Rewrite check for loops for break/continue under varying CF to use WalkAST()
This commit is contained in:
7
ast.cpp
7
ast.cpp
@@ -75,8 +75,11 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc,
|
|||||||
return;
|
return;
|
||||||
|
|
||||||
// Call the callback function
|
// Call the callback function
|
||||||
if (preFunc != NULL)
|
if (preFunc != NULL) {
|
||||||
preFunc(node, data);
|
if (preFunc(node, data) == false)
|
||||||
|
// The function asked us to not continue recursively, so stop.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////
|
||||||
// Handle Statements
|
// Handle Statements
|
||||||
|
|||||||
14
ast.h
14
ast.h
@@ -92,13 +92,17 @@ private:
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
typedef void (* ASTCallBackFunc)(ASTNode *node, void *data);
|
/** Callback function type for the AST walk.
|
||||||
|
*/
|
||||||
|
typedef bool (* ASTCallBackFunc)(ASTNode *node, void *data);
|
||||||
|
|
||||||
/** Walk (some portion of) an AST, starting from the given root node. At
|
/** Walk (some portion of) an AST, starting from the given root node. At
|
||||||
each node, if preFunc is non-NULL, call it preFunc, passing the given
|
each node, if preFunc is non-NULL, call it, passing the given void
|
||||||
void *data pointer. Makes recursive calls to WalkAST() to process the
|
*data pointer; if the call to preFunc function returns false, then the
|
||||||
node's children; after doing so, postFunc, if non-NULL is called at the
|
children of the node aren't visited. This then makes recursive calls
|
||||||
node. */
|
to WalkAST() to process the node's children; after doing so, calls
|
||||||
|
postFunc, at the node. The return value from the postFunc call is
|
||||||
|
ignored. */
|
||||||
extern void WalkAST(ASTNode *root, ASTCallBackFunc preFunc,
|
extern void WalkAST(ASTNode *root, ASTCallBackFunc preFunc,
|
||||||
ASTCallBackFunc postFunc, void *data);
|
ASTCallBackFunc postFunc, void *data);
|
||||||
|
|
||||||
|
|||||||
124
stmt.cpp
124
stmt.cpp
@@ -656,46 +656,54 @@ IfStmt::emitMaskedTrueAndFalse(FunctionEmitContext *ctx, llvm::Value *oldMask,
|
|||||||
|
|
||||||
|
|
||||||
/** Given an AST node, check to see if it's safe if we happen to run the
|
/** Given an AST node, check to see if it's safe if we happen to run the
|
||||||
code for that node with the execution mask all off. */
|
code for that node with the execution mask all off.
|
||||||
static void
|
|
||||||
|
FIXME: this is actually a target-specific thing; for non SSE/AVX
|
||||||
|
targets with more complete masking support, some of this won't apply...
|
||||||
|
*/
|
||||||
|
static bool
|
||||||
lCheckAllOffSafety(ASTNode *node, void *data) {
|
lCheckAllOffSafety(ASTNode *node, void *data) {
|
||||||
bool *okPtr = (bool *)data;
|
bool *okPtr = (bool *)data;
|
||||||
|
|
||||||
if (dynamic_cast<FunctionCallExpr *>(node) != NULL)
|
if (dynamic_cast<FunctionCallExpr *>(node) != NULL) {
|
||||||
// FIXME: If we could somehow determine that the function being
|
// FIXME: If we could somehow determine that the function being
|
||||||
// called was safe (and all of the args Exprs were safe, then it'd
|
// called was safe (and all of the args Exprs were safe, then it'd
|
||||||
// be nice to be able to return true here. (Consider a call to
|
// be nice to be able to return true here. (Consider a call to
|
||||||
// e.g. floatbits() in the stdlib.) Unfortunately for now we just
|
// e.g. floatbits() in the stdlib.) Unfortunately for now we just
|
||||||
// have to be conservative.
|
// have to be conservative.
|
||||||
*okPtr = false;
|
*okPtr = false;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
if (dynamic_cast<AssertStmt *>(node) != NULL)
|
if (dynamic_cast<AssertStmt *>(node) != NULL) {
|
||||||
// While this is fine for varying tests, it's not going to be
|
// While it's fine to run the assert for varying tests, it's not
|
||||||
// desirable to check an assert on a uniform variable if all of the
|
// desirable to check an assert on a uniform variable if all of the
|
||||||
// lanes are off.
|
// lanes are off.
|
||||||
*okPtr = false;
|
*okPtr = false;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
IndexExpr *ie;
|
IndexExpr *ie;
|
||||||
if ((ie = dynamic_cast<IndexExpr *>(node)) != NULL && ie->baseExpr != NULL) {
|
if ((ie = dynamic_cast<IndexExpr *>(node)) != NULL && ie->baseExpr != NULL) {
|
||||||
const Type *type = ie->baseExpr->GetType();
|
const Type *type = ie->baseExpr->GetType();
|
||||||
if (type == NULL)
|
if (type == NULL)
|
||||||
return;
|
return true;
|
||||||
if (dynamic_cast<const ReferenceType *>(type) != NULL)
|
if (dynamic_cast<const ReferenceType *>(type) != NULL)
|
||||||
type = type->GetReferenceTarget();
|
type = type->GetReferenceTarget();
|
||||||
|
|
||||||
ConstExpr *ce = dynamic_cast<ConstExpr *>(ie->index);
|
ConstExpr *ce = dynamic_cast<ConstExpr *>(ie->index);
|
||||||
if (ce == NULL) {
|
if (ce == NULL) {
|
||||||
// indexing with a variable...
|
// indexing with a variable... -> not safe
|
||||||
*okPtr = false;
|
*okPtr = false;
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const PointerType *pointerType =
|
const PointerType *pointerType =
|
||||||
dynamic_cast<const PointerType *>(type);
|
dynamic_cast<const PointerType *>(type);
|
||||||
if (pointerType != NULL) {
|
if (pointerType != NULL) {
|
||||||
// pointer[index] -> can't be sure
|
// pointer[index] -> can't be sure -> not safe
|
||||||
*okPtr = false;
|
*okPtr = false;
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const SequentialType *seqType =
|
const SequentialType *seqType =
|
||||||
@@ -703,22 +711,25 @@ lCheckAllOffSafety(ASTNode *node, void *data) {
|
|||||||
Assert(seqType != NULL);
|
Assert(seqType != NULL);
|
||||||
int nElements = seqType->GetElementCount();
|
int nElements = seqType->GetElementCount();
|
||||||
if (nElements == 0) {
|
if (nElements == 0) {
|
||||||
// Unsized array, so we can't be sure
|
// Unsized array, so we can't be sure -> not safe
|
||||||
*okPtr = false;
|
*okPtr = false;
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t indices[ISPC_MAX_NVEC];
|
int32_t indices[ISPC_MAX_NVEC];
|
||||||
int count = ce->AsInt32(indices);
|
int count = ce->AsInt32(indices);
|
||||||
for (int i = 0; i < count; ++i) {
|
for (int i = 0; i < count; ++i) {
|
||||||
if (indices[i] < 0 || indices[i] >= nElements) {
|
if (indices[i] < 0 || indices[i] >= nElements) {
|
||||||
|
// Index is out of bounds -> not safe
|
||||||
*okPtr = false;
|
*okPtr = false;
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// All indices are in-bounds
|
// All indices are in-bounds
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -974,6 +985,83 @@ lHasVaryingBreakOrContinue(Stmt *stmt, bool inVaryingCF = false) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
struct VaryingBCCheckInfo {
|
||||||
|
VaryingBCCheckInfo() {
|
||||||
|
varyingControlFlowDepth = 0;
|
||||||
|
foundVaryingBreakOrContinue = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int varyingControlFlowDepth;
|
||||||
|
bool foundVaryingBreakOrContinue;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/** Returns true if the given node is an 'if' statement where the test
|
||||||
|
condition has varying type. */
|
||||||
|
static bool
|
||||||
|
lIsVaryingFor(ASTNode *node) {
|
||||||
|
IfStmt *ifStmt;
|
||||||
|
if ((ifStmt = dynamic_cast<IfStmt *>(node)) != NULL &&
|
||||||
|
ifStmt->test != NULL) {
|
||||||
|
const Type *type = ifStmt->test->GetType();
|
||||||
|
return (type != NULL && type->IsVaryingType());
|
||||||
|
}
|
||||||
|
else
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/** Preorder callback function for checking for varying breaks or
|
||||||
|
continues. */
|
||||||
|
static bool
|
||||||
|
lVaryingBCPreFunc(ASTNode *node, void *d) {
|
||||||
|
VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d;
|
||||||
|
|
||||||
|
// We found a break or continue statement; if we're under varying
|
||||||
|
// control flow, then bingo.
|
||||||
|
if ((dynamic_cast<BreakStmt *>(node) != NULL ||
|
||||||
|
dynamic_cast<ContinueStmt *>(node) != NULL) &&
|
||||||
|
info->varyingControlFlowDepth > 0) {
|
||||||
|
info->foundVaryingBreakOrContinue = true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the count of the nesting depth of varying control flow if
|
||||||
|
// this is an if statement with a varying condition.
|
||||||
|
if (lIsVaryingFor(node))
|
||||||
|
++info->varyingControlFlowDepth;
|
||||||
|
|
||||||
|
if (dynamic_cast<ForStmt *>(node) != NULL ||
|
||||||
|
dynamic_cast<DoStmt *>(node) != NULL ||
|
||||||
|
dynamic_cast<ForeachStmt *>(node) != NULL)
|
||||||
|
// Don't recurse into these guys, since we don't care about varying
|
||||||
|
// breaks or continues within them...
|
||||||
|
return false;
|
||||||
|
else
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/** Postorder callback function for checking for varying breaks or
|
||||||
|
continues; decrement the varying control flow depth after the node's
|
||||||
|
children have been processed, if this is a varying if statement. */
|
||||||
|
static bool
|
||||||
|
lVaryingBCPostFunc(ASTNode *node, void *d) {
|
||||||
|
VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d;
|
||||||
|
if (lIsVaryingFor(node))
|
||||||
|
--info->varyingControlFlowDepth;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static bool
|
||||||
|
lHasVaryingBreakOrContinue2(Stmt *stmt) {
|
||||||
|
VaryingBCCheckInfo info;
|
||||||
|
WalkAST(stmt, lVaryingBCPreFunc, lVaryingBCPostFunc, &info);
|
||||||
|
return info.foundVaryingBreakOrContinue;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
DoStmt::DoStmt(Expr *t, Stmt *s, bool cc, SourcePos p)
|
DoStmt::DoStmt(Expr *t, Stmt *s, bool cc, SourcePos p)
|
||||||
: Stmt(p), testExpr(t), bodyStmts(s),
|
: Stmt(p), testExpr(t), bodyStmts(s),
|
||||||
doCoherentCheck(cc && !g->opt.disableCoherentControlFlow) {
|
doCoherentCheck(cc && !g->opt.disableCoherentControlFlow) {
|
||||||
@@ -1124,6 +1212,9 @@ DoStmt::TypeCheck() {
|
|||||||
// code generated for the loop includes masking stuff, so
|
// code generated for the loop includes masking stuff, so
|
||||||
// that we can track which lanes actually want to be
|
// that we can track which lanes actually want to be
|
||||||
// running, accounting for breaks/continues.
|
// running, accounting for breaks/continues.
|
||||||
|
assert(lHasVaryingBreakOrContinue(bodyStmts) ==
|
||||||
|
lHasVaryingBreakOrContinue2(bodyStmts));
|
||||||
|
|
||||||
bool uniformTest = (testType->IsUniformType() &&
|
bool uniformTest = (testType->IsUniformType() &&
|
||||||
!g->opt.disableUniformControlFlow &&
|
!g->opt.disableUniformControlFlow &&
|
||||||
!lHasVaryingBreakOrContinue(bodyStmts));
|
!lHasVaryingBreakOrContinue(bodyStmts));
|
||||||
@@ -1185,6 +1276,9 @@ ForStmt::EmitCode(FunctionEmitContext *ctx) const {
|
|||||||
(!g->opt.disableUniformControlFlow &&
|
(!g->opt.disableUniformControlFlow &&
|
||||||
!lHasVaryingBreakOrContinue(stmts));
|
!lHasVaryingBreakOrContinue(stmts));
|
||||||
|
|
||||||
|
assert(lHasVaryingBreakOrContinue(stmts) ==
|
||||||
|
lHasVaryingBreakOrContinue2(stmts));
|
||||||
|
|
||||||
ctx->StartLoop(bexit, bstep, uniformTest);
|
ctx->StartLoop(bexit, bstep, uniformTest);
|
||||||
ctx->SetDebugPos(pos);
|
ctx->SetDebugPos(pos);
|
||||||
|
|
||||||
@@ -1328,6 +1422,8 @@ ForStmt::TypeCheck() {
|
|||||||
|
|
||||||
// See comments in DoStmt::TypeCheck() regarding
|
// See comments in DoStmt::TypeCheck() regarding
|
||||||
// 'uniformTest' and the type cast here.
|
// 'uniformTest' and the type cast here.
|
||||||
|
assert(lHasVaryingBreakOrContinue(stmts) ==
|
||||||
|
lHasVaryingBreakOrContinue2(stmts));
|
||||||
bool uniformTest = (testType->IsUniformType() &&
|
bool uniformTest = (testType->IsUniformType() &&
|
||||||
!g->opt.disableUniformControlFlow &&
|
!g->opt.disableUniformControlFlow &&
|
||||||
!lHasVaryingBreakOrContinue(stmts));
|
!lHasVaryingBreakOrContinue(stmts));
|
||||||
@@ -1353,6 +1449,8 @@ ForStmt::EstimateCost() const {
|
|||||||
bool uniformTest = test ? test->GetType()->IsUniformType() :
|
bool uniformTest = test ? test->GetType()->IsUniformType() :
|
||||||
(!g->opt.disableUniformControlFlow &&
|
(!g->opt.disableUniformControlFlow &&
|
||||||
!lHasVaryingBreakOrContinue(stmts));
|
!lHasVaryingBreakOrContinue(stmts));
|
||||||
|
assert(lHasVaryingBreakOrContinue(stmts) ==
|
||||||
|
lHasVaryingBreakOrContinue2(stmts));
|
||||||
|
|
||||||
return ((init ? init->EstimateCost() : 0) +
|
return ((init ? init->EstimateCost() : 0) +
|
||||||
(test ? test->EstimateCost() : 0) +
|
(test ? test->EstimateCost() : 0) +
|
||||||
|
|||||||
Reference in New Issue
Block a user