diff --git a/ast.cpp b/ast.cpp index 960f3503..80d90f10 100644 --- a/ast.cpp +++ b/ast.cpp @@ -75,8 +75,11 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc, return; // Call the callback function - if (preFunc != NULL) - preFunc(node, data); + if (preFunc != NULL) { + if (preFunc(node, data) == false) + // The function asked us to not continue recursively, so stop. + return; + } //////////////////////////////////////////////////////////////////////////// // Handle Statements diff --git a/ast.h b/ast.h index 969b58e2..6fcb709a 100644 --- a/ast.h +++ b/ast.h @@ -92,13 +92,17 @@ private: }; -typedef void (* ASTCallBackFunc)(ASTNode *node, void *data); +/** Callback function type for the AST walk. + */ +typedef bool (* ASTCallBackFunc)(ASTNode *node, void *data); /** Walk (some portion of) an AST, starting from the given root node. At - each node, if preFunc is non-NULL, call it preFunc, passing the given - void *data pointer. Makes recursive calls to WalkAST() to process the - node's children; after doing so, postFunc, if non-NULL is called at the - node. */ + each node, if preFunc is non-NULL, call it, passing the given void + *data pointer; if the call to preFunc function returns false, then the + children of the node aren't visited. This then makes recursive calls + to WalkAST() to process the node's children; after doing so, calls + postFunc, at the node. The return value from the postFunc call is + ignored. */ extern void WalkAST(ASTNode *root, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc, void *data); diff --git a/stmt.cpp b/stmt.cpp index d5913cae..c7fa9cc9 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -656,46 +656,54 @@ IfStmt::emitMaskedTrueAndFalse(FunctionEmitContext *ctx, llvm::Value *oldMask, /** Given an AST node, check to see if it's safe if we happen to run the - code for that node with the execution mask all off. */ -static void + code for that node with the execution mask all off. + + FIXME: this is actually a target-specific thing; for non SSE/AVX + targets with more complete masking support, some of this won't apply... + */ +static bool lCheckAllOffSafety(ASTNode *node, void *data) { bool *okPtr = (bool *)data; - if (dynamic_cast(node) != NULL) + if (dynamic_cast(node) != NULL) { // FIXME: If we could somehow determine that the function being // called was safe (and all of the args Exprs were safe, then it'd // be nice to be able to return true here. (Consider a call to // e.g. floatbits() in the stdlib.) Unfortunately for now we just // have to be conservative. *okPtr = false; + return false; + } - if (dynamic_cast(node) != NULL) - // While this is fine for varying tests, it's not going to be + if (dynamic_cast(node) != NULL) { + // While it's fine to run the assert for varying tests, it's not // desirable to check an assert on a uniform variable if all of the // lanes are off. *okPtr = false; + return false; + } IndexExpr *ie; if ((ie = dynamic_cast(node)) != NULL && ie->baseExpr != NULL) { const Type *type = ie->baseExpr->GetType(); if (type == NULL) - return; + return true; if (dynamic_cast(type) != NULL) type = type->GetReferenceTarget(); ConstExpr *ce = dynamic_cast(ie->index); if (ce == NULL) { - // indexing with a variable... + // indexing with a variable... -> not safe *okPtr = false; - return; + return false; } const PointerType *pointerType = dynamic_cast(type); if (pointerType != NULL) { - // pointer[index] -> can't be sure + // pointer[index] -> can't be sure -> not safe *okPtr = false; - return; + return false; } const SequentialType *seqType = @@ -703,22 +711,25 @@ lCheckAllOffSafety(ASTNode *node, void *data) { Assert(seqType != NULL); int nElements = seqType->GetElementCount(); if (nElements == 0) { - // Unsized array, so we can't be sure + // Unsized array, so we can't be sure -> not safe *okPtr = false; - return; + return false; } int32_t indices[ISPC_MAX_NVEC]; int count = ce->AsInt32(indices); for (int i = 0; i < count; ++i) { if (indices[i] < 0 || indices[i] >= nElements) { + // Index is out of bounds -> not safe *okPtr = false; - return; + return false; } } // All indices are in-bounds } + + return true; } @@ -974,6 +985,83 @@ lHasVaryingBreakOrContinue(Stmt *stmt, bool inVaryingCF = false) { } +struct VaryingBCCheckInfo { + VaryingBCCheckInfo() { + varyingControlFlowDepth = 0; + foundVaryingBreakOrContinue = false; + } + + int varyingControlFlowDepth; + bool foundVaryingBreakOrContinue; +}; + + +/** Returns true if the given node is an 'if' statement where the test + condition has varying type. */ +static bool +lIsVaryingFor(ASTNode *node) { + IfStmt *ifStmt; + if ((ifStmt = dynamic_cast(node)) != NULL && + ifStmt->test != NULL) { + const Type *type = ifStmt->test->GetType(); + return (type != NULL && type->IsVaryingType()); + } + else + return false; +} + + +/** Preorder callback function for checking for varying breaks or + continues. */ +static bool +lVaryingBCPreFunc(ASTNode *node, void *d) { + VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d; + + // We found a break or continue statement; if we're under varying + // control flow, then bingo. + if ((dynamic_cast(node) != NULL || + dynamic_cast(node) != NULL) && + info->varyingControlFlowDepth > 0) { + info->foundVaryingBreakOrContinue = true; + return false; + } + + // Update the count of the nesting depth of varying control flow if + // this is an if statement with a varying condition. + if (lIsVaryingFor(node)) + ++info->varyingControlFlowDepth; + + if (dynamic_cast(node) != NULL || + dynamic_cast(node) != NULL || + dynamic_cast(node) != NULL) + // Don't recurse into these guys, since we don't care about varying + // breaks or continues within them... + return false; + else + return true; +} + + +/** Postorder callback function for checking for varying breaks or + continues; decrement the varying control flow depth after the node's + children have been processed, if this is a varying if statement. */ +static bool +lVaryingBCPostFunc(ASTNode *node, void *d) { + VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d; + if (lIsVaryingFor(node)) + --info->varyingControlFlowDepth; + return true; +} + + +static bool +lHasVaryingBreakOrContinue2(Stmt *stmt) { + VaryingBCCheckInfo info; + WalkAST(stmt, lVaryingBCPreFunc, lVaryingBCPostFunc, &info); + return info.foundVaryingBreakOrContinue; +} + + DoStmt::DoStmt(Expr *t, Stmt *s, bool cc, SourcePos p) : Stmt(p), testExpr(t), bodyStmts(s), doCoherentCheck(cc && !g->opt.disableCoherentControlFlow) { @@ -1124,6 +1212,9 @@ DoStmt::TypeCheck() { // code generated for the loop includes masking stuff, so // that we can track which lanes actually want to be // running, accounting for breaks/continues. + assert(lHasVaryingBreakOrContinue(bodyStmts) == + lHasVaryingBreakOrContinue2(bodyStmts)); + bool uniformTest = (testType->IsUniformType() && !g->opt.disableUniformControlFlow && !lHasVaryingBreakOrContinue(bodyStmts)); @@ -1185,6 +1276,9 @@ ForStmt::EmitCode(FunctionEmitContext *ctx) const { (!g->opt.disableUniformControlFlow && !lHasVaryingBreakOrContinue(stmts)); + assert(lHasVaryingBreakOrContinue(stmts) == + lHasVaryingBreakOrContinue2(stmts)); + ctx->StartLoop(bexit, bstep, uniformTest); ctx->SetDebugPos(pos); @@ -1328,6 +1422,8 @@ ForStmt::TypeCheck() { // See comments in DoStmt::TypeCheck() regarding // 'uniformTest' and the type cast here. + assert(lHasVaryingBreakOrContinue(stmts) == + lHasVaryingBreakOrContinue2(stmts)); bool uniformTest = (testType->IsUniformType() && !g->opt.disableUniformControlFlow && !lHasVaryingBreakOrContinue(stmts)); @@ -1353,6 +1449,8 @@ ForStmt::EstimateCost() const { bool uniformTest = test ? test->GetType()->IsUniformType() : (!g->opt.disableUniformControlFlow && !lHasVaryingBreakOrContinue(stmts)); + assert(lHasVaryingBreakOrContinue(stmts) == + lHasVaryingBreakOrContinue2(stmts)); return ((init ? init->EstimateCost() : 0) + (test ? test->EstimateCost() : 0) +