Add WalkAST() function for generic AST walking.

For starters, use it for the check to see if code is safe to run with the
mask all off.

This also fixes a bug where we would sometimes incorrectly say that
a whole block of code was unsafe to run with an all off mask because we came
to a NULL AST node during traversal.
This commit is contained in:
Matt Pharr
2011-12-15 16:52:47 -08:00
parent 6f6e28077f
commit f9463af75b
3 changed files with 234 additions and 3 deletions

148
ast.cpp
View File

@@ -36,8 +36,11 @@
*/
#include "ast.h"
#include "expr.h"
#include "func.h"
#include "stmt.h"
#include "sym.h"
#include "util.h"
///////////////////////////////////////////////////////////////////////////
// ASTNode
@@ -63,3 +66,148 @@ AST::GenerateIR() {
functions[i]->GenerateIR();
}
///////////////////////////////////////////////////////////////////////////
void
WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc,
void *data) {
if (node == NULL)
return;
// Call the callback function
if (preFunc != NULL)
preFunc(node, data);
////////////////////////////////////////////////////////////////////////////
// Handle Statements
if (dynamic_cast<Stmt *>(node) != NULL) {
ExprStmt *es;
DeclStmt *ds;
IfStmt *is;
DoStmt *dos;
ForStmt *fs;
ForeachStmt *fes;
ReturnStmt *rs;
StmtList *sl;
PrintStmt *ps;
AssertStmt *as;
if ((es = dynamic_cast<ExprStmt *>(node)) != NULL)
WalkAST(es->expr, preFunc, postFunc, data);
else if ((ds = dynamic_cast<DeclStmt *>(node)) != NULL) {
for (unsigned int i = 0; i < ds->vars.size(); ++i)
WalkAST(ds->vars[i].init, preFunc, postFunc, data);
}
else if ((is = dynamic_cast<IfStmt *>(node)) != NULL) {
WalkAST(is->test, preFunc, postFunc, data);
WalkAST(is->trueStmts, preFunc, postFunc, data);
WalkAST(is->falseStmts, preFunc, postFunc, data);
}
else if ((dos = dynamic_cast<DoStmt *>(node)) != NULL) {
WalkAST(dos->testExpr, preFunc, postFunc, data);
WalkAST(dos->bodyStmts, preFunc, postFunc, data);
}
else if ((fs = dynamic_cast<ForStmt *>(node)) != NULL) {
WalkAST(fs->init, preFunc, postFunc, data);
WalkAST(fs->test, preFunc, postFunc, data);
WalkAST(fs->step, preFunc, postFunc, data);
WalkAST(fs->stmts, preFunc, postFunc, data);
}
else if ((fes = dynamic_cast<ForeachStmt *>(node)) != NULL) {
for (unsigned int i = 0; i < fes->startExprs.size(); ++i)
WalkAST(fes->startExprs[i], preFunc, postFunc, data);
for (unsigned int i = 0; i < fes->endExprs.size(); ++i)
WalkAST(fes->endExprs[i], preFunc, postFunc, data);
WalkAST(fes->stmts, preFunc, postFunc, data);
}
else if (dynamic_cast<BreakStmt *>(node) != NULL ||
dynamic_cast<ContinueStmt *>(node) != NULL) {
// nothing
}
else if ((rs = dynamic_cast<ReturnStmt *>(node)) != NULL)
WalkAST(rs->val, preFunc, postFunc, data);
else if ((sl = dynamic_cast<StmtList *>(node)) != NULL) {
const std::vector<Stmt *> &sls = sl->GetStatements();
for (unsigned int i = 0; i < sls.size(); ++i)
WalkAST(sls[i], preFunc, postFunc, data);
}
else if ((ps = dynamic_cast<PrintStmt *>(node)) != NULL)
WalkAST(ps->values, preFunc, postFunc, data);
else if ((as = dynamic_cast<AssertStmt *>(node)) != NULL)
return WalkAST(as->expr, preFunc, postFunc, data);
else
FATAL("Unhandled statement type in WalkAST()");
}
else {
///////////////////////////////////////////////////////////////////////////
// Handle expressions
assert(dynamic_cast<Expr *>(node) != NULL);
UnaryExpr *ue;
BinaryExpr *be;
AssignExpr *ae;
SelectExpr *se;
ExprList *el;
FunctionCallExpr *fce;
IndexExpr *ie;
MemberExpr *me;
TypeCastExpr *tce;
ReferenceExpr *re;
DereferenceExpr *dre;
SizeOfExpr *soe;
AddressOfExpr *aoe;
if ((ue = dynamic_cast<UnaryExpr *>(node)) != NULL)
WalkAST(ue->expr, preFunc, postFunc, data);
else if ((be = dynamic_cast<BinaryExpr *>(node)) != NULL) {
WalkAST(be->arg0, preFunc, postFunc, data);
WalkAST(be->arg1, preFunc, postFunc, data);
}
else if ((ae = dynamic_cast<AssignExpr *>(node)) != NULL) {
WalkAST(ae->lvalue, preFunc, postFunc, data);
WalkAST(ae->rvalue, preFunc, postFunc, data);
}
else if ((se = dynamic_cast<SelectExpr *>(node)) != NULL) {
WalkAST(se->test, preFunc, postFunc, data);
WalkAST(se->expr1, preFunc, postFunc, data);
WalkAST(se->expr2, preFunc, postFunc, data);
}
else if ((el = dynamic_cast<ExprList *>(node)) != NULL) {
for (unsigned int i = 0; i < el->exprs.size(); ++i)
WalkAST(el->exprs[i], preFunc, postFunc, data);
}
else if ((fce = dynamic_cast<FunctionCallExpr *>(node)) != NULL) {
WalkAST(fce->func, preFunc, postFunc, data);
WalkAST(fce->args, preFunc, postFunc, data);
WalkAST(fce->launchCountExpr, preFunc, postFunc, data);
}
else if ((ie = dynamic_cast<IndexExpr *>(node)) != NULL) {
WalkAST(ie->baseExpr, preFunc, postFunc, data);
WalkAST(ie->index, preFunc, postFunc, data);
}
else if ((me = dynamic_cast<MemberExpr *>(node)) != NULL)
WalkAST(me->expr, preFunc, postFunc, data);
else if ((tce = dynamic_cast<TypeCastExpr *>(node)) != NULL)
return WalkAST(tce->expr, preFunc, postFunc, data);
else if ((re = dynamic_cast<ReferenceExpr *>(node)) != NULL)
return WalkAST(re->expr, preFunc, postFunc, data);
else if ((dre = dynamic_cast<DereferenceExpr *>(node)) != NULL)
return WalkAST(dre->expr, preFunc, postFunc, data);
else if ((soe = dynamic_cast<SizeOfExpr *>(node)) != NULL)
return WalkAST(soe->expr, preFunc, postFunc, data);
else if ((aoe = dynamic_cast<AddressOfExpr *>(node)) != NULL)
WalkAST(aoe->expr, preFunc, postFunc, data);
else if (dynamic_cast<SymbolExpr *>(node) != NULL ||
dynamic_cast<ConstExpr *>(node) != NULL ||
dynamic_cast<FunctionSymbolExpr *>(node) != NULL ||
dynamic_cast<SyncExpr *>(node) != NULL ||
dynamic_cast<NullPointerExpr *>(node) != NULL) {
// nothing to do
}
else
FATAL("Unhandled expression type in WalkAST().");
}
// Call the callback function
if (postFunc != NULL)
postFunc(node, data);
}

11
ast.h
View File

@@ -91,4 +91,15 @@ private:
std::vector<Function *> functions;
};
typedef void (* ASTCallBackFunc)(ASTNode *node, void *data);
/** Walk (some portion of) an AST, starting from the given root node. At
each node, if preFunc is non-NULL, call it preFunc, passing the given
void *data pointer. Makes recursive calls to WalkAST() to process the
node's children; after doing so, postFunc, if non-NULL is called at the
node. */
extern void WalkAST(ASTNode *root, ASTCallBackFunc preFunc,
ASTCallBackFunc postFunc, void *data);
#endif // ISPC_AST_H

View File

@@ -655,6 +655,72 @@ IfStmt::emitMaskedTrueAndFalse(FunctionEmitContext *ctx, llvm::Value *oldMask,
}
static void
lCheckAllOffSafety(ASTNode *node, void *data) {
bool *okPtr = (bool *)data;
if (dynamic_cast<FunctionCallExpr *>(node) != NULL)
// FIXME: If we could somehow determine that the function being
// called was safe (and all of the args Exprs were safe, then it'd
// be nice to be able to return true here. (Consider a call to
// e.g. floatbits() in the stdlib.) Unfortunately for now we just
// have to be conservative.
*okPtr = false;
if (dynamic_cast<AssertStmt *>(node) != NULL)
// While this is fine for varying tests, it's not going to be
// desirable to check an assert on a uniform variable if all of the
// lanes are off.
*okPtr = false;
IndexExpr *ie;
if ((ie = dynamic_cast<IndexExpr *>(node)) != NULL && ie->baseExpr != NULL) {
const Type *type = ie->baseExpr->GetType();
if (type == NULL)
return;
if (dynamic_cast<const ReferenceType *>(type) != NULL)
type = type->GetReferenceTarget();
ConstExpr *ce = dynamic_cast<ConstExpr *>(ie->index);
if (ce == NULL) {
// indexing with a variable...
*okPtr = false;
return;
}
const PointerType *pointerType =
dynamic_cast<const PointerType *>(type);
if (pointerType != NULL) {
// pointer[index] -> can't be sure
*okPtr = false;
return;
}
const SequentialType *seqType =
dynamic_cast<const SequentialType *>(type);
Assert(seqType != NULL);
int nElements = seqType->GetElementCount();
if (nElements == 0) {
// Unsized array, so we can't be sure
*okPtr = false;
return;
}
int32_t indices[ISPC_MAX_NVEC];
int count = ce->AsInt32(indices);
for (int i = 0; i < count; ++i) {
if (indices[i] < 0 || indices[i] >= nElements) {
*okPtr = false;
return;
}
}
// All indices are in-bounds
}
}
/** Similar to the Stmt variant of this function, this conservatively
checks to see if it's safe to run the code for the given Expr even if
the mask is 'all off'.
@@ -662,7 +728,7 @@ IfStmt::emitMaskedTrueAndFalse(FunctionEmitContext *ctx, llvm::Value *oldMask,
static bool
lSafeToRunWithAllLanesOff(Expr *expr) {
if (expr == NULL)
return false;
return true;
UnaryExpr *ue;
if ((ue = dynamic_cast<UnaryExpr *>(expr)) != NULL)
@@ -925,8 +991,14 @@ IfStmt::emitVaryingIf(FunctionEmitContext *ctx, llvm::Value *ltest) const {
bool costIsAcceptable = ((trueStmts ? trueStmts->EstimateCost() : 0) +
(falseStmts ? falseStmts->EstimateCost() : 0)) <
PREDICATE_SAFE_IF_STATEMENT_COST;
if (lSafeToRunWithAllLanesOff(trueStmts) &&
lSafeToRunWithAllLanesOff(falseStmts) &&
bool safeToRunWithAllLanesOff = true;
WalkAST(trueStmts, lCheckAllOffSafety, NULL, &safeToRunWithAllLanesOff);
WalkAST(falseStmts, lCheckAllOffSafety, NULL, &safeToRunWithAllLanesOff);
assert(safe == (lSafeToRunWithAllLanesOff(trueStmts) &
lSafeToRunWithAllLanesOff(falseStmts)));
if (safeToRunWithAllLanesOff &&
(costIsAcceptable || g->opt.disableCoherentControlFlow)) {
ctx->StartVaryingIf(oldMask);
emitMaskedTrueAndFalse(ctx, oldMask, ltest);