diff --git a/ast.cpp b/ast.cpp index 71dd06bb..960f3503 100644 --- a/ast.cpp +++ b/ast.cpp @@ -36,8 +36,11 @@ */ #include "ast.h" +#include "expr.h" #include "func.h" +#include "stmt.h" #include "sym.h" +#include "util.h" /////////////////////////////////////////////////////////////////////////// // ASTNode @@ -63,3 +66,148 @@ AST::GenerateIR() { functions[i]->GenerateIR(); } +/////////////////////////////////////////////////////////////////////////// + +void +WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc, + void *data) { + if (node == NULL) + return; + + // Call the callback function + if (preFunc != NULL) + preFunc(node, data); + + //////////////////////////////////////////////////////////////////////////// + // Handle Statements + if (dynamic_cast(node) != NULL) { + ExprStmt *es; + DeclStmt *ds; + IfStmt *is; + DoStmt *dos; + ForStmt *fs; + ForeachStmt *fes; + ReturnStmt *rs; + StmtList *sl; + PrintStmt *ps; + AssertStmt *as; + + if ((es = dynamic_cast(node)) != NULL) + WalkAST(es->expr, preFunc, postFunc, data); + else if ((ds = dynamic_cast(node)) != NULL) { + for (unsigned int i = 0; i < ds->vars.size(); ++i) + WalkAST(ds->vars[i].init, preFunc, postFunc, data); + } + else if ((is = dynamic_cast(node)) != NULL) { + WalkAST(is->test, preFunc, postFunc, data); + WalkAST(is->trueStmts, preFunc, postFunc, data); + WalkAST(is->falseStmts, preFunc, postFunc, data); + } + else if ((dos = dynamic_cast(node)) != NULL) { + WalkAST(dos->testExpr, preFunc, postFunc, data); + WalkAST(dos->bodyStmts, preFunc, postFunc, data); + } + else if ((fs = dynamic_cast(node)) != NULL) { + WalkAST(fs->init, preFunc, postFunc, data); + WalkAST(fs->test, preFunc, postFunc, data); + WalkAST(fs->step, preFunc, postFunc, data); + WalkAST(fs->stmts, preFunc, postFunc, data); + } + else if ((fes = dynamic_cast(node)) != NULL) { + for (unsigned int i = 0; i < fes->startExprs.size(); ++i) + WalkAST(fes->startExprs[i], preFunc, postFunc, data); + for (unsigned int i = 0; i < fes->endExprs.size(); ++i) + WalkAST(fes->endExprs[i], preFunc, postFunc, data); + WalkAST(fes->stmts, preFunc, postFunc, data); + } + else if (dynamic_cast(node) != NULL || + dynamic_cast(node) != NULL) { + // nothing + } + else if ((rs = dynamic_cast(node)) != NULL) + WalkAST(rs->val, preFunc, postFunc, data); + else if ((sl = dynamic_cast(node)) != NULL) { + const std::vector &sls = sl->GetStatements(); + for (unsigned int i = 0; i < sls.size(); ++i) + WalkAST(sls[i], preFunc, postFunc, data); + } + else if ((ps = dynamic_cast(node)) != NULL) + WalkAST(ps->values, preFunc, postFunc, data); + else if ((as = dynamic_cast(node)) != NULL) + return WalkAST(as->expr, preFunc, postFunc, data); + else + FATAL("Unhandled statement type in WalkAST()"); + } + else { + /////////////////////////////////////////////////////////////////////////// + // Handle expressions + assert(dynamic_cast(node) != NULL); + UnaryExpr *ue; + BinaryExpr *be; + AssignExpr *ae; + SelectExpr *se; + ExprList *el; + FunctionCallExpr *fce; + IndexExpr *ie; + MemberExpr *me; + TypeCastExpr *tce; + ReferenceExpr *re; + DereferenceExpr *dre; + SizeOfExpr *soe; + AddressOfExpr *aoe; + + if ((ue = dynamic_cast(node)) != NULL) + WalkAST(ue->expr, preFunc, postFunc, data); + else if ((be = dynamic_cast(node)) != NULL) { + WalkAST(be->arg0, preFunc, postFunc, data); + WalkAST(be->arg1, preFunc, postFunc, data); + } + else if ((ae = dynamic_cast(node)) != NULL) { + WalkAST(ae->lvalue, preFunc, postFunc, data); + WalkAST(ae->rvalue, preFunc, postFunc, data); + } + else if ((se = dynamic_cast(node)) != NULL) { + WalkAST(se->test, preFunc, postFunc, data); + WalkAST(se->expr1, preFunc, postFunc, data); + WalkAST(se->expr2, preFunc, postFunc, data); + } + else if ((el = dynamic_cast(node)) != NULL) { + for (unsigned int i = 0; i < el->exprs.size(); ++i) + WalkAST(el->exprs[i], preFunc, postFunc, data); + } + else if ((fce = dynamic_cast(node)) != NULL) { + WalkAST(fce->func, preFunc, postFunc, data); + WalkAST(fce->args, preFunc, postFunc, data); + WalkAST(fce->launchCountExpr, preFunc, postFunc, data); + } + else if ((ie = dynamic_cast(node)) != NULL) { + WalkAST(ie->baseExpr, preFunc, postFunc, data); + WalkAST(ie->index, preFunc, postFunc, data); + } + else if ((me = dynamic_cast(node)) != NULL) + WalkAST(me->expr, preFunc, postFunc, data); + else if ((tce = dynamic_cast(node)) != NULL) + return WalkAST(tce->expr, preFunc, postFunc, data); + else if ((re = dynamic_cast(node)) != NULL) + return WalkAST(re->expr, preFunc, postFunc, data); + else if ((dre = dynamic_cast(node)) != NULL) + return WalkAST(dre->expr, preFunc, postFunc, data); + else if ((soe = dynamic_cast(node)) != NULL) + return WalkAST(soe->expr, preFunc, postFunc, data); + else if ((aoe = dynamic_cast(node)) != NULL) + WalkAST(aoe->expr, preFunc, postFunc, data); + else if (dynamic_cast(node) != NULL || + dynamic_cast(node) != NULL || + dynamic_cast(node) != NULL || + dynamic_cast(node) != NULL || + dynamic_cast(node) != NULL) { + // nothing to do + } + else + FATAL("Unhandled expression type in WalkAST()."); + } + + // Call the callback function + if (postFunc != NULL) + postFunc(node, data); +} diff --git a/ast.h b/ast.h index bb574b02..969b58e2 100644 --- a/ast.h +++ b/ast.h @@ -91,4 +91,15 @@ private: std::vector functions; }; + +typedef void (* ASTCallBackFunc)(ASTNode *node, void *data); + +/** Walk (some portion of) an AST, starting from the given root node. At + each node, if preFunc is non-NULL, call it preFunc, passing the given + void *data pointer. Makes recursive calls to WalkAST() to process the + node's children; after doing so, postFunc, if non-NULL is called at the + node. */ +extern void WalkAST(ASTNode *root, ASTCallBackFunc preFunc, + ASTCallBackFunc postFunc, void *data); + #endif // ISPC_AST_H diff --git a/stmt.cpp b/stmt.cpp index 8a34ba1c..2058c662 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -655,6 +655,72 @@ IfStmt::emitMaskedTrueAndFalse(FunctionEmitContext *ctx, llvm::Value *oldMask, } + +static void +lCheckAllOffSafety(ASTNode *node, void *data) { + bool *okPtr = (bool *)data; + + if (dynamic_cast(node) != NULL) + // FIXME: If we could somehow determine that the function being + // called was safe (and all of the args Exprs were safe, then it'd + // be nice to be able to return true here. (Consider a call to + // e.g. floatbits() in the stdlib.) Unfortunately for now we just + // have to be conservative. + *okPtr = false; + + if (dynamic_cast(node) != NULL) + // While this is fine for varying tests, it's not going to be + // desirable to check an assert on a uniform variable if all of the + // lanes are off. + *okPtr = false; + + IndexExpr *ie; + if ((ie = dynamic_cast(node)) != NULL && ie->baseExpr != NULL) { + const Type *type = ie->baseExpr->GetType(); + if (type == NULL) + return; + if (dynamic_cast(type) != NULL) + type = type->GetReferenceTarget(); + + ConstExpr *ce = dynamic_cast(ie->index); + if (ce == NULL) { + // indexing with a variable... + *okPtr = false; + return; + } + + const PointerType *pointerType = + dynamic_cast(type); + if (pointerType != NULL) { + // pointer[index] -> can't be sure + *okPtr = false; + return; + } + + const SequentialType *seqType = + dynamic_cast(type); + Assert(seqType != NULL); + int nElements = seqType->GetElementCount(); + if (nElements == 0) { + // Unsized array, so we can't be sure + *okPtr = false; + return; + } + + int32_t indices[ISPC_MAX_NVEC]; + int count = ce->AsInt32(indices); + for (int i = 0; i < count; ++i) { + if (indices[i] < 0 || indices[i] >= nElements) { + *okPtr = false; + return; + } + } + + // All indices are in-bounds + } +} + + /** Similar to the Stmt variant of this function, this conservatively checks to see if it's safe to run the code for the given Expr even if the mask is 'all off'. @@ -662,7 +728,7 @@ IfStmt::emitMaskedTrueAndFalse(FunctionEmitContext *ctx, llvm::Value *oldMask, static bool lSafeToRunWithAllLanesOff(Expr *expr) { if (expr == NULL) - return false; + return true; UnaryExpr *ue; if ((ue = dynamic_cast(expr)) != NULL) @@ -925,8 +991,14 @@ IfStmt::emitVaryingIf(FunctionEmitContext *ctx, llvm::Value *ltest) const { bool costIsAcceptable = ((trueStmts ? trueStmts->EstimateCost() : 0) + (falseStmts ? falseStmts->EstimateCost() : 0)) < PREDICATE_SAFE_IF_STATEMENT_COST; - if (lSafeToRunWithAllLanesOff(trueStmts) && - lSafeToRunWithAllLanesOff(falseStmts) && + + bool safeToRunWithAllLanesOff = true; + WalkAST(trueStmts, lCheckAllOffSafety, NULL, &safeToRunWithAllLanesOff); + WalkAST(falseStmts, lCheckAllOffSafety, NULL, &safeToRunWithAllLanesOff); + assert(safe == (lSafeToRunWithAllLanesOff(trueStmts) & + lSafeToRunWithAllLanesOff(falseStmts))); + + if (safeToRunWithAllLanesOff && (costIsAcceptable || g->opt.disableCoherentControlFlow)) { ctx->StartVaryingIf(oldMask); emitMaskedTrueAndFalse(ctx, oldMask, ltest);