Rewrite check for loops for break/continue under varying CF to use WalkAST()

This commit is contained in:
Matt Pharr
2011-12-16 10:44:37 -08:00
parent 45767ad197
commit 34eda04d9b
3 changed files with 125 additions and 20 deletions

View File

@@ -75,8 +75,11 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc,
return;
// Call the callback function
if (preFunc != NULL)
preFunc(node, data);
if (preFunc != NULL) {
if (preFunc(node, data) == false)
// The function asked us to not continue recursively, so stop.
return;
}
////////////////////////////////////////////////////////////////////////////
// Handle Statements

14
ast.h
View File

@@ -92,13 +92,17 @@ private:
};
typedef void (* ASTCallBackFunc)(ASTNode *node, void *data);
/** Callback function type for the AST walk.
*/
typedef bool (* ASTCallBackFunc)(ASTNode *node, void *data);
/** Walk (some portion of) an AST, starting from the given root node. At
each node, if preFunc is non-NULL, call it preFunc, passing the given
void *data pointer. Makes recursive calls to WalkAST() to process the
node's children; after doing so, postFunc, if non-NULL is called at the
node. */
each node, if preFunc is non-NULL, call it, passing the given void
*data pointer; if the call to preFunc function returns false, then the
children of the node aren't visited. This then makes recursive calls
to WalkAST() to process the node's children; after doing so, calls
postFunc, at the node. The return value from the postFunc call is
ignored. */
extern void WalkAST(ASTNode *root, ASTCallBackFunc preFunc,
ASTCallBackFunc postFunc, void *data);

124
stmt.cpp
View File

@@ -656,46 +656,54 @@ IfStmt::emitMaskedTrueAndFalse(FunctionEmitContext *ctx, llvm::Value *oldMask,
/** Given an AST node, check to see if it's safe if we happen to run the
code for that node with the execution mask all off. */
static void
code for that node with the execution mask all off.
FIXME: this is actually a target-specific thing; for non SSE/AVX
targets with more complete masking support, some of this won't apply...
*/
static bool
lCheckAllOffSafety(ASTNode *node, void *data) {
bool *okPtr = (bool *)data;
if (dynamic_cast<FunctionCallExpr *>(node) != NULL)
if (dynamic_cast<FunctionCallExpr *>(node) != NULL) {
// FIXME: If we could somehow determine that the function being
// called was safe (and all of the args Exprs were safe, then it'd
// be nice to be able to return true here. (Consider a call to
// e.g. floatbits() in the stdlib.) Unfortunately for now we just
// have to be conservative.
*okPtr = false;
return false;
}
if (dynamic_cast<AssertStmt *>(node) != NULL)
// While this is fine for varying tests, it's not going to be
if (dynamic_cast<AssertStmt *>(node) != NULL) {
// While it's fine to run the assert for varying tests, it's not
// desirable to check an assert on a uniform variable if all of the
// lanes are off.
*okPtr = false;
return false;
}
IndexExpr *ie;
if ((ie = dynamic_cast<IndexExpr *>(node)) != NULL && ie->baseExpr != NULL) {
const Type *type = ie->baseExpr->GetType();
if (type == NULL)
return;
return true;
if (dynamic_cast<const ReferenceType *>(type) != NULL)
type = type->GetReferenceTarget();
ConstExpr *ce = dynamic_cast<ConstExpr *>(ie->index);
if (ce == NULL) {
// indexing with a variable...
// indexing with a variable... -> not safe
*okPtr = false;
return;
return false;
}
const PointerType *pointerType =
dynamic_cast<const PointerType *>(type);
if (pointerType != NULL) {
// pointer[index] -> can't be sure
// pointer[index] -> can't be sure -> not safe
*okPtr = false;
return;
return false;
}
const SequentialType *seqType =
@@ -703,22 +711,25 @@ lCheckAllOffSafety(ASTNode *node, void *data) {
Assert(seqType != NULL);
int nElements = seqType->GetElementCount();
if (nElements == 0) {
// Unsized array, so we can't be sure
// Unsized array, so we can't be sure -> not safe
*okPtr = false;
return;
return false;
}
int32_t indices[ISPC_MAX_NVEC];
int count = ce->AsInt32(indices);
for (int i = 0; i < count; ++i) {
if (indices[i] < 0 || indices[i] >= nElements) {
// Index is out of bounds -> not safe
*okPtr = false;
return;
return false;
}
}
// All indices are in-bounds
}
return true;
}
@@ -974,6 +985,83 @@ lHasVaryingBreakOrContinue(Stmt *stmt, bool inVaryingCF = false) {
}
struct VaryingBCCheckInfo {
VaryingBCCheckInfo() {
varyingControlFlowDepth = 0;
foundVaryingBreakOrContinue = false;
}
int varyingControlFlowDepth;
bool foundVaryingBreakOrContinue;
};
/** Returns true if the given node is an 'if' statement where the test
condition has varying type. */
static bool
lIsVaryingFor(ASTNode *node) {
IfStmt *ifStmt;
if ((ifStmt = dynamic_cast<IfStmt *>(node)) != NULL &&
ifStmt->test != NULL) {
const Type *type = ifStmt->test->GetType();
return (type != NULL && type->IsVaryingType());
}
else
return false;
}
/** Preorder callback function for checking for varying breaks or
continues. */
static bool
lVaryingBCPreFunc(ASTNode *node, void *d) {
VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d;
// We found a break or continue statement; if we're under varying
// control flow, then bingo.
if ((dynamic_cast<BreakStmt *>(node) != NULL ||
dynamic_cast<ContinueStmt *>(node) != NULL) &&
info->varyingControlFlowDepth > 0) {
info->foundVaryingBreakOrContinue = true;
return false;
}
// Update the count of the nesting depth of varying control flow if
// this is an if statement with a varying condition.
if (lIsVaryingFor(node))
++info->varyingControlFlowDepth;
if (dynamic_cast<ForStmt *>(node) != NULL ||
dynamic_cast<DoStmt *>(node) != NULL ||
dynamic_cast<ForeachStmt *>(node) != NULL)
// Don't recurse into these guys, since we don't care about varying
// breaks or continues within them...
return false;
else
return true;
}
/** Postorder callback function for checking for varying breaks or
continues; decrement the varying control flow depth after the node's
children have been processed, if this is a varying if statement. */
static bool
lVaryingBCPostFunc(ASTNode *node, void *d) {
VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d;
if (lIsVaryingFor(node))
--info->varyingControlFlowDepth;
return true;
}
static bool
lHasVaryingBreakOrContinue2(Stmt *stmt) {
VaryingBCCheckInfo info;
WalkAST(stmt, lVaryingBCPreFunc, lVaryingBCPostFunc, &info);
return info.foundVaryingBreakOrContinue;
}
DoStmt::DoStmt(Expr *t, Stmt *s, bool cc, SourcePos p)
: Stmt(p), testExpr(t), bodyStmts(s),
doCoherentCheck(cc && !g->opt.disableCoherentControlFlow) {
@@ -1124,6 +1212,9 @@ DoStmt::TypeCheck() {
// code generated for the loop includes masking stuff, so
// that we can track which lanes actually want to be
// running, accounting for breaks/continues.
assert(lHasVaryingBreakOrContinue(bodyStmts) ==
lHasVaryingBreakOrContinue2(bodyStmts));
bool uniformTest = (testType->IsUniformType() &&
!g->opt.disableUniformControlFlow &&
!lHasVaryingBreakOrContinue(bodyStmts));
@@ -1185,6 +1276,9 @@ ForStmt::EmitCode(FunctionEmitContext *ctx) const {
(!g->opt.disableUniformControlFlow &&
!lHasVaryingBreakOrContinue(stmts));
assert(lHasVaryingBreakOrContinue(stmts) ==
lHasVaryingBreakOrContinue2(stmts));
ctx->StartLoop(bexit, bstep, uniformTest);
ctx->SetDebugPos(pos);
@@ -1328,6 +1422,8 @@ ForStmt::TypeCheck() {
// See comments in DoStmt::TypeCheck() regarding
// 'uniformTest' and the type cast here.
assert(lHasVaryingBreakOrContinue(stmts) ==
lHasVaryingBreakOrContinue2(stmts));
bool uniformTest = (testType->IsUniformType() &&
!g->opt.disableUniformControlFlow &&
!lHasVaryingBreakOrContinue(stmts));
@@ -1353,6 +1449,8 @@ ForStmt::EstimateCost() const {
bool uniformTest = test ? test->GetType()->IsUniformType() :
(!g->opt.disableUniformControlFlow &&
!lHasVaryingBreakOrContinue(stmts));
assert(lHasVaryingBreakOrContinue(stmts) ==
lHasVaryingBreakOrContinue2(stmts));
return ((init ? init->EstimateCost() : 0) +
(test ? test->EstimateCost() : 0) +