diff --git a/ctx.cpp b/ctx.cpp index 7c238635..0a7dd6d0 100644 --- a/ctx.cpp +++ b/ctx.cpp @@ -82,7 +82,8 @@ struct CFInfo { llvm::Value *switchExpr, llvm::BasicBlock *bbDefault, const std::vector > *bbCases, - const std::map *bbNext); + const std::map *bbNext, + bool scUniform); bool IsIf() { return type == If; } bool IsLoop() { return type == Loop; } @@ -101,6 +102,7 @@ struct CFInfo { llvm::BasicBlock *savedDefaultBlock; const std::vector > *savedCaseBlocks; const std::map *savedNextBlocks; + bool savedSwitchConditionWasUniform; private: CFInfo(CFType t, bool uniformIf, llvm::Value *sm) { @@ -119,7 +121,8 @@ private: llvm::Value *sb, llvm::Value *sc, llvm::Value *sm, llvm::Value *lm, llvm::Value *sse = NULL, llvm::BasicBlock *bbd = NULL, const std::vector > *bbc = NULL, - const std::map *bbn = NULL) { + const std::map *bbn = NULL, + bool scu = false) { Assert(t == Loop || t == Switch); type = t; isUniform = iu; @@ -133,6 +136,7 @@ private: savedDefaultBlock = bbd; savedCaseBlocks = bbc; savedNextBlocks = bbn; + savedSwitchConditionWasUniform = scu; } CFInfo(CFType t, llvm::BasicBlock *bt, llvm::BasicBlock *ct, llvm::Value *sb, llvm::Value *sc, llvm::Value *sm, @@ -192,11 +196,12 @@ CFInfo::GetSwitch(bool isUniform, llvm::BasicBlock *breakTarget, llvm::Value *savedLoopMask, llvm::Value *savedSwitchExpr, llvm::BasicBlock *savedDefaultBlock, const std::vector > *savedCases, - const std::map *savedNext) { + const std::map *savedNext, + bool savedSwitchConditionUniform) { return new CFInfo(Switch, isUniform, breakTarget, continueTarget, savedBreakLanesPtr, savedContinueLanesPtr, savedMask, savedLoopMask, savedSwitchExpr, savedDefaultBlock, - savedCases, savedNext); + savedCases, savedNext, savedSwitchConditionUniform); } /////////////////////////////////////////////////////////////////////////// @@ -618,36 +623,20 @@ FunctionEmitContext::restoreMaskGivenReturns(llvm::Value *oldMask) { } -/** Returns "true" if the first enclosing "switch" statement (if any) has a - uniform condition. It is legal to call this outside of the scope of an - enclosing switch. */ +/** Returns "true" if the first enclosing non-if control flow expression is + a "switch" statement. +*/ bool -FunctionEmitContext::inUniformSwitch() const { +FunctionEmitContext::inSwitchStatement() const { // Go backwards through controlFlowInfo, since we add new nested scopes // to the back. int i = controlFlowInfo.size() - 1; - while (i >= 0 && controlFlowInfo[i]->type != CFInfo::Switch) + while (i >= 0 && controlFlowInfo[i]->IsIf()) --i; + // Got to the first non-if (or end of CF info) if (i == -1) return false; - return controlFlowInfo[i]->IsUniform(); -} - - -/** Along the lines of inUniformSwitch(), this returns "true" if the first - enclosing switch has a varying condition. Note that both - inUniformSwitch() and inVaryingSwitch() may return false, which - indicates that we're not currently inside a switch's scope. */ -bool -FunctionEmitContext::inVaryingSwitch() const { - // Go backwards through controlFlowInfo, since we add new nested scopes - // to the back. - int i = controlFlowInfo.size() - 1; - while (i >= 0 && controlFlowInfo[i]->type != CFInfo::Switch) - --i; - if (i == -1) - return false; - return controlFlowInfo[i]->IsVarying(); + return controlFlowInfo[i]->IsSwitch(); } @@ -663,16 +652,9 @@ FunctionEmitContext::Break(bool doCoherenceCheck) { if (bblock == NULL) return; - if (inUniformSwitch()) { - // FIXME: Currently, if there are any "break" statements under - // varying "if" statements inside a switch with a uniform - // condition, then the SwitchStmt code promotes the condition to - // varying; hence this assert. However, we can do better than - // that--see issue XXX. When that issue is fixed, this assert will - // be wrong, and should be a second test in the if() statement - // above. - Assert(ifsInCFAllUniform(CFInfo::Switch)); - + if (inSwitchStatement() == true && + switchConditionWasUniform == true && + ifsInCFAllUniform(CFInfo::Switch)) { // We know that all program instances are executing the break, so // just jump to the block immediately after the switch. Assert(breakTarget != NULL); @@ -684,8 +666,9 @@ FunctionEmitContext::Break(bool doCoherenceCheck) { // If all of the enclosing 'if' tests in the loop have uniform control // flow or if we can tell that the mask is all on, then we can just // jump to the break location. - if (!inVaryingSwitch() && (ifsInCFAllUniform(CFInfo::Loop) || - GetInternalMask() == LLVMMaskAllOn)) { + if (inSwitchStatement() == false && + (ifsInCFAllUniform(CFInfo::Loop) || + GetInternalMask() == LLVMMaskAllOn)) { BranchInst(breakTarget); if (ifsInCFAllUniform(CFInfo::Loop) && doCoherenceCheck) Warning(currentPos, "Coherent break statement not necessary in " @@ -694,9 +677,10 @@ FunctionEmitContext::Break(bool doCoherenceCheck) { bblock = NULL; } else { - // Varying switch, or a loop with varying 'if's above the break. - // In these cases, we need to update the mask of the lanes that - // have executed a 'break' statement: + // Varying switch, uniform switch where the 'break' is under + // varying control flow, or a loop with varying 'if's above the + // break. In these cases, we need to update the mask of the lanes + // that have executed a 'break' statement: // breakLanes = breakLanes | mask Assert(breakLanesPtr != NULL); llvm::Value *mask = GetInternalMask(); @@ -712,16 +696,20 @@ FunctionEmitContext::Break(bool doCoherenceCheck) { // an 'if' statement and restore the mask then. SetInternalMask(LLVMMaskAllOff); - if (doCoherenceCheck && !inVaryingSwitch()) - // If the user has indicated that this is a 'coherent' break - // statement, then check to see if the mask is all off. If so, - // we have to conservatively jump to the continueTarget, not - // the breakTarget, since part of the reason the mask is all - // off may be due to 'continue' statements that executed in the - // current loop iteration. - // FIXME: if the loop only has break statements and no - // continues, we can jump to breakTarget in that case. - jumpIfAllLoopLanesAreDone(continueTarget); + if (doCoherenceCheck) { + if (continueTarget != NULL) + // If the user has indicated that this is a 'coherent' + // break statement, then check to see if the mask is all + // off. If so, we have to conservatively jump to the + // continueTarget, not the breakTarget, since part of the + // reason the mask is all off may be due to 'continue' + // statements that executed in the current loop iteration. + jumpIfAllLoopLanesAreDone(continueTarget); + else if (breakTarget != NULL) + // Similarly handle these for switch statements, where we + // only have a break target. + jumpIfAllLoopLanesAreDone(breakTarget); + } } } @@ -861,13 +849,14 @@ FunctionEmitContext::RestoreContinuedLanes() { void -FunctionEmitContext::StartSwitch(bool isUniform, llvm::BasicBlock *bbBreak) { +FunctionEmitContext::StartSwitch(bool cfIsUniform, llvm::BasicBlock *bbBreak) { llvm::Value *oldMask = GetInternalMask(); - controlFlowInfo.push_back(CFInfo::GetSwitch(isUniform, breakTarget, + controlFlowInfo.push_back(CFInfo::GetSwitch(cfIsUniform, breakTarget, continueTarget, breakLanesPtr, continueLanesPtr, oldMask, loopMask, switchExpr, defaultBlock, - caseBlocks, nextBlocks)); + caseBlocks, nextBlocks, + switchConditionWasUniform)); breakLanesPtr = AllocaInst(LLVMTypes::MaskType, "break_lanes_memory"); StoreInst(LLVMMaskAllOff, breakLanesPtr); @@ -931,7 +920,7 @@ FunctionEmitContext::getMaskAtSwitchEntry() { void FunctionEmitContext::EmitDefaultLabel(bool checkMask, SourcePos pos) { - if (!inUniformSwitch() && !inVaryingSwitch()) { + if (inSwitchStatement() == false) { Error(pos, "\"default\" label illegal outside of \"switch\" " "statement."); return; @@ -948,9 +937,9 @@ FunctionEmitContext::EmitDefaultLabel(bool checkMask, SourcePos pos) { BranchInst(defaultBlock); SetCurrentBasicBlock(defaultBlock); - if (inUniformSwitch()) - // Nothing more to do for the uniform case; return back to the - // caller, which will then emit the code for the default case. + if (switchConditionWasUniform) + // Nothing more to do for this case; return back to the caller, + // which will then emit the code for the default case. return; // For a varying switch, we need to update the execution mask. @@ -994,7 +983,7 @@ FunctionEmitContext::EmitDefaultLabel(bool checkMask, SourcePos pos) { void FunctionEmitContext::EmitCaseLabel(int value, bool checkMask, SourcePos pos) { - if (!inUniformSwitch() && !inVaryingSwitch()) { + if (inSwitchStatement() == false) { Error(pos, "\"case\" label illegal outside of \"switch\" statement."); return; } @@ -1014,7 +1003,7 @@ FunctionEmitContext::EmitCaseLabel(int value, bool checkMask, SourcePos pos) { BranchInst(bbCase); SetCurrentBasicBlock(bbCase); - if (inUniformSwitch()) + if (switchConditionWasUniform) return; // update the mask: first, get a mask that indicates which program @@ -1057,12 +1046,12 @@ FunctionEmitContext::SwitchInst(llvm::Value *expr, llvm::BasicBlock *bbDefault, defaultBlock = bbDefault; caseBlocks = new std::vector >(bbCases); nextBlocks = new std::map(bbNext); + switchConditionWasUniform = + (llvm::isa(expr->getType()) == false); - if (inUniformSwitch()) { - // For a uniform switch, just wire things up to the LLVM switch - // instruction. - Assert(llvm::isa(expr->getType()) == - false); + if (switchConditionWasUniform == true) { + // For a uniform switch condition, just wire things up to the LLVM + // switch instruction. llvm::SwitchInst *s = llvm::SwitchInst::Create(expr, bbDefault, bbCases.size(), bblock); for (int i = 0; i < (int)bbCases.size(); ++i) { @@ -2996,6 +2985,7 @@ FunctionEmitContext::popCFState() { defaultBlock = ci->savedDefaultBlock; caseBlocks = ci->savedCaseBlocks; nextBlocks = ci->savedNextBlocks; + switchConditionWasUniform = ci->savedSwitchConditionWasUniform; } else if (ci->IsLoop() || ci->IsForeach()) { breakTarget = ci->savedBreakTarget; diff --git a/ctx.h b/ctx.h index 1cd59a54..adf26560 100644 --- a/ctx.h +++ b/ctx.h @@ -580,6 +580,13 @@ private: if present), this map gives the basic block for the immediately following case/default label. */ const std::map *nextBlocks; + + /** Records whether the switch condition was uniform; this is a + distinct notion from whether the switch represents uniform or + varying control flow; we may have varying control flow from a + uniform switch condition if there is a 'break' inside the switch + that's under varying control flow. */ + bool switchConditionWasUniform; /** @} */ /** A pointer to memory that records which of the program instances @@ -634,8 +641,7 @@ private: void restoreMaskGivenReturns(llvm::Value *oldMask); void addSwitchMaskCheck(llvm::Value *mask); - bool inUniformSwitch() const; - bool inVaryingSwitch() const; + bool inSwitchStatement() const; llvm::Value *getMaskAtSwitchEntry(); CFInfo *popCFState(); diff --git a/stmt.cpp b/stmt.cpp index 0ea0c355..fda693e5 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -2335,7 +2335,9 @@ SwitchStmt::EmitCode(FunctionEmitContext *ctx) const { return; } - ctx->StartSwitch(type->IsUniformType(), bbDone); + bool isUniformCF = (type->IsUniformType() && + lHasVaryingBreakOrContinue(stmts) == false); + ctx->StartSwitch(isUniformCF, bbDone); ctx->SwitchInst(exprValue, svi.defaultBlock ? svi.defaultBlock : bbDone, svi.caseBlocks, svi.nextBlock); @@ -2375,15 +2377,7 @@ SwitchStmt::TypeCheck() { exprType->GetAsUniformType() == AtomicType::UniformConstInt64); - // FIXME: if there's a break or continue under varying control flow - // within a switch with a "uniform" condition, we promote the condition - // to varying so that everything works out and we are set to handle the - // resulting divergent control flow. This is somewhat sub-optimal; see - // https://github.com/ispc/ispc/issues/156 for details. - bool isUniform = (exprType->IsUniformType() && - lHasVaryingBreakOrContinue(stmts) == false); - - if (isUniform) { + if (exprType->IsUniformType()) { if (is64bit) toType = AtomicType::UniformInt64; else toType = AtomicType::UniformInt32; } diff --git a/tests/switch-13.ispc b/tests/switch-13.ispc new file mode 100644 index 00000000..c1cea57e --- /dev/null +++ b/tests/switch-13.ispc @@ -0,0 +1,28 @@ + +export uniform int width() { return programCount; } + +int switchit(int a, uniform int b) { + int r = -1; + switch (b) { + case 5: + if (a & 1) { + r=3; + break; + } + r= 2; + break; + default: + r= 3; + } + return r; +} + +export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) { + int a = aFOO[programIndex]; + int x = switchit(a, b); + RET[programIndex] = x; +} + +export void result(uniform float RET[]) { + RET[programIndex] = (programIndex & 1) ? 2 : 3; +} diff --git a/tests/switch-14.ispc b/tests/switch-14.ispc new file mode 100644 index 00000000..fc8cfd72 --- /dev/null +++ b/tests/switch-14.ispc @@ -0,0 +1,24 @@ + +export uniform int width() { return programCount; } + +int switchit(int a, uniform int b) { + switch (b) { + case 5: + if (a & 1) + break; + return 2; + default: + return 42; + } + return 3; +} + +export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) { + int a = aFOO[programIndex]; + int x = switchit(a, b); + RET[programIndex] = x; +} + +export void result(uniform float RET[]) { + RET[programIndex] = (programIndex & 1) ? 2 : 3; +}