Add foreach_active iteration statement.

Issue #298.
This commit is contained in:
Matt Pharr
2012-06-22 10:35:43 -07:00
parent ed13dd066b
commit b4a078e2f6
15 changed files with 644 additions and 279 deletions

259
stmt.cpp
View File

@@ -1915,6 +1915,188 @@ ForeachStmt::Print(int indent) const {
}
///////////////////////////////////////////////////////////////////////////
// ForeachActiveStmt
ForeachActiveStmt::ForeachActiveStmt(Symbol *s, Stmt *st, SourcePos pos)
: Stmt(pos) {
sym = s;
stmts = st;
}
void
ForeachActiveStmt::EmitCode(FunctionEmitContext *ctx) const {
if (!ctx->GetCurrentBasicBlock())
return;
// Allocate storage for the symbol that we'll use for the uniform
// variable that holds the current program instance in each loop
// iteration.
if (sym->type == NULL) {
Assert(m->errorCount > 0);
return;
}
Assert(Type::Equal(sym->type,
AtomicType::UniformInt64->GetAsConstType()));
sym->storagePtr = ctx->AllocaInst(LLVMTypes::Int64Type, sym->name.c_str());
ctx->SetDebugPos(pos);
ctx->EmitVariableDebugInfo(sym);
// The various basic blocks that we'll need in the below
llvm::BasicBlock *bbFindNext =
ctx->CreateBasicBlock("foreach_active_find_next");
llvm::BasicBlock *bbBody = ctx->CreateBasicBlock("foreach_active_body");
llvm::BasicBlock *bbCheckForMore =
ctx->CreateBasicBlock("foreach_active_check_for_more");
llvm::BasicBlock *bbDone = ctx->CreateBasicBlock("foreach_active_done");
// Save the old mask so that we can restore it at the end
llvm::Value *oldInternalMask = ctx->GetInternalMask();
// Now, *maskBitsPtr will maintain a bitmask for the lanes that remain
// to be processed by a pass through the loop body. It starts out with
// the current execution mask (which should never be all off going in
// to this)...
llvm::Value *oldFullMask = ctx->GetFullMask();
llvm::Value *maskBitsPtr =
ctx->AllocaInst(LLVMTypes::Int64Type, "mask_bits");
llvm::Value *movmsk = ctx->LaneMask(oldFullMask);
ctx->StoreInst(movmsk, maskBitsPtr);
// Officially start the loop.
ctx->StartScope();
ctx->StartForeach(FunctionEmitContext::FOREACH_ACTIVE);
ctx->SetContinueTarget(bbCheckForMore);
// Onward to find the first set of program instance to run the loop for
ctx->BranchInst(bbFindNext);
ctx->SetCurrentBasicBlock(bbFindNext); {
// Load the bitmask of the lanes left to be processed
llvm::Value *remainingBits = ctx->LoadInst(maskBitsPtr, "remaining_bits");
// Find the index of the first set bit in the mask
llvm::Function *ctlzFunc =
m->module->getFunction("__count_trailing_zeros_i64");
Assert(ctlzFunc != NULL);
llvm::Value *firstSet = ctx->CallInst(ctlzFunc, NULL, remainingBits,
"first_set");
// Store that value into the storage allocated for the iteration
// variable.
ctx->StoreInst(firstSet, sym->storagePtr);
// Now set the execution mask to be only on for the current program
// instance. (TODO: is there a more efficient way to do this? e.g.
// for AVX1, we might want to do this as float rather than int
// math...)
// Get the "program index" vector value
llvm::Value *programIndex =
llvm::UndefValue::get(LLVMTypes::Int32VectorType);
for (int i = 0; i < g->target.vectorWidth; ++i)
programIndex = ctx->InsertInst(programIndex, LLVMInt32(i), i,
"prog_index");
// And smear the current lane out to a vector
llvm::Value *firstSet32 =
ctx->TruncInst(firstSet, LLVMTypes::Int32Type, "first_set32");
llvm::Value *firstSet32Smear = ctx->SmearUniform(firstSet32);
// Now set the execution mask based on doing a vector compare of
// these two
llvm::Value *iterMask =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ,
firstSet32Smear, programIndex);
iterMask = ctx->I1VecToBoolVec(iterMask);
ctx->SetInternalMask(iterMask);
// Also update the bitvector of lanes left to turn off the bit for
// the lane we're about to run.
llvm::Value *setMask =
ctx->BinaryOperator(llvm::Instruction::Shl, LLVMInt64(1),
firstSet, "set_mask");
llvm::Value *notSetMask = ctx->NotOperator(setMask);
llvm::Value *newRemaining =
ctx->BinaryOperator(llvm::Instruction::And, remainingBits,
notSetMask, "new_remaining");
ctx->StoreInst(newRemaining, maskBitsPtr);
// and onward to run the loop body...
ctx->BranchInst(bbBody);
}
ctx->SetCurrentBasicBlock(bbBody); {
// Run the code in the body of the loop. This is easy now.
if (stmts)
stmts->EmitCode(ctx);
Assert(ctx->GetCurrentBasicBlock() != NULL);
ctx->BranchInst(bbCheckForMore);
}
ctx->SetCurrentBasicBlock(bbCheckForMore); {
// At the end of the loop body (either due to running the
// statements normally, or a continue statement in the middle of
// the loop that jumps to the end, see if there are any lanes left
// to be processed.
llvm::Value *remainingBits = ctx->LoadInst(maskBitsPtr, "remaining_bits");
llvm::Value *nonZero =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE,
remainingBits, LLVMInt64(0), "remaining_ne_zero");
ctx->BranchInst(bbFindNext, bbDone, nonZero);
}
ctx->SetCurrentBasicBlock(bbDone);
ctx->SetInternalMask(oldInternalMask);
ctx->EndForeach();
ctx->EndScope();
}
void
ForeachActiveStmt::Print(int indent) const {
printf("%*cForeach_active Stmt", indent, ' ');
pos.Print();
printf("\n");
printf("%*cIter symbol: ", indent+4, ' ');
if (sym != NULL) {
printf("%s", sym->name.c_str());
if (sym->type != NULL)
printf(" %s", sym->type->GetString().c_str());
}
else
printf("NULL");
printf("\n");
printf("%*cStmts:\n", indent+4, ' ');
if (stmts != NULL)
stmts->Print(indent+8);
else
printf("NULL");
printf("\n");
}
Stmt *
ForeachActiveStmt::TypeCheck() {
if (sym == NULL)
return NULL;
return this;
}
int
ForeachActiveStmt::EstimateCost() const {
return COST_VARYING_LOOP;
}
///////////////////////////////////////////////////////////////////////////
// ForeachUniqueStmt
@@ -3043,80 +3225,3 @@ int
DeleteStmt::EstimateCost() const {
return COST_DELETE;
}
///////////////////////////////////////////////////////////////////////////
/** This generates AST nodes for an __foreach_active statement. This
construct can be synthesized ouf of the existing ForStmt (and other AST
nodes), so here we just build up the AST that we need rather than
having a new Stmt implementation for __foreach_active.
@param iterSym Symbol for the iteration variable (e.g. "i" in
__foreach_active (i) { .. .}
@param stmts Statements to execute each time through the loop, for
each active program instance.
@param pos Position of the __foreach_active statement in the source
file.
*/
Stmt *
CreateForeachActiveStmt(Symbol *iterSym, Stmt *stmts, SourcePos pos) {
if (iterSym == NULL) {
AssertPos(pos, m->errorCount > 0);
return NULL;
}
// loop initializer: set iter = 0
std::vector<VariableDeclaration> var;
ConstExpr *zeroExpr = new ConstExpr(AtomicType::UniformInt32, 0,
iterSym->pos);
var.push_back(VariableDeclaration(iterSym, zeroExpr));
Stmt *initStmt = new DeclStmt(var, iterSym->pos);
// loop test: (iter < programCount)
ConstExpr *progCountExpr =
new ConstExpr(AtomicType::UniformInt32, g->target.vectorWidth,
pos);
SymbolExpr *symExpr = new SymbolExpr(iterSym, iterSym->pos);
Expr *testExpr = new BinaryExpr(BinaryExpr::Lt, symExpr, progCountExpr,
pos);
// loop step: ++iterSym
UnaryExpr *incExpr = new UnaryExpr(UnaryExpr::PreInc, symExpr, pos);
Stmt *stepStmt = new ExprStmt(incExpr, pos);
// loop body
// First, call __movmsk(__mask)) to get the mask as a set of bits.
// This should be hoisted out of the loop
Symbol *maskSym = m->symbolTable->LookupVariable("__mask");
AssertPos(pos, maskSym != NULL);
Expr *maskVecExpr = new SymbolExpr(maskSym, pos);
std::vector<Symbol *> mmFuns;
m->symbolTable->LookupFunction("__movmsk", &mmFuns);
AssertPos(pos, mmFuns.size() == (g->target.maskBitCount == 32 ? 2 : 1));
FunctionSymbolExpr *movmskFunc = new FunctionSymbolExpr("__movmsk", mmFuns,
pos);
ExprList *movmskArgs = new ExprList(maskVecExpr, pos);
FunctionCallExpr *movmskExpr = new FunctionCallExpr(movmskFunc, movmskArgs,
pos);
// Compute the per lane mask to test the mask bits against: (1 << iter)
ConstExpr *oneExpr = new ConstExpr(AtomicType::UniformInt64, int64_t(1),
iterSym->pos);
Expr *shiftLaneExpr = new BinaryExpr(BinaryExpr::Shl, oneExpr, symExpr,
pos);
// Compute the AND: movmsk & (1 << iter)
Expr *maskAndLaneExpr = new BinaryExpr(BinaryExpr::BitAnd, movmskExpr,
shiftLaneExpr, pos);
// Test to see if it's non-zero: (mask & (1 << iter)) != 0
Expr *ifTestExpr = new BinaryExpr(BinaryExpr::NotEqual, maskAndLaneExpr,
zeroExpr, pos);
// Now, enclose the provided statements in an if test such that they
// only run if the mask is non-zero for the lane we're currently
// handling in the loop.
IfStmt *laneCheckIf = new IfStmt(ifTestExpr, stmts, NULL, false, pos);
// And return a for loop that wires it all together.
return new ForStmt(initStmt, testExpr, stepStmt, laneCheckIf, false, pos);
}