Add support for "switch" statements.

Switches with both uniform and varying "switch" expressions are
supported.  Switch statements with varying expressions and very
large numbers of labels may not perform well; some issues to be
filed shortly will track opportunities for improving these.
This commit is contained in:
Matt Pharr
2012-01-11 09:15:05 -08:00
parent 9670ab0887
commit b67446d998
26 changed files with 1402 additions and 92 deletions

301
stmt.cpp
View File

@@ -1886,6 +1886,307 @@ ForeachStmt::Print(int indent) const {
}
///////////////////////////////////////////////////////////////////////////
// CaseStmt
/** Given the statements following a 'case' or 'default' label, this
function determines whether the mask should be checked to see if it is
"all off" immediately after the label, before executing the code for
the statements.
*/
static bool
lCheckMask(Stmt *stmts) {
if (stmts == NULL)
return false;
int cost = EstimateCost(stmts);
bool safeToRunWithAllLanesOff = true;
WalkAST(stmts, lCheckAllOffSafety, NULL, &safeToRunWithAllLanesOff);
// The mask should be checked if the code following the
// 'case'/'default' is relatively complex, or if it would be unsafe to
// run that code with the execution mask all off.
return (cost > PREDICATE_SAFE_IF_STATEMENT_COST ||
safeToRunWithAllLanesOff == false);
}
CaseStmt::CaseStmt(int v, Stmt *s, SourcePos pos)
: Stmt(pos), value(v) {
stmts = s;
}
void
CaseStmt::EmitCode(FunctionEmitContext *ctx) const {
ctx->EmitCaseLabel(value, lCheckMask(stmts), pos);
if (stmts)
stmts->EmitCode(ctx);
}
void
CaseStmt::Print(int indent) const {
printf("%*cCase [%d] label", indent, ' ', value);
pos.Print();
printf("\n");
stmts->Print(indent+4);
}
Stmt *
CaseStmt::TypeCheck() {
return this;
}
int
CaseStmt::EstimateCost() const {
return 0;
}
///////////////////////////////////////////////////////////////////////////
// DefaultStmt
DefaultStmt::DefaultStmt(Stmt *s, SourcePos pos)
: Stmt(pos) {
stmts = s;
}
void
DefaultStmt::EmitCode(FunctionEmitContext *ctx) const {
ctx->EmitDefaultLabel(lCheckMask(stmts), pos);
if (stmts)
stmts->EmitCode(ctx);
}
void
DefaultStmt::Print(int indent) const {
printf("%*cDefault Stmt", indent, ' ');
pos.Print();
printf("\n");
stmts->Print(indent+4);
}
Stmt *
DefaultStmt::TypeCheck() {
return this;
}
int
DefaultStmt::EstimateCost() const {
return 0;
}
///////////////////////////////////////////////////////////////////////////
// SwitchStmt
SwitchStmt::SwitchStmt(Expr *e, Stmt *s, SourcePos pos)
: Stmt(pos) {
expr = e;
stmts = s;
}
/* An instance of this structure is carried along as we traverse the AST
nodes for the statements after a "switch" statement. We use this
structure to record all of the 'case' and 'default' statements after the
"switch". */
struct SwitchVisitInfo {
SwitchVisitInfo(FunctionEmitContext *c) {
ctx = c;
defaultBlock = NULL;
lastBlock = NULL;
}
FunctionEmitContext *ctx;
/* Basic block for the code following the "default" label (if any). */
llvm::BasicBlock *defaultBlock;
/* Map from integer values after "case" labels to the basic blocks that
follow the corresponding "case" label. */
std::vector<std::pair<int, llvm::BasicBlock *> > caseBlocks;
/* For each basic block for a "case" label or a "default" label,
nextBlock[block] stores the basic block pointer for the next
subsequent "case" or "default" label in the program. */
std::map<llvm::BasicBlock *, llvm::BasicBlock *> nextBlock;
/* The last basic block created for a "case" or "default" label; when
we create the basic block for the next one, we'll use this to update
the nextBlock map<> above. */
llvm::BasicBlock *lastBlock;
};
static bool
lSwitchASTPreVisit(ASTNode *node, void *d) {
if (dynamic_cast<SwitchStmt *>(node) != NULL)
// don't continue recursively into a nested switch--we only want
// our own case and default statements!
return false;
CaseStmt *cs = dynamic_cast<CaseStmt *>(node);
DefaultStmt *ds = dynamic_cast<DefaultStmt *>(node);
SwitchVisitInfo *svi = (SwitchVisitInfo *)d;
llvm::BasicBlock *bb = NULL;
if (cs != NULL) {
// Complain if we've seen a case statement with the same value
// already
for (int i = 0; i < (int)svi->caseBlocks.size(); ++i) {
if (svi->caseBlocks[i].first == cs->value) {
Error(cs->pos, "Duplicate case value \"%d\".", cs->value);
return true;
}
}
// Otherwise create a new basic block for the code following this
// 'case' statement and record the mappign between the case label
// value and the basic block
char buf[32];
sprintf(buf, "case_%d", cs->value);
bb = svi->ctx->CreateBasicBlock(buf);
svi->caseBlocks.push_back(std::make_pair(cs->value, bb));
}
else if (ds != NULL) {
// And complain if we've seen another 'default' label..
if (svi->defaultBlock != NULL) {
Error(ds->pos, "Multiple \"default\" lables in switch statement.");
return true;
}
else {
// Otherwise create a basic block for the code following the
// "default".
bb = svi->ctx->CreateBasicBlock("default");
svi->defaultBlock = bb;
}
}
// If we saw a "case" or "default" label, then update the map to record
// that the block we just created follows the block created for the
// previous label in the "switch".
if (bb != NULL) {
svi->nextBlock[svi->lastBlock] = bb;
svi->lastBlock = bb;
}
return true;
}
void
SwitchStmt::EmitCode(FunctionEmitContext *ctx) const {
if (ctx->GetCurrentBasicBlock() == NULL)
return;
const Type *type;
if (expr == NULL || ((type = expr->GetType()) == NULL)) {
Assert(m->errorCount > 0);
return;
}
// Basic block we'll end up after the switch statement
llvm::BasicBlock *bbDone = ctx->CreateBasicBlock("switch_done");
// Walk the AST of the statements after the 'switch' to collect a bunch
// of information about the structure of the 'case' and 'default'
// statements.
SwitchVisitInfo svi(ctx);
WalkAST(stmts, lSwitchASTPreVisit, NULL, &svi);
// Record that the basic block following the last one created for a
// case/default is the block after the end of the switch statement.
svi.nextBlock[svi.lastBlock] = bbDone;
llvm::Value *exprValue = expr->GetValue(ctx);
if (exprValue == NULL) {
Assert(m->errorCount > 0);
return;
}
ctx->StartSwitch(type->IsUniformType(), bbDone);
ctx->SwitchInst(exprValue, svi.defaultBlock ? svi.defaultBlock : bbDone,
svi.caseBlocks, svi.nextBlock);
if (stmts != NULL)
stmts->EmitCode(ctx);
if (ctx->GetCurrentBasicBlock() != NULL)
ctx->BranchInst(bbDone);
ctx->SetCurrentBasicBlock(bbDone);
ctx->EndSwitch();
}
void
SwitchStmt::Print(int indent) const {
printf("%*cSwitch Stmt", indent, ' ');
pos.Print();
printf("\n");
printf("%*cexpr = ", indent, ' ');
expr->Print();
printf("\n");
stmts->Print(indent+4);
}
Stmt *
SwitchStmt::TypeCheck() {
const Type *exprType = expr->GetType();
if (exprType == NULL)
return NULL;
const Type *toType = NULL;
exprType = exprType->GetAsConstType();
bool is64bit = (exprType->GetAsUniformType() ==
AtomicType::UniformConstUInt64 ||
exprType->GetAsUniformType() ==
AtomicType::UniformConstInt64);
// FIXME: if there's a break or continue under varying control flow
// within a switch with a "uniform" condition, we promote the condition
// to varying so that everything works out and we are set to handle the
// resulting divergent control flow. This is somewhat sub-optimal; see
// Issue #XXX for details.
bool isUniform = (exprType->IsUniformType() &&
lHasVaryingBreakOrContinue(stmts) == false);
if (isUniform) {
if (is64bit) toType = AtomicType::UniformInt64;
else toType = AtomicType::UniformInt32;
}
else {
if (is64bit) toType = AtomicType::VaryingInt64;
else toType = AtomicType::VaryingInt32;
}
expr = TypeConvertExpr(expr, toType, "switch expression");
if (expr == NULL)
return NULL;
return this;
}
int
SwitchStmt::EstimateCost() const {
const Type *type = expr->GetType();
if (type && type->IsVaryingType())
return COST_VARYING_SWITCH;
else
return COST_UNIFORM_SWITCH;
}
///////////////////////////////////////////////////////////////////////////
// ReturnStmt