Provide mask at block entry for switch statements.

This fixes a crash if 'cbreak' was used in a 'switch'.  Renamed
FunctionEmitContext::SetLoopMask() to SetBlockEntryMask(), and
similarly the loopMask member variable.
This commit is contained in:
Matt Pharr
2012-07-06 11:08:05 -07:00
parent ac421f68e2
commit 73afab464f
3 changed files with 43 additions and 40 deletions

68
ctx.cpp
View File

@@ -66,20 +66,20 @@ struct CFInfo {
llvm::BasicBlock *continueTarget, llvm::BasicBlock *continueTarget,
llvm::Value *savedBreakLanesPtr, llvm::Value *savedBreakLanesPtr,
llvm::Value *savedContinueLanesPtr, llvm::Value *savedContinueLanesPtr,
llvm::Value *savedMask, llvm::Value *savedLoopMask); llvm::Value *savedMask, llvm::Value *savedBlockEntryMask);
static CFInfo *GetForeach(FunctionEmitContext::ForeachType ft, static CFInfo *GetForeach(FunctionEmitContext::ForeachType ft,
llvm::BasicBlock *breakTarget, llvm::BasicBlock *breakTarget,
llvm::BasicBlock *continueTarget, llvm::BasicBlock *continueTarget,
llvm::Value *savedBreakLanesPtr, llvm::Value *savedBreakLanesPtr,
llvm::Value *savedContinueLanesPtr, llvm::Value *savedContinueLanesPtr,
llvm::Value *savedMask, llvm::Value *savedLoopMask); llvm::Value *savedMask, llvm::Value *savedBlockEntryMask);
static CFInfo *GetSwitch(bool isUniform, llvm::BasicBlock *breakTarget, static CFInfo *GetSwitch(bool isUniform, llvm::BasicBlock *breakTarget,
llvm::BasicBlock *continueTarget, llvm::BasicBlock *continueTarget,
llvm::Value *savedBreakLanesPtr, llvm::Value *savedBreakLanesPtr,
llvm::Value *savedContinueLanesPtr, llvm::Value *savedContinueLanesPtr,
llvm::Value *savedMask, llvm::Value *savedLoopMask, llvm::Value *savedMask, llvm::Value *savedBlockEntryMask,
llvm::Value *switchExpr, llvm::Value *switchExpr,
llvm::BasicBlock *bbDefault, llvm::BasicBlock *bbDefault,
const std::vector<std::pair<int, llvm::BasicBlock *> > *bbCases, const std::vector<std::pair<int, llvm::BasicBlock *> > *bbCases,
@@ -101,7 +101,7 @@ struct CFInfo {
bool isUniform; bool isUniform;
llvm::BasicBlock *savedBreakTarget, *savedContinueTarget; llvm::BasicBlock *savedBreakTarget, *savedContinueTarget;
llvm::Value *savedBreakLanesPtr, *savedContinueLanesPtr; llvm::Value *savedBreakLanesPtr, *savedContinueLanesPtr;
llvm::Value *savedMask, *savedLoopMask; llvm::Value *savedMask, *savedBlockEntryMask;
llvm::Value *savedSwitchExpr; llvm::Value *savedSwitchExpr;
llvm::BasicBlock *savedDefaultBlock; llvm::BasicBlock *savedDefaultBlock;
const std::vector<std::pair<int, llvm::BasicBlock *> > *savedCaseBlocks; const std::vector<std::pair<int, llvm::BasicBlock *> > *savedCaseBlocks;
@@ -115,7 +115,7 @@ private:
isUniform = uniformIf; isUniform = uniformIf;
savedBreakTarget = savedContinueTarget = NULL; savedBreakTarget = savedContinueTarget = NULL;
savedBreakLanesPtr = savedContinueLanesPtr = NULL; savedBreakLanesPtr = savedContinueLanesPtr = NULL;
savedMask = savedLoopMask = sm; savedMask = savedBlockEntryMask = sm;
savedSwitchExpr = NULL; savedSwitchExpr = NULL;
savedDefaultBlock = NULL; savedDefaultBlock = NULL;
savedCaseBlocks = NULL; savedCaseBlocks = NULL;
@@ -135,7 +135,7 @@ private:
savedBreakLanesPtr = sb; savedBreakLanesPtr = sb;
savedContinueLanesPtr = sc; savedContinueLanesPtr = sc;
savedMask = sm; savedMask = sm;
savedLoopMask = lm; savedBlockEntryMask = lm;
savedSwitchExpr = sse; savedSwitchExpr = sse;
savedDefaultBlock = bbd; savedDefaultBlock = bbd;
savedCaseBlocks = bbc; savedCaseBlocks = bbc;
@@ -153,7 +153,7 @@ private:
savedBreakLanesPtr = sb; savedBreakLanesPtr = sb;
savedContinueLanesPtr = sc; savedContinueLanesPtr = sc;
savedMask = sm; savedMask = sm;
savedLoopMask = lm; savedBlockEntryMask = lm;
savedSwitchExpr = NULL; savedSwitchExpr = NULL;
savedDefaultBlock = NULL; savedDefaultBlock = NULL;
savedCaseBlocks = NULL; savedCaseBlocks = NULL;
@@ -173,10 +173,10 @@ CFInfo::GetLoop(bool isUniform, llvm::BasicBlock *breakTarget,
llvm::BasicBlock *continueTarget, llvm::BasicBlock *continueTarget,
llvm::Value *savedBreakLanesPtr, llvm::Value *savedBreakLanesPtr,
llvm::Value *savedContinueLanesPtr, llvm::Value *savedContinueLanesPtr,
llvm::Value *savedMask, llvm::Value *savedLoopMask) { llvm::Value *savedMask, llvm::Value *savedBlockEntryMask) {
return new CFInfo(Loop, isUniform, breakTarget, continueTarget, return new CFInfo(Loop, isUniform, breakTarget, continueTarget,
savedBreakLanesPtr, savedContinueLanesPtr, savedBreakLanesPtr, savedContinueLanesPtr,
savedMask, savedLoopMask); savedMask, savedBlockEntryMask);
} }
@@ -214,14 +214,14 @@ CFInfo::GetSwitch(bool isUniform, llvm::BasicBlock *breakTarget,
llvm::BasicBlock *continueTarget, llvm::BasicBlock *continueTarget,
llvm::Value *savedBreakLanesPtr, llvm::Value *savedBreakLanesPtr,
llvm::Value *savedContinueLanesPtr, llvm::Value *savedMask, llvm::Value *savedContinueLanesPtr, llvm::Value *savedMask,
llvm::Value *savedLoopMask, llvm::Value *savedSwitchExpr, llvm::Value *savedBlockEntryMask, llvm::Value *savedSwitchExpr,
llvm::BasicBlock *savedDefaultBlock, llvm::BasicBlock *savedDefaultBlock,
const std::vector<std::pair<int, llvm::BasicBlock *> > *savedCases, const std::vector<std::pair<int, llvm::BasicBlock *> > *savedCases,
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *savedNext, const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *savedNext,
bool savedSwitchConditionUniform) { bool savedSwitchConditionUniform) {
return new CFInfo(Switch, isUniform, breakTarget, continueTarget, return new CFInfo(Switch, isUniform, breakTarget, continueTarget,
savedBreakLanesPtr, savedContinueLanesPtr, savedBreakLanesPtr, savedContinueLanesPtr,
savedMask, savedLoopMask, savedSwitchExpr, savedDefaultBlock, savedMask, savedBlockEntryMask, savedSwitchExpr, savedDefaultBlock,
savedCases, savedNext, savedSwitchConditionUniform); savedCases, savedNext, savedSwitchConditionUniform);
} }
@@ -249,7 +249,7 @@ FunctionEmitContext::FunctionEmitContext(Function *func, Symbol *funSym,
fullMaskPointer = AllocaInst(LLVMTypes::MaskType, "full_mask_memory"); fullMaskPointer = AllocaInst(LLVMTypes::MaskType, "full_mask_memory");
StoreInst(LLVMMaskAllOn, fullMaskPointer); StoreInst(LLVMMaskAllOn, fullMaskPointer);
loopMask = NULL; blockEntryMask = NULL;
breakLanesPtr = continueLanesPtr = NULL; breakLanesPtr = continueLanesPtr = NULL;
breakTarget = continueTarget = NULL; breakTarget = continueTarget = NULL;
@@ -422,8 +422,8 @@ FunctionEmitContext::SetFunctionMask(llvm::Value *value) {
void void
FunctionEmitContext::SetLoopMask(llvm::Value *value) { FunctionEmitContext::SetBlockEntryMask(llvm::Value *value) {
loopMask = value; blockEntryMask = value;
} }
@@ -567,7 +567,7 @@ FunctionEmitContext::StartLoop(llvm::BasicBlock *bt, llvm::BasicBlock *ct,
llvm::Value *oldMask = GetInternalMask(); llvm::Value *oldMask = GetInternalMask();
controlFlowInfo.push_back(CFInfo::GetLoop(uniformCF, breakTarget, controlFlowInfo.push_back(CFInfo::GetLoop(uniformCF, breakTarget,
continueTarget, breakLanesPtr, continueTarget, breakLanesPtr,
continueLanesPtr, oldMask, loopMask)); continueLanesPtr, oldMask, blockEntryMask));
if (uniformCF) if (uniformCF)
// If the loop has a uniform condition, we don't need to track // If the loop has a uniform condition, we don't need to track
// which lanes 'break' or 'continue'; all of the running ones go // which lanes 'break' or 'continue'; all of the running ones go
@@ -584,7 +584,7 @@ FunctionEmitContext::StartLoop(llvm::BasicBlock *bt, llvm::BasicBlock *ct,
breakTarget = bt; breakTarget = bt;
continueTarget = ct; continueTarget = ct;
loopMask = NULL; // this better be set by the loop! blockEntryMask = NULL; // this better be set by the loop!
} }
@@ -625,7 +625,7 @@ FunctionEmitContext::StartForeach(ForeachType ft) {
llvm::Value *oldMask = GetInternalMask(); llvm::Value *oldMask = GetInternalMask();
controlFlowInfo.push_back(CFInfo::GetForeach(ft, breakTarget, continueTarget, controlFlowInfo.push_back(CFInfo::GetForeach(ft, breakTarget, continueTarget,
breakLanesPtr, continueLanesPtr, breakLanesPtr, continueLanesPtr,
oldMask, loopMask)); oldMask, blockEntryMask));
breakLanesPtr = NULL; breakLanesPtr = NULL;
breakTarget = NULL; breakTarget = NULL;
@@ -633,7 +633,7 @@ FunctionEmitContext::StartForeach(ForeachType ft) {
StoreInst(LLVMMaskAllOff, continueLanesPtr); StoreInst(LLVMMaskAllOff, continueLanesPtr);
continueTarget = NULL; // should be set by SetContinueTarget() continueTarget = NULL; // should be set by SetContinueTarget()
loopMask = NULL; blockEntryMask = NULL;
} }
@@ -834,7 +834,7 @@ FunctionEmitContext::ifsInCFAllUniform(int type) const {
void void
FunctionEmitContext::jumpIfAllLoopLanesAreDone(llvm::BasicBlock *target) { FunctionEmitContext::jumpIfAllLoopLanesAreDone(llvm::BasicBlock *target) {
llvm::Value *allDone = NULL; llvm::Value *allDone = NULL;
AssertPos(currentPos, continueLanesPtr != NULL);
if (breakLanesPtr == NULL) { if (breakLanesPtr == NULL) {
// In a foreach loop, break and return are illegal, and // In a foreach loop, break and return are illegal, and
// breakLanesPtr is NULL. In this case, the mask is guaranteed to // breakLanesPtr is NULL. In this case, the mask is guaranteed to
@@ -850,18 +850,20 @@ FunctionEmitContext::jumpIfAllLoopLanesAreDone(llvm::BasicBlock *target) {
// so, everyone is done and we can jump to the given target // so, everyone is done and we can jump to the given target
llvm::Value *returned = LoadInst(returnedLanesPtr, llvm::Value *returned = LoadInst(returnedLanesPtr,
"returned_lanes"); "returned_lanes");
llvm::Value *continued = LoadInst(continueLanesPtr,
"continue_lanes");
llvm::Value *breaked = LoadInst(breakLanesPtr, "break_lanes"); llvm::Value *breaked = LoadInst(breakLanesPtr, "break_lanes");
llvm::Value *returnedOrContinued = BinaryOperator(llvm::Instruction::Or, llvm::Value *finishedLanes = BinaryOperator(llvm::Instruction::Or,
returned, continued, returned, breaked,
"returned|continued"); "returned|breaked");
llvm::Value *returnedOrContinuedOrBreaked = if (continueLanesPtr != NULL) {
BinaryOperator(llvm::Instruction::Or, returnedOrContinued, // It's NULL for "switch" statements...
breaked, "returned|continued"); llvm::Value *continued = LoadInst(continueLanesPtr,
"continue_lanes");
finishedLanes = BinaryOperator(llvm::Instruction::Or, finishedLanes,
continued, "returned|breaked|continued");
}
// Do we match the mask at loop entry? // Do we match the mask at loop or switch statement entry?
allDone = MasksAllEqual(returnedOrContinuedOrBreaked, loopMask); allDone = MasksAllEqual(finishedLanes, blockEntryMask);
} }
llvm::BasicBlock *bAll = CreateBasicBlock("all_continued_or_breaked"); llvm::BasicBlock *bAll = CreateBasicBlock("all_continued_or_breaked");
@@ -905,7 +907,7 @@ FunctionEmitContext::StartSwitch(bool cfIsUniform, llvm::BasicBlock *bbBreak) {
controlFlowInfo.push_back(CFInfo::GetSwitch(cfIsUniform, breakTarget, controlFlowInfo.push_back(CFInfo::GetSwitch(cfIsUniform, breakTarget,
continueTarget, breakLanesPtr, continueTarget, breakLanesPtr,
continueLanesPtr, oldMask, continueLanesPtr, oldMask,
loopMask, switchExpr, defaultBlock, blockEntryMask, switchExpr, defaultBlock,
caseBlocks, nextBlocks, caseBlocks, nextBlocks,
switchConditionWasUniform)); switchConditionWasUniform));
@@ -915,7 +917,7 @@ FunctionEmitContext::StartSwitch(bool cfIsUniform, llvm::BasicBlock *bbBreak) {
continueLanesPtr = NULL; continueLanesPtr = NULL;
continueTarget = NULL; continueTarget = NULL;
loopMask = NULL; blockEntryMask = NULL;
// These will be set by the SwitchInst() method // These will be set by the SwitchInst() method
switchExpr = NULL; switchExpr = NULL;
@@ -3526,7 +3528,7 @@ FunctionEmitContext::popCFState() {
continueTarget = ci->savedContinueTarget; continueTarget = ci->savedContinueTarget;
breakLanesPtr = ci->savedBreakLanesPtr; breakLanesPtr = ci->savedBreakLanesPtr;
continueLanesPtr = ci->savedContinueLanesPtr; continueLanesPtr = ci->savedContinueLanesPtr;
loopMask = ci->savedLoopMask; blockEntryMask = ci->savedBlockEntryMask;
switchExpr = ci->savedSwitchExpr; switchExpr = ci->savedSwitchExpr;
defaultBlock = ci->savedDefaultBlock; defaultBlock = ci->savedDefaultBlock;
caseBlocks = ci->savedCaseBlocks; caseBlocks = ci->savedCaseBlocks;
@@ -3538,7 +3540,7 @@ FunctionEmitContext::popCFState() {
continueTarget = ci->savedContinueTarget; continueTarget = ci->savedContinueTarget;
breakLanesPtr = ci->savedBreakLanesPtr; breakLanesPtr = ci->savedBreakLanesPtr;
continueLanesPtr = ci->savedContinueLanesPtr; continueLanesPtr = ci->savedContinueLanesPtr;
loopMask = ci->savedLoopMask; blockEntryMask = ci->savedBlockEntryMask;
} }
else { else {
AssertPos(currentPos, ci->IsIf()); AssertPos(currentPos, ci->IsIf());

10
ctx.h
View File

@@ -158,8 +158,8 @@ public:
bool uniformControlFlow); bool uniformControlFlow);
/** Informs FunctionEmitContext of the value of the mask at the start /** Informs FunctionEmitContext of the value of the mask at the start
of a loop body. */ of a loop body or switch statement. */
void SetLoopMask(llvm::Value *mask); void SetBlockEntryMask(llvm::Value *mask);
/** Informs FunctionEmitContext that code generation for a loop is /** Informs FunctionEmitContext that code generation for a loop is
finished. */ finished. */
@@ -566,9 +566,9 @@ private:
for error messages and debugging symbols. */ for error messages and debugging symbols. */
SourcePos funcStartPos; SourcePos funcStartPos;
/** If currently in a loop body, the value of the mask at the start of /** If currently in a loop body or switch statement, the value of the
the loop. */ mask at the start of it. */
llvm::Value *loopMask; llvm::Value *blockEntryMask;
/** If currently in a loop body or switch statement, this is a pointer /** If currently in a loop body or switch statement, this is a pointer
to memory to store a mask value that represents which of the lanes to memory to store a mask value that represents which of the lanes

View File

@@ -830,7 +830,7 @@ void DoStmt::EmitCode(FunctionEmitContext *ctx) const {
// And now emit code for the loop body // And now emit code for the loop body
ctx->SetCurrentBasicBlock(bloop); ctx->SetCurrentBasicBlock(bloop);
ctx->SetLoopMask(ctx->GetInternalMask()); ctx->SetBlockEntryMask(ctx->GetInternalMask());
ctx->SetDebugPos(pos); ctx->SetDebugPos(pos);
// FIXME: in the StmtList::EmitCode() method takes starts/stops a new // FIXME: in the StmtList::EmitCode() method takes starts/stops a new
// scope around the statements in the list. So if the body is just a // scope around the statements in the list. So if the body is just a
@@ -1047,7 +1047,7 @@ ForStmt::EmitCode(FunctionEmitContext *ctx) const {
// On to emitting the code for the loop body. // On to emitting the code for the loop body.
ctx->SetCurrentBasicBlock(bloop); ctx->SetCurrentBasicBlock(bloop);
ctx->SetLoopMask(ctx->GetInternalMask()); ctx->SetBlockEntryMask(ctx->GetInternalMask());
ctx->AddInstrumentationPoint("for loop body"); ctx->AddInstrumentationPoint("for loop body");
if (!dynamic_cast<StmtList *>(stmts)) if (!dynamic_cast<StmtList *>(stmts))
ctx->StartScope(); ctx->StartScope();
@@ -2557,6 +2557,7 @@ SwitchStmt::EmitCode(FunctionEmitContext *ctx) const {
bool isUniformCF = (type->IsUniformType() && bool isUniformCF = (type->IsUniformType() &&
lHasVaryingBreakOrContinue(stmts) == false); lHasVaryingBreakOrContinue(stmts) == false);
ctx->StartSwitch(isUniformCF, bbDone); ctx->StartSwitch(isUniformCF, bbDone);
ctx->SetBlockEntryMask(ctx->GetInternalMask());
ctx->SwitchInst(exprValue, svi.defaultBlock ? svi.defaultBlock : bbDone, ctx->SwitchInst(exprValue, svi.defaultBlock ? svi.defaultBlock : bbDone,
svi.caseBlocks, svi.nextBlock); svi.caseBlocks, svi.nextBlock);