diff --git a/ast.cpp b/ast.cpp index 80d90f10..5a8bec8a 100644 --- a/ast.cpp +++ b/ast.cpp @@ -68,17 +68,17 @@ AST::GenerateIR() { /////////////////////////////////////////////////////////////////////////// -void -WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc, +ASTNode * +WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, void *data) { if (node == NULL) - return; + return node; // Call the callback function if (preFunc != NULL) { if (preFunc(node, data) == false) // The function asked us to not continue recursively, so stop. - return; + return node; } //////////////////////////////////////////////////////////////////////////// @@ -96,48 +96,55 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc, AssertStmt *as; if ((es = dynamic_cast(node)) != NULL) - WalkAST(es->expr, preFunc, postFunc, data); + es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data); else if ((ds = dynamic_cast(node)) != NULL) { for (unsigned int i = 0; i < ds->vars.size(); ++i) - WalkAST(ds->vars[i].init, preFunc, postFunc, data); + ds->vars[i].init = (Expr *)WalkAST(ds->vars[i].init, preFunc, + postFunc, data); } else if ((is = dynamic_cast(node)) != NULL) { - WalkAST(is->test, preFunc, postFunc, data); - WalkAST(is->trueStmts, preFunc, postFunc, data); - WalkAST(is->falseStmts, preFunc, postFunc, data); + is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data); + is->trueStmts = (Stmt *)WalkAST(is->trueStmts, preFunc, + postFunc, data); + is->falseStmts = (Stmt *)WalkAST(is->falseStmts, preFunc, + postFunc, data); } else if ((dos = dynamic_cast(node)) != NULL) { - WalkAST(dos->testExpr, preFunc, postFunc, data); - WalkAST(dos->bodyStmts, preFunc, postFunc, data); + dos->testExpr = (Expr *)WalkAST(dos->testExpr, preFunc, + postFunc, data); + dos->bodyStmts = (Stmt *)WalkAST(dos->bodyStmts, preFunc, + postFunc, data); } else if ((fs = dynamic_cast(node)) != NULL) { - WalkAST(fs->init, preFunc, postFunc, data); - WalkAST(fs->test, preFunc, postFunc, data); - WalkAST(fs->step, preFunc, postFunc, data); - WalkAST(fs->stmts, preFunc, postFunc, data); + fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data); + fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data); + fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data); + fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data); } else if ((fes = dynamic_cast(node)) != NULL) { for (unsigned int i = 0; i < fes->startExprs.size(); ++i) - WalkAST(fes->startExprs[i], preFunc, postFunc, data); + fes->startExprs[i] = (Expr *)WalkAST(fes->startExprs[i], preFunc, + postFunc, data); for (unsigned int i = 0; i < fes->endExprs.size(); ++i) - WalkAST(fes->endExprs[i], preFunc, postFunc, data); - WalkAST(fes->stmts, preFunc, postFunc, data); + fes->endExprs[i] = (Expr *)WalkAST(fes->endExprs[i], preFunc, + postFunc, data); + fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data); } else if (dynamic_cast(node) != NULL || dynamic_cast(node) != NULL) { // nothing } else if ((rs = dynamic_cast(node)) != NULL) - WalkAST(rs->val, preFunc, postFunc, data); + rs->val = (Expr *)WalkAST(rs->val, preFunc, postFunc, data); else if ((sl = dynamic_cast(node)) != NULL) { - const std::vector &sls = sl->GetStatements(); + std::vector &sls = sl->stmts; for (unsigned int i = 0; i < sls.size(); ++i) - WalkAST(sls[i], preFunc, postFunc, data); + sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data); } else if ((ps = dynamic_cast(node)) != NULL) - WalkAST(ps->values, preFunc, postFunc, data); + ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data); else if ((as = dynamic_cast(node)) != NULL) - return WalkAST(as->expr, preFunc, postFunc, data); + as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data); else FATAL("Unhandled statement type in WalkAST()"); } @@ -160,45 +167,47 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc, AddressOfExpr *aoe; if ((ue = dynamic_cast(node)) != NULL) - WalkAST(ue->expr, preFunc, postFunc, data); + ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data); else if ((be = dynamic_cast(node)) != NULL) { - WalkAST(be->arg0, preFunc, postFunc, data); - WalkAST(be->arg1, preFunc, postFunc, data); + be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data); + be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data); } else if ((ae = dynamic_cast(node)) != NULL) { - WalkAST(ae->lvalue, preFunc, postFunc, data); - WalkAST(ae->rvalue, preFunc, postFunc, data); + ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data); + ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data); } else if ((se = dynamic_cast(node)) != NULL) { - WalkAST(se->test, preFunc, postFunc, data); - WalkAST(se->expr1, preFunc, postFunc, data); - WalkAST(se->expr2, preFunc, postFunc, data); + se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data); + se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data); + se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data); } else if ((el = dynamic_cast(node)) != NULL) { for (unsigned int i = 0; i < el->exprs.size(); ++i) - WalkAST(el->exprs[i], preFunc, postFunc, data); + el->exprs[i] = (Expr *)WalkAST(el->exprs[i], preFunc, + postFunc, data); } else if ((fce = dynamic_cast(node)) != NULL) { - WalkAST(fce->func, preFunc, postFunc, data); - WalkAST(fce->args, preFunc, postFunc, data); - WalkAST(fce->launchCountExpr, preFunc, postFunc, data); + fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data); + fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data); + fce->launchCountExpr = (Expr *)WalkAST(fce->launchCountExpr, preFunc, + postFunc, data); } else if ((ie = dynamic_cast(node)) != NULL) { - WalkAST(ie->baseExpr, preFunc, postFunc, data); - WalkAST(ie->index, preFunc, postFunc, data); + ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data); + ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data); } else if ((me = dynamic_cast(node)) != NULL) - WalkAST(me->expr, preFunc, postFunc, data); + me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data); else if ((tce = dynamic_cast(node)) != NULL) - return WalkAST(tce->expr, preFunc, postFunc, data); + tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data); else if ((re = dynamic_cast(node)) != NULL) - return WalkAST(re->expr, preFunc, postFunc, data); + re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data); else if ((dre = dynamic_cast(node)) != NULL) - return WalkAST(dre->expr, preFunc, postFunc, data); + dre->expr = (Expr *)WalkAST(dre->expr, preFunc, postFunc, data); else if ((soe = dynamic_cast(node)) != NULL) - return WalkAST(soe->expr, preFunc, postFunc, data); + soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data); else if ((aoe = dynamic_cast(node)) != NULL) - WalkAST(aoe->expr, preFunc, postFunc, data); + aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data); else if (dynamic_cast(node) != NULL || dynamic_cast(node) != NULL || dynamic_cast(node) != NULL || @@ -212,5 +221,7 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc, // Call the callback function if (postFunc != NULL) - postFunc(node, data); + return postFunc(node, data); + else + return node; } diff --git a/ast.h b/ast.h index 6fcb709a..8a24aa92 100644 --- a/ast.h +++ b/ast.h @@ -92,18 +92,24 @@ private: }; -/** Callback function type for the AST walk. +/** Callback function type for preorder traversial visiting function for + the AST walk. */ -typedef bool (* ASTCallBackFunc)(ASTNode *node, void *data); +typedef bool (* ASTPreCallBackFunc)(ASTNode *node, void *data); + +/** Callback function type for postorder traversial visiting function for + the AST walk. + */ +typedef ASTNode * (* ASTPostCallBackFunc)(ASTNode *node, void *data); /** Walk (some portion of) an AST, starting from the given root node. At each node, if preFunc is non-NULL, call it, passing the given void *data pointer; if the call to preFunc function returns false, then the - children of the node aren't visited. This then makes recursive calls - to WalkAST() to process the node's children; after doing so, calls - postFunc, at the node. The return value from the postFunc call is - ignored. */ -extern void WalkAST(ASTNode *root, ASTCallBackFunc preFunc, - ASTCallBackFunc postFunc, void *data); + children of the node aren't visited. This function then makes + recursive calls to WalkAST() to process the node's children; after + doing so, calls postFunc, at the node. The return value from the + postFunc call is ignored. */ +extern ASTNode *WalkAST(ASTNode *root, ASTPreCallBackFunc preFunc, + ASTPostCallBackFunc postFunc, void *data); #endif // ISPC_AST_H diff --git a/func.cpp b/func.cpp index 603a6641..641ff9e3 100644 --- a/func.cpp +++ b/func.cpp @@ -387,9 +387,8 @@ Function::GenerateIR() { SourcePos firstStmtPos = sym->pos; if (code) { StmtList *sl = dynamic_cast(code); - if (sl && sl->GetStatements().size() > 0 && - sl->GetStatements()[0] != NULL) - firstStmtPos = sl->GetStatements()[0]->pos; + if (sl && sl->stmts.size() > 0 && sl->stmts[0] != NULL) + firstStmtPos = sl->stmts[0]->pos; else firstStmtPos = code->pos; } diff --git a/stmt.cpp b/stmt.cpp index 3dcbc8d1..f8c994fa 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -997,12 +997,12 @@ lVaryingBCPreFunc(ASTNode *node, void *d) { /** Postorder callback function for checking for varying breaks or continues; decrement the varying control flow depth after the node's children have been processed, if this is a varying if statement. */ -static bool +static ASTNode * lVaryingBCPostFunc(ASTNode *node, void *d) { VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d; if (lIsVaryingFor(node)) --info->varyingControlFlowDepth; - return true; + return node; } diff --git a/stmt.h b/stmt.h index 2ce97753..e74ce125 100644 --- a/stmt.h +++ b/stmt.h @@ -302,9 +302,7 @@ public: int EstimateCost() const; void Add(Stmt *s) { if (s) stmts.push_back(s); } - const std::vector &GetStatements() { return stmts; } -private: std::vector stmts; };