259
stmt.cpp
259
stmt.cpp
@@ -1915,6 +1915,188 @@ ForeachStmt::Print(int indent) const {
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// ForeachActiveStmt
|
||||
|
||||
ForeachActiveStmt::ForeachActiveStmt(Symbol *s, Stmt *st, SourcePos pos)
|
||||
: Stmt(pos) {
|
||||
sym = s;
|
||||
stmts = st;
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
ForeachActiveStmt::EmitCode(FunctionEmitContext *ctx) const {
|
||||
if (!ctx->GetCurrentBasicBlock())
|
||||
return;
|
||||
|
||||
// Allocate storage for the symbol that we'll use for the uniform
|
||||
// variable that holds the current program instance in each loop
|
||||
// iteration.
|
||||
if (sym->type == NULL) {
|
||||
Assert(m->errorCount > 0);
|
||||
return;
|
||||
}
|
||||
Assert(Type::Equal(sym->type,
|
||||
AtomicType::UniformInt64->GetAsConstType()));
|
||||
sym->storagePtr = ctx->AllocaInst(LLVMTypes::Int64Type, sym->name.c_str());
|
||||
|
||||
ctx->SetDebugPos(pos);
|
||||
ctx->EmitVariableDebugInfo(sym);
|
||||
|
||||
// The various basic blocks that we'll need in the below
|
||||
llvm::BasicBlock *bbFindNext =
|
||||
ctx->CreateBasicBlock("foreach_active_find_next");
|
||||
llvm::BasicBlock *bbBody = ctx->CreateBasicBlock("foreach_active_body");
|
||||
llvm::BasicBlock *bbCheckForMore =
|
||||
ctx->CreateBasicBlock("foreach_active_check_for_more");
|
||||
llvm::BasicBlock *bbDone = ctx->CreateBasicBlock("foreach_active_done");
|
||||
|
||||
// Save the old mask so that we can restore it at the end
|
||||
llvm::Value *oldInternalMask = ctx->GetInternalMask();
|
||||
|
||||
// Now, *maskBitsPtr will maintain a bitmask for the lanes that remain
|
||||
// to be processed by a pass through the loop body. It starts out with
|
||||
// the current execution mask (which should never be all off going in
|
||||
// to this)...
|
||||
llvm::Value *oldFullMask = ctx->GetFullMask();
|
||||
llvm::Value *maskBitsPtr =
|
||||
ctx->AllocaInst(LLVMTypes::Int64Type, "mask_bits");
|
||||
llvm::Value *movmsk = ctx->LaneMask(oldFullMask);
|
||||
ctx->StoreInst(movmsk, maskBitsPtr);
|
||||
|
||||
// Officially start the loop.
|
||||
ctx->StartScope();
|
||||
ctx->StartForeach(FunctionEmitContext::FOREACH_ACTIVE);
|
||||
ctx->SetContinueTarget(bbCheckForMore);
|
||||
|
||||
// Onward to find the first set of program instance to run the loop for
|
||||
ctx->BranchInst(bbFindNext);
|
||||
|
||||
ctx->SetCurrentBasicBlock(bbFindNext); {
|
||||
// Load the bitmask of the lanes left to be processed
|
||||
llvm::Value *remainingBits = ctx->LoadInst(maskBitsPtr, "remaining_bits");
|
||||
|
||||
// Find the index of the first set bit in the mask
|
||||
llvm::Function *ctlzFunc =
|
||||
m->module->getFunction("__count_trailing_zeros_i64");
|
||||
Assert(ctlzFunc != NULL);
|
||||
llvm::Value *firstSet = ctx->CallInst(ctlzFunc, NULL, remainingBits,
|
||||
"first_set");
|
||||
|
||||
// Store that value into the storage allocated for the iteration
|
||||
// variable.
|
||||
ctx->StoreInst(firstSet, sym->storagePtr);
|
||||
|
||||
// Now set the execution mask to be only on for the current program
|
||||
// instance. (TODO: is there a more efficient way to do this? e.g.
|
||||
// for AVX1, we might want to do this as float rather than int
|
||||
// math...)
|
||||
|
||||
// Get the "program index" vector value
|
||||
llvm::Value *programIndex =
|
||||
llvm::UndefValue::get(LLVMTypes::Int32VectorType);
|
||||
for (int i = 0; i < g->target.vectorWidth; ++i)
|
||||
programIndex = ctx->InsertInst(programIndex, LLVMInt32(i), i,
|
||||
"prog_index");
|
||||
|
||||
// And smear the current lane out to a vector
|
||||
llvm::Value *firstSet32 =
|
||||
ctx->TruncInst(firstSet, LLVMTypes::Int32Type, "first_set32");
|
||||
llvm::Value *firstSet32Smear = ctx->SmearUniform(firstSet32);
|
||||
|
||||
// Now set the execution mask based on doing a vector compare of
|
||||
// these two
|
||||
llvm::Value *iterMask =
|
||||
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ,
|
||||
firstSet32Smear, programIndex);
|
||||
iterMask = ctx->I1VecToBoolVec(iterMask);
|
||||
|
||||
ctx->SetInternalMask(iterMask);
|
||||
|
||||
// Also update the bitvector of lanes left to turn off the bit for
|
||||
// the lane we're about to run.
|
||||
llvm::Value *setMask =
|
||||
ctx->BinaryOperator(llvm::Instruction::Shl, LLVMInt64(1),
|
||||
firstSet, "set_mask");
|
||||
llvm::Value *notSetMask = ctx->NotOperator(setMask);
|
||||
llvm::Value *newRemaining =
|
||||
ctx->BinaryOperator(llvm::Instruction::And, remainingBits,
|
||||
notSetMask, "new_remaining");
|
||||
ctx->StoreInst(newRemaining, maskBitsPtr);
|
||||
|
||||
// and onward to run the loop body...
|
||||
ctx->BranchInst(bbBody);
|
||||
}
|
||||
|
||||
ctx->SetCurrentBasicBlock(bbBody); {
|
||||
// Run the code in the body of the loop. This is easy now.
|
||||
if (stmts)
|
||||
stmts->EmitCode(ctx);
|
||||
|
||||
Assert(ctx->GetCurrentBasicBlock() != NULL);
|
||||
ctx->BranchInst(bbCheckForMore);
|
||||
}
|
||||
|
||||
ctx->SetCurrentBasicBlock(bbCheckForMore); {
|
||||
// At the end of the loop body (either due to running the
|
||||
// statements normally, or a continue statement in the middle of
|
||||
// the loop that jumps to the end, see if there are any lanes left
|
||||
// to be processed.
|
||||
llvm::Value *remainingBits = ctx->LoadInst(maskBitsPtr, "remaining_bits");
|
||||
llvm::Value *nonZero =
|
||||
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE,
|
||||
remainingBits, LLVMInt64(0), "remaining_ne_zero");
|
||||
ctx->BranchInst(bbFindNext, bbDone, nonZero);
|
||||
}
|
||||
|
||||
ctx->SetCurrentBasicBlock(bbDone);
|
||||
ctx->SetInternalMask(oldInternalMask);
|
||||
ctx->EndForeach();
|
||||
ctx->EndScope();
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
ForeachActiveStmt::Print(int indent) const {
|
||||
printf("%*cForeach_active Stmt", indent, ' ');
|
||||
pos.Print();
|
||||
printf("\n");
|
||||
|
||||
printf("%*cIter symbol: ", indent+4, ' ');
|
||||
if (sym != NULL) {
|
||||
printf("%s", sym->name.c_str());
|
||||
if (sym->type != NULL)
|
||||
printf(" %s", sym->type->GetString().c_str());
|
||||
}
|
||||
else
|
||||
printf("NULL");
|
||||
printf("\n");
|
||||
|
||||
printf("%*cStmts:\n", indent+4, ' ');
|
||||
if (stmts != NULL)
|
||||
stmts->Print(indent+8);
|
||||
else
|
||||
printf("NULL");
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
|
||||
Stmt *
|
||||
ForeachActiveStmt::TypeCheck() {
|
||||
if (sym == NULL)
|
||||
return NULL;
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
int
|
||||
ForeachActiveStmt::EstimateCost() const {
|
||||
return COST_VARYING_LOOP;
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// ForeachUniqueStmt
|
||||
|
||||
@@ -3043,80 +3225,3 @@ int
|
||||
DeleteStmt::EstimateCost() const {
|
||||
return COST_DELETE;
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/** This generates AST nodes for an __foreach_active statement. This
|
||||
construct can be synthesized ouf of the existing ForStmt (and other AST
|
||||
nodes), so here we just build up the AST that we need rather than
|
||||
having a new Stmt implementation for __foreach_active.
|
||||
|
||||
@param iterSym Symbol for the iteration variable (e.g. "i" in
|
||||
__foreach_active (i) { .. .}
|
||||
@param stmts Statements to execute each time through the loop, for
|
||||
each active program instance.
|
||||
@param pos Position of the __foreach_active statement in the source
|
||||
file.
|
||||
*/
|
||||
Stmt *
|
||||
CreateForeachActiveStmt(Symbol *iterSym, Stmt *stmts, SourcePos pos) {
|
||||
if (iterSym == NULL) {
|
||||
AssertPos(pos, m->errorCount > 0);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// loop initializer: set iter = 0
|
||||
std::vector<VariableDeclaration> var;
|
||||
ConstExpr *zeroExpr = new ConstExpr(AtomicType::UniformInt32, 0,
|
||||
iterSym->pos);
|
||||
var.push_back(VariableDeclaration(iterSym, zeroExpr));
|
||||
Stmt *initStmt = new DeclStmt(var, iterSym->pos);
|
||||
|
||||
// loop test: (iter < programCount)
|
||||
ConstExpr *progCountExpr =
|
||||
new ConstExpr(AtomicType::UniformInt32, g->target.vectorWidth,
|
||||
pos);
|
||||
SymbolExpr *symExpr = new SymbolExpr(iterSym, iterSym->pos);
|
||||
Expr *testExpr = new BinaryExpr(BinaryExpr::Lt, symExpr, progCountExpr,
|
||||
pos);
|
||||
|
||||
// loop step: ++iterSym
|
||||
UnaryExpr *incExpr = new UnaryExpr(UnaryExpr::PreInc, symExpr, pos);
|
||||
Stmt *stepStmt = new ExprStmt(incExpr, pos);
|
||||
|
||||
// loop body
|
||||
// First, call __movmsk(__mask)) to get the mask as a set of bits.
|
||||
// This should be hoisted out of the loop
|
||||
Symbol *maskSym = m->symbolTable->LookupVariable("__mask");
|
||||
AssertPos(pos, maskSym != NULL);
|
||||
Expr *maskVecExpr = new SymbolExpr(maskSym, pos);
|
||||
std::vector<Symbol *> mmFuns;
|
||||
m->symbolTable->LookupFunction("__movmsk", &mmFuns);
|
||||
AssertPos(pos, mmFuns.size() == (g->target.maskBitCount == 32 ? 2 : 1));
|
||||
FunctionSymbolExpr *movmskFunc = new FunctionSymbolExpr("__movmsk", mmFuns,
|
||||
pos);
|
||||
ExprList *movmskArgs = new ExprList(maskVecExpr, pos);
|
||||
FunctionCallExpr *movmskExpr = new FunctionCallExpr(movmskFunc, movmskArgs,
|
||||
pos);
|
||||
|
||||
// Compute the per lane mask to test the mask bits against: (1 << iter)
|
||||
ConstExpr *oneExpr = new ConstExpr(AtomicType::UniformInt64, int64_t(1),
|
||||
iterSym->pos);
|
||||
Expr *shiftLaneExpr = new BinaryExpr(BinaryExpr::Shl, oneExpr, symExpr,
|
||||
pos);
|
||||
|
||||
// Compute the AND: movmsk & (1 << iter)
|
||||
Expr *maskAndLaneExpr = new BinaryExpr(BinaryExpr::BitAnd, movmskExpr,
|
||||
shiftLaneExpr, pos);
|
||||
// Test to see if it's non-zero: (mask & (1 << iter)) != 0
|
||||
Expr *ifTestExpr = new BinaryExpr(BinaryExpr::NotEqual, maskAndLaneExpr,
|
||||
zeroExpr, pos);
|
||||
|
||||
// Now, enclose the provided statements in an if test such that they
|
||||
// only run if the mask is non-zero for the lane we're currently
|
||||
// handling in the loop.
|
||||
IfStmt *laneCheckIf = new IfStmt(ifTestExpr, stmts, NULL, false, pos);
|
||||
|
||||
// And return a for loop that wires it all together.
|
||||
return new ForStmt(initStmt, testExpr, stepStmt, laneCheckIf, false, pos);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user