Improve code for uniform switches with a 'break' under varying control flow.

Previously, when we had a switch statement with a uniform switch condition
but a 'break' statement that was under varying control flow inside the
switch, we'd promote the switch condition to be varying so that the
break would work correctly.

Now, we leave the condition as uniform and are thus able to use the
more-efficient LLVM switch instruction in this case.

Issue #156.
This commit is contained in:
Matt Pharr
2012-01-19 08:41:19 -07:00
parent 6451c3d99d
commit 748b292e77
5 changed files with 120 additions and 78 deletions

122
ctx.cpp
View File

@@ -82,7 +82,8 @@ struct CFInfo {
llvm::Value *switchExpr, llvm::Value *switchExpr,
llvm::BasicBlock *bbDefault, llvm::BasicBlock *bbDefault,
const std::vector<std::pair<int, llvm::BasicBlock *> > *bbCases, const std::vector<std::pair<int, llvm::BasicBlock *> > *bbCases,
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *bbNext); const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *bbNext,
bool scUniform);
bool IsIf() { return type == If; } bool IsIf() { return type == If; }
bool IsLoop() { return type == Loop; } bool IsLoop() { return type == Loop; }
@@ -101,6 +102,7 @@ struct CFInfo {
llvm::BasicBlock *savedDefaultBlock; llvm::BasicBlock *savedDefaultBlock;
const std::vector<std::pair<int, llvm::BasicBlock *> > *savedCaseBlocks; const std::vector<std::pair<int, llvm::BasicBlock *> > *savedCaseBlocks;
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *savedNextBlocks; const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *savedNextBlocks;
bool savedSwitchConditionWasUniform;
private: private:
CFInfo(CFType t, bool uniformIf, llvm::Value *sm) { CFInfo(CFType t, bool uniformIf, llvm::Value *sm) {
@@ -119,7 +121,8 @@ private:
llvm::Value *sb, llvm::Value *sc, llvm::Value *sm, llvm::Value *sb, llvm::Value *sc, llvm::Value *sm,
llvm::Value *lm, llvm::Value *sse = NULL, llvm::BasicBlock *bbd = NULL, llvm::Value *lm, llvm::Value *sse = NULL, llvm::BasicBlock *bbd = NULL,
const std::vector<std::pair<int, llvm::BasicBlock *> > *bbc = NULL, const std::vector<std::pair<int, llvm::BasicBlock *> > *bbc = NULL,
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *bbn = NULL) { const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *bbn = NULL,
bool scu = false) {
Assert(t == Loop || t == Switch); Assert(t == Loop || t == Switch);
type = t; type = t;
isUniform = iu; isUniform = iu;
@@ -133,6 +136,7 @@ private:
savedDefaultBlock = bbd; savedDefaultBlock = bbd;
savedCaseBlocks = bbc; savedCaseBlocks = bbc;
savedNextBlocks = bbn; savedNextBlocks = bbn;
savedSwitchConditionWasUniform = scu;
} }
CFInfo(CFType t, llvm::BasicBlock *bt, llvm::BasicBlock *ct, CFInfo(CFType t, llvm::BasicBlock *bt, llvm::BasicBlock *ct,
llvm::Value *sb, llvm::Value *sc, llvm::Value *sm, llvm::Value *sb, llvm::Value *sc, llvm::Value *sm,
@@ -192,11 +196,12 @@ CFInfo::GetSwitch(bool isUniform, llvm::BasicBlock *breakTarget,
llvm::Value *savedLoopMask, llvm::Value *savedSwitchExpr, llvm::Value *savedLoopMask, llvm::Value *savedSwitchExpr,
llvm::BasicBlock *savedDefaultBlock, llvm::BasicBlock *savedDefaultBlock,
const std::vector<std::pair<int, llvm::BasicBlock *> > *savedCases, const std::vector<std::pair<int, llvm::BasicBlock *> > *savedCases,
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *savedNext) { const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *savedNext,
bool savedSwitchConditionUniform) {
return new CFInfo(Switch, isUniform, breakTarget, continueTarget, return new CFInfo(Switch, isUniform, breakTarget, continueTarget,
savedBreakLanesPtr, savedContinueLanesPtr, savedBreakLanesPtr, savedContinueLanesPtr,
savedMask, savedLoopMask, savedSwitchExpr, savedDefaultBlock, savedMask, savedLoopMask, savedSwitchExpr, savedDefaultBlock,
savedCases, savedNext); savedCases, savedNext, savedSwitchConditionUniform);
} }
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
@@ -618,36 +623,20 @@ FunctionEmitContext::restoreMaskGivenReturns(llvm::Value *oldMask) {
} }
/** Returns "true" if the first enclosing "switch" statement (if any) has a /** Returns "true" if the first enclosing non-if control flow expression is
uniform condition. It is legal to call this outside of the scope of an a "switch" statement.
enclosing switch. */ */
bool bool
FunctionEmitContext::inUniformSwitch() const { FunctionEmitContext::inSwitchStatement() const {
// Go backwards through controlFlowInfo, since we add new nested scopes // Go backwards through controlFlowInfo, since we add new nested scopes
// to the back. // to the back.
int i = controlFlowInfo.size() - 1; int i = controlFlowInfo.size() - 1;
while (i >= 0 && controlFlowInfo[i]->type != CFInfo::Switch) while (i >= 0 && controlFlowInfo[i]->IsIf())
--i; --i;
// Got to the first non-if (or end of CF info)
if (i == -1) if (i == -1)
return false; return false;
return controlFlowInfo[i]->IsUniform(); return controlFlowInfo[i]->IsSwitch();
}
/** Along the lines of inUniformSwitch(), this returns "true" if the first
enclosing switch has a varying condition. Note that both
inUniformSwitch() and inVaryingSwitch() may return false, which
indicates that we're not currently inside a switch's scope. */
bool
FunctionEmitContext::inVaryingSwitch() const {
// Go backwards through controlFlowInfo, since we add new nested scopes
// to the back.
int i = controlFlowInfo.size() - 1;
while (i >= 0 && controlFlowInfo[i]->type != CFInfo::Switch)
--i;
if (i == -1)
return false;
return controlFlowInfo[i]->IsVarying();
} }
@@ -663,16 +652,9 @@ FunctionEmitContext::Break(bool doCoherenceCheck) {
if (bblock == NULL) if (bblock == NULL)
return; return;
if (inUniformSwitch()) { if (inSwitchStatement() == true &&
// FIXME: Currently, if there are any "break" statements under switchConditionWasUniform == true &&
// varying "if" statements inside a switch with a uniform ifsInCFAllUniform(CFInfo::Switch)) {
// condition, then the SwitchStmt code promotes the condition to
// varying; hence this assert. However, we can do better than
// that--see issue XXX. When that issue is fixed, this assert will
// be wrong, and should be a second test in the if() statement
// above.
Assert(ifsInCFAllUniform(CFInfo::Switch));
// We know that all program instances are executing the break, so // We know that all program instances are executing the break, so
// just jump to the block immediately after the switch. // just jump to the block immediately after the switch.
Assert(breakTarget != NULL); Assert(breakTarget != NULL);
@@ -684,8 +666,9 @@ FunctionEmitContext::Break(bool doCoherenceCheck) {
// If all of the enclosing 'if' tests in the loop have uniform control // If all of the enclosing 'if' tests in the loop have uniform control
// flow or if we can tell that the mask is all on, then we can just // flow or if we can tell that the mask is all on, then we can just
// jump to the break location. // jump to the break location.
if (!inVaryingSwitch() && (ifsInCFAllUniform(CFInfo::Loop) || if (inSwitchStatement() == false &&
GetInternalMask() == LLVMMaskAllOn)) { (ifsInCFAllUniform(CFInfo::Loop) ||
GetInternalMask() == LLVMMaskAllOn)) {
BranchInst(breakTarget); BranchInst(breakTarget);
if (ifsInCFAllUniform(CFInfo::Loop) && doCoherenceCheck) if (ifsInCFAllUniform(CFInfo::Loop) && doCoherenceCheck)
Warning(currentPos, "Coherent break statement not necessary in " Warning(currentPos, "Coherent break statement not necessary in "
@@ -694,9 +677,10 @@ FunctionEmitContext::Break(bool doCoherenceCheck) {
bblock = NULL; bblock = NULL;
} }
else { else {
// Varying switch, or a loop with varying 'if's above the break. // Varying switch, uniform switch where the 'break' is under
// In these cases, we need to update the mask of the lanes that // varying control flow, or a loop with varying 'if's above the
// have executed a 'break' statement: // break. In these cases, we need to update the mask of the lanes
// that have executed a 'break' statement:
// breakLanes = breakLanes | mask // breakLanes = breakLanes | mask
Assert(breakLanesPtr != NULL); Assert(breakLanesPtr != NULL);
llvm::Value *mask = GetInternalMask(); llvm::Value *mask = GetInternalMask();
@@ -712,16 +696,20 @@ FunctionEmitContext::Break(bool doCoherenceCheck) {
// an 'if' statement and restore the mask then. // an 'if' statement and restore the mask then.
SetInternalMask(LLVMMaskAllOff); SetInternalMask(LLVMMaskAllOff);
if (doCoherenceCheck && !inVaryingSwitch()) if (doCoherenceCheck) {
// If the user has indicated that this is a 'coherent' break if (continueTarget != NULL)
// statement, then check to see if the mask is all off. If so, // If the user has indicated that this is a 'coherent'
// we have to conservatively jump to the continueTarget, not // break statement, then check to see if the mask is all
// the breakTarget, since part of the reason the mask is all // off. If so, we have to conservatively jump to the
// off may be due to 'continue' statements that executed in the // continueTarget, not the breakTarget, since part of the
// current loop iteration. // reason the mask is all off may be due to 'continue'
// FIXME: if the loop only has break statements and no // statements that executed in the current loop iteration.
// continues, we can jump to breakTarget in that case. jumpIfAllLoopLanesAreDone(continueTarget);
jumpIfAllLoopLanesAreDone(continueTarget); else if (breakTarget != NULL)
// Similarly handle these for switch statements, where we
// only have a break target.
jumpIfAllLoopLanesAreDone(breakTarget);
}
} }
} }
@@ -861,13 +849,14 @@ FunctionEmitContext::RestoreContinuedLanes() {
void void
FunctionEmitContext::StartSwitch(bool isUniform, llvm::BasicBlock *bbBreak) { FunctionEmitContext::StartSwitch(bool cfIsUniform, llvm::BasicBlock *bbBreak) {
llvm::Value *oldMask = GetInternalMask(); llvm::Value *oldMask = GetInternalMask();
controlFlowInfo.push_back(CFInfo::GetSwitch(isUniform, breakTarget, controlFlowInfo.push_back(CFInfo::GetSwitch(cfIsUniform, breakTarget,
continueTarget, breakLanesPtr, continueTarget, breakLanesPtr,
continueLanesPtr, oldMask, continueLanesPtr, oldMask,
loopMask, switchExpr, defaultBlock, loopMask, switchExpr, defaultBlock,
caseBlocks, nextBlocks)); caseBlocks, nextBlocks,
switchConditionWasUniform));
breakLanesPtr = AllocaInst(LLVMTypes::MaskType, "break_lanes_memory"); breakLanesPtr = AllocaInst(LLVMTypes::MaskType, "break_lanes_memory");
StoreInst(LLVMMaskAllOff, breakLanesPtr); StoreInst(LLVMMaskAllOff, breakLanesPtr);
@@ -931,7 +920,7 @@ FunctionEmitContext::getMaskAtSwitchEntry() {
void void
FunctionEmitContext::EmitDefaultLabel(bool checkMask, SourcePos pos) { FunctionEmitContext::EmitDefaultLabel(bool checkMask, SourcePos pos) {
if (!inUniformSwitch() && !inVaryingSwitch()) { if (inSwitchStatement() == false) {
Error(pos, "\"default\" label illegal outside of \"switch\" " Error(pos, "\"default\" label illegal outside of \"switch\" "
"statement."); "statement.");
return; return;
@@ -948,9 +937,9 @@ FunctionEmitContext::EmitDefaultLabel(bool checkMask, SourcePos pos) {
BranchInst(defaultBlock); BranchInst(defaultBlock);
SetCurrentBasicBlock(defaultBlock); SetCurrentBasicBlock(defaultBlock);
if (inUniformSwitch()) if (switchConditionWasUniform)
// Nothing more to do for the uniform case; return back to the // Nothing more to do for this case; return back to the caller,
// caller, which will then emit the code for the default case. // which will then emit the code for the default case.
return; return;
// For a varying switch, we need to update the execution mask. // For a varying switch, we need to update the execution mask.
@@ -994,7 +983,7 @@ FunctionEmitContext::EmitDefaultLabel(bool checkMask, SourcePos pos) {
void void
FunctionEmitContext::EmitCaseLabel(int value, bool checkMask, SourcePos pos) { FunctionEmitContext::EmitCaseLabel(int value, bool checkMask, SourcePos pos) {
if (!inUniformSwitch() && !inVaryingSwitch()) { if (inSwitchStatement() == false) {
Error(pos, "\"case\" label illegal outside of \"switch\" statement."); Error(pos, "\"case\" label illegal outside of \"switch\" statement.");
return; return;
} }
@@ -1014,7 +1003,7 @@ FunctionEmitContext::EmitCaseLabel(int value, bool checkMask, SourcePos pos) {
BranchInst(bbCase); BranchInst(bbCase);
SetCurrentBasicBlock(bbCase); SetCurrentBasicBlock(bbCase);
if (inUniformSwitch()) if (switchConditionWasUniform)
return; return;
// update the mask: first, get a mask that indicates which program // update the mask: first, get a mask that indicates which program
@@ -1057,12 +1046,12 @@ FunctionEmitContext::SwitchInst(llvm::Value *expr, llvm::BasicBlock *bbDefault,
defaultBlock = bbDefault; defaultBlock = bbDefault;
caseBlocks = new std::vector<std::pair<int, llvm::BasicBlock *> >(bbCases); caseBlocks = new std::vector<std::pair<int, llvm::BasicBlock *> >(bbCases);
nextBlocks = new std::map<llvm::BasicBlock *, llvm::BasicBlock *>(bbNext); nextBlocks = new std::map<llvm::BasicBlock *, llvm::BasicBlock *>(bbNext);
switchConditionWasUniform =
(llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(expr->getType()) == false);
if (inUniformSwitch()) { if (switchConditionWasUniform == true) {
// For a uniform switch, just wire things up to the LLVM switch // For a uniform switch condition, just wire things up to the LLVM
// instruction. // switch instruction.
Assert(llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(expr->getType()) ==
false);
llvm::SwitchInst *s = llvm::SwitchInst::Create(expr, bbDefault, llvm::SwitchInst *s = llvm::SwitchInst::Create(expr, bbDefault,
bbCases.size(), bblock); bbCases.size(), bblock);
for (int i = 0; i < (int)bbCases.size(); ++i) { for (int i = 0; i < (int)bbCases.size(); ++i) {
@@ -2996,6 +2985,7 @@ FunctionEmitContext::popCFState() {
defaultBlock = ci->savedDefaultBlock; defaultBlock = ci->savedDefaultBlock;
caseBlocks = ci->savedCaseBlocks; caseBlocks = ci->savedCaseBlocks;
nextBlocks = ci->savedNextBlocks; nextBlocks = ci->savedNextBlocks;
switchConditionWasUniform = ci->savedSwitchConditionWasUniform;
} }
else if (ci->IsLoop() || ci->IsForeach()) { else if (ci->IsLoop() || ci->IsForeach()) {
breakTarget = ci->savedBreakTarget; breakTarget = ci->savedBreakTarget;

10
ctx.h
View File

@@ -580,6 +580,13 @@ private:
if present), this map gives the basic block for the immediately if present), this map gives the basic block for the immediately
following case/default label. */ following case/default label. */
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *nextBlocks; const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *nextBlocks;
/** Records whether the switch condition was uniform; this is a
distinct notion from whether the switch represents uniform or
varying control flow; we may have varying control flow from a
uniform switch condition if there is a 'break' inside the switch
that's under varying control flow. */
bool switchConditionWasUniform;
/** @} */ /** @} */
/** A pointer to memory that records which of the program instances /** A pointer to memory that records which of the program instances
@@ -634,8 +641,7 @@ private:
void restoreMaskGivenReturns(llvm::Value *oldMask); void restoreMaskGivenReturns(llvm::Value *oldMask);
void addSwitchMaskCheck(llvm::Value *mask); void addSwitchMaskCheck(llvm::Value *mask);
bool inUniformSwitch() const; bool inSwitchStatement() const;
bool inVaryingSwitch() const;
llvm::Value *getMaskAtSwitchEntry(); llvm::Value *getMaskAtSwitchEntry();
CFInfo *popCFState(); CFInfo *popCFState();

View File

@@ -2335,7 +2335,9 @@ SwitchStmt::EmitCode(FunctionEmitContext *ctx) const {
return; return;
} }
ctx->StartSwitch(type->IsUniformType(), bbDone); bool isUniformCF = (type->IsUniformType() &&
lHasVaryingBreakOrContinue(stmts) == false);
ctx->StartSwitch(isUniformCF, bbDone);
ctx->SwitchInst(exprValue, svi.defaultBlock ? svi.defaultBlock : bbDone, ctx->SwitchInst(exprValue, svi.defaultBlock ? svi.defaultBlock : bbDone,
svi.caseBlocks, svi.nextBlock); svi.caseBlocks, svi.nextBlock);
@@ -2375,15 +2377,7 @@ SwitchStmt::TypeCheck() {
exprType->GetAsUniformType() == exprType->GetAsUniformType() ==
AtomicType::UniformConstInt64); AtomicType::UniformConstInt64);
// FIXME: if there's a break or continue under varying control flow if (exprType->IsUniformType()) {
// within a switch with a "uniform" condition, we promote the condition
// to varying so that everything works out and we are set to handle the
// resulting divergent control flow. This is somewhat sub-optimal; see
// https://github.com/ispc/ispc/issues/156 for details.
bool isUniform = (exprType->IsUniformType() &&
lHasVaryingBreakOrContinue(stmts) == false);
if (isUniform) {
if (is64bit) toType = AtomicType::UniformInt64; if (is64bit) toType = AtomicType::UniformInt64;
else toType = AtomicType::UniformInt32; else toType = AtomicType::UniformInt32;
} }

28
tests/switch-13.ispc Normal file
View File

@@ -0,0 +1,28 @@
export uniform int width() { return programCount; }
int switchit(int a, uniform int b) {
int r = -1;
switch (b) {
case 5:
if (a & 1) {
r=3;
break;
}
r= 2;
break;
default:
r= 3;
}
return r;
}
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
int x = switchit(a, b);
RET[programIndex] = x;
}
export void result(uniform float RET[]) {
RET[programIndex] = (programIndex & 1) ? 2 : 3;
}

24
tests/switch-14.ispc Normal file
View File

@@ -0,0 +1,24 @@
export uniform int width() { return programCount; }
int switchit(int a, uniform int b) {
switch (b) {
case 5:
if (a & 1)
break;
return 2;
default:
return 42;
}
return 3;
}
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
int x = switchit(a, b);
RET[programIndex] = x;
}
export void result(uniform float RET[]) {
RET[programIndex] = (programIndex & 1) ? 2 : 3;
}