Rewrite check for loops for break/continue under varying CF to use WalkAST()

This commit is contained in:
Matt Pharr
2011-12-16 10:44:37 -08:00
parent 45767ad197
commit 34eda04d9b
3 changed files with 125 additions and 20 deletions

View File

@@ -75,8 +75,11 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc,
return; return;
// Call the callback function // Call the callback function
if (preFunc != NULL) if (preFunc != NULL) {
preFunc(node, data); if (preFunc(node, data) == false)
// The function asked us to not continue recursively, so stop.
return;
}
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Handle Statements // Handle Statements

14
ast.h
View File

@@ -92,13 +92,17 @@ private:
}; };
typedef void (* ASTCallBackFunc)(ASTNode *node, void *data); /** Callback function type for the AST walk.
*/
typedef bool (* ASTCallBackFunc)(ASTNode *node, void *data);
/** Walk (some portion of) an AST, starting from the given root node. At /** Walk (some portion of) an AST, starting from the given root node. At
each node, if preFunc is non-NULL, call it preFunc, passing the given each node, if preFunc is non-NULL, call it, passing the given void
void *data pointer. Makes recursive calls to WalkAST() to process the *data pointer; if the call to preFunc function returns false, then the
node's children; after doing so, postFunc, if non-NULL is called at the children of the node aren't visited. This then makes recursive calls
node. */ to WalkAST() to process the node's children; after doing so, calls
postFunc, at the node. The return value from the postFunc call is
ignored. */
extern void WalkAST(ASTNode *root, ASTCallBackFunc preFunc, extern void WalkAST(ASTNode *root, ASTCallBackFunc preFunc,
ASTCallBackFunc postFunc, void *data); ASTCallBackFunc postFunc, void *data);

124
stmt.cpp
View File

@@ -656,46 +656,54 @@ IfStmt::emitMaskedTrueAndFalse(FunctionEmitContext *ctx, llvm::Value *oldMask,
/** Given an AST node, check to see if it's safe if we happen to run the /** Given an AST node, check to see if it's safe if we happen to run the
code for that node with the execution mask all off. */ code for that node with the execution mask all off.
static void
FIXME: this is actually a target-specific thing; for non SSE/AVX
targets with more complete masking support, some of this won't apply...
*/
static bool
lCheckAllOffSafety(ASTNode *node, void *data) { lCheckAllOffSafety(ASTNode *node, void *data) {
bool *okPtr = (bool *)data; bool *okPtr = (bool *)data;
if (dynamic_cast<FunctionCallExpr *>(node) != NULL) if (dynamic_cast<FunctionCallExpr *>(node) != NULL) {
// FIXME: If we could somehow determine that the function being // FIXME: If we could somehow determine that the function being
// called was safe (and all of the args Exprs were safe, then it'd // called was safe (and all of the args Exprs were safe, then it'd
// be nice to be able to return true here. (Consider a call to // be nice to be able to return true here. (Consider a call to
// e.g. floatbits() in the stdlib.) Unfortunately for now we just // e.g. floatbits() in the stdlib.) Unfortunately for now we just
// have to be conservative. // have to be conservative.
*okPtr = false; *okPtr = false;
return false;
}
if (dynamic_cast<AssertStmt *>(node) != NULL) if (dynamic_cast<AssertStmt *>(node) != NULL) {
// While this is fine for varying tests, it's not going to be // While it's fine to run the assert for varying tests, it's not
// desirable to check an assert on a uniform variable if all of the // desirable to check an assert on a uniform variable if all of the
// lanes are off. // lanes are off.
*okPtr = false; *okPtr = false;
return false;
}
IndexExpr *ie; IndexExpr *ie;
if ((ie = dynamic_cast<IndexExpr *>(node)) != NULL && ie->baseExpr != NULL) { if ((ie = dynamic_cast<IndexExpr *>(node)) != NULL && ie->baseExpr != NULL) {
const Type *type = ie->baseExpr->GetType(); const Type *type = ie->baseExpr->GetType();
if (type == NULL) if (type == NULL)
return; return true;
if (dynamic_cast<const ReferenceType *>(type) != NULL) if (dynamic_cast<const ReferenceType *>(type) != NULL)
type = type->GetReferenceTarget(); type = type->GetReferenceTarget();
ConstExpr *ce = dynamic_cast<ConstExpr *>(ie->index); ConstExpr *ce = dynamic_cast<ConstExpr *>(ie->index);
if (ce == NULL) { if (ce == NULL) {
// indexing with a variable... // indexing with a variable... -> not safe
*okPtr = false; *okPtr = false;
return; return false;
} }
const PointerType *pointerType = const PointerType *pointerType =
dynamic_cast<const PointerType *>(type); dynamic_cast<const PointerType *>(type);
if (pointerType != NULL) { if (pointerType != NULL) {
// pointer[index] -> can't be sure // pointer[index] -> can't be sure -> not safe
*okPtr = false; *okPtr = false;
return; return false;
} }
const SequentialType *seqType = const SequentialType *seqType =
@@ -703,22 +711,25 @@ lCheckAllOffSafety(ASTNode *node, void *data) {
Assert(seqType != NULL); Assert(seqType != NULL);
int nElements = seqType->GetElementCount(); int nElements = seqType->GetElementCount();
if (nElements == 0) { if (nElements == 0) {
// Unsized array, so we can't be sure // Unsized array, so we can't be sure -> not safe
*okPtr = false; *okPtr = false;
return; return false;
} }
int32_t indices[ISPC_MAX_NVEC]; int32_t indices[ISPC_MAX_NVEC];
int count = ce->AsInt32(indices); int count = ce->AsInt32(indices);
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
if (indices[i] < 0 || indices[i] >= nElements) { if (indices[i] < 0 || indices[i] >= nElements) {
// Index is out of bounds -> not safe
*okPtr = false; *okPtr = false;
return; return false;
} }
} }
// All indices are in-bounds // All indices are in-bounds
} }
return true;
} }
@@ -974,6 +985,83 @@ lHasVaryingBreakOrContinue(Stmt *stmt, bool inVaryingCF = false) {
} }
struct VaryingBCCheckInfo {
VaryingBCCheckInfo() {
varyingControlFlowDepth = 0;
foundVaryingBreakOrContinue = false;
}
int varyingControlFlowDepth;
bool foundVaryingBreakOrContinue;
};
/** Returns true if the given node is an 'if' statement where the test
condition has varying type. */
static bool
lIsVaryingFor(ASTNode *node) {
IfStmt *ifStmt;
if ((ifStmt = dynamic_cast<IfStmt *>(node)) != NULL &&
ifStmt->test != NULL) {
const Type *type = ifStmt->test->GetType();
return (type != NULL && type->IsVaryingType());
}
else
return false;
}
/** Preorder callback function for checking for varying breaks or
continues. */
static bool
lVaryingBCPreFunc(ASTNode *node, void *d) {
VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d;
// We found a break or continue statement; if we're under varying
// control flow, then bingo.
if ((dynamic_cast<BreakStmt *>(node) != NULL ||
dynamic_cast<ContinueStmt *>(node) != NULL) &&
info->varyingControlFlowDepth > 0) {
info->foundVaryingBreakOrContinue = true;
return false;
}
// Update the count of the nesting depth of varying control flow if
// this is an if statement with a varying condition.
if (lIsVaryingFor(node))
++info->varyingControlFlowDepth;
if (dynamic_cast<ForStmt *>(node) != NULL ||
dynamic_cast<DoStmt *>(node) != NULL ||
dynamic_cast<ForeachStmt *>(node) != NULL)
// Don't recurse into these guys, since we don't care about varying
// breaks or continues within them...
return false;
else
return true;
}
/** Postorder callback function for checking for varying breaks or
continues; decrement the varying control flow depth after the node's
children have been processed, if this is a varying if statement. */
static bool
lVaryingBCPostFunc(ASTNode *node, void *d) {
VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d;
if (lIsVaryingFor(node))
--info->varyingControlFlowDepth;
return true;
}
static bool
lHasVaryingBreakOrContinue2(Stmt *stmt) {
VaryingBCCheckInfo info;
WalkAST(stmt, lVaryingBCPreFunc, lVaryingBCPostFunc, &info);
return info.foundVaryingBreakOrContinue;
}
DoStmt::DoStmt(Expr *t, Stmt *s, bool cc, SourcePos p) DoStmt::DoStmt(Expr *t, Stmt *s, bool cc, SourcePos p)
: Stmt(p), testExpr(t), bodyStmts(s), : Stmt(p), testExpr(t), bodyStmts(s),
doCoherentCheck(cc && !g->opt.disableCoherentControlFlow) { doCoherentCheck(cc && !g->opt.disableCoherentControlFlow) {
@@ -1124,6 +1212,9 @@ DoStmt::TypeCheck() {
// code generated for the loop includes masking stuff, so // code generated for the loop includes masking stuff, so
// that we can track which lanes actually want to be // that we can track which lanes actually want to be
// running, accounting for breaks/continues. // running, accounting for breaks/continues.
assert(lHasVaryingBreakOrContinue(bodyStmts) ==
lHasVaryingBreakOrContinue2(bodyStmts));
bool uniformTest = (testType->IsUniformType() && bool uniformTest = (testType->IsUniformType() &&
!g->opt.disableUniformControlFlow && !g->opt.disableUniformControlFlow &&
!lHasVaryingBreakOrContinue(bodyStmts)); !lHasVaryingBreakOrContinue(bodyStmts));
@@ -1185,6 +1276,9 @@ ForStmt::EmitCode(FunctionEmitContext *ctx) const {
(!g->opt.disableUniformControlFlow && (!g->opt.disableUniformControlFlow &&
!lHasVaryingBreakOrContinue(stmts)); !lHasVaryingBreakOrContinue(stmts));
assert(lHasVaryingBreakOrContinue(stmts) ==
lHasVaryingBreakOrContinue2(stmts));
ctx->StartLoop(bexit, bstep, uniformTest); ctx->StartLoop(bexit, bstep, uniformTest);
ctx->SetDebugPos(pos); ctx->SetDebugPos(pos);
@@ -1328,6 +1422,8 @@ ForStmt::TypeCheck() {
// See comments in DoStmt::TypeCheck() regarding // See comments in DoStmt::TypeCheck() regarding
// 'uniformTest' and the type cast here. // 'uniformTest' and the type cast here.
assert(lHasVaryingBreakOrContinue(stmts) ==
lHasVaryingBreakOrContinue2(stmts));
bool uniformTest = (testType->IsUniformType() && bool uniformTest = (testType->IsUniformType() &&
!g->opt.disableUniformControlFlow && !g->opt.disableUniformControlFlow &&
!lHasVaryingBreakOrContinue(stmts)); !lHasVaryingBreakOrContinue(stmts));
@@ -1353,6 +1449,8 @@ ForStmt::EstimateCost() const {
bool uniformTest = test ? test->GetType()->IsUniformType() : bool uniformTest = test ? test->GetType()->IsUniformType() :
(!g->opt.disableUniformControlFlow && (!g->opt.disableUniformControlFlow &&
!lHasVaryingBreakOrContinue(stmts)); !lHasVaryingBreakOrContinue(stmts));
assert(lHasVaryingBreakOrContinue(stmts) ==
lHasVaryingBreakOrContinue2(stmts));
return ((init ? init->EstimateCost() : 0) + return ((init ? init->EstimateCost() : 0) +
(test ? test->EstimateCost() : 0) + (test ? test->EstimateCost() : 0) +