Add foreach and foreach_tiled looping constructs

These make it easier to iterate over an arbitrary number of data
elements; specifically, they automatically handle the "ragged
extra bits" that come up when the number of elements to be
processed isn't evenly divisible by programCount.

TODO: documentation
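
To make the "ragged extra bits" described above concrete: for a vector width of programCount, the generated loop runs full-width iterations over the evenly divisible portion of the range and then one masked pass over the remainder. The following is a minimal standalone C++ sketch of that split, with made-up values; it is an illustration, not code from this commit.

#include <cstdio>

int main() {
    const int programCount = 8;               // hypothetical vector width
    const int count = 19;                     // 19 elements: not a multiple of 8
    const int nExtras = count % programCount; // 3 "ragged" elements
    const int alignedEnd = count - nExtras;   // 16

    // Full vectors: all lanes active, no per-lane masking needed.
    for (int base = 0; base < alignedEnd; base += programCount)
        printf("full vector over elements [%d, %d)\n", base, base + programCount);

    // One final pass where lanes with index >= count are masked off.
    for (int lane = 0; lane < programCount; ++lane) {
        int index = alignedEnd + lane;
        printf("extras lane %d: index %d %s\n", lane, index,
               index < count ? "active" : "masked off");
    }
    return 0;
}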
Matt Pharr
2011-11-30 13:17:31 -08:00
parent b48775a549
commit 8bc7367109
32 changed files with 1120 additions and 78 deletions

stmt.cpp

@@ -819,6 +819,17 @@ lSafeToRunWithAllLanesOff(Stmt *stmt) {
lSafeToRunWithAllLanesOff(fs->step) &&
lSafeToRunWithAllLanesOff(fs->stmts));
ForeachStmt *fes;
if ((fes = dynamic_cast<ForeachStmt *>(stmt)) != NULL) {
for (unsigned int i = 0; i < fes->startExprs.size(); ++i)
if (!lSafeToRunWithAllLanesOff(fes->startExprs[i]))
return false;
for (unsigned int i = 0; i < fes->endExprs.size(); ++i)
if (!lSafeToRunWithAllLanesOff(fes->endExprs[i]))
return false;
return lSafeToRunWithAllLanesOff(fes->stmts);
}
if (dynamic_cast<BreakStmt *>(stmt) != NULL ||
dynamic_cast<ContinueStmt *>(stmt) != NULL)
return true;
@@ -1592,6 +1603,463 @@ ContinueStmt::Print(int indent) const {
}
///////////////////////////////////////////////////////////////////////////
// ForeachStmt
ForeachStmt::ForeachStmt(const std::vector<Symbol *> &lvs,
const std::vector<Expr *> &se,
const std::vector<Expr *> &ee,
Stmt *s, bool t, SourcePos pos)
: Stmt(pos), dimVariables(lvs), startExprs(se), endExprs(ee), isTiled(t),
stmts(s) {
}
/* Given a uniform counter value in the memory location pointed to by
uniformCounterPtr, compute the corresponding set of varying counter
values for use within the loop body.
*/
static llvm::Value *
lUpdateVaryingCounter(int dim, int nDims, FunctionEmitContext *ctx,
llvm::Value *uniformCounterPtr,
llvm::Value *varyingCounterPtr,
const std::vector<int> &spans) {
// Smear the uniform counter value out to be varying
llvm::Value *counter = ctx->LoadInst(uniformCounterPtr);
llvm::Value *smearCounter =
llvm::UndefValue::get(LLVMTypes::Int32VectorType);
for (int i = 0; i < g->target.vectorWidth; ++i)
smearCounter =
ctx->InsertInst(smearCounter, counter, i, "smear_counter");
// Figure out the offsets; this is a little bit tricky. As an example,
// consider a 2D tiled foreach loop, where we're running 8-wide and
// where the inner dimension has a stride of 4 and the outer dimension
// has a stride of 2. For the inner dimension, we want the offsets
// (0,1,2,3,0,1,2,3), and for the outer dimension we want
// (0,0,0,0,1,1,1,1).
int32_t delta[ISPC_MAX_NVEC];
for (int i = 0; i < g->target.vectorWidth; ++i) {
int d = i;
// First, account for the effect of any dimensions at deeper
// nesting levels than the current one.
int prevDimSpanCount = 1;
for (int j = dim; j < nDims-1; ++j)
prevDimSpanCount *= spans[j+1];
d /= prevDimSpanCount;
// And now with what's left, figure out our own offset
delta[i] = d % spans[dim];
}
// Add the deltas to compute the varying counter values; store the
// result to memory and then return it directly as well.
llvm::Value *varyingCounter =
ctx->BinaryOperator(llvm::Instruction::Add, smearCounter,
LLVMInt32Vector(delta), "iter_val");
ctx->StoreInst(varyingCounter, varyingCounterPtr);
return varyingCounter;
}
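
The delta computation above can be checked in isolation. Below is a standalone sketch (not part of stmt.cpp) that re-states the loop without the LLVM plumbing; with the 2D tiled example from the comment (8-wide, spans {2, 4}) it reproduces the inner-dimension offsets (0,1,2,3,0,1,2,3) and the outer-dimension offsets (0,0,0,0,1,1,1,1). The function name printDeltas is ours, chosen for illustration.

#include <cstdio>
#include <vector>

// Re-statement of the delta loop from lUpdateVaryingCounter, outside LLVM.
static void printDeltas(int dim, int nDims, int vectorWidth,
                        const std::vector<int> &spans) {
    for (int i = 0; i < vectorWidth; ++i) {
        int d = i;
        int prevDimSpanCount = 1;
        for (int j = dim; j < nDims - 1; ++j)
            prevDimSpanCount *= spans[j + 1];
        d /= prevDimSpanCount;
        printf("%d ", d % spans[dim]);
    }
    printf("\n");
}

int main() {
    std::vector<int> spans;
    spans.push_back(2);                  // outer dimension span
    spans.push_back(4);                  // inner dimension span
    printDeltas(1, 2, 8, spans);         // inner dim -> 0 1 2 3 0 1 2 3
    printDeltas(0, 2, 8, spans);         // outer dim -> 0 0 0 0 1 1 1 1
    return 0;
}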
/** Returns the integer log2 of the given integer. */
static int
lLog2(int i) {
int ret = 0;
while (i != 0) {
++ret;
i >>= 1;
}
return ret-1;
}
/* Figure out how many elements to process in each dimension for each time
through a foreach loop. The untiled case is easy; all of the outer
dimensions up until the innermost one have a span of 1, and the
innermost one takes the entire vector width. For the tiled case, we
give wider spans to the innermost dimensions while also trying to
generate relatively square domains.
This code works recursively from outer dimensions to inner dimensions.
*/
static void
lGetSpans(int dimsLeft, int nDims, int itemsLeft, bool isTiled, int *a) {
if (dimsLeft == 0) {
// Nothing left to do but give all of the remaining work to the
// innermost domain.
*a = itemsLeft;
return;
}
if (isTiled == false || (dimsLeft >= lLog2(itemsLeft)))
// If we're not tiled, or if there are enough dimensions left that
// giving this one any more than a span of one would mean that a
// later dimension would have to have a span of one, give this one
// a span of one to save the available items for later.
*a = 1;
else if (itemsLeft >= 16 && (dimsLeft == 1))
// Special case to have 4x4 domains for the 2D case when running
// 16-wide.
*a = 4;
else
// Otherwise give this dimension a span of two.
*a = 2;
lGetSpans(dimsLeft-1, nDims, itemsLeft / *a, isTiled, a+1);
}
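
For reference, here is what this recursion settles on for a few configurations, in a standalone sketch that duplicates lLog2 and lGetSpans under the names log2i and getSpans so it can be compiled on its own: an untiled 2D loop on an 8-wide target gets spans {1, 8}, the tiled 2D case on an 8-wide target gets {2, 4}, the tiled 2D case on a 16-wide target gets the 4x4 special case {4, 4}, and a tiled 3D loop on an 8-wide target gets {2, 2, 2}.

#include <cstdio>
#include <vector>

// Copies of lLog2 and lGetSpans so the span selection can be run standalone.
static int log2i(int i) {
    int ret = 0;
    while (i != 0) { ++ret; i >>= 1; }
    return ret - 1;
}

static void getSpans(int dimsLeft, int nDims, int itemsLeft, bool isTiled, int *a) {
    if (dimsLeft == 0) { *a = itemsLeft; return; }
    if (isTiled == false || (dimsLeft >= log2i(itemsLeft)))
        *a = 1;
    else if (itemsLeft >= 16 && (dimsLeft == 1))
        *a = 4;
    else
        *a = 2;
    getSpans(dimsLeft - 1, nDims, itemsLeft / *a, isTiled, a + 1);
}

static void show(int nDims, int width, bool tiled) {
    std::vector<int> span(nDims, 0);
    getSpans(nDims - 1, nDims, width, tiled, &span[0]);
    printf("%dD, width %2d, %stiled:", nDims, width, tiled ? "" : "un");
    for (int i = 0; i < (int)span.size(); ++i)
        printf(" %d", span[i]);
    printf("\n");
}

int main() {
    show(2, 8, false);   // 1 8
    show(2, 8, true);    // 2 4
    show(2, 16, true);   // 4 4
    show(3, 8, true);    // 2 2 2
    return 0;
}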
/* Emit code for a foreach statement. We effectively emit code to run the
set of n-dimensional nested loops corresponding to the dimensionality of
the foreach statement along with the extra logic to deal with mismatches
between the vector width we're compiling to and the number of elements
to process.
*/
void
ForeachStmt::EmitCode(FunctionEmitContext *ctx) const {
if (ctx->GetCurrentBasicBlock() == NULL || stmts == NULL)
return;
llvm::BasicBlock *bbCheckExtras = ctx->CreateBasicBlock("foreach_check_extras");
llvm::BasicBlock *bbDoExtras = ctx->CreateBasicBlock("foreach_do_extras");
llvm::BasicBlock *bbBody = ctx->CreateBasicBlock("foreach_body");
llvm::BasicBlock *bbExit = ctx->CreateBasicBlock("foreach_exit");
llvm::Value *oldMask = ctx->GetInternalMask();
ctx->StartForeach();
ctx->SetDebugPos(pos);
ctx->StartScope();
// This should be caught during typechecking
assert(startExprs.size() == dimVariables.size() &&
endExprs.size() == dimVariables.size());
int nDims = (int)dimVariables.size();
///////////////////////////////////////////////////////////////////////
// Setup: compute the number of items we have to work on in each
// dimension and a number of derived values.
std::vector<llvm::BasicBlock *> bbReset, bbStep, bbTest;
std::vector<llvm::Value *> startVals, endVals, uniformCounterPtrs;
std::vector<llvm::Value *> nItems, nExtras, alignedEnd;
std::vector<llvm::Value *> extrasMaskPtrs;
std::vector<int> span(nDims, 0);
lGetSpans(nDims-1, nDims, g->target.vectorWidth, isTiled, &span[0]);
for (int i = 0; i < nDims; ++i) {
// Basic blocks that we'll fill in later with the looping logic for
// this dimension.
bbReset.push_back(ctx->CreateBasicBlock("foreach_reset"));
bbStep.push_back(ctx->CreateBasicBlock("foreach_step"));
bbTest.push_back(ctx->CreateBasicBlock("foreach_test"));
// Start and end value for this loop dimension
llvm::Value *sv = startExprs[i]->GetValue(ctx);
llvm::Value *ev = endExprs[i]->GetValue(ctx);
if (sv == NULL || ev == NULL)
return;
startVals.push_back(sv);
endVals.push_back(ev);
// nItems = endVal - startVal
nItems.push_back(ctx->BinaryOperator(llvm::Instruction::Sub, ev, sv,
"nitems"));
// nExtras = nItems % (span for this dimension)
// This gives us the number of extra elements we need to deal with
// at the end of the loop for this dimension that don't fit cleanly
// into a vector width.
nExtras.push_back(ctx->BinaryOperator(llvm::Instruction::SRem, nItems[i],
LLVMInt32(span[i]), "nextras"));
// alignedEnd = endVal - nExtras
alignedEnd.push_back(ctx->BinaryOperator(llvm::Instruction::Sub, ev,
nExtras[i], "aligned_end"));
///////////////////////////////////////////////////////////////////////
// Each dimension has a loop counter that is a uniform value that
// goes from startVal to endVal, in steps of the span for this
// dimension. Its value is only used internally here for looping
// logic and isn't directly available in the user's program code.
uniformCounterPtrs.push_back(ctx->AllocaInst(LLVMTypes::Int32Type,
"counter"));
ctx->StoreInst(startVals[i], uniformCounterPtrs[i]);
// There is also a varying variable that holds the set of index
// values for each dimension in the current loop iteration; this is
// the value that is program-visible.
dimVariables[i]->storagePtr = ctx->AllocaInst(LLVMTypes::Int32VectorType,
dimVariables[i]->name.c_str());
dimVariables[i]->parentFunction = ctx->GetFunction();
ctx->EmitVariableDebugInfo(dimVariables[i]);
// Each dimension also maintains a mask that represents which of
// the varying elements in the current iteration should be
// processed. (i.e. this is used to disable the lanes that have
// out-of-bounds offsets.)
extrasMaskPtrs.push_back(ctx->AllocaInst(LLVMTypes::MaskType, "extras mask"));
ctx->StoreInst(LLVMMaskAllOn, extrasMaskPtrs[i]);
}
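
To put numbers on the derived values: for a hypothetical 1D foreach over [3, 20) on an 8-wide target (span 8), the setup above computes nItems = 17, nExtras = 17 % 8 = 1, and alignedEnd = 20 - 1 = 19. The uniform counter then visits 3 and 11 as full-vector iterations and 19 as the single masked "extras" iteration. A small sketch of just that arithmetic (values chosen purely for illustration):

#include <cstdio>

int main() {
    // Hypothetical 1D foreach (i = 3 ... 20) on an 8-wide target.
    const int span = 8, startVal = 3, endVal = 20;
    int nItems = endVal - startVal;        // 17
    int nExtras = nItems % span;           // 1
    int alignedEnd = endVal - nExtras;     // 19

    for (int counter = startVal; counter < endVal; counter += span) {
        bool inExtras = (endVal > alignedEnd) && (counter == alignedEnd);
        printf("counter %2d: %s iteration\n", counter,
               inExtras ? "masked extras" : "full-vector");
    }
    return 0;
}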
// On to the outermost loop's test
ctx->BranchInst(bbTest[0]);
///////////////////////////////////////////////////////////////////////////
// foreach_reset: this code runs when we need to reset the counter for
// a given dimension in preparation for running through its loop again,
// after the enclosing level advances its counter.
for (int i = 0; i < nDims; ++i) {
ctx->SetCurrentBasicBlock(bbReset[i]);
if (i == 0)
ctx->BranchInst(bbExit);
else {
ctx->StoreInst(LLVMMaskAllOn, extrasMaskPtrs[i]);
ctx->StoreInst(startVals[i], uniformCounterPtrs[i]);
ctx->BranchInst(bbStep[i-1]);
}
}
///////////////////////////////////////////////////////////////////////////
// foreach_test
std::vector<llvm::Value *> inExtras;
for (int i = 0; i < nDims; ++i) {
ctx->SetCurrentBasicBlock(bbTest[i]);
llvm::Value *haveExtras =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_SGT,
endVals[i], alignedEnd[i], "have_extras");
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[i], "counter");
llvm::Value *atAlignedEnd =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ,
counter, alignedEnd[i], "at_aligned_end");
llvm::Value *inEx =
ctx->BinaryOperator(llvm::Instruction::And, haveExtras,
atAlignedEnd, "in_extras");
if (i == 0)
inExtras.push_back(inEx);
else
inExtras.push_back(ctx->BinaryOperator(llvm::Instruction::Or, inEx,
inExtras[i-1], "in_extras_all"));
llvm::Value *varyingCounter =
lUpdateVaryingCounter(i, nDims, ctx, uniformCounterPtrs[i],
dimVariables[i]->storagePtr, span);
llvm::Value *smearEnd = llvm::UndefValue::get(LLVMTypes::Int32VectorType);
for (int j = 0; j < g->target.vectorWidth; ++j)
smearEnd = ctx->InsertInst(smearEnd, endVals[i], j, "smear_end");
// Do a vector compare of its value to the end value to generate a
// mask for this last bit of work.
llvm::Value *emask =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_SLT,
varyingCounter, smearEnd);
emask = ctx->I1VecToBoolVec(emask);
if (i == 0)
ctx->StoreInst(emask, extrasMaskPtrs[i]);
else {
// FIXME: at least specialize the innermost loop to not do all
// this mask stuff each time through the test...
llvm::Value *oldMask = ctx->LoadInst(extrasMaskPtrs[i-1]);
llvm::Value *newMask =
ctx->BinaryOperator(llvm::Instruction::And, oldMask, emask,
"extras_mask");
ctx->StoreInst(newMask, extrasMaskPtrs[i]);
}
llvm::Value *notAtEnd =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_SLT,
counter, endVals[i]);
if (i != nDims-1)
ctx->BranchInst(bbTest[i+1], bbReset[i], notAtEnd);
else
ctx->BranchInst(bbCheckExtras, bbReset[i], notAtEnd);
}
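
Continuing that 1D example at the extras iteration (counter = 19, end = 20, 8-wide): the smeared counter plus the per-lane deltas gives varying indices 19 through 26, and the ICMP_SLT compare against the smeared end value yields an extras mask with only lane 0 on, so exactly one element is processed in the final pass. A per-lane sketch of that compare (again illustrative, not emitted code):

#include <cstdio>

int main() {
    const int vectorWidth = 8;
    const int counter = 19, endVal = 20;   // values from the 1D example above
    for (int lane = 0; lane < vectorWidth; ++lane) {
        int varyingCounter = counter + lane;     // smeared counter + delta
        bool maskOn = varyingCounter < endVal;   // the ICMP_SLT compare
        printf("lane %d: index %2d mask %s\n", lane, varyingCounter,
               maskOn ? "on" : "off");
    }
    return 0;
}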
///////////////////////////////////////////////////////////////////////////
// foreach_step: increment the uniform counter by the vector width.
// Note that we don't increment the varying counter here as well but
// just generate its value when we need it in the loop body.
for (int i = 0; i < nDims; ++i) {
ctx->SetCurrentBasicBlock(bbStep[i]);
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[i]);
llvm::Value *newCounter =
ctx->BinaryOperator(llvm::Instruction::Add, counter,
LLVMInt32(span[i]), "new_counter");
ctx->StoreInst(newCounter, uniformCounterPtrs[i]);
ctx->BranchInst(bbTest[i]);
}
///////////////////////////////////////////////////////////////////////////
// foreach_check_extras: see if we need to deal with any partial
// vector's worth of work that's left.
ctx->SetCurrentBasicBlock(bbCheckExtras);
ctx->AddInstrumentationPoint("foreach loop check extras");
ctx->BranchInst(bbDoExtras, bbBody, inExtras[nDims-1]);
///////////////////////////////////////////////////////////////////////////
// foreach_body: do a full vector's worth of work. We know that all
// lanes will be running here, so we explicitly set the mask to be 'all
// on'. This ends up being relatively straightforward: just update the
// value of the varying loop counter and have the statements in the
// loop body emit their code.
ctx->SetCurrentBasicBlock(bbBody);
ctx->SetInternalMask(LLVMMaskAllOn);
ctx->AddInstrumentationPoint("foreach loop body");
stmts->EmitCode(ctx);
assert(ctx->GetCurrentBasicBlock() != NULL);
ctx->BranchInst(bbStep[nDims-1]);
///////////////////////////////////////////////////////////////////////////
// foreach_doextras: set the mask and have the statements emit their
// code again. Note that it's generally worthwhile having two copies
// of the statements' code, since the code above is emitted with the
// mask known to be all-on, which in turn leads to more efficient code
// for that case.
ctx->SetCurrentBasicBlock(bbDoExtras);
llvm::Value *mask = ctx->LoadInst(extrasMaskPtrs[nDims-1]);
ctx->SetInternalMask(mask);
stmts->EmitCode(ctx);
ctx->BranchInst(bbStep[nDims-1]);
///////////////////////////////////////////////////////////////////////////
// foreach_exit: All done. Restore the old mask and clean up
ctx->SetCurrentBasicBlock(bbExit);
ctx->SetInternalMask(oldMask);
ctx->EndForeach();
ctx->EndScope();
}
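
Taken together, the basic blocks emitted above behave roughly like the following scalar model of a 2D foreach (a sketch under assumed bounds and spans, not the generated code): each dimension advances its uniform counter by its span, and whenever any enclosing dimension is sitting in its ragged tail the body runs under the accumulated extras mask rather than the all-on mask.

#include <cstdio>

// Scalar model of the emitted 2D control flow: spans {2, 4} on an 8-wide
// target, iterating x over [0, 5) and y over [0, 7).
int main() {
    const int spanX = 2, spanY = 4;
    const int endX = 5, endY = 7;
    const int alignedEndX = endX - endX % spanX;   // 4
    const int alignedEndY = endY - endY % spanY;   // 4

    for (int x = 0; x < endX; x += spanX) {
        for (int y = 0; y < endY; y += spanY) {
            bool inExtras = (x == alignedEndX) || (y == alignedEndY);
            // The full-vector body runs with the mask all-on; the extras
            // body runs under a mask that disables out-of-bounds lanes.
            printf("x=%d y=%d -> %s body\n", x, y,
                   inExtras ? "masked extras" : "full-vector");
        }
    }
    return 0;
}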
Stmt *
ForeachStmt::Optimize() {
bool anyErrors = false;
for (unsigned int i = 0; i < startExprs.size(); ++i) {
if (startExprs[i] != NULL)
startExprs[i] = startExprs[i]->Optimize();
anyErrors |= (startExprs[i] == NULL);
}
for (unsigned int i = 0; i < endExprs.size(); ++i) {
if (endExprs[i] != NULL)
endExprs[i] = endExprs[i]->Optimize();
anyErrors |= (endExprs[i] == NULL);
}
if (stmts != NULL)
stmts = stmts->Optimize();
anyErrors |= (stmts == NULL);
return anyErrors ? NULL : this;
}
Stmt *
ForeachStmt::TypeCheck() {
bool anyErrors = false;
for (unsigned int i = 0; i < startExprs.size(); ++i) {
if (startExprs[i] != NULL)
startExprs[i] = TypeConvertExpr(startExprs[i],
AtomicType::UniformInt32,
"foreach starting value");
if (startExprs[i] != NULL)
startExprs[i] = startExprs[i]->TypeCheck();
anyErrors |= (startExprs[i] == NULL);
}
for (unsigned int i = 0; i < endExprs.size(); ++i) {
if (endExprs[i] != NULL)
endExprs[i] = TypeConvertExpr(endExprs[i], AtomicType::UniformInt32,
"foreach ending value");
if (endExprs[i] != NULL)
endExprs[i] = endExprs[i]->TypeCheck();
anyErrors |= (endExprs[i] == NULL);
}
if (stmts != NULL)
stmts = stmts->TypeCheck();
anyErrors |= (stmts == NULL);
if (startExprs.size() < dimVariables.size()) {
Error(pos, "Not enough initial values provided for \"foreach\" loop; "
"got %d, expected %d\n", (int)startExprs.size(), (int)dimVariables.size());
anyErrors = true;
}
else if (startExprs.size() > dimVariables.size()) {
Error(pos, "Too many initial values provided for \"foreach\" loop; "
"got %d, expected %d\n", (int)startExprs.size(), (int)dimVariables.size());
anyErrors = true;
}
if (endExprs.size() < dimVariables.size()) {
Error(pos, "Not enough ending values provided for \"foreach\" loop; "
"got %d, expected %d\n", (int)endExprs.size(), (int)dimVariables.size());
anyErrors = true;
}
else if (endExprs.size() > dimVariables.size()) {
Error(pos, "Too many ending values provided for \"foreach\" loop; "
"got %d, expected %d\n", (int)endExprs.size(), (int)dimVariables.size());
anyErrors = true;
}
return anyErrors ? NULL : this;
}
int
ForeachStmt::EstimateCost() const {
return dimVariables.size() * (COST_UNIFORM_LOOP + COST_SIMPLE_ARITH_LOGIC_OP) +
(stmts ? stmts->EstimateCost() : 0);
}
void
ForeachStmt::Print(int indent) const {
printf("%*cForeach Stmt", indent, ' ');
pos.Print();
printf("\n");
for (unsigned int i = 0; i < dimVariables.size(); ++i)
if (dimVariables[i] != NULL)
printf("%*cVar %d: %s\n", indent+4, ' ', i,
dimVariables[i]->name.c_str());
else
printf("%*cVar %d: NULL\n", indent+4, ' ', i);
printf("Start values:\n");
for (unsigned int i = 0; i < startExprs.size(); ++i) {
if (startExprs[i] != NULL)
startExprs[i]->Print();
else
printf("NULL");
if (i != startExprs.size()-1)
printf(", ");
else
printf("\n");
}
printf("End values:\n");
for (unsigned int i = 0; i < endExprs.size(); ++i) {
if (endExprs[i] != NULL)
endExprs[i]->Print();
else
printf("NULL");
if (i != endExprs.size()-1)
printf(", ");
else
printf("\n");
}
if (stmts != NULL) {
printf("%*cStmts:\n", indent+4, ' ');
stmts->Print(indent+8);
}
}
///////////////////////////////////////////////////////////////////////////
// ReturnStmt
@@ -1606,6 +2074,11 @@ ReturnStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
if (ctx->InForeachLoop()) {
Error(pos, "\"return\" statement is illegal inside a \"foreach\" loop.");
return;
}
ctx->SetDebugPos(pos);
ctx->CurrentLanesReturned(val, doCoherenceCheck);
}