Add WalkAST() function for generic AST walking.
For starters, use it for the check to see if code is safe to run with the mask all off. This also fixes a bug where we would sometimes incorrectly say that a whole block of code was unsafe to run with an all off mask because we came to a NULL AST node during traversal.
This commit is contained in:
148
ast.cpp
148
ast.cpp
@@ -36,8 +36,11 @@
|
||||
*/
|
||||
|
||||
#include "ast.h"
|
||||
#include "expr.h"
|
||||
#include "func.h"
|
||||
#include "stmt.h"
|
||||
#include "sym.h"
|
||||
#include "util.h"
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// ASTNode
|
||||
@@ -63,3 +66,148 @@ AST::GenerateIR() {
|
||||
functions[i]->GenerateIR();
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void
|
||||
WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc,
|
||||
void *data) {
|
||||
if (node == NULL)
|
||||
return;
|
||||
|
||||
// Call the callback function
|
||||
if (preFunc != NULL)
|
||||
preFunc(node, data);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
// Handle Statements
|
||||
if (dynamic_cast<Stmt *>(node) != NULL) {
|
||||
ExprStmt *es;
|
||||
DeclStmt *ds;
|
||||
IfStmt *is;
|
||||
DoStmt *dos;
|
||||
ForStmt *fs;
|
||||
ForeachStmt *fes;
|
||||
ReturnStmt *rs;
|
||||
StmtList *sl;
|
||||
PrintStmt *ps;
|
||||
AssertStmt *as;
|
||||
|
||||
if ((es = dynamic_cast<ExprStmt *>(node)) != NULL)
|
||||
WalkAST(es->expr, preFunc, postFunc, data);
|
||||
else if ((ds = dynamic_cast<DeclStmt *>(node)) != NULL) {
|
||||
for (unsigned int i = 0; i < ds->vars.size(); ++i)
|
||||
WalkAST(ds->vars[i].init, preFunc, postFunc, data);
|
||||
}
|
||||
else if ((is = dynamic_cast<IfStmt *>(node)) != NULL) {
|
||||
WalkAST(is->test, preFunc, postFunc, data);
|
||||
WalkAST(is->trueStmts, preFunc, postFunc, data);
|
||||
WalkAST(is->falseStmts, preFunc, postFunc, data);
|
||||
}
|
||||
else if ((dos = dynamic_cast<DoStmt *>(node)) != NULL) {
|
||||
WalkAST(dos->testExpr, preFunc, postFunc, data);
|
||||
WalkAST(dos->bodyStmts, preFunc, postFunc, data);
|
||||
}
|
||||
else if ((fs = dynamic_cast<ForStmt *>(node)) != NULL) {
|
||||
WalkAST(fs->init, preFunc, postFunc, data);
|
||||
WalkAST(fs->test, preFunc, postFunc, data);
|
||||
WalkAST(fs->step, preFunc, postFunc, data);
|
||||
WalkAST(fs->stmts, preFunc, postFunc, data);
|
||||
}
|
||||
else if ((fes = dynamic_cast<ForeachStmt *>(node)) != NULL) {
|
||||
for (unsigned int i = 0; i < fes->startExprs.size(); ++i)
|
||||
WalkAST(fes->startExprs[i], preFunc, postFunc, data);
|
||||
for (unsigned int i = 0; i < fes->endExprs.size(); ++i)
|
||||
WalkAST(fes->endExprs[i], preFunc, postFunc, data);
|
||||
WalkAST(fes->stmts, preFunc, postFunc, data);
|
||||
}
|
||||
else if (dynamic_cast<BreakStmt *>(node) != NULL ||
|
||||
dynamic_cast<ContinueStmt *>(node) != NULL) {
|
||||
// nothing
|
||||
}
|
||||
else if ((rs = dynamic_cast<ReturnStmt *>(node)) != NULL)
|
||||
WalkAST(rs->val, preFunc, postFunc, data);
|
||||
else if ((sl = dynamic_cast<StmtList *>(node)) != NULL) {
|
||||
const std::vector<Stmt *> &sls = sl->GetStatements();
|
||||
for (unsigned int i = 0; i < sls.size(); ++i)
|
||||
WalkAST(sls[i], preFunc, postFunc, data);
|
||||
}
|
||||
else if ((ps = dynamic_cast<PrintStmt *>(node)) != NULL)
|
||||
WalkAST(ps->values, preFunc, postFunc, data);
|
||||
else if ((as = dynamic_cast<AssertStmt *>(node)) != NULL)
|
||||
return WalkAST(as->expr, preFunc, postFunc, data);
|
||||
else
|
||||
FATAL("Unhandled statement type in WalkAST()");
|
||||
}
|
||||
else {
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// Handle expressions
|
||||
assert(dynamic_cast<Expr *>(node) != NULL);
|
||||
UnaryExpr *ue;
|
||||
BinaryExpr *be;
|
||||
AssignExpr *ae;
|
||||
SelectExpr *se;
|
||||
ExprList *el;
|
||||
FunctionCallExpr *fce;
|
||||
IndexExpr *ie;
|
||||
MemberExpr *me;
|
||||
TypeCastExpr *tce;
|
||||
ReferenceExpr *re;
|
||||
DereferenceExpr *dre;
|
||||
SizeOfExpr *soe;
|
||||
AddressOfExpr *aoe;
|
||||
|
||||
if ((ue = dynamic_cast<UnaryExpr *>(node)) != NULL)
|
||||
WalkAST(ue->expr, preFunc, postFunc, data);
|
||||
else if ((be = dynamic_cast<BinaryExpr *>(node)) != NULL) {
|
||||
WalkAST(be->arg0, preFunc, postFunc, data);
|
||||
WalkAST(be->arg1, preFunc, postFunc, data);
|
||||
}
|
||||
else if ((ae = dynamic_cast<AssignExpr *>(node)) != NULL) {
|
||||
WalkAST(ae->lvalue, preFunc, postFunc, data);
|
||||
WalkAST(ae->rvalue, preFunc, postFunc, data);
|
||||
}
|
||||
else if ((se = dynamic_cast<SelectExpr *>(node)) != NULL) {
|
||||
WalkAST(se->test, preFunc, postFunc, data);
|
||||
WalkAST(se->expr1, preFunc, postFunc, data);
|
||||
WalkAST(se->expr2, preFunc, postFunc, data);
|
||||
}
|
||||
else if ((el = dynamic_cast<ExprList *>(node)) != NULL) {
|
||||
for (unsigned int i = 0; i < el->exprs.size(); ++i)
|
||||
WalkAST(el->exprs[i], preFunc, postFunc, data);
|
||||
}
|
||||
else if ((fce = dynamic_cast<FunctionCallExpr *>(node)) != NULL) {
|
||||
WalkAST(fce->func, preFunc, postFunc, data);
|
||||
WalkAST(fce->args, preFunc, postFunc, data);
|
||||
WalkAST(fce->launchCountExpr, preFunc, postFunc, data);
|
||||
}
|
||||
else if ((ie = dynamic_cast<IndexExpr *>(node)) != NULL) {
|
||||
WalkAST(ie->baseExpr, preFunc, postFunc, data);
|
||||
WalkAST(ie->index, preFunc, postFunc, data);
|
||||
}
|
||||
else if ((me = dynamic_cast<MemberExpr *>(node)) != NULL)
|
||||
WalkAST(me->expr, preFunc, postFunc, data);
|
||||
else if ((tce = dynamic_cast<TypeCastExpr *>(node)) != NULL)
|
||||
return WalkAST(tce->expr, preFunc, postFunc, data);
|
||||
else if ((re = dynamic_cast<ReferenceExpr *>(node)) != NULL)
|
||||
return WalkAST(re->expr, preFunc, postFunc, data);
|
||||
else if ((dre = dynamic_cast<DereferenceExpr *>(node)) != NULL)
|
||||
return WalkAST(dre->expr, preFunc, postFunc, data);
|
||||
else if ((soe = dynamic_cast<SizeOfExpr *>(node)) != NULL)
|
||||
return WalkAST(soe->expr, preFunc, postFunc, data);
|
||||
else if ((aoe = dynamic_cast<AddressOfExpr *>(node)) != NULL)
|
||||
WalkAST(aoe->expr, preFunc, postFunc, data);
|
||||
else if (dynamic_cast<SymbolExpr *>(node) != NULL ||
|
||||
dynamic_cast<ConstExpr *>(node) != NULL ||
|
||||
dynamic_cast<FunctionSymbolExpr *>(node) != NULL ||
|
||||
dynamic_cast<SyncExpr *>(node) != NULL ||
|
||||
dynamic_cast<NullPointerExpr *>(node) != NULL) {
|
||||
// nothing to do
|
||||
}
|
||||
else
|
||||
FATAL("Unhandled expression type in WalkAST().");
|
||||
}
|
||||
|
||||
// Call the callback function
|
||||
if (postFunc != NULL)
|
||||
postFunc(node, data);
|
||||
}
|
||||
|
||||
11
ast.h
11
ast.h
@@ -91,4 +91,15 @@ private:
|
||||
std::vector<Function *> functions;
|
||||
};
|
||||
|
||||
|
||||
typedef void (* ASTCallBackFunc)(ASTNode *node, void *data);
|
||||
|
||||
/** Walk (some portion of) an AST, starting from the given root node. At
|
||||
each node, if preFunc is non-NULL, call it preFunc, passing the given
|
||||
void *data pointer. Makes recursive calls to WalkAST() to process the
|
||||
node's children; after doing so, postFunc, if non-NULL is called at the
|
||||
node. */
|
||||
extern void WalkAST(ASTNode *root, ASTCallBackFunc preFunc,
|
||||
ASTCallBackFunc postFunc, void *data);
|
||||
|
||||
#endif // ISPC_AST_H
|
||||
|
||||
78
stmt.cpp
78
stmt.cpp
@@ -655,6 +655,72 @@ IfStmt::emitMaskedTrueAndFalse(FunctionEmitContext *ctx, llvm::Value *oldMask,
|
||||
}
|
||||
|
||||
|
||||
|
||||
static void
|
||||
lCheckAllOffSafety(ASTNode *node, void *data) {
|
||||
bool *okPtr = (bool *)data;
|
||||
|
||||
if (dynamic_cast<FunctionCallExpr *>(node) != NULL)
|
||||
// FIXME: If we could somehow determine that the function being
|
||||
// called was safe (and all of the args Exprs were safe, then it'd
|
||||
// be nice to be able to return true here. (Consider a call to
|
||||
// e.g. floatbits() in the stdlib.) Unfortunately for now we just
|
||||
// have to be conservative.
|
||||
*okPtr = false;
|
||||
|
||||
if (dynamic_cast<AssertStmt *>(node) != NULL)
|
||||
// While this is fine for varying tests, it's not going to be
|
||||
// desirable to check an assert on a uniform variable if all of the
|
||||
// lanes are off.
|
||||
*okPtr = false;
|
||||
|
||||
IndexExpr *ie;
|
||||
if ((ie = dynamic_cast<IndexExpr *>(node)) != NULL && ie->baseExpr != NULL) {
|
||||
const Type *type = ie->baseExpr->GetType();
|
||||
if (type == NULL)
|
||||
return;
|
||||
if (dynamic_cast<const ReferenceType *>(type) != NULL)
|
||||
type = type->GetReferenceTarget();
|
||||
|
||||
ConstExpr *ce = dynamic_cast<ConstExpr *>(ie->index);
|
||||
if (ce == NULL) {
|
||||
// indexing with a variable...
|
||||
*okPtr = false;
|
||||
return;
|
||||
}
|
||||
|
||||
const PointerType *pointerType =
|
||||
dynamic_cast<const PointerType *>(type);
|
||||
if (pointerType != NULL) {
|
||||
// pointer[index] -> can't be sure
|
||||
*okPtr = false;
|
||||
return;
|
||||
}
|
||||
|
||||
const SequentialType *seqType =
|
||||
dynamic_cast<const SequentialType *>(type);
|
||||
Assert(seqType != NULL);
|
||||
int nElements = seqType->GetElementCount();
|
||||
if (nElements == 0) {
|
||||
// Unsized array, so we can't be sure
|
||||
*okPtr = false;
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t indices[ISPC_MAX_NVEC];
|
||||
int count = ce->AsInt32(indices);
|
||||
for (int i = 0; i < count; ++i) {
|
||||
if (indices[i] < 0 || indices[i] >= nElements) {
|
||||
*okPtr = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// All indices are in-bounds
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/** Similar to the Stmt variant of this function, this conservatively
|
||||
checks to see if it's safe to run the code for the given Expr even if
|
||||
the mask is 'all off'.
|
||||
@@ -662,7 +728,7 @@ IfStmt::emitMaskedTrueAndFalse(FunctionEmitContext *ctx, llvm::Value *oldMask,
|
||||
static bool
|
||||
lSafeToRunWithAllLanesOff(Expr *expr) {
|
||||
if (expr == NULL)
|
||||
return false;
|
||||
return true;
|
||||
|
||||
UnaryExpr *ue;
|
||||
if ((ue = dynamic_cast<UnaryExpr *>(expr)) != NULL)
|
||||
@@ -925,8 +991,14 @@ IfStmt::emitVaryingIf(FunctionEmitContext *ctx, llvm::Value *ltest) const {
|
||||
bool costIsAcceptable = ((trueStmts ? trueStmts->EstimateCost() : 0) +
|
||||
(falseStmts ? falseStmts->EstimateCost() : 0)) <
|
||||
PREDICATE_SAFE_IF_STATEMENT_COST;
|
||||
if (lSafeToRunWithAllLanesOff(trueStmts) &&
|
||||
lSafeToRunWithAllLanesOff(falseStmts) &&
|
||||
|
||||
bool safeToRunWithAllLanesOff = true;
|
||||
WalkAST(trueStmts, lCheckAllOffSafety, NULL, &safeToRunWithAllLanesOff);
|
||||
WalkAST(falseStmts, lCheckAllOffSafety, NULL, &safeToRunWithAllLanesOff);
|
||||
assert(safe == (lSafeToRunWithAllLanesOff(trueStmts) &
|
||||
lSafeToRunWithAllLanesOff(falseStmts)));
|
||||
|
||||
if (safeToRunWithAllLanesOff &&
|
||||
(costIsAcceptable || g->opt.disableCoherentControlFlow)) {
|
||||
ctx->StartVaryingIf(oldMask);
|
||||
emitMaskedTrueAndFalse(ctx, oldMask, ltest);
|
||||
|
||||
Reference in New Issue
Block a user