Add foreach_unique iteration construct.

Idea via Ingo Wald / IVL compiler.
This commit is contained in:
Matt Pharr
2012-06-20 10:03:44 -07:00
parent fae47e0dfc
commit 3bc66136b2
17 changed files with 488 additions and 6 deletions

236
stmt.cpp
View File

@@ -1915,6 +1915,242 @@ ForeachStmt::Print(int indent) const {
}
///////////////////////////////////////////////////////////////////////////
// ForeachUniqueStmt
// Builds a foreach_unique statement node.
//
// iterName: name of the loop's iteration variable; it is resolved to a
//           Symbol through the module's symbol table right here, at
//           construction time.
// e:        the expression whose unique values the loop iterates over.
// s:        the loop body.
// pos:      source position, forwarded to the Stmt base class.
ForeachUniqueStmt::ForeachUniqueStmt(const char *iterName, Expr *e,
                                     Stmt *s, SourcePos pos)
    : Stmt(pos) {
    this->expr = e;
    this->stmts = s;
    // The result is stored as-is; both TypeCheck() and Print() test sym
    // against NULL before using it.
    this->sym = m->symbolTable->LookupVariable(iterName);
}
// Emits the code for a foreach_unique loop.
//
// The generated code evaluates the varying expression once, then loops:
// each pass picks the value held by the first not-yet-processed lane,
// stores it in the (uniform) iteration variable, and runs the loop body
// with the internal mask restricted to the lanes whose expression value
// equals it.  A 64-bit bitmask of lanes still to be processed is kept in
// stack storage and has the just-handled lanes cleared on every pass; the
// loop ends when that bitmask reaches zero.
//
// Basic block layout:
//   foreach_find_next      -> select next unique value, set mask
//   foreach_body           -> loop body statements
//   foreach_check_for_more -> continue target; loop back if lanes remain
//   foreach_done           -> restore the saved internal mask
//
// ctx: the function emit context that the instructions are emitted into.
void
ForeachUniqueStmt::EmitCode(FunctionEmitContext *ctx) const {
// Nothing to emit if there is no current basic block.
if (!ctx->GetCurrentBasicBlock())
return;
// First, allocate local storage for the symbol that we'll use for the
// uniform variable that holds the current unique value through each
// loop.
// NOTE(review): sym is dereferenced without a NULL check here.
// TypeCheck() returns NULL when sym is NULL, which presumably keeps this
// from being reached in that case -- confirm.
if (sym->type == NULL) {
Assert(m->errorCount > 0);
return;
}
llvm::Type *symType = sym->type->LLVMType(g->ctx);
if (symType == NULL) {
Assert(m->errorCount > 0);
return;
}
sym->storagePtr = ctx->AllocaInst(symType, sym->name.c_str());
ctx->SetDebugPos(pos);
ctx->EmitVariableDebugInfo(sym);
// The various basic blocks that we'll need in the below
llvm::BasicBlock *bbFindNext = ctx->CreateBasicBlock("foreach_find_next");
llvm::BasicBlock *bbBody = ctx->CreateBasicBlock("foreach_body");
llvm::BasicBlock *bbCheckForMore = ctx->CreateBasicBlock("foreach_check_for_more");
llvm::BasicBlock *bbDone = ctx->CreateBasicBlock("foreach_done");
// Prepare the FunctionEmitContext
ctx->StartScope();
// Save the old internal mask so that we can restore it at the end
llvm::Value *oldMask = ctx->GetInternalMask();
// Now, *maskBitsPtr will maintain a bitmask for the lanes that remain
// to be processed by a pass through the foreach_unique loop body. It
// starts out with the full execution mask (which should never be all
// off going in to this)...
llvm::Value *oldFullMask = ctx->GetFullMask();
llvm::Value *maskBitsPtr = ctx->AllocaInst(LLVMTypes::Int64Type, "mask_bits");
llvm::Value *movmsk = ctx->LaneMask(oldFullMask);
ctx->StoreInst(movmsk, maskBitsPtr);
// Officially start the loop; as far as the FunctionEmitContext is
// concerned, this can be handled the same way as a regular foreach
// loop (continue allowed but not break and return, etc.)
ctx->StartForeach();
ctx->SetContinueTarget(bbCheckForMore);
// Evaluate the varying expression we're iterating over just once.
llvm::Value *exprValue = expr->GetValue(ctx);
// And we'll store its value into locally-allocated storage, for ease
// of indexing over it with non-compile-time-constant indices.
const Type *exprType;
llvm::VectorType *llvmExprType;
// Bail out if the expression produced no value, has no type, or its
// LLVM value is not a vector; an error is expected to have been counted
// already in all of these cases.
// NOTE(review): this return happens after StartScope() and
// StartForeach() above, without the matching EndForeach()/EndScope()
// calls made at the end of the function -- presumably harmless since
// compilation has already failed, but worth confirming.
if (exprValue == NULL ||
(exprType = expr->GetType()) == NULL ||
(llvmExprType =
llvm::dyn_cast<llvm::VectorType>(exprValue->getType())) == NULL) {
Assert(m->errorCount > 0);
return;
}
ctx->SetDebugPos(pos);
const Type *exprPtrType = PointerType::GetUniform(exprType);
llvm::Value *exprMem = ctx->AllocaInst(llvmExprType, "expr_mem");
ctx->StoreInst(exprValue, exprMem);
// Onward to find the first set of lanes to run the loop for
ctx->BranchInst(bbFindNext);
ctx->SetCurrentBasicBlock(bbFindNext); {
// Load the bitmask of the lanes left to be processed
llvm::Value *remainingBits = ctx->LoadInst(maskBitsPtr, "remaining_bits");
// Find the index of the first set bit in the mask
// NOTE(review): the local is named ctlzFunc, but the function looked up
// is the count-*trailing*-zeros one, which is what yields the index of
// the lowest set bit.
llvm::Function *ctlzFunc =
m->module->getFunction("__count_trailing_zeros_i64");
Assert(ctlzFunc != NULL);
llvm::Value *firstSet = ctx->CallInst(ctlzFunc, NULL, remainingBits,
"first_set");
// And load the corresponding element value from the temporary
// memory storing the value of the varying expr.
llvm::Value *uniqueValuePtr =
ctx->GetElementPtrInst(exprMem, LLVMInt64(0), firstSet, exprPtrType,
"unique_index_ptr");
llvm::Value *uniqueValue = ctx->LoadInst(uniqueValuePtr, "unique_value");
// If it's a varying pointer type, need to convert from the int
// type we store in the vector to the actual pointer type
if (llvm::dyn_cast<llvm::PointerType>(symType) != NULL)
uniqueValue = ctx->IntToPtrInst(uniqueValue, symType);
// Store that value in sym's storage so that the iteration variable
// has the right value inside the loop body
ctx->StoreInst(uniqueValue, sym->storagePtr);
// Set the execution mask so that it's on for any lane that a) was
// running at the start of the foreach loop, and b) where that
// lane's value of the varying expression is the same as the value
// we've selected to process this time through--i.e.:
// oldMask & (smear(element) == exprValue)
// NOTE(review): "unique_semar" looks like a typo for "unique_smear";
// it is only the name given to the emitted value.
llvm::Value *uniqueSmear = ctx->SmearUniform(uniqueValue, "unique_semar");
llvm::Value *matchingLanes = NULL;
// Floating-point values need an FCmp; everything else uses an integer
// equality compare.
// NOTE(review): FCMP_OEQ is an ordered comparison, so a NaN element
// would not compare equal to itself.  In that case loopMask would
// presumably be all off and no bits would be cleared from the
// remaining-lanes bitmask below, so the loop might never terminate --
// verify how NaN inputs are meant to be handled.
if (uniqueValue->getType()->isFloatingPointTy())
matchingLanes =
ctx->CmpInst(llvm::Instruction::FCmp, llvm::CmpInst::FCMP_OEQ,
uniqueSmear, exprValue, "matching_lanes");
else
matchingLanes =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ,
uniqueSmear, exprValue, "matching_lanes");
matchingLanes = ctx->I1VecToBoolVec(matchingLanes);
llvm::Value *loopMask =
ctx->BinaryOperator(llvm::Instruction::And, oldMask, matchingLanes,
"foreach_unique_loop_mask");
ctx->SetInternalMask(loopMask);
// Also update the bitvector of lanes left to process in subsequent
// loop iterations:
// remainingBits &= ~movmsk(current mask)
llvm::Value *loopMaskMM = ctx->LaneMask(loopMask);
llvm::Value *notLoopMaskMM = ctx->NotOperator(loopMaskMM);
llvm::Value *newRemaining =
ctx->BinaryOperator(llvm::Instruction::And, remainingBits,
notLoopMaskMM, "new_remaining");
ctx->StoreInst(newRemaining, maskBitsPtr);
// and onward...
ctx->BranchInst(bbBody);
}
ctx->SetCurrentBasicBlock(bbBody); {
// Run the code in the body of the loop. This is easy now.
if (stmts)
stmts->EmitCode(ctx);
// The body is expected to leave a current basic block in place, since
// we branch from it to the check-for-more block next.
Assert(ctx->GetCurrentBasicBlock() != NULL);
ctx->BranchInst(bbCheckForMore);
}
ctx->SetCurrentBasicBlock(bbCheckForMore); {
// At the end of the loop body (either due to running the
// statements normally, or a continue statement in the middle of
// the loop that jumps to the end), see if there are any lanes left
// to be processed.
llvm::Value *remainingBits = ctx->LoadInst(maskBitsPtr, "remaining_bits");
llvm::Value *nonZero =
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_NE,
remainingBits, LLVMInt64(0), "remaining_ne_zero");
// Loop back for the next unique value if any bits remain; otherwise
// fall out to the done block.
ctx->BranchInst(bbFindNext, bbDone, nonZero);
}
// Loop finished: restore the internal mask saved on entry and close the
// foreach and the scope opened above.
ctx->SetCurrentBasicBlock(bbDone);
ctx->SetInternalMask(oldMask);
ctx->EndForeach();
ctx->EndScope();
}
// Prints a human-readable dump of this foreach_unique statement to stdout:
// a header line with the source position, then the iteration symbol (name
// and, when known, type), the iteration expression, and the body
// statements.  Any of the three that is missing is printed as "NULL".
//
// indent: column width used for the header line; the symbol/expr/stmts
//         labels are printed at indent+4 and the body at indent+8.
void
ForeachUniqueStmt::Print(int indent) const {
    const int labelIndent = indent + 4;
    const int bodyIndent = indent + 8;

    // Header line.
    printf("%*cForeach_unique Stmt", indent, ' ');
    pos.Print();
    printf("\n");

    // Iteration symbol.
    printf("%*cIter symbol: ", labelIndent, ' ');
    if (sym == NULL)
        printf("NULL");
    else {
        printf("%s", sym->name.c_str());
        if (sym->type != NULL)
            printf(" %s", sym->type->GetString().c_str());
    }
    printf("\n");

    // Expression being iterated over.
    printf("%*cIter expr: ", labelIndent, ' ');
    if (expr == NULL)
        printf("NULL");
    else
        expr->Print();
    printf("\n");

    // Loop body.
    printf("%*cStmts:\n", labelIndent, ' ');
    if (stmts == NULL)
        printf("NULL");
    else
        stmts->Print(bodyIndent);
    printf("\n");
}
// Type-checks a foreach_unique statement.
//
// The iteration expression must have a varying type, and that type must
// be accepted by Type::IsBasicType(); an error is reported at the
// expression's source position otherwise.
//
// Returns this statement when the checks pass.  Returns NULL when the
// iteration symbol, the expression, or the expression's type is missing
// (no new error is reported in that case), or after reporting one of the
// errors above.
Stmt *
ForeachUniqueStmt::TypeCheck() {
    const Type *type;
    if (sym == NULL || expr == NULL || (type = expr->GetType()) == NULL)
        return NULL;

    if (type->IsVaryingType() == false) {
        Error(expr->pos, "Iteration domain type in \"foreach_unique\" loop "
              "must be \"varying\" type, not \"%s\".",
              type->GetString().c_str());
        // Return a real null pointer: "return false" from a function
        // returning Stmt * only worked via the legacy rule that treated
        // the literal false as a null pointer constant.
        return NULL;
    }

    if (Type::IsBasicType(type) == false) {
        Error(expr->pos, "Iteration domain type in \"foreach_unique\" loop "
              "must be an atomic, pointer, or enum type, not \"%s\".",
              type->GetString().c_str());
        return NULL;
    }

    return this;
}
// Cost estimate for a foreach_unique statement: always the fixed
// COST_VARYING_LOOP value, independent of the expression or the body.
int
ForeachUniqueStmt::EstimateCost() const {
    const int loopCost = COST_VARYING_LOOP;
    return loopCost;
}
///////////////////////////////////////////////////////////////////////////
// CaseStmt