diff --git a/ctx.cpp b/ctx.cpp index e45d5465..de8f6177 100644 --- a/ctx.cpp +++ b/ctx.cpp @@ -66,20 +66,20 @@ struct CFInfo { llvm::BasicBlock *continueTarget, llvm::Value *savedBreakLanesPtr, llvm::Value *savedContinueLanesPtr, - llvm::Value *savedMask, llvm::Value *savedLoopMask); + llvm::Value *savedMask, llvm::Value *savedBlockEntryMask); static CFInfo *GetForeach(FunctionEmitContext::ForeachType ft, llvm::BasicBlock *breakTarget, llvm::BasicBlock *continueTarget, llvm::Value *savedBreakLanesPtr, llvm::Value *savedContinueLanesPtr, - llvm::Value *savedMask, llvm::Value *savedLoopMask); + llvm::Value *savedMask, llvm::Value *savedBlockEntryMask); static CFInfo *GetSwitch(bool isUniform, llvm::BasicBlock *breakTarget, llvm::BasicBlock *continueTarget, llvm::Value *savedBreakLanesPtr, llvm::Value *savedContinueLanesPtr, - llvm::Value *savedMask, llvm::Value *savedLoopMask, + llvm::Value *savedMask, llvm::Value *savedBlockEntryMask, llvm::Value *switchExpr, llvm::BasicBlock *bbDefault, const std::vector > *bbCases, @@ -101,7 +101,7 @@ struct CFInfo { bool isUniform; llvm::BasicBlock *savedBreakTarget, *savedContinueTarget; llvm::Value *savedBreakLanesPtr, *savedContinueLanesPtr; - llvm::Value *savedMask, *savedLoopMask; + llvm::Value *savedMask, *savedBlockEntryMask; llvm::Value *savedSwitchExpr; llvm::BasicBlock *savedDefaultBlock; const std::vector > *savedCaseBlocks; @@ -115,7 +115,7 @@ private: isUniform = uniformIf; savedBreakTarget = savedContinueTarget = NULL; savedBreakLanesPtr = savedContinueLanesPtr = NULL; - savedMask = savedLoopMask = sm; + savedMask = savedBlockEntryMask = sm; savedSwitchExpr = NULL; savedDefaultBlock = NULL; savedCaseBlocks = NULL; @@ -135,7 +135,7 @@ private: savedBreakLanesPtr = sb; savedContinueLanesPtr = sc; savedMask = sm; - savedLoopMask = lm; + savedBlockEntryMask = lm; savedSwitchExpr = sse; savedDefaultBlock = bbd; savedCaseBlocks = bbc; @@ -153,7 +153,7 @@ private: savedBreakLanesPtr = sb; savedContinueLanesPtr = sc; savedMask = sm; - savedLoopMask = lm; + savedBlockEntryMask = lm; savedSwitchExpr = NULL; savedDefaultBlock = NULL; savedCaseBlocks = NULL; @@ -173,10 +173,10 @@ CFInfo::GetLoop(bool isUniform, llvm::BasicBlock *breakTarget, llvm::BasicBlock *continueTarget, llvm::Value *savedBreakLanesPtr, llvm::Value *savedContinueLanesPtr, - llvm::Value *savedMask, llvm::Value *savedLoopMask) { + llvm::Value *savedMask, llvm::Value *savedBlockEntryMask) { return new CFInfo(Loop, isUniform, breakTarget, continueTarget, savedBreakLanesPtr, savedContinueLanesPtr, - savedMask, savedLoopMask); + savedMask, savedBlockEntryMask); } @@ -214,14 +214,14 @@ CFInfo::GetSwitch(bool isUniform, llvm::BasicBlock *breakTarget, llvm::BasicBlock *continueTarget, llvm::Value *savedBreakLanesPtr, llvm::Value *savedContinueLanesPtr, llvm::Value *savedMask, - llvm::Value *savedLoopMask, llvm::Value *savedSwitchExpr, + llvm::Value *savedBlockEntryMask, llvm::Value *savedSwitchExpr, llvm::BasicBlock *savedDefaultBlock, const std::vector > *savedCases, const std::map *savedNext, bool savedSwitchConditionUniform) { return new CFInfo(Switch, isUniform, breakTarget, continueTarget, savedBreakLanesPtr, savedContinueLanesPtr, - savedMask, savedLoopMask, savedSwitchExpr, savedDefaultBlock, + savedMask, savedBlockEntryMask, savedSwitchExpr, savedDefaultBlock, savedCases, savedNext, savedSwitchConditionUniform); } @@ -249,7 +249,7 @@ FunctionEmitContext::FunctionEmitContext(Function *func, Symbol *funSym, fullMaskPointer = AllocaInst(LLVMTypes::MaskType, "full_mask_memory"); StoreInst(LLVMMaskAllOn, fullMaskPointer); - loopMask = NULL; + blockEntryMask = NULL; breakLanesPtr = continueLanesPtr = NULL; breakTarget = continueTarget = NULL; @@ -422,8 +422,8 @@ FunctionEmitContext::SetFunctionMask(llvm::Value *value) { void -FunctionEmitContext::SetLoopMask(llvm::Value *value) { - loopMask = value; +FunctionEmitContext::SetBlockEntryMask(llvm::Value *value) { + blockEntryMask = value; } @@ -567,7 +567,7 @@ FunctionEmitContext::StartLoop(llvm::BasicBlock *bt, llvm::BasicBlock *ct, llvm::Value *oldMask = GetInternalMask(); controlFlowInfo.push_back(CFInfo::GetLoop(uniformCF, breakTarget, continueTarget, breakLanesPtr, - continueLanesPtr, oldMask, loopMask)); + continueLanesPtr, oldMask, blockEntryMask)); if (uniformCF) // If the loop has a uniform condition, we don't need to track // which lanes 'break' or 'continue'; all of the running ones go @@ -584,7 +584,7 @@ FunctionEmitContext::StartLoop(llvm::BasicBlock *bt, llvm::BasicBlock *ct, breakTarget = bt; continueTarget = ct; - loopMask = NULL; // this better be set by the loop! + blockEntryMask = NULL; // this better be set by the loop! } @@ -625,7 +625,7 @@ FunctionEmitContext::StartForeach(ForeachType ft) { llvm::Value *oldMask = GetInternalMask(); controlFlowInfo.push_back(CFInfo::GetForeach(ft, breakTarget, continueTarget, breakLanesPtr, continueLanesPtr, - oldMask, loopMask)); + oldMask, blockEntryMask)); breakLanesPtr = NULL; breakTarget = NULL; @@ -633,7 +633,7 @@ FunctionEmitContext::StartForeach(ForeachType ft) { StoreInst(LLVMMaskAllOff, continueLanesPtr); continueTarget = NULL; // should be set by SetContinueTarget() - loopMask = NULL; + blockEntryMask = NULL; } @@ -834,7 +834,7 @@ FunctionEmitContext::ifsInCFAllUniform(int type) const { void FunctionEmitContext::jumpIfAllLoopLanesAreDone(llvm::BasicBlock *target) { llvm::Value *allDone = NULL; - AssertPos(currentPos, continueLanesPtr != NULL); + if (breakLanesPtr == NULL) { // In a foreach loop, break and return are illegal, and // breakLanesPtr is NULL. In this case, the mask is guaranteed to @@ -850,18 +850,20 @@ FunctionEmitContext::jumpIfAllLoopLanesAreDone(llvm::BasicBlock *target) { // so, everyone is done and we can jump to the given target llvm::Value *returned = LoadInst(returnedLanesPtr, "returned_lanes"); - llvm::Value *continued = LoadInst(continueLanesPtr, - "continue_lanes"); llvm::Value *breaked = LoadInst(breakLanesPtr, "break_lanes"); - llvm::Value *returnedOrContinued = BinaryOperator(llvm::Instruction::Or, - returned, continued, - "returned|continued"); - llvm::Value *returnedOrContinuedOrBreaked = - BinaryOperator(llvm::Instruction::Or, returnedOrContinued, - breaked, "returned|continued"); + llvm::Value *finishedLanes = BinaryOperator(llvm::Instruction::Or, + returned, breaked, + "returned|breaked"); + if (continueLanesPtr != NULL) { + // It's NULL for "switch" statements... + llvm::Value *continued = LoadInst(continueLanesPtr, + "continue_lanes"); + finishedLanes = BinaryOperator(llvm::Instruction::Or, finishedLanes, + continued, "returned|breaked|continued"); + } - // Do we match the mask at loop entry? - allDone = MasksAllEqual(returnedOrContinuedOrBreaked, loopMask); + // Do we match the mask at loop or switch statement entry? + allDone = MasksAllEqual(finishedLanes, blockEntryMask); } llvm::BasicBlock *bAll = CreateBasicBlock("all_continued_or_breaked"); @@ -905,7 +907,7 @@ FunctionEmitContext::StartSwitch(bool cfIsUniform, llvm::BasicBlock *bbBreak) { controlFlowInfo.push_back(CFInfo::GetSwitch(cfIsUniform, breakTarget, continueTarget, breakLanesPtr, continueLanesPtr, oldMask, - loopMask, switchExpr, defaultBlock, + blockEntryMask, switchExpr, defaultBlock, caseBlocks, nextBlocks, switchConditionWasUniform)); @@ -915,7 +917,7 @@ FunctionEmitContext::StartSwitch(bool cfIsUniform, llvm::BasicBlock *bbBreak) { continueLanesPtr = NULL; continueTarget = NULL; - loopMask = NULL; + blockEntryMask = NULL; // These will be set by the SwitchInst() method switchExpr = NULL; @@ -3526,7 +3528,7 @@ FunctionEmitContext::popCFState() { continueTarget = ci->savedContinueTarget; breakLanesPtr = ci->savedBreakLanesPtr; continueLanesPtr = ci->savedContinueLanesPtr; - loopMask = ci->savedLoopMask; + blockEntryMask = ci->savedBlockEntryMask; switchExpr = ci->savedSwitchExpr; defaultBlock = ci->savedDefaultBlock; caseBlocks = ci->savedCaseBlocks; @@ -3538,7 +3540,7 @@ FunctionEmitContext::popCFState() { continueTarget = ci->savedContinueTarget; breakLanesPtr = ci->savedBreakLanesPtr; continueLanesPtr = ci->savedContinueLanesPtr; - loopMask = ci->savedLoopMask; + blockEntryMask = ci->savedBlockEntryMask; } else { AssertPos(currentPos, ci->IsIf()); diff --git a/ctx.h b/ctx.h index 793242c7..9f1f2788 100644 --- a/ctx.h +++ b/ctx.h @@ -158,8 +158,8 @@ public: bool uniformControlFlow); /** Informs FunctionEmitContext of the value of the mask at the start - of a loop body. */ - void SetLoopMask(llvm::Value *mask); + of a loop body or switch statement. */ + void SetBlockEntryMask(llvm::Value *mask); /** Informs FunctionEmitContext that code generation for a loop is finished. */ @@ -566,9 +566,9 @@ private: for error messages and debugging symbols. */ SourcePos funcStartPos; - /** If currently in a loop body, the value of the mask at the start of - the loop. */ - llvm::Value *loopMask; + /** If currently in a loop body or switch statement, the value of the + mask at the start of it. */ + llvm::Value *blockEntryMask; /** If currently in a loop body or switch statement, this is a pointer to memory to store a mask value that represents which of the lanes diff --git a/stmt.cpp b/stmt.cpp index 30166f9c..344b6993 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -830,7 +830,7 @@ void DoStmt::EmitCode(FunctionEmitContext *ctx) const { // And now emit code for the loop body ctx->SetCurrentBasicBlock(bloop); - ctx->SetLoopMask(ctx->GetInternalMask()); + ctx->SetBlockEntryMask(ctx->GetInternalMask()); ctx->SetDebugPos(pos); // FIXME: in the StmtList::EmitCode() method takes starts/stops a new // scope around the statements in the list. So if the body is just a @@ -1047,7 +1047,7 @@ ForStmt::EmitCode(FunctionEmitContext *ctx) const { // On to emitting the code for the loop body. ctx->SetCurrentBasicBlock(bloop); - ctx->SetLoopMask(ctx->GetInternalMask()); + ctx->SetBlockEntryMask(ctx->GetInternalMask()); ctx->AddInstrumentationPoint("for loop body"); if (!dynamic_cast(stmts)) ctx->StartScope(); @@ -2557,6 +2557,7 @@ SwitchStmt::EmitCode(FunctionEmitContext *ctx) const { bool isUniformCF = (type->IsUniformType() && lHasVaryingBreakOrContinue(stmts) == false); ctx->StartSwitch(isUniformCF, bbDone); + ctx->SetBlockEntryMask(ctx->GetInternalMask()); ctx->SwitchInst(exprValue, svi.defaultBlock ? svi.defaultBlock : bbDone, svi.caseBlocks, svi.nextBlock);