Add support for "switch" statements.

Switches with both uniform and varying "switch" expressions are
supported.  Switch statements with varying expressions and very
large numbers of labels may not perform well; some issues to be
filed shortly will track opportunities for improving these.
This commit is contained in:
Matt Pharr
2012-01-11 09:15:05 -08:00
parent 9670ab0887
commit b67446d998
26 changed files with 1402 additions and 92 deletions

11
ast.cpp
View File

@@ -90,6 +90,9 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
DoStmt *dos; DoStmt *dos;
ForStmt *fs; ForStmt *fs;
ForeachStmt *fes; ForeachStmt *fes;
CaseStmt *cs;
DefaultStmt *defs;
SwitchStmt *ss;
ReturnStmt *rs; ReturnStmt *rs;
LabeledStmt *ls; LabeledStmt *ls;
StmtList *sl; StmtList *sl;
@@ -131,6 +134,14 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
postFunc, data); postFunc, data);
fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data); fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data);
} }
else if ((cs = dynamic_cast<CaseStmt *>(node)) != NULL)
cs->stmts = (Stmt *)WalkAST(cs->stmts, preFunc, postFunc, data);
else if ((defs = dynamic_cast<DefaultStmt *>(node)) != NULL)
defs->stmts = (Stmt *)WalkAST(defs->stmts, preFunc, postFunc, data);
else if ((ss = dynamic_cast<SwitchStmt *>(node)) != NULL) {
ss->expr = (Expr *)WalkAST(ss->expr, preFunc, postFunc, data);
ss->stmts = (Stmt *)WalkAST(ss->stmts, preFunc, postFunc, data);
}
else if (dynamic_cast<BreakStmt *>(node) != NULL || else if (dynamic_cast<BreakStmt *>(node) != NULL ||
dynamic_cast<ContinueStmt *>(node) != NULL || dynamic_cast<ContinueStmt *>(node) != NULL ||
dynamic_cast<GotoStmt *>(node) != NULL) { dynamic_cast<GotoStmt *>(node) != NULL) {

539
ctx.cpp
View File

@@ -74,18 +74,33 @@ struct CFInfo {
llvm::Value *savedContinueLanesPtr, llvm::Value *savedContinueLanesPtr,
llvm::Value *savedMask, llvm::Value *savedLoopMask); llvm::Value *savedMask, llvm::Value *savedLoopMask);
static CFInfo *GetSwitch(bool isUniform, llvm::BasicBlock *breakTarget,
llvm::BasicBlock *continueTarget,
llvm::Value *savedBreakLanesPtr,
llvm::Value *savedContinueLanesPtr,
llvm::Value *savedMask, llvm::Value *savedLoopMask,
llvm::Value *switchExpr,
llvm::BasicBlock *bbDefault,
const std::vector<std::pair<int, llvm::BasicBlock *> > *bbCases,
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *bbNext);
bool IsIf() { return type == If; } bool IsIf() { return type == If; }
bool IsLoop() { return type == Loop; } bool IsLoop() { return type == Loop; }
bool IsForeach() { return type == Foreach; } bool IsForeach() { return type == Foreach; }
bool IsSwitch() { return type == Switch; }
bool IsVarying() { return !isUniform; } bool IsVarying() { return !isUniform; }
bool IsUniform() { return isUniform; } bool IsUniform() { return isUniform; }
enum CFType { If, Loop, Foreach }; enum CFType { If, Loop, Foreach, Switch };
CFType type; CFType type;
bool isUniform; bool isUniform;
llvm::BasicBlock *savedBreakTarget, *savedContinueTarget; llvm::BasicBlock *savedBreakTarget, *savedContinueTarget;
llvm::Value *savedBreakLanesPtr, *savedContinueLanesPtr; llvm::Value *savedBreakLanesPtr, *savedContinueLanesPtr;
llvm::Value *savedMask, *savedLoopMask; llvm::Value *savedMask, *savedLoopMask;
llvm::Value *savedSwitchExpr;
llvm::BasicBlock *savedDefaultBlock;
const std::vector<std::pair<int, llvm::BasicBlock *> > *savedCaseBlocks;
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *savedNextBlocks;
private: private:
CFInfo(CFType t, bool uniformIf, llvm::Value *sm) { CFInfo(CFType t, bool uniformIf, llvm::Value *sm) {
@@ -95,11 +110,17 @@ private:
savedBreakTarget = savedContinueTarget = NULL; savedBreakTarget = savedContinueTarget = NULL;
savedBreakLanesPtr = savedContinueLanesPtr = NULL; savedBreakLanesPtr = savedContinueLanesPtr = NULL;
savedMask = savedLoopMask = sm; savedMask = savedLoopMask = sm;
savedSwitchExpr = NULL;
savedDefaultBlock = NULL;
savedCaseBlocks = NULL;
savedNextBlocks = NULL;
} }
CFInfo(CFType t, bool iu, llvm::BasicBlock *bt, llvm::BasicBlock *ct, CFInfo(CFType t, bool iu, llvm::BasicBlock *bt, llvm::BasicBlock *ct,
llvm::Value *sb, llvm::Value *sc, llvm::Value *sm, llvm::Value *sb, llvm::Value *sc, llvm::Value *sm,
llvm::Value *lm) { llvm::Value *lm, llvm::Value *sse = NULL, llvm::BasicBlock *bbd = NULL,
Assert(t == Loop); const std::vector<std::pair<int, llvm::BasicBlock *> > *bbc = NULL,
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *bbn = NULL) {
Assert(t == Loop || t == Switch);
type = t; type = t;
isUniform = iu; isUniform = iu;
savedBreakTarget = bt; savedBreakTarget = bt;
@@ -108,6 +129,10 @@ private:
savedContinueLanesPtr = sc; savedContinueLanesPtr = sc;
savedMask = sm; savedMask = sm;
savedLoopMask = lm; savedLoopMask = lm;
savedSwitchExpr = sse;
savedDefaultBlock = bbd;
savedCaseBlocks = bbc;
savedNextBlocks = bbn;
} }
CFInfo(CFType t, llvm::BasicBlock *bt, llvm::BasicBlock *ct, CFInfo(CFType t, llvm::BasicBlock *bt, llvm::BasicBlock *ct,
llvm::Value *sb, llvm::Value *sc, llvm::Value *sm, llvm::Value *sb, llvm::Value *sc, llvm::Value *sm,
@@ -121,6 +146,10 @@ private:
savedContinueLanesPtr = sc; savedContinueLanesPtr = sc;
savedMask = sm; savedMask = sm;
savedLoopMask = lm; savedLoopMask = lm;
savedSwitchExpr = NULL;
savedDefaultBlock = NULL;
savedCaseBlocks = NULL;
savedNextBlocks = NULL;
} }
}; };
@@ -154,6 +183,22 @@ CFInfo::GetForeach(llvm::BasicBlock *breakTarget,
savedMask, savedForeachMask); savedMask, savedForeachMask);
} }
CFInfo *
CFInfo::GetSwitch(bool isUniform, llvm::BasicBlock *breakTarget,
llvm::BasicBlock *continueTarget,
llvm::Value *savedBreakLanesPtr,
llvm::Value *savedContinueLanesPtr, llvm::Value *savedMask,
llvm::Value *savedLoopMask, llvm::Value *savedSwitchExpr,
llvm::BasicBlock *savedDefaultBlock,
const std::vector<std::pair<int, llvm::BasicBlock *> > *savedCases,
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *savedNext) {
return new CFInfo(Switch, isUniform, breakTarget, continueTarget,
savedBreakLanesPtr, savedContinueLanesPtr,
savedMask, savedLoopMask, savedSwitchExpr, savedDefaultBlock,
savedCases, savedNext);
}
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
FunctionEmitContext::FunctionEmitContext(Function *func, Symbol *funSym, FunctionEmitContext::FunctionEmitContext(Function *func, Symbol *funSym,
@@ -182,6 +227,11 @@ FunctionEmitContext::FunctionEmitContext(Function *func, Symbol *funSym,
breakLanesPtr = continueLanesPtr = NULL; breakLanesPtr = continueLanesPtr = NULL;
breakTarget = continueTarget = NULL; breakTarget = continueTarget = NULL;
switchExpr = NULL;
caseBlocks = NULL;
defaultBlock = NULL;
nextBlocks = NULL;
returnedLanesPtr = AllocaInst(LLVMTypes::MaskType, "returned_lanes_memory"); returnedLanesPtr = AllocaInst(LLVMTypes::MaskType, "returned_lanes_memory");
StoreInst(LLVMMaskAllOff, returnedLanesPtr); StoreInst(LLVMMaskAllOff, returnedLanesPtr);
@@ -422,51 +472,61 @@ FunctionEmitContext::StartVaryingIf(llvm::Value *oldMask) {
void void
FunctionEmitContext::EndIf() { FunctionEmitContext::EndIf() {
CFInfo *ci = popCFState();
// Make sure we match up with a Start{Uniform,Varying}If(). // Make sure we match up with a Start{Uniform,Varying}If().
Assert(controlFlowInfo.size() > 0 && controlFlowInfo.back()->IsIf()); Assert(ci->IsIf());
CFInfo *ci = controlFlowInfo.back();
controlFlowInfo.pop_back();
// 'uniform' ifs don't change the mask so we only need to restore the // 'uniform' ifs don't change the mask so we only need to restore the
// mask going into the if for 'varying' if statements // mask going into the if for 'varying' if statements
if (!ci->IsUniform() && bblock != NULL) { if (ci->IsUniform() || bblock == NULL)
// We can't just restore the mask as it was going into the 'if' return;
// statement. First we have to take into account any program
// instances that have executed 'return' statements; the restored
// mask must be off for those lanes.
restoreMaskGivenReturns(ci->savedMask);
// If the 'if' statement is inside a loop with a 'varying' // We can't just restore the mask as it was going into the 'if'
// consdition, we also need to account for any break or continue // statement. First we have to take into account any program
// statements that executed inside the 'if' statmeent; we also must // instances that have executed 'return' statements; the restored
// leave the lane masks for the program instances that ran those // mask must be off for those lanes.
// off after we restore the mask after the 'if'. The code below restoreMaskGivenReturns(ci->savedMask);
// ends up being optimized out in the case that there were no break
// or continue statements (and breakLanesPtr and continueLanesPtr
// have their initial 'all off' values), so we don't need to check
// for that here.
if (continueLanesPtr != NULL) {
// We want to compute:
// newMask = (oldMask & ~(breakLanes | continueLanes))
llvm::Value *oldMask = GetInternalMask();
llvm::Value *continueLanes = LoadInst(continueLanesPtr,
"continue_lanes");
llvm::Value *bcLanes = continueLanes;
if (breakLanesPtr != NULL) { // If the 'if' statement is inside a loop with a 'varying'
// breakLanesPtr will be NULL if we're inside a 'foreach' loop // condition, we also need to account for any break or continue
llvm::Value *breakLanes = LoadInst(breakLanesPtr, "break_lanes"); // statements that executed inside the 'if' statmeent; we also must
bcLanes = BinaryOperator(llvm::Instruction::Or, breakLanes, // leave the lane masks for the program instances that ran those
continueLanes, "break|continue_lanes"); // off after we restore the mask after the 'if'. The code below
} // ends up being optimized out in the case that there were no break
// or continue statements (and breakLanesPtr and continueLanesPtr
// have their initial 'all off' values), so we don't need to check
// for that here.
//
// There are three general cases to deal with here:
// - Loops: both break and continue are allowed, and thus the corresponding
// lane mask pointers are non-NULL
// - Foreach: only continueLanesPtr may be non-NULL
// - Switch: only breakLanesPtr may be non-NULL
if (continueLanesPtr != NULL || breakLanesPtr != NULL) {
// We want to compute:
// newMask = (oldMask & ~(breakLanes | continueLanes)),
// treading breakLanes or continueLanes as "all off" if the
// corresponding pointer is NULL.
llvm::Value *bcLanes = NULL;
llvm::Value *notBreakOrContinue = if (continueLanesPtr != NULL)
NotOperator(bcLanes, "!(break|continue)_lanes"); bcLanes = LoadInst(continueLanesPtr, "continue_lanes");
llvm::Value *newMask = else
BinaryOperator(llvm::Instruction::And, oldMask, bcLanes = LLVMMaskAllOff;
notBreakOrContinue, "new_mask");
SetInternalMask(newMask); if (breakLanesPtr != NULL) {
llvm::Value *breakLanes = LoadInst(breakLanesPtr, "break_lanes");
bcLanes = BinaryOperator(llvm::Instruction::Or, bcLanes,
breakLanes, "|break_lanes");
} }
llvm::Value *notBreakOrContinue =
NotOperator(bcLanes, "!(break|continue)_lanes");
llvm::Value *oldMask = GetInternalMask();
llvm::Value *newMask =
BinaryOperator(llvm::Instruction::And, oldMask,
notBreakOrContinue, "new_mask");
SetInternalMask(newMask);
} }
} }
@@ -502,17 +562,8 @@ FunctionEmitContext::StartLoop(llvm::BasicBlock *bt, llvm::BasicBlock *ct,
void void
FunctionEmitContext::EndLoop() { FunctionEmitContext::EndLoop() {
Assert(controlFlowInfo.size() && controlFlowInfo.back()->IsLoop()); CFInfo *ci = popCFState();
CFInfo *ci = controlFlowInfo.back(); Assert(ci->IsLoop());
controlFlowInfo.pop_back();
// Restore the break/continue state information to what it was before
// we went into this loop.
breakTarget = ci->savedBreakTarget;
continueTarget = ci->savedContinueTarget;
breakLanesPtr = ci->savedBreakLanesPtr;
continueLanesPtr = ci->savedContinueLanesPtr;
loopMask = ci->savedLoopMask;
if (!ci->IsUniform()) if (!ci->IsUniform())
// If the loop had a 'uniform' test, then it didn't make any // If the loop had a 'uniform' test, then it didn't make any
@@ -545,17 +596,8 @@ FunctionEmitContext::StartForeach(llvm::BasicBlock *ct) {
void void
FunctionEmitContext::EndForeach() { FunctionEmitContext::EndForeach() {
Assert(controlFlowInfo.size() && controlFlowInfo.back()->IsForeach()); CFInfo *ci = popCFState();
CFInfo *ci = controlFlowInfo.back(); Assert(ci->IsForeach());
controlFlowInfo.pop_back();
// Restore the break/continue state information to what it was before
// we went into this loop.
breakTarget = ci->savedBreakTarget;
continueTarget = ci->savedContinueTarget;
breakLanesPtr = ci->savedBreakLanesPtr;
continueLanesPtr = ci->savedContinueLanesPtr;
loopMask = ci->savedLoopMask;
} }
@@ -576,28 +618,85 @@ FunctionEmitContext::restoreMaskGivenReturns(llvm::Value *oldMask) {
} }
/** Returns "true" if the first enclosing "switch" statement (if any) has a
uniform condition. It is legal to call this outside of the scope of an
enclosing switch. */
bool
FunctionEmitContext::inUniformSwitch() const {
// Go backwards through controlFlowInfo, since we add new nested scopes
// to the back.
int i = controlFlowInfo.size() - 1;
while (i >= 0 && controlFlowInfo[i]->type != CFInfo::Switch)
--i;
if (i == -1)
return false;
return controlFlowInfo[i]->IsUniform();
}
/** Along the lines of inUniformSwitch(), this returns "true" if the first
enclosing switch has a varying condition. Note that both
inUniformSwitch() and inVaryingSwitch() may return false, which
indicates that we're not currently inside a switch's scope. */
bool
FunctionEmitContext::inVaryingSwitch() const {
// Go backwards through controlFlowInfo, since we add new nested scopes
// to the back.
int i = controlFlowInfo.size() - 1;
while (i >= 0 && controlFlowInfo[i]->type != CFInfo::Switch)
--i;
if (i == -1)
return false;
return controlFlowInfo[i]->IsVarying();
}
void void
FunctionEmitContext::Break(bool doCoherenceCheck) { FunctionEmitContext::Break(bool doCoherenceCheck) {
Assert(controlFlowInfo.size() > 0);
if (breakTarget == NULL) { if (breakTarget == NULL) {
Error(currentPos, "\"break\" statement is illegal outside of " Error(currentPos, "\"break\" statement is illegal outside of "
"for/while/do loops."); "for/while/do loops and \"switch\" statements.");
return;
}
if (bblock == NULL)
return;
if (inUniformSwitch()) {
// FIXME: Currently, if there are any "break" statements under
// varying "if" statements inside a switch with a uniform
// condition, then the SwitchStmt code promotes the condition to
// varying; hence this assert. However, we can do better than
// that--see issue XXX. When that issue is fixed, this assert will
// be wrong, and should be a second test in the if() statement
// above.
Assert(ifsInCFAllUniform(CFInfo::Switch));
// We know that all program instances are executing the break, so
// just jump to the block immediately after the switch.
Assert(breakTarget != NULL);
BranchInst(breakTarget);
bblock = NULL;
return; return;
} }
// If all of the enclosing 'if' tests in the loop have uniform control // If all of the enclosing 'if' tests in the loop have uniform control
// flow or if we can tell that the mask is all on, then we can just // flow or if we can tell that the mask is all on, then we can just
// jump to the break location. // jump to the break location.
if (ifsInLoopAllUniform() || GetInternalMask() == LLVMMaskAllOn) { if (!inVaryingSwitch() && (ifsInCFAllUniform(CFInfo::Loop) ||
GetInternalMask() == LLVMMaskAllOn)) {
BranchInst(breakTarget); BranchInst(breakTarget);
if (ifsInLoopAllUniform() && doCoherenceCheck) if (ifsInCFAllUniform(CFInfo::Loop) && doCoherenceCheck)
Warning(currentPos, "Coherent break statement not necessary in fully uniform " Warning(currentPos, "Coherent break statement not necessary in "
"control flow."); "fully uniform control flow.");
// Set bblock to NULL since the jump has terminated the basic block // Set bblock to NULL since the jump has terminated the basic block
bblock = NULL; bblock = NULL;
} }
else { else {
// Otherwise we need to update the mask of the lanes that have // Varying switch, or a loop with varying 'if's above the break.
// executed a 'break' statement: // In these cases, we need to update the mask of the lanes that
// have executed a 'break' statement:
// breakLanes = breakLanes | mask // breakLanes = breakLanes | mask
Assert(breakLanesPtr != NULL); Assert(breakLanesPtr != NULL);
llvm::Value *mask = GetInternalMask(); llvm::Value *mask = GetInternalMask();
@@ -613,7 +712,7 @@ FunctionEmitContext::Break(bool doCoherenceCheck) {
// an 'if' statement and restore the mask then. // an 'if' statement and restore the mask then.
SetInternalMask(LLVMMaskAllOff); SetInternalMask(LLVMMaskAllOff);
if (doCoherenceCheck) if (doCoherenceCheck && !inVaryingSwitch())
// If the user has indicated that this is a 'coherent' break // If the user has indicated that this is a 'coherent' break
// statement, then check to see if the mask is all off. If so, // statement, then check to see if the mask is all off. If so,
// we have to conservatively jump to the continueTarget, not // we have to conservatively jump to the continueTarget, not
@@ -635,12 +734,12 @@ FunctionEmitContext::Continue(bool doCoherenceCheck) {
return; return;
} }
if (ifsInLoopAllUniform() || GetInternalMask() == LLVMMaskAllOn) { if (ifsInCFAllUniform(CFInfo::Loop) || GetInternalMask() == LLVMMaskAllOn) {
// Similarly to 'break' statements, we can immediately jump to the // Similarly to 'break' statements, we can immediately jump to the
// continue target if we're only in 'uniform' control flow within // continue target if we're only in 'uniform' control flow within
// loop or if we can tell that the mask is all on. // loop or if we can tell that the mask is all on.
AddInstrumentationPoint("continue: uniform CF, jumped"); AddInstrumentationPoint("continue: uniform CF, jumped");
if (ifsInLoopAllUniform() && doCoherenceCheck) if (ifsInCFAllUniform(CFInfo::Loop) && doCoherenceCheck)
Warning(currentPos, "Coherent continue statement not necessary in " Warning(currentPos, "Coherent continue statement not necessary in "
"fully uniform control flow."); "fully uniform control flow.");
BranchInst(continueTarget); BranchInst(continueTarget);
@@ -653,8 +752,9 @@ FunctionEmitContext::Continue(bool doCoherenceCheck) {
llvm::Value *mask = GetInternalMask(); llvm::Value *mask = GetInternalMask();
llvm::Value *continueMask = llvm::Value *continueMask =
LoadInst(continueLanesPtr, "continue_mask"); LoadInst(continueLanesPtr, "continue_mask");
llvm::Value *newMask = BinaryOperator(llvm::Instruction::Or, llvm::Value *newMask =
mask, continueMask, "mask|continueMask"); BinaryOperator(llvm::Instruction::Or, mask, continueMask,
"mask|continueMask");
StoreInst(newMask, continueLanesPtr); StoreInst(newMask, continueLanesPtr);
// And set the current mask to be all off in case there are any // And set the current mask to be all off in case there are any
@@ -671,22 +771,23 @@ FunctionEmitContext::Continue(bool doCoherenceCheck) {
/** This function checks to see if all of the 'if' statements (if any) /** This function checks to see if all of the 'if' statements (if any)
between the current scope and the first enclosing loop have 'uniform' between the current scope and the first enclosing loop/switch of given
tests. control flow type have 'uniform' tests.
*/ */
bool bool
FunctionEmitContext::ifsInLoopAllUniform() const { FunctionEmitContext::ifsInCFAllUniform(int type) const {
Assert(controlFlowInfo.size() > 0); Assert(controlFlowInfo.size() > 0);
// Go backwards through controlFlowInfo, since we add new nested scopes // Go backwards through controlFlowInfo, since we add new nested scopes
// to the back. Stop once we come to the first enclosing loop. // to the back. Stop once we come to the first enclosing control flow
// structure of the desired type.
int i = controlFlowInfo.size() - 1; int i = controlFlowInfo.size() - 1;
while (i >= 0 && controlFlowInfo[i]->type != CFInfo::Loop) { while (i >= 0 && controlFlowInfo[i]->type != type) {
if (controlFlowInfo[i]->isUniform == false) if (controlFlowInfo[i]->isUniform == false)
// Found a scope due to an 'if' statement with a varying test // Found a scope due to an 'if' statement with a varying test
return false; return false;
--i; --i;
} }
Assert(i >= 0); // else we didn't find a loop! Assert(i >= 0); // else we didn't find the expected control flow type!
return true; return true;
} }
@@ -759,6 +860,243 @@ FunctionEmitContext::RestoreContinuedLanes() {
} }
void
FunctionEmitContext::StartSwitch(bool isUniform, llvm::BasicBlock *bbBreak) {
llvm::Value *oldMask = GetInternalMask();
controlFlowInfo.push_back(CFInfo::GetSwitch(isUniform, breakTarget,
continueTarget, breakLanesPtr,
continueLanesPtr, oldMask,
loopMask, switchExpr, defaultBlock,
caseBlocks, nextBlocks));
breakLanesPtr = AllocaInst(LLVMTypes::MaskType, "break_lanes_memory");
StoreInst(LLVMMaskAllOff, breakLanesPtr);
breakTarget = bbBreak;
continueLanesPtr = NULL;
continueTarget = NULL;
loopMask = NULL;
// These will be set by the SwitchInst() method
switchExpr = NULL;
defaultBlock = NULL;
caseBlocks = NULL;
nextBlocks = NULL;
}
void
FunctionEmitContext::EndSwitch() {
Assert(bblock != NULL);
CFInfo *ci = popCFState();
if (ci->IsVarying() && bblock != NULL)
restoreMaskGivenReturns(ci->savedMask);
}
/** Emit code to check for an "all off" mask before the code for a
case or default label in a "switch" statement.
*/
void
FunctionEmitContext::addSwitchMaskCheck(llvm::Value *mask) {
llvm::Value *allOff = None(mask);
llvm::BasicBlock *bbSome = CreateBasicBlock("case_default_on");
// Find the basic block for the case or default label immediately after
// the current one in the switch statement--that's where we want to
// jump if the mask is all off at this label.
Assert(nextBlocks->find(bblock) != nextBlocks->end());
llvm::BasicBlock *bbNext = nextBlocks->find(bblock)->second;
// Jump to the next one of the mask is all off; otherwise jump to the
// newly created block that will hold the actual code for this label.
BranchInst(bbNext, bbSome, allOff);
SetCurrentBasicBlock(bbSome);
}
/** Returns the execution mask at entry to the first enclosing "switch"
statement. */
llvm::Value *
FunctionEmitContext::getMaskAtSwitchEntry() {
Assert(controlFlowInfo.size() > 0);
int i = controlFlowInfo.size() - 1;
while (i >= 0 && controlFlowInfo[i]->type != CFInfo::Switch)
--i;
Assert(i != -1);
return controlFlowInfo[i]->savedMask;
}
void
FunctionEmitContext::EmitDefaultLabel(bool checkMask, SourcePos pos) {
if (!inUniformSwitch() && !inVaryingSwitch()) {
Error(pos, "\"default\" label illegal outside of \"switch\" "
"statement.");
return;
}
// If there's a default label in the switch, a basic block for it
// should have been provided in the previous call to SwitchInst().
Assert(defaultBlock != NULL);
if (bblock != NULL)
// The previous case in the switch fell through, or we're in a
// varying switch; terminate the current block with a jump to the
// block for the code for the default label.
BranchInst(defaultBlock);
SetCurrentBasicBlock(defaultBlock);
if (inUniformSwitch())
// Nothing more to do for the uniform case; return back to the
// caller, which will then emit the code for the default case.
return;
// For a varying switch, we need to update the execution mask.
//
// First, compute the mask that corresponds to which program instances
// should execute the "default" code; this corresponds to the set of
// program instances that don't match any of the case statements.
// Therefore, we generate code that compares the value of the switch
// expression to the value associated with each of the "case"
// statements such that the surviving lanes didn't match any of them.
llvm::Value *matchesDefault = getMaskAtSwitchEntry();
for (int i = 0; i < (int)caseBlocks->size(); ++i) {
int value = (*caseBlocks)[i].first;
llvm::Value *valueVec = (switchExpr->getType() == LLVMTypes::Int32VectorType) ?
LLVMInt32Vector(value) : LLVMInt64Vector(value);
// TODO: for AVX2 at least, the following generates better code
// than doing ICMP_NE and skipping the NotOperator() below; file a
// LLVM bug?
llvm::Value *matchesCaseValue =
CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, switchExpr,
valueVec, "cmp_case_value");
matchesCaseValue = I1VecToBoolVec(matchesCaseValue);
llvm::Value *notMatchesCaseValue = NotOperator(matchesCaseValue);
matchesDefault = BinaryOperator(llvm::Instruction::And, matchesDefault,
notMatchesCaseValue, "default&~case_match");
}
// The mask may have some lanes on, which corresponds to the previous
// label falling through; compute the updated mask by ANDing with the
// current mask.
llvm::Value *oldMask = GetInternalMask();
llvm::Value *newMask = BinaryOperator(llvm::Instruction::Or, oldMask,
matchesDefault, "old_mask|matches_default");
SetInternalMask(newMask);
if (checkMask)
addSwitchMaskCheck(newMask);
}
void
FunctionEmitContext::EmitCaseLabel(int value, bool checkMask, SourcePos pos) {
if (!inUniformSwitch() && !inVaryingSwitch()) {
Error(pos, "\"case\" label illegal outside of \"switch\" statement.");
return;
}
// Find the basic block for this case statement.
llvm::BasicBlock *bbCase = NULL;
Assert(caseBlocks != NULL);
for (int i = 0; i < (int)caseBlocks->size(); ++i)
if ((*caseBlocks)[i].first == value) {
bbCase = (*caseBlocks)[i].second;
break;
}
Assert(bbCase != NULL);
if (bblock != NULL)
// fall through from the previous case
BranchInst(bbCase);
SetCurrentBasicBlock(bbCase);
if (inUniformSwitch())
return;
// update the mask: first, get a mask that indicates which program
// instances have a value for the switch expression that matches this
// case statement.
llvm::Value *valueVec = (switchExpr->getType() == LLVMTypes::Int32VectorType) ?
LLVMInt32Vector(value) : LLVMInt64Vector(value);
llvm::Value *matchesCaseValue =
CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, switchExpr,
valueVec, "cmp_case_value");
matchesCaseValue = I1VecToBoolVec(matchesCaseValue);
// If a lane was off going into the switch, we don't care if has a
// value in the switch expression that happens to match this case.
llvm::Value *entryMask = getMaskAtSwitchEntry();
matchesCaseValue = BinaryOperator(llvm::Instruction::And, entryMask,
matchesCaseValue, "entry_mask&case_match");
// Take the surviving lanes and turn on the mask for them.
llvm::Value *oldMask = GetInternalMask();
llvm::Value *newMask = BinaryOperator(llvm::Instruction::Or, oldMask,
matchesCaseValue, "mask|case_match");
SetInternalMask(newMask);
if (checkMask)
addSwitchMaskCheck(newMask);
}
void
FunctionEmitContext::SwitchInst(llvm::Value *expr, llvm::BasicBlock *bbDefault,
const std::vector<std::pair<int, llvm::BasicBlock *> > &bbCases,
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> &bbNext) {
// The calling code should have called StartSwitch() before calling
// SwitchInst().
Assert(controlFlowInfo.size() &&
controlFlowInfo.back()->IsSwitch());
switchExpr = expr;
defaultBlock = bbDefault;
caseBlocks = new std::vector<std::pair<int, llvm::BasicBlock *> >(bbCases);
nextBlocks = new std::map<llvm::BasicBlock *, llvm::BasicBlock *>(bbNext);
if (inUniformSwitch()) {
// For a uniform switch, just wire things up to the LLVM switch
// instruction.
Assert(llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(expr->getType()) ==
false);
llvm::SwitchInst *s = llvm::SwitchInst::Create(expr, bbDefault,
bbCases.size(), bblock);
for (int i = 0; i < (int)bbCases.size(); ++i) {
if (expr->getType() == LLVMTypes::Int32Type)
s->addCase(LLVMInt32(bbCases[i].first), bbCases[i].second);
else {
Assert(expr->getType() == LLVMTypes::Int64Type);
s->addCase(LLVMInt64(bbCases[i].first), bbCases[i].second);
}
}
AddDebugPos(s);
// switch is a terminator
bblock = NULL;
}
else {
// For a varying switch, we first turn off all lanes of the mask
SetInternalMask(LLVMMaskAllOff);
if (nextBlocks->size() > 0) {
// If there are any labels inside the switch, jump to the first
// one; any code before the first label won't be executed by
// anyone.
std::map<llvm::BasicBlock *, llvm::BasicBlock *>::const_iterator iter;
iter = nextBlocks->find(NULL);
Assert(iter != nextBlocks->end());
llvm::BasicBlock *bbFirst = iter->second;
BranchInst(bbFirst);
bblock = NULL;
}
}
}
int int
FunctionEmitContext::VaryingCFDepth() const { FunctionEmitContext::VaryingCFDepth() const {
int sum = 0; int sum = 0;
@@ -905,6 +1243,14 @@ FunctionEmitContext::All(llvm::Value *mask) {
} }
llvm::Value *
FunctionEmitContext::None(llvm::Value *mask) {
llvm::Value *mmval = LaneMask(mask);
return CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ, mmval,
LLVMInt32(0), "none_mm_cmp");
}
llvm::Value * llvm::Value *
FunctionEmitContext::LaneMask(llvm::Value *v) { FunctionEmitContext::LaneMask(llvm::Value *v) {
// Call the target-dependent movmsk function to turn the vector mask // Call the target-dependent movmsk function to turn the vector mask
@@ -2632,3 +2978,36 @@ FunctionEmitContext::addVaryingOffsetsIfNeeded(llvm::Value *ptr,
return BinaryOperator(llvm::Instruction::Add, ptr, offset); return BinaryOperator(llvm::Instruction::Add, ptr, offset);
} }
CFInfo *
FunctionEmitContext::popCFState() {
Assert(controlFlowInfo.size() > 0);
CFInfo *ci = controlFlowInfo.back();
controlFlowInfo.pop_back();
if (ci->IsSwitch()) {
breakTarget = ci->savedBreakTarget;
continueTarget = ci->savedContinueTarget;
breakLanesPtr = ci->savedBreakLanesPtr;
continueLanesPtr = ci->savedContinueLanesPtr;
loopMask = ci->savedLoopMask;
switchExpr = ci->savedSwitchExpr;
defaultBlock = ci->savedDefaultBlock;
caseBlocks = ci->savedCaseBlocks;
nextBlocks = ci->savedNextBlocks;
}
else if (ci->IsLoop() || ci->IsForeach()) {
breakTarget = ci->savedBreakTarget;
continueTarget = ci->savedContinueTarget;
breakLanesPtr = ci->savedBreakLanesPtr;
continueLanesPtr = ci->savedContinueLanesPtr;
loopMask = ci->savedLoopMask;
}
else {
Assert(ci->IsIf());
// nothing to do
}
return ci;
}

93
ctx.h
View File

@@ -187,6 +187,45 @@ public:
previous iteration. */ previous iteration. */
void RestoreContinuedLanes(); void RestoreContinuedLanes();
/** Indicates that code generation for a "switch" statement is about to
start. isUniform indicates whether the "switch" value is uniform,
and bbAfterSwitch gives the basic block immediately following the
"switch" statement. (For example, if the switch condition is
uniform, we jump here upon executing a "break" statement.) */
void StartSwitch(bool isUniform, llvm::BasicBlock *bbAfterSwitch);
/** Indicates the end of code generation for a "switch" statement. */
void EndSwitch();
/** Emits code for a "switch" statement in the program.
@param expr Gives the value of the expression after the "switch"
@param defaultBlock Basic block to execute for the "default" case. This
should be NULL if there is no "default" label inside
the switch.
@param caseBlocks vector that stores the mapping from label values
after "case" statements to basic blocks corresponding
to the "case" labels.
@param nextBlocks For each basic block for a "case" or "default"
label, this gives the basic block for the
immediately-following "case" or "default" label (or
the basic block after the "switch" statement for the
last label.)
*/
void SwitchInst(llvm::Value *expr, llvm::BasicBlock *defaultBlock,
const std::vector<std::pair<int, llvm::BasicBlock *> > &caseBlocks,
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> &nextBlocks);
/** Generates code for a "default" label after a "switch" statement.
The checkMask parameter indicates whether additional code should be
generated to check to see if the execution mask is all off after
the default label (in which case a jump to the following label will
be issued. */
void EmitDefaultLabel(bool checkMask, SourcePos pos);
/** Generates code for a "case" label after a "switch" statement. See
the documentation for EmitDefaultLabel() for discussion of the
checkMask parameter. */
void EmitCaseLabel(int value, bool checkMask, SourcePos pos);
/** Returns the current number of nested levels of 'varying' control /** Returns the current number of nested levels of 'varying' control
flow */ flow */
int VaryingCFDepth() const; int VaryingCFDepth() const;
@@ -221,6 +260,10 @@ public:
i1 value that indicates if all of the mask lanes are on. */ i1 value that indicates if all of the mask lanes are on. */
llvm::Value *All(llvm::Value *mask); llvm::Value *All(llvm::Value *mask);
/** Given a boolean mask value of type LLVMTypes::MaskType, return an
i1 value that indicates if all of the mask lanes are off. */
llvm::Value *None(llvm::Value *mask);
/** Given a boolean mask value of type LLVMTypes::MaskType, return an /** Given a boolean mask value of type LLVMTypes::MaskType, return an
i32 value wherein the i'th bit is on if and only if the i'th lane i32 value wherein the i'th bit is on if and only if the i'th lane
of the mask is on. */ of the mask is on. */
@@ -492,10 +535,10 @@ private:
the loop. */ the loop. */
llvm::Value *loopMask; llvm::Value *loopMask;
/** If currently in a loop body, this is a pointer to memory to store a /** If currently in a loop body or switch statement, this is a pointer
mask value that represents which of the lanes have executed a to memory to store a mask value that represents which of the lanes
'break' statement. If we're not in a loop body, this should be have executed a 'break' statement. If we're not in a loop body or
NULL. */ switch, this should be NULL. */
llvm::Value *breakLanesPtr; llvm::Value *breakLanesPtr;
/** Similar to breakLanesPtr, if we're inside a loop, this is a pointer /** Similar to breakLanesPtr, if we're inside a loop, this is a pointer
@@ -503,16 +546,42 @@ private:
'continue' statement. */ 'continue' statement. */
llvm::Value *continueLanesPtr; llvm::Value *continueLanesPtr;
/** If we're inside a loop, this gives the basic block immediately /** If we're inside a loop or switch statement, this gives the basic
after the current loop, which we will jump to if all of the lanes block immediately after the current loop or switch, which we will
have executed a break statement or are otherwise done with the jump to if all of the lanes have executed a break statement or are
loop. */ otherwise done with it. */
llvm::BasicBlock *breakTarget; llvm::BasicBlock *breakTarget;
/** If we're inside a loop, this gives the block to jump to if all of /** If we're inside a loop, this gives the block to jump to if all of
the running lanes have executed a 'continue' statement. */ the running lanes have executed a 'continue' statement. */
llvm::BasicBlock *continueTarget; llvm::BasicBlock *continueTarget;
/** @name Switch statement state
These variables store various state that's active when we're
generating code for a switch statement. They should all be NULL
outside of a switch.
@{
*/
/** The value of the expression used to determine which case in the
statements after the switch to execute. */
llvm::Value *switchExpr;
/** Map from case label numbers to the basic block that will hold code
for that case. */
const std::vector<std::pair<int, llvm::BasicBlock *> > *caseBlocks;
/** The basic block of code to run for the "default" label in the
switch statement. */
llvm::BasicBlock *defaultBlock;
/** For each basic block for the code for cases (and the default label,
if present), this map gives the basic block for the immediately
following case/default label. */
const std::map<llvm::BasicBlock *, llvm::BasicBlock *> *nextBlocks;
/** @} */
/** A pointer to memory that records which of the program instances /** A pointer to memory that records which of the program instances
have executed a 'return' statement (and are thus really truly done have executed a 'return' statement (and are thus really truly done
running any more instructions in this functions. */ running any more instructions in this functions. */
@@ -556,7 +625,7 @@ private:
llvm::Value *pointerVectorToVoidPointers(llvm::Value *value); llvm::Value *pointerVectorToVoidPointers(llvm::Value *value);
static void addGSMetadata(llvm::Value *inst, SourcePos pos); static void addGSMetadata(llvm::Value *inst, SourcePos pos);
bool ifsInLoopAllUniform() const; bool ifsInCFAllUniform(int cfType) const;
void jumpIfAllLoopLanesAreDone(llvm::BasicBlock *target); void jumpIfAllLoopLanesAreDone(llvm::BasicBlock *target);
llvm::Value *emitGatherCallback(llvm::Value *lvalue, llvm::Value *retPtr); llvm::Value *emitGatherCallback(llvm::Value *lvalue, llvm::Value *retPtr);
@@ -564,6 +633,12 @@ private:
const Type *ptrType); const Type *ptrType);
void restoreMaskGivenReturns(llvm::Value *oldMask); void restoreMaskGivenReturns(llvm::Value *oldMask);
void addSwitchMaskCheck(llvm::Value *mask);
bool inUniformSwitch() const;
bool inVaryingSwitch() const;
llvm::Value *getMaskAtSwitchEntry();
CFInfo *popCFState();
void scatter(llvm::Value *value, llvm::Value *ptr, const Type *ptrType, void scatter(llvm::Value *value, llvm::Value *ptr, const Type *ptrType,
llvm::Value *mask); llvm::Value *mask);

View File

@@ -99,6 +99,7 @@ Contents:
+ `Control Flow`_ + `Control Flow`_
* `Conditional Statements: "if"`_ * `Conditional Statements: "if"`_
* `Conditional Statements: "switch"`_
* `Basic Iteration Statements: "for", "while", and "do"`_ * `Basic Iteration Statements: "for", "while", and "do"`_
* `Unstructured Control Flow: "goto"`_ * `Unstructured Control Flow: "goto"`_
* `"Coherent" Control Flow Statements: "cif" and Friends`_ * `"Coherent" Control Flow Statements: "cif" and Friends`_
@@ -1994,6 +1995,31 @@ executes if the condition is false.
else else
x *= 2.; x *= 2.;
Conditional Statements: "switch"
--------------------------------
The ``switch`` conditional statement is also available, again with the same
behavior as in C; the expression used in the ``switch`` must be of integer
type (but it can be uniform or varying). As in C, if there is no ``break``
statement at the end of the code for a given case, execution "falls
through" to the following case. These features are demonstrated in the
code below.
::
int x = ...;
switch (x) {
case 0:
case 1:
foo(x);
/* fall through */
case 5:
x = 0;
break;
default:
x *= x;
}
Basic Iteration Statements: "for", "while", and "do" Basic Iteration Statements: "for", "while", and "do"
---------------------------------------------------- ----------------------------------------------------

2
ispc.h
View File

@@ -437,6 +437,8 @@ enum {
COST_VARYING_IF = 3, COST_VARYING_IF = 3,
COST_UNIFORM_LOOP = 4, COST_UNIFORM_LOOP = 4,
COST_VARYING_LOOP = 6, COST_VARYING_LOOP = 6,
COST_UNIFORM_SWITCH = 4,
COST_VARYING_SWITCH = 12,
COST_ASSERT = 8, COST_ASSERT = 8,
CHECK_MASK_AT_FUNCTION_START_COST = 16, CHECK_MASK_AT_FUNCTION_START_COST = 16,

View File

@@ -1265,9 +1265,17 @@ labeled_statement
$$ = new LabeledStmt($1, $3, @1); $$ = new LabeledStmt($1, $3, @1);
} }
| TOKEN_CASE constant_expression ':' statement | TOKEN_CASE constant_expression ':' statement
{ UNIMPLEMENTED; } {
int value;
if ($2 != NULL &&
lGetConstantInt($2, &value, @2, "Case statement value")) {
$$ = new CaseStmt(value, $4, Union(@1, @2));
}
else
$$ = NULL;
}
| TOKEN_DEFAULT ':' statement | TOKEN_DEFAULT ':' statement
{ UNIMPLEMENTED; } { $$ = new DefaultStmt($3, @1); }
; ;
start_scope start_scope
@@ -1313,7 +1321,7 @@ selection_statement
| TOKEN_CIF '(' expression ')' statement TOKEN_ELSE statement | TOKEN_CIF '(' expression ')' statement TOKEN_ELSE statement
{ $$ = new IfStmt($3, $5, $7, true, @1); } { $$ = new IfStmt($3, $5, $7, true, @1); }
| TOKEN_SWITCH '(' expression ')' statement | TOKEN_SWITCH '(' expression ')' statement
{ UNIMPLEMENTED; } { $$ = new SwitchStmt($3, $5, @1); }
; ;
for_test for_test

301
stmt.cpp
View File

@@ -1886,6 +1886,307 @@ ForeachStmt::Print(int indent) const {
} }
///////////////////////////////////////////////////////////////////////////
// CaseStmt
/** Given the statements following a 'case' or 'default' label, this
function determines whether the mask should be checked to see if it is
"all off" immediately after the label, before executing the code for
the statements.
*/
static bool
lCheckMask(Stmt *stmts) {
if (stmts == NULL)
return false;
int cost = EstimateCost(stmts);
bool safeToRunWithAllLanesOff = true;
WalkAST(stmts, lCheckAllOffSafety, NULL, &safeToRunWithAllLanesOff);
// The mask should be checked if the code following the
// 'case'/'default' is relatively complex, or if it would be unsafe to
// run that code with the execution mask all off.
return (cost > PREDICATE_SAFE_IF_STATEMENT_COST ||
safeToRunWithAllLanesOff == false);
}
CaseStmt::CaseStmt(int v, Stmt *s, SourcePos pos)
: Stmt(pos), value(v) {
stmts = s;
}
void
CaseStmt::EmitCode(FunctionEmitContext *ctx) const {
ctx->EmitCaseLabel(value, lCheckMask(stmts), pos);
if (stmts)
stmts->EmitCode(ctx);
}
void
CaseStmt::Print(int indent) const {
printf("%*cCase [%d] label", indent, ' ', value);
pos.Print();
printf("\n");
stmts->Print(indent+4);
}
Stmt *
CaseStmt::TypeCheck() {
return this;
}
int
CaseStmt::EstimateCost() const {
return 0;
}
///////////////////////////////////////////////////////////////////////////
// DefaultStmt
DefaultStmt::DefaultStmt(Stmt *s, SourcePos pos)
: Stmt(pos) {
stmts = s;
}
void
DefaultStmt::EmitCode(FunctionEmitContext *ctx) const {
ctx->EmitDefaultLabel(lCheckMask(stmts), pos);
if (stmts)
stmts->EmitCode(ctx);
}
void
DefaultStmt::Print(int indent) const {
printf("%*cDefault Stmt", indent, ' ');
pos.Print();
printf("\n");
stmts->Print(indent+4);
}
Stmt *
DefaultStmt::TypeCheck() {
return this;
}
int
DefaultStmt::EstimateCost() const {
return 0;
}
///////////////////////////////////////////////////////////////////////////
// SwitchStmt
SwitchStmt::SwitchStmt(Expr *e, Stmt *s, SourcePos pos)
: Stmt(pos) {
expr = e;
stmts = s;
}
/* An instance of this structure is carried along as we traverse the AST
nodes for the statements after a "switch" statement. We use this
structure to record all of the 'case' and 'default' statements after the
"switch". */
struct SwitchVisitInfo {
SwitchVisitInfo(FunctionEmitContext *c) {
ctx = c;
defaultBlock = NULL;
lastBlock = NULL;
}
FunctionEmitContext *ctx;
/* Basic block for the code following the "default" label (if any). */
llvm::BasicBlock *defaultBlock;
/* Map from integer values after "case" labels to the basic blocks that
follow the corresponding "case" label. */
std::vector<std::pair<int, llvm::BasicBlock *> > caseBlocks;
/* For each basic block for a "case" label or a "default" label,
nextBlock[block] stores the basic block pointer for the next
subsequent "case" or "default" label in the program. */
std::map<llvm::BasicBlock *, llvm::BasicBlock *> nextBlock;
/* The last basic block created for a "case" or "default" label; when
we create the basic block for the next one, we'll use this to update
the nextBlock map<> above. */
llvm::BasicBlock *lastBlock;
};
static bool
lSwitchASTPreVisit(ASTNode *node, void *d) {
if (dynamic_cast<SwitchStmt *>(node) != NULL)
// don't continue recursively into a nested switch--we only want
// our own case and default statements!
return false;
CaseStmt *cs = dynamic_cast<CaseStmt *>(node);
DefaultStmt *ds = dynamic_cast<DefaultStmt *>(node);
SwitchVisitInfo *svi = (SwitchVisitInfo *)d;
llvm::BasicBlock *bb = NULL;
if (cs != NULL) {
// Complain if we've seen a case statement with the same value
// already
for (int i = 0; i < (int)svi->caseBlocks.size(); ++i) {
if (svi->caseBlocks[i].first == cs->value) {
Error(cs->pos, "Duplicate case value \"%d\".", cs->value);
return true;
}
}
// Otherwise create a new basic block for the code following this
// 'case' statement and record the mappign between the case label
// value and the basic block
char buf[32];
sprintf(buf, "case_%d", cs->value);
bb = svi->ctx->CreateBasicBlock(buf);
svi->caseBlocks.push_back(std::make_pair(cs->value, bb));
}
else if (ds != NULL) {
// And complain if we've seen another 'default' label..
if (svi->defaultBlock != NULL) {
Error(ds->pos, "Multiple \"default\" lables in switch statement.");
return true;
}
else {
// Otherwise create a basic block for the code following the
// "default".
bb = svi->ctx->CreateBasicBlock("default");
svi->defaultBlock = bb;
}
}
// If we saw a "case" or "default" label, then update the map to record
// that the block we just created follows the block created for the
// previous label in the "switch".
if (bb != NULL) {
svi->nextBlock[svi->lastBlock] = bb;
svi->lastBlock = bb;
}
return true;
}
void
SwitchStmt::EmitCode(FunctionEmitContext *ctx) const {
if (ctx->GetCurrentBasicBlock() == NULL)
return;
const Type *type;
if (expr == NULL || ((type = expr->GetType()) == NULL)) {
Assert(m->errorCount > 0);
return;
}
// Basic block we'll end up after the switch statement
llvm::BasicBlock *bbDone = ctx->CreateBasicBlock("switch_done");
// Walk the AST of the statements after the 'switch' to collect a bunch
// of information about the structure of the 'case' and 'default'
// statements.
SwitchVisitInfo svi(ctx);
WalkAST(stmts, lSwitchASTPreVisit, NULL, &svi);
// Record that the basic block following the last one created for a
// case/default is the block after the end of the switch statement.
svi.nextBlock[svi.lastBlock] = bbDone;
llvm::Value *exprValue = expr->GetValue(ctx);
if (exprValue == NULL) {
Assert(m->errorCount > 0);
return;
}
ctx->StartSwitch(type->IsUniformType(), bbDone);
ctx->SwitchInst(exprValue, svi.defaultBlock ? svi.defaultBlock : bbDone,
svi.caseBlocks, svi.nextBlock);
if (stmts != NULL)
stmts->EmitCode(ctx);
if (ctx->GetCurrentBasicBlock() != NULL)
ctx->BranchInst(bbDone);
ctx->SetCurrentBasicBlock(bbDone);
ctx->EndSwitch();
}
void
SwitchStmt::Print(int indent) const {
printf("%*cSwitch Stmt", indent, ' ');
pos.Print();
printf("\n");
printf("%*cexpr = ", indent, ' ');
expr->Print();
printf("\n");
stmts->Print(indent+4);
}
Stmt *
SwitchStmt::TypeCheck() {
const Type *exprType = expr->GetType();
if (exprType == NULL)
return NULL;
const Type *toType = NULL;
exprType = exprType->GetAsConstType();
bool is64bit = (exprType->GetAsUniformType() ==
AtomicType::UniformConstUInt64 ||
exprType->GetAsUniformType() ==
AtomicType::UniformConstInt64);
// FIXME: if there's a break or continue under varying control flow
// within a switch with a "uniform" condition, we promote the condition
// to varying so that everything works out and we are set to handle the
// resulting divergent control flow. This is somewhat sub-optimal; see
// Issue #XXX for details.
bool isUniform = (exprType->IsUniformType() &&
lHasVaryingBreakOrContinue(stmts) == false);
if (isUniform) {
if (is64bit) toType = AtomicType::UniformInt64;
else toType = AtomicType::UniformInt32;
}
else {
if (is64bit) toType = AtomicType::VaryingInt64;
else toType = AtomicType::VaryingInt32;
}
expr = TypeConvertExpr(expr, toType, "switch expression");
if (expr == NULL)
return NULL;
return this;
}
int
SwitchStmt::EstimateCost() const {
const Type *type = expr->GetType();
if (type && type->IsVaryingType())
return COST_VARYING_SWITCH;
else
return COST_UNIFORM_SWITCH;
}
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// ReturnStmt // ReturnStmt

59
stmt.h
View File

@@ -282,6 +282,60 @@ public:
}; };
/** Statement corresponding to a "case" label in the program. In addition
to the value associated with the "case", this statement also stores the
statements following it. */
class CaseStmt : public Stmt {
public:
CaseStmt(int value, Stmt *stmt, SourcePos pos);
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *TypeCheck();
int EstimateCost() const;
/** Integer value after the "case" statement */
const int value;
Stmt *stmts;
};
/** Statement for a "default" label (as would be found inside a "switch"
statement). */
class DefaultStmt : public Stmt {
public:
DefaultStmt(Stmt *stmt, SourcePos pos);
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *TypeCheck();
int EstimateCost() const;
Stmt *stmts;
};
/** A "switch" statement in the program. */
class SwitchStmt : public Stmt {
public:
SwitchStmt(Expr *expr, Stmt *stmts, SourcePos pos);
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *TypeCheck();
int EstimateCost() const;
/** Expression that is used to determine which label to jump to. */
Expr *expr;
/** Statement block after the "switch" expression. */
Stmt *stmts;
};
/** A "goto" in an ispc program. */
class GotoStmt : public Stmt { class GotoStmt : public Stmt {
public: public:
GotoStmt(const char *label, SourcePos gotoPos, SourcePos idPos); GotoStmt(const char *label, SourcePos gotoPos, SourcePos idPos);
@@ -293,11 +347,14 @@ public:
Stmt *TypeCheck(); Stmt *TypeCheck();
int EstimateCost() const; int EstimateCost() const;
/** Name of the label to jump to when the goto is executed. */
std::string label; std::string label;
SourcePos identifierPos; SourcePos identifierPos;
}; };
/** Statement corresponding to a label (as would be used as a goto target)
in the program. */
class LabeledStmt : public Stmt { class LabeledStmt : public Stmt {
public: public:
LabeledStmt(const char *label, Stmt *stmt, SourcePos p); LabeledStmt(const char *label, Stmt *stmt, SourcePos p);
@@ -309,7 +366,9 @@ public:
Stmt *TypeCheck(); Stmt *TypeCheck();
int EstimateCost() const; int EstimateCost() const;
/** Name of the label. */
std::string name; std::string name;
/** Statements following the label. */
Stmt *stmt; Stmt *stmt;
}; };

18
tests/switch-1.ispc Normal file
View File

@@ -0,0 +1,18 @@
export uniform int width() { return programCount; }
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
switch (b) {
default:
RET[programIndex] = -1;
break;
case 5:
RET[programIndex] = 0;
}
}
export void result(uniform float RET[]) {
RET[programIndex] = 0;
}

44
tests/switch-10.ispc Normal file
View File

@@ -0,0 +1,44 @@
export uniform int width() { return programCount; }
int switchit(int a, uniform int b) {
switch (a) {
case 3:
return 1;
case 7:
case 6:
case 4:
case 5:
if (a & 1)
break;
return 2;
case 1: {
switch (a+b) {
case 6:
return 42;
default:
break;
}
return -1234;
}
case 32:
*((int *)NULL) = 0;
default:
return 0;
}
return 3;
}
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
int x = switchit(a, b);
RET[programIndex] = x;
}
export void result(uniform float RET[]) {
RET[programIndex] = 0;
RET[0] = 42;
RET[2] = 1;
RET[6] = RET[4] = 3;
RET[5] = RET[3] = 2;
}

50
tests/switch-11.ispc Normal file
View File

@@ -0,0 +1,50 @@
export uniform int width() { return programCount; }
int switchit(int a, uniform int b) {
switch (a) {
case 3:
return 1;
case 7:
case 6:
case 4:
case 5:
if (a & 1)
break;
return 2;
case 1: {
switch (a+b) {
case 60:
return -1234;
default:
break;
case 6:
if (b == 5)
break;
return -42;
case 12:
return -1;
}
return 42;
}
case 32:
*((int *)NULL) = 0;
default:
return 0;
}
return 3;
}
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
int x = switchit(a, b);
RET[programIndex] = x;
}
export void result(uniform float RET[]) {
RET[programIndex] = 0;
RET[0] = 42;
RET[2] = 1;
RET[6] = RET[4] = 3;
RET[5] = RET[3] = 2;
}

54
tests/switch-12.ispc Normal file
View File

@@ -0,0 +1,54 @@
export uniform int width() { return programCount; }
int switchit(int a, uniform int b) {
switch (a) {
case 3:
return 1;
case 7:
case 6:
case 4:
case 5:
if (a & 1)
break;
return 2;
case 1: {
switch (a+b) {
case 60:
return -1234;
default:
break;
case 6:
int count = 0;
for (count = 0; count < 10; ++count) {
a += b;
if (a == 11)
break;
}
return a;
case 12:
return -1;
}
return 42;
}
case 32:
*((int *)NULL) = 0;
default:
return 0;
}
return 3;
}
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
int x = switchit(a, b);
RET[programIndex] = x;
}
export void result(uniform float RET[]) {
RET[programIndex] = 0;
RET[0] = 11;
RET[2] = 1;
RET[6] = RET[4] = 3;
RET[5] = RET[3] = 2;
}

17
tests/switch-2.ispc Normal file
View File

@@ -0,0 +1,17 @@
export uniform int width() { return programCount; }
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
switch (b) {
default:
RET[programIndex] = -1;
case 5:
RET[programIndex] = 0;
}
}
export void result(uniform float RET[]) {
RET[programIndex] = 0;
}

18
tests/switch-3.ispc Normal file
View File

@@ -0,0 +1,18 @@
export uniform int width() { return programCount; }
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
switch (b) {
case 5:
RET[programIndex] = 0;
break;
default:
RET[programIndex] = -1;
}
}
export void result(uniform float RET[]) {
RET[programIndex] = 0;
}

24
tests/switch-4.ispc Normal file
View File

@@ -0,0 +1,24 @@
export uniform int width() { return programCount; }
int switchit(int a, uniform int b) {
int r = 0;
switch (a) {
case 3:
r = 1;
break;
default:
r = 0;
}
return r;
}
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
int x = switchit(a, b);
RET[programIndex] = x;
}
export void result(uniform float RET[]) {
RET[programIndex] = (programIndex == 2) ? 1 : 0;
}

22
tests/switch-5.ispc Normal file
View File

@@ -0,0 +1,22 @@
export uniform int width() { return programCount; }
int switchit(int a, uniform int b) {
int r = 0;
switch (a) {
case 3:
return 1;
default:
return 0;
}
}
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
int x = switchit(a, b);
RET[programIndex] = x;
}
export void result(uniform float RET[]) {
RET[programIndex] = (programIndex == 2) ? 1 : 0;
}

27
tests/switch-6.ispc Normal file
View File

@@ -0,0 +1,27 @@
export uniform int width() { return programCount; }
int switchit(int a, uniform int b) {
switch (a) {
case 3:
return 1;
case 7:
if (b == 5)
break;
default:
return 0;
}
return -1;
}
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
int x = switchit(a, b);
RET[programIndex] = x;
}
export void result(uniform float RET[]) {
RET[programIndex] = 0;
RET[2] = 1;
RET[6] = -1;
}

32
tests/switch-7.ispc Normal file
View File

@@ -0,0 +1,32 @@
export uniform int width() { return programCount; }
int switchit(int a, uniform int b) {
switch (a) {
case 3:
return 1;
case 7:
case 6:
case 4:
case 5:
if (a & 1)
break;
return 2;
default:
return 0;
}
return 3;
}
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
int x = switchit(a, b);
RET[programIndex] = x;
}
export void result(uniform float RET[]) {
RET[programIndex] = 0;
RET[2] = 1;
RET[6] = RET[4] = 3;
RET[5] = RET[3] = 2;
}

36
tests/switch-8.ispc Normal file
View File

@@ -0,0 +1,36 @@
export uniform int width() { return programCount; }
int switchit(int a, uniform int b) {
switch (a) {
case 3:
return 1;
case 7:
case 6:
case 4:
case 5:
if (a & 1)
break;
return 2;
case 32:
*((int *)NULL) = 0;
//CO default:
case 1:
case 2:
return 0;
}
return 3;
}
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
int x = switchit(a, b);
RET[programIndex] = x;
}
export void result(uniform float RET[]) {
RET[programIndex] = 0;
RET[2] = 1;
RET[6] = RET[4] = 3;
RET[5] = RET[3] = 2;
}

34
tests/switch-9.ispc Normal file
View File

@@ -0,0 +1,34 @@
export uniform int width() { return programCount; }
int switchit(int a, uniform int b) {
switch (a) {
case 3:
return 1;
case 7:
case 6:
case 4:
case 5:
if (a & 1)
break;
return 2;
case 32:
*((int *)NULL) = 0;
default:
return 0;
}
return 3;
}
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
int a = aFOO[programIndex];
int x = switchit(a, b);
RET[programIndex] = x;
}
export void result(uniform float RET[]) {
RET[programIndex] = 0;
RET[2] = 1;
RET[6] = RET[4] = 3;
RET[5] = RET[3] = 2;
}

View File

@@ -0,0 +1,9 @@
// Case statement value must be a compile-time integer constant
void foo(float f) {
switch (f) {
case 1.5:
++f;
}
}

View File

@@ -0,0 +1,12 @@
// Duplicate case value "1"
void foo(float f) {
switch (f) {
case 1:
++f;
case 2:
case 1:
f = 0;
}
}

View File

@@ -0,0 +1,13 @@
// "case" label illegal outside of "switch" statement
void foo(float f) {
switch (f) {
case 1:
++f;
case 2:
f = 0;
}
case 3:
--f;
}

View File

@@ -0,0 +1,13 @@
// "default" label illegal outside of "switch" statement
void foo(float f) {
default:
++f;
switch (f) {
case 1:
++f;
case 2:
f = 0;
}
}

View File

@@ -0,0 +1,14 @@
// "default" label illegal outside of "switch" statement
void foo(float f) {
default:
++f;
switch (f) {
case 1:
++f;
continue;
case 2:
f = 0;
}
}

View File

@@ -0,0 +1,12 @@
// "continue" statement illegal outside of for/while/do/foreach loops
void foo(float f) {
switch (f) {
case 1:
++f;
continue;
case 2:
f = 0;
}
}