Add support for "switch" statements.
Switches with both uniform and varying "switch" expressions are supported. Switch statements with varying expressions and very large numbers of labels may not perform well; some issues to be filed shortly will track opportunities for improving these.
This commit is contained in:
301
stmt.cpp
301
stmt.cpp
@@ -1886,6 +1886,307 @@ ForeachStmt::Print(int indent) const {
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// CaseStmt
|
||||
|
||||
/** Given the statements following a 'case' or 'default' label, this
|
||||
function determines whether the mask should be checked to see if it is
|
||||
"all off" immediately after the label, before executing the code for
|
||||
the statements.
|
||||
*/
|
||||
static bool
|
||||
lCheckMask(Stmt *stmts) {
|
||||
if (stmts == NULL)
|
||||
return false;
|
||||
|
||||
int cost = EstimateCost(stmts);
|
||||
|
||||
bool safeToRunWithAllLanesOff = true;
|
||||
WalkAST(stmts, lCheckAllOffSafety, NULL, &safeToRunWithAllLanesOff);
|
||||
|
||||
// The mask should be checked if the code following the
|
||||
// 'case'/'default' is relatively complex, or if it would be unsafe to
|
||||
// run that code with the execution mask all off.
|
||||
return (cost > PREDICATE_SAFE_IF_STATEMENT_COST ||
|
||||
safeToRunWithAllLanesOff == false);
|
||||
}
|
||||
|
||||
|
||||
CaseStmt::CaseStmt(int v, Stmt *s, SourcePos pos)
|
||||
: Stmt(pos), value(v) {
|
||||
stmts = s;
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
CaseStmt::EmitCode(FunctionEmitContext *ctx) const {
|
||||
ctx->EmitCaseLabel(value, lCheckMask(stmts), pos);
|
||||
if (stmts)
|
||||
stmts->EmitCode(ctx);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
CaseStmt::Print(int indent) const {
|
||||
printf("%*cCase [%d] label", indent, ' ', value);
|
||||
pos.Print();
|
||||
printf("\n");
|
||||
stmts->Print(indent+4);
|
||||
}
|
||||
|
||||
|
||||
Stmt *
|
||||
CaseStmt::TypeCheck() {
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
int
|
||||
CaseStmt::EstimateCost() const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// DefaultStmt
|
||||
|
||||
DefaultStmt::DefaultStmt(Stmt *s, SourcePos pos)
|
||||
: Stmt(pos) {
|
||||
stmts = s;
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
DefaultStmt::EmitCode(FunctionEmitContext *ctx) const {
|
||||
ctx->EmitDefaultLabel(lCheckMask(stmts), pos);
|
||||
if (stmts)
|
||||
stmts->EmitCode(ctx);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
DefaultStmt::Print(int indent) const {
|
||||
printf("%*cDefault Stmt", indent, ' ');
|
||||
pos.Print();
|
||||
printf("\n");
|
||||
stmts->Print(indent+4);
|
||||
}
|
||||
|
||||
|
||||
Stmt *
|
||||
DefaultStmt::TypeCheck() {
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
int
|
||||
DefaultStmt::EstimateCost() const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// SwitchStmt
|
||||
|
||||
SwitchStmt::SwitchStmt(Expr *e, Stmt *s, SourcePos pos)
|
||||
: Stmt(pos) {
|
||||
expr = e;
|
||||
stmts = s;
|
||||
}
|
||||
|
||||
|
||||
/* An instance of this structure is carried along as we traverse the AST
|
||||
nodes for the statements after a "switch" statement. We use this
|
||||
structure to record all of the 'case' and 'default' statements after the
|
||||
"switch". */
|
||||
struct SwitchVisitInfo {
|
||||
SwitchVisitInfo(FunctionEmitContext *c) {
|
||||
ctx = c;
|
||||
defaultBlock = NULL;
|
||||
lastBlock = NULL;
|
||||
}
|
||||
|
||||
FunctionEmitContext *ctx;
|
||||
|
||||
/* Basic block for the code following the "default" label (if any). */
|
||||
llvm::BasicBlock *defaultBlock;
|
||||
|
||||
/* Map from integer values after "case" labels to the basic blocks that
|
||||
follow the corresponding "case" label. */
|
||||
std::vector<std::pair<int, llvm::BasicBlock *> > caseBlocks;
|
||||
|
||||
/* For each basic block for a "case" label or a "default" label,
|
||||
nextBlock[block] stores the basic block pointer for the next
|
||||
subsequent "case" or "default" label in the program. */
|
||||
std::map<llvm::BasicBlock *, llvm::BasicBlock *> nextBlock;
|
||||
|
||||
/* The last basic block created for a "case" or "default" label; when
|
||||
we create the basic block for the next one, we'll use this to update
|
||||
the nextBlock map<> above. */
|
||||
llvm::BasicBlock *lastBlock;
|
||||
};
|
||||
|
||||
|
||||
static bool
|
||||
lSwitchASTPreVisit(ASTNode *node, void *d) {
|
||||
if (dynamic_cast<SwitchStmt *>(node) != NULL)
|
||||
// don't continue recursively into a nested switch--we only want
|
||||
// our own case and default statements!
|
||||
return false;
|
||||
|
||||
CaseStmt *cs = dynamic_cast<CaseStmt *>(node);
|
||||
DefaultStmt *ds = dynamic_cast<DefaultStmt *>(node);
|
||||
|
||||
SwitchVisitInfo *svi = (SwitchVisitInfo *)d;
|
||||
llvm::BasicBlock *bb = NULL;
|
||||
if (cs != NULL) {
|
||||
// Complain if we've seen a case statement with the same value
|
||||
// already
|
||||
for (int i = 0; i < (int)svi->caseBlocks.size(); ++i) {
|
||||
if (svi->caseBlocks[i].first == cs->value) {
|
||||
Error(cs->pos, "Duplicate case value \"%d\".", cs->value);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise create a new basic block for the code following this
|
||||
// 'case' statement and record the mappign between the case label
|
||||
// value and the basic block
|
||||
char buf[32];
|
||||
sprintf(buf, "case_%d", cs->value);
|
||||
bb = svi->ctx->CreateBasicBlock(buf);
|
||||
svi->caseBlocks.push_back(std::make_pair(cs->value, bb));
|
||||
}
|
||||
else if (ds != NULL) {
|
||||
// And complain if we've seen another 'default' label..
|
||||
if (svi->defaultBlock != NULL) {
|
||||
Error(ds->pos, "Multiple \"default\" lables in switch statement.");
|
||||
return true;
|
||||
}
|
||||
else {
|
||||
// Otherwise create a basic block for the code following the
|
||||
// "default".
|
||||
bb = svi->ctx->CreateBasicBlock("default");
|
||||
svi->defaultBlock = bb;
|
||||
}
|
||||
}
|
||||
|
||||
// If we saw a "case" or "default" label, then update the map to record
|
||||
// that the block we just created follows the block created for the
|
||||
// previous label in the "switch".
|
||||
if (bb != NULL) {
|
||||
svi->nextBlock[svi->lastBlock] = bb;
|
||||
svi->lastBlock = bb;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
SwitchStmt::EmitCode(FunctionEmitContext *ctx) const {
|
||||
if (ctx->GetCurrentBasicBlock() == NULL)
|
||||
return;
|
||||
|
||||
const Type *type;
|
||||
if (expr == NULL || ((type = expr->GetType()) == NULL)) {
|
||||
Assert(m->errorCount > 0);
|
||||
return;
|
||||
}
|
||||
|
||||
// Basic block we'll end up after the switch statement
|
||||
llvm::BasicBlock *bbDone = ctx->CreateBasicBlock("switch_done");
|
||||
|
||||
// Walk the AST of the statements after the 'switch' to collect a bunch
|
||||
// of information about the structure of the 'case' and 'default'
|
||||
// statements.
|
||||
SwitchVisitInfo svi(ctx);
|
||||
WalkAST(stmts, lSwitchASTPreVisit, NULL, &svi);
|
||||
// Record that the basic block following the last one created for a
|
||||
// case/default is the block after the end of the switch statement.
|
||||
svi.nextBlock[svi.lastBlock] = bbDone;
|
||||
|
||||
llvm::Value *exprValue = expr->GetValue(ctx);
|
||||
if (exprValue == NULL) {
|
||||
Assert(m->errorCount > 0);
|
||||
return;
|
||||
}
|
||||
|
||||
ctx->StartSwitch(type->IsUniformType(), bbDone);
|
||||
ctx->SwitchInst(exprValue, svi.defaultBlock ? svi.defaultBlock : bbDone,
|
||||
svi.caseBlocks, svi.nextBlock);
|
||||
|
||||
if (stmts != NULL)
|
||||
stmts->EmitCode(ctx);
|
||||
|
||||
if (ctx->GetCurrentBasicBlock() != NULL)
|
||||
ctx->BranchInst(bbDone);
|
||||
|
||||
ctx->SetCurrentBasicBlock(bbDone);
|
||||
ctx->EndSwitch();
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
SwitchStmt::Print(int indent) const {
|
||||
printf("%*cSwitch Stmt", indent, ' ');
|
||||
pos.Print();
|
||||
printf("\n");
|
||||
printf("%*cexpr = ", indent, ' ');
|
||||
expr->Print();
|
||||
printf("\n");
|
||||
stmts->Print(indent+4);
|
||||
}
|
||||
|
||||
|
||||
Stmt *
|
||||
SwitchStmt::TypeCheck() {
|
||||
const Type *exprType = expr->GetType();
|
||||
if (exprType == NULL)
|
||||
return NULL;
|
||||
|
||||
const Type *toType = NULL;
|
||||
exprType = exprType->GetAsConstType();
|
||||
bool is64bit = (exprType->GetAsUniformType() ==
|
||||
AtomicType::UniformConstUInt64 ||
|
||||
exprType->GetAsUniformType() ==
|
||||
AtomicType::UniformConstInt64);
|
||||
|
||||
// FIXME: if there's a break or continue under varying control flow
|
||||
// within a switch with a "uniform" condition, we promote the condition
|
||||
// to varying so that everything works out and we are set to handle the
|
||||
// resulting divergent control flow. This is somewhat sub-optimal; see
|
||||
// Issue #XXX for details.
|
||||
bool isUniform = (exprType->IsUniformType() &&
|
||||
lHasVaryingBreakOrContinue(stmts) == false);
|
||||
|
||||
if (isUniform) {
|
||||
if (is64bit) toType = AtomicType::UniformInt64;
|
||||
else toType = AtomicType::UniformInt32;
|
||||
}
|
||||
else {
|
||||
if (is64bit) toType = AtomicType::VaryingInt64;
|
||||
else toType = AtomicType::VaryingInt32;
|
||||
}
|
||||
|
||||
expr = TypeConvertExpr(expr, toType, "switch expression");
|
||||
if (expr == NULL)
|
||||
return NULL;
|
||||
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
int
|
||||
SwitchStmt::EstimateCost() const {
|
||||
const Type *type = expr->GetType();
|
||||
if (type && type->IsVaryingType())
|
||||
return COST_VARYING_SWITCH;
|
||||
else
|
||||
return COST_UNIFORM_SWITCH;
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// ReturnStmt
|
||||
|
||||
|
||||
Reference in New Issue
Block a user