Add foreach_unique iteration construct.
Idea via Ingo Wald / IVL compiler.
This commit is contained in:
236
stmt.cpp
236
stmt.cpp
@@ -1915,6 +1915,242 @@ ForeachStmt::Print(int indent) const {
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// ForeachUniqueStmt
|
||||
|
||||
/** Constructs a foreach_unique statement node.

    @param iterName  Name of the iteration variable; it is resolved through
                     the module's symbol table when the node is built.
    @param e         Expression whose values the loop iterates over.
    @param s         Statement(s) forming the loop body.
    @param pos       Source position, forwarded to the Stmt base class.
 */
ForeachUniqueStmt::ForeachUniqueStmt(const char *iterName, Expr *e,
                                     Stmt *s, SourcePos pos)
    : Stmt(pos) {
    expr = e;
    stmts = s;
    // The lookup result is stored unchecked here; TypeCheck() and Print()
    // both test sym against NULL before using it.
    sym = m->symbolTable->LookupVariable(iterName);
}
|
||||
|
||||
|
||||
void
|
||||
ForeachUniqueStmt::EmitCode(FunctionEmitContext *ctx) const {
|
||||
if (!ctx->GetCurrentBasicBlock())
|
||||
return;
|
||||
|
||||
// First, allocate local storage for the symbol that we'll use for the
|
||||
// uniform variable that holds the current unique value through each
|
||||
// loop.
|
||||
if (sym->type == NULL) {
|
||||
Assert(m->errorCount > 0);
|
||||
return;
|
||||
}
|
||||
llvm::Type *symType = sym->type->LLVMType(g->ctx);
|
||||
if (symType == NULL) {
|
||||
Assert(m->errorCount > 0);
|
||||
return;
|
||||
}
|
||||
sym->storagePtr = ctx->AllocaInst(symType, sym->name.c_str());
|
||||
|
||||
ctx->SetDebugPos(pos);
|
||||
ctx->EmitVariableDebugInfo(sym);
|
||||
|
||||
// The various basic blocks that we'll need in the below
|
||||
llvm::BasicBlock *bbFindNext = ctx->CreateBasicBlock("foreach_find_next");
|
||||
llvm::BasicBlock *bbBody = ctx->CreateBasicBlock("foreach_body");
|
||||
llvm::BasicBlock *bbCheckForMore = ctx->CreateBasicBlock("foreach_check_for_more");
|
||||
llvm::BasicBlock *bbDone = ctx->CreateBasicBlock("foreach_done");
|
||||
|
||||
// Prepare the FunctionEmitContext
|
||||
ctx->StartScope();
|
||||
|
||||
// Save the old internal mask so that we can restore it at the end
|
||||
llvm::Value *oldMask = ctx->GetInternalMask();
|
||||
|
||||
// Now, *maskBitsPtr will maintain a bitmask for the lanes that remain
|
||||
// to be processed by a pass through the foreach_unique loop body. It
|
||||
// starts out with the full execution mask (which should never be all
|
||||
// off going in to this)...
|
||||
llvm::Value *oldFullMask = ctx->GetFullMask();
|
||||
llvm::Value *maskBitsPtr = ctx->AllocaInst(LLVMTypes::Int64Type, "mask_bits");
|
||||
llvm::Value *movmsk = ctx->LaneMask(oldFullMask);
|
||||
ctx->StoreInst(movmsk, maskBitsPtr);
|
||||
|
||||
// Officially start the loop; as far as the FunctionEmitContext is
|
||||
// concerned, this can be handled the same way as a regular foreach
|
||||
// loop (continue allowed but not break and return, etc.)
|
||||
ctx->StartForeach();
|
||||
ctx->SetContinueTarget(bbCheckForMore);
|
||||
|
||||
// Evaluate the varying expression we're iterating over just once.
|
||||
llvm::Value *exprValue = expr->GetValue(ctx);
|
||||
|
||||
// And we'll store its value into locally-allocated storage, for ease
|
||||
// of indexing over it with non-compile-time-constant indices.
|
||||
const Type *exprType;
|
||||
llvm::VectorType *llvmExprType;
|
||||
if (exprValue == NULL ||
|
||||
(exprType = expr->GetType()) == NULL ||
|
||||
(llvmExprType =
|
||||
llvm::dyn_cast<llvm::VectorType>(exprValue->getType())) == NULL) {
|
||||
Assert(m->errorCount > 0);
|
||||
return;
|
||||
}
|
||||
ctx->SetDebugPos(pos);
|
||||
const Type *exprPtrType = PointerType::GetUniform(exprType);
|
||||
llvm::Value *exprMem = ctx->AllocaInst(llvmExprType, "expr_mem");
|
||||
ctx->StoreInst(exprValue, exprMem);
|
||||
|
||||
// Onward to find the first set of lanes to run the loop for
|
||||
ctx->BranchInst(bbFindNext);
|
||||
|
||||
ctx->SetCurrentBasicBlock(bbFindNext); {
|
||||
// Load the bitmask of the lanes left to be processed
|
||||
llvm::Value *remainingBits = ctx->LoadInst(maskBitsPtr, "remaining_bits");
|
||||
|
||||
// Find the index of the first set bit in the mask
|
||||
llvm::Function *ctlzFunc =
|
||||
m->module->getFunction("__count_trailing_zeros_i64");
|
||||
Assert(ctlzFunc != NULL);
|
||||
llvm::Value *firstSet = ctx->CallInst(ctlzFunc, NULL, remainingBits,
|
||||
"first_set");
|
||||
|
||||
// And load the corresponding element value from the temporary
|
||||
// memory storing the value of the varying expr.
|
||||
llvm::Value *uniqueValuePtr =
|
||||
ctx->GetElementPtrInst(exprMem, LLVMInt64(0), firstSet, exprPtrType,
|
||||
"unique_index_ptr");
|
||||
llvm::Value *uniqueValue = ctx->LoadInst(uniqueValuePtr, "unique_value");
|
||||
|
||||
// If it's a varying pointer type, need to convert from the int
|
||||
// type we store in the vector to the actual pointer type
|
||||
if (llvm::dyn_cast<llvm::PointerType>(symType) != NULL)
|
||||
uniqueValue = ctx->IntToPtrInst(uniqueValue, symType);
|
||||
|
||||
// Store that value in sym's storage so that the iteration variable
|
||||
// has the right value inside the loop body
|
||||
ctx->StoreInst(uniqueValue, sym->storagePtr);
|
||||
|
||||
// Set the execution mask so that it's on for any lane that a) was
|
||||
// running at the start of the foreach loop, and b) where that
|
||||
// lane's value of the varying expression is the same as the value
|
||||
// we've selected to process this time through--i.e.:
|
||||
// oldMask & (smear(element) == exprValue)
|
||||
llvm::Value *uniqueSmear = ctx->SmearUniform(uniqueValue, "unique_semar");
|
||||
llvm::Value *matchingLanes = NULL;
|
||||
if (uniqueValue->getType()->isFloatingPointTy())
|
||||
matchingLanes =
|
||||
ctx->CmpInst(llvm::Instruction::FCmp, llvm::CmpInst::FCMP_OEQ,
|
||||
uniqueSmear, exprValue, "matching_lanes");
|
||||
else
|
||||
matchingLanes =
|
||||
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ,
|
||||
uniqueSmear, exprValue, "matching_lanes");
|
||||
matchingLanes = ctx->I1VecToBoolVec(matchingLanes);
|
||||
|
||||
llvm::Value *loopMask =
|
||||
ctx->BinaryOperator(llvm::Instruction::And, oldMask, matchingLanes,
|
||||
"foreach_unique_loop_mask");
|
||||
ctx->SetInternalMask(loopMask);
|
||||
|
||||
// Also update the bitvector of lanes left to process in subsequent
|
||||
// loop iterations:
|
||||
// remainingBits &= ~movmsk(current mask)
|
||||
llvm::Value *loopMaskMM = ctx->LaneMask(loopMask);
|
||||
llvm::Value *notLoopMaskMM = ctx->NotOperator(loopMaskMM);
|
||||
llvm::Value *newRemaining =
|
||||
ctx->BinaryOperator(llvm::Instruction::And, remainingBits,
|
||||
notLoopMaskMM, "new_remaining");
|
||||
ctx->StoreInst(newRemaining, maskBitsPtr);
|
||||
|
||||
// and onward...
|
||||
ctx->BranchInst(bbBody);
|
||||
}
|
||||
|
||||
ctx->SetCurrentBasicBlock(bbBody); {
|
||||
// Run the code in the body of the loop. This is easy now.
|
||||
if (stmts)
|
||||
stmts->EmitCode(ctx);
|
||||
|
||||
Assert(ctx->GetCurrentBasicBlock() != NULL);
|
||||
ctx->BranchInst(bbCheckForMore);
|
||||
}
|
||||
|
||||
ctx->SetCurrentBasicBlock(bbCheckForMore); {
|
||||
// At the end of the loop body (either due to running the
|
||||
// statements normally, or a continue statement in the middle of
|
||||
// the loop that jumps to the end, see if there are any lanes left
|
||||
// to be processed.
|
||||
llvm::Value *remainingBits = ctx->LoadInst(maskBitsPtr, "remaining_bits");
|
||||
llvm::Value *nonZero =
|
||||
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE,
|
||||
remainingBits, LLVMInt64(0), "remaining_ne_zero");
|
||||
ctx->BranchInst(bbFindNext, bbDone, nonZero);
|
||||
}
|
||||
|
||||
ctx->SetCurrentBasicBlock(bbDone);
|
||||
ctx->SetInternalMask(oldMask);
|
||||
ctx->EndForeach();
|
||||
ctx->EndScope();
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
ForeachUniqueStmt::Print(int indent) const {
|
||||
printf("%*cForeach_unique Stmt", indent, ' ');
|
||||
pos.Print();
|
||||
printf("\n");
|
||||
|
||||
printf("%*cIter symbol: ", indent+4, ' ');
|
||||
if (sym != NULL) {
|
||||
printf("%s", sym->name.c_str());
|
||||
if (sym->type != NULL)
|
||||
printf(" %s", sym->type->GetString().c_str());
|
||||
}
|
||||
else
|
||||
printf("NULL");
|
||||
printf("\n");
|
||||
|
||||
printf("%*cIter expr: ", indent+4, ' ');
|
||||
if (expr != NULL)
|
||||
expr->Print();
|
||||
else
|
||||
printf("NULL");
|
||||
printf("\n");
|
||||
|
||||
printf("%*cStmts:\n", indent+4, ' ');
|
||||
if (stmts != NULL)
|
||||
stmts->Print(indent+8);
|
||||
else
|
||||
printf("NULL");
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
|
||||
/** Type-checks the foreach_unique statement.

    The iteration expression must have a varying type, and that type must
    satisfy Type::IsBasicType().  An error is reported at the expression's
    source position when either requirement is violated.

    @return this when the statement is valid; NULL when the iteration symbol,
            the expression, or the expression's type is unavailable, or when
            a type error has been reported.
 */
Stmt *
ForeachUniqueStmt::TypeCheck() {
    const Type *type;
    if (sym == NULL || expr == NULL || (type = expr->GetType()) == NULL)
        return NULL;

    if (type->IsVaryingType() == false) {
        Error(expr->pos, "Iteration domain type in \"foreach_unique\" loop "
              "must be \"varying\" type, not \"%s\".",
              type->GetString().c_str());
        // The function returns a Stmt *, so failure is signalled with NULL
        // (the same value the early-out above uses), not the bool false.
        return NULL;
    }

    if (Type::IsBasicType(type) == false) {
        Error(expr->pos, "Iteration domain type in \"foreach_unique\" loop "
              "must be an atomic, pointer, or enum type, not \"%s\".",
              type->GetString().c_str());
        return NULL;
    }

    return this;
}
|
||||
|
||||
|
||||
int
|
||||
ForeachUniqueStmt::EstimateCost() const {
|
||||
return COST_VARYING_LOOP;
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// CaseStmt
|
||||
|
||||
|
||||
Reference in New Issue
Block a user