From ced3f1f5fcdc3fc08fa516a50fa01d01195290d0 Mon Sep 17 00:00:00 2001 From: Matt Pharr Date: Fri, 16 Dec 2011 11:06:09 -0800 Subject: [PATCH] Have WalkAST postorder callback function return an ASTNode * In general, it should just return the original node pointer, but for type checking and optimization passes, it can return a new value for the node (that will be assigned where the old one was in the tree.) Along the way, fixed some bugs in WalkAST() where the postorder callback wouldn't end up being called for a few expr types (sizeof, dereference, address of, reference). --- ast.cpp | 99 +++++++++++++++++++++++++++++++------------------------- ast.h | 22 ++++++++----- func.cpp | 5 ++- stmt.cpp | 4 +-- stmt.h | 2 -- 5 files changed, 73 insertions(+), 59 deletions(-) diff --git a/ast.cpp b/ast.cpp index 80d90f10..5a8bec8a 100644 --- a/ast.cpp +++ b/ast.cpp @@ -68,17 +68,17 @@ AST::GenerateIR() { /////////////////////////////////////////////////////////////////////////// -void -WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc, +ASTNode * +WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, void *data) { if (node == NULL) - return; + return node; // Call the callback function if (preFunc != NULL) { if (preFunc(node, data) == false) // The function asked us to not continue recursively, so stop. - return; + return node; } //////////////////////////////////////////////////////////////////////////// @@ -96,48 +96,55 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc, AssertStmt *as; if ((es = dynamic_cast(node)) != NULL) - WalkAST(es->expr, preFunc, postFunc, data); + es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data); else if ((ds = dynamic_cast(node)) != NULL) { for (unsigned int i = 0; i < ds->vars.size(); ++i) - WalkAST(ds->vars[i].init, preFunc, postFunc, data); + ds->vars[i].init = (Expr *)WalkAST(ds->vars[i].init, preFunc, + postFunc, data); } else if ((is = dynamic_cast(node)) != NULL) { - WalkAST(is->test, preFunc, postFunc, data); - WalkAST(is->trueStmts, preFunc, postFunc, data); - WalkAST(is->falseStmts, preFunc, postFunc, data); + is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data); + is->trueStmts = (Stmt *)WalkAST(is->trueStmts, preFunc, + postFunc, data); + is->falseStmts = (Stmt *)WalkAST(is->falseStmts, preFunc, + postFunc, data); } else if ((dos = dynamic_cast(node)) != NULL) { - WalkAST(dos->testExpr, preFunc, postFunc, data); - WalkAST(dos->bodyStmts, preFunc, postFunc, data); + dos->testExpr = (Expr *)WalkAST(dos->testExpr, preFunc, + postFunc, data); + dos->bodyStmts = (Stmt *)WalkAST(dos->bodyStmts, preFunc, + postFunc, data); } else if ((fs = dynamic_cast(node)) != NULL) { - WalkAST(fs->init, preFunc, postFunc, data); - WalkAST(fs->test, preFunc, postFunc, data); - WalkAST(fs->step, preFunc, postFunc, data); - WalkAST(fs->stmts, preFunc, postFunc, data); + fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data); + fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data); + fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data); + fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data); } else if ((fes = dynamic_cast(node)) != NULL) { for (unsigned int i = 0; i < fes->startExprs.size(); ++i) - WalkAST(fes->startExprs[i], preFunc, postFunc, data); + fes->startExprs[i] = (Expr *)WalkAST(fes->startExprs[i], preFunc, + postFunc, data); for (unsigned int i = 0; i < fes->endExprs.size(); ++i) - WalkAST(fes->endExprs[i], preFunc, postFunc, data); - WalkAST(fes->stmts, preFunc, postFunc, data); + fes->endExprs[i] = (Expr *)WalkAST(fes->endExprs[i], preFunc, + postFunc, data); + fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data); } else if (dynamic_cast(node) != NULL || dynamic_cast(node) != NULL) { // nothing } else if ((rs = dynamic_cast(node)) != NULL) - WalkAST(rs->val, preFunc, postFunc, data); + rs->val = (Expr *)WalkAST(rs->val, preFunc, postFunc, data); else if ((sl = dynamic_cast(node)) != NULL) { - const std::vector &sls = sl->GetStatements(); + std::vector &sls = sl->stmts; for (unsigned int i = 0; i < sls.size(); ++i) - WalkAST(sls[i], preFunc, postFunc, data); + sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data); } else if ((ps = dynamic_cast(node)) != NULL) - WalkAST(ps->values, preFunc, postFunc, data); + ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data); else if ((as = dynamic_cast(node)) != NULL) - return WalkAST(as->expr, preFunc, postFunc, data); + as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data); else FATAL("Unhandled statement type in WalkAST()"); } @@ -160,45 +167,47 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc, AddressOfExpr *aoe; if ((ue = dynamic_cast(node)) != NULL) - WalkAST(ue->expr, preFunc, postFunc, data); + ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data); else if ((be = dynamic_cast(node)) != NULL) { - WalkAST(be->arg0, preFunc, postFunc, data); - WalkAST(be->arg1, preFunc, postFunc, data); + be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data); + be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data); } else if ((ae = dynamic_cast(node)) != NULL) { - WalkAST(ae->lvalue, preFunc, postFunc, data); - WalkAST(ae->rvalue, preFunc, postFunc, data); + ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data); + ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data); } else if ((se = dynamic_cast(node)) != NULL) { - WalkAST(se->test, preFunc, postFunc, data); - WalkAST(se->expr1, preFunc, postFunc, data); - WalkAST(se->expr2, preFunc, postFunc, data); + se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data); + se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data); + se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data); } else if ((el = dynamic_cast(node)) != NULL) { for (unsigned int i = 0; i < el->exprs.size(); ++i) - WalkAST(el->exprs[i], preFunc, postFunc, data); + el->exprs[i] = (Expr *)WalkAST(el->exprs[i], preFunc, + postFunc, data); } else if ((fce = dynamic_cast(node)) != NULL) { - WalkAST(fce->func, preFunc, postFunc, data); - WalkAST(fce->args, preFunc, postFunc, data); - WalkAST(fce->launchCountExpr, preFunc, postFunc, data); + fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data); + fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data); + fce->launchCountExpr = (Expr *)WalkAST(fce->launchCountExpr, preFunc, + postFunc, data); } else if ((ie = dynamic_cast(node)) != NULL) { - WalkAST(ie->baseExpr, preFunc, postFunc, data); - WalkAST(ie->index, preFunc, postFunc, data); + ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data); + ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data); } else if ((me = dynamic_cast(node)) != NULL) - WalkAST(me->expr, preFunc, postFunc, data); + me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data); else if ((tce = dynamic_cast(node)) != NULL) - return WalkAST(tce->expr, preFunc, postFunc, data); + tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data); else if ((re = dynamic_cast(node)) != NULL) - return WalkAST(re->expr, preFunc, postFunc, data); + re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data); else if ((dre = dynamic_cast(node)) != NULL) - return WalkAST(dre->expr, preFunc, postFunc, data); + dre->expr = (Expr *)WalkAST(dre->expr, preFunc, postFunc, data); else if ((soe = dynamic_cast(node)) != NULL) - return WalkAST(soe->expr, preFunc, postFunc, data); + soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data); else if ((aoe = dynamic_cast(node)) != NULL) - WalkAST(aoe->expr, preFunc, postFunc, data); + aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data); else if (dynamic_cast(node) != NULL || dynamic_cast(node) != NULL || dynamic_cast(node) != NULL || @@ -212,5 +221,7 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc, // Call the callback function if (postFunc != NULL) - postFunc(node, data); + return postFunc(node, data); + else + return node; } diff --git a/ast.h b/ast.h index 6fcb709a..8a24aa92 100644 --- a/ast.h +++ b/ast.h @@ -92,18 +92,24 @@ private: }; -/** Callback function type for the AST walk. +/** Callback function type for preorder traversial visiting function for + the AST walk. */ -typedef bool (* ASTCallBackFunc)(ASTNode *node, void *data); +typedef bool (* ASTPreCallBackFunc)(ASTNode *node, void *data); + +/** Callback function type for postorder traversial visiting function for + the AST walk. + */ +typedef ASTNode * (* ASTPostCallBackFunc)(ASTNode *node, void *data); /** Walk (some portion of) an AST, starting from the given root node. At each node, if preFunc is non-NULL, call it, passing the given void *data pointer; if the call to preFunc function returns false, then the - children of the node aren't visited. This then makes recursive calls - to WalkAST() to process the node's children; after doing so, calls - postFunc, at the node. The return value from the postFunc call is - ignored. */ -extern void WalkAST(ASTNode *root, ASTCallBackFunc preFunc, - ASTCallBackFunc postFunc, void *data); + children of the node aren't visited. This function then makes + recursive calls to WalkAST() to process the node's children; after + doing so, calls postFunc, at the node. The return value from the + postFunc call is ignored. */ +extern ASTNode *WalkAST(ASTNode *root, ASTPreCallBackFunc preFunc, + ASTPostCallBackFunc postFunc, void *data); #endif // ISPC_AST_H diff --git a/func.cpp b/func.cpp index 603a6641..641ff9e3 100644 --- a/func.cpp +++ b/func.cpp @@ -387,9 +387,8 @@ Function::GenerateIR() { SourcePos firstStmtPos = sym->pos; if (code) { StmtList *sl = dynamic_cast(code); - if (sl && sl->GetStatements().size() > 0 && - sl->GetStatements()[0] != NULL) - firstStmtPos = sl->GetStatements()[0]->pos; + if (sl && sl->stmts.size() > 0 && sl->stmts[0] != NULL) + firstStmtPos = sl->stmts[0]->pos; else firstStmtPos = code->pos; } diff --git a/stmt.cpp b/stmt.cpp index 3dcbc8d1..f8c994fa 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -997,12 +997,12 @@ lVaryingBCPreFunc(ASTNode *node, void *d) { /** Postorder callback function for checking for varying breaks or continues; decrement the varying control flow depth after the node's children have been processed, if this is a varying if statement. */ -static bool +static ASTNode * lVaryingBCPostFunc(ASTNode *node, void *d) { VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d; if (lIsVaryingFor(node)) --info->varyingControlFlowDepth; - return true; + return node; } diff --git a/stmt.h b/stmt.h index 2ce97753..e74ce125 100644 --- a/stmt.h +++ b/stmt.h @@ -302,9 +302,7 @@ public: int EstimateCost() const; void Add(Stmt *s) { if (s) stmts.push_back(s); } - const std::vector &GetStatements() { return stmts; } -private: std::vector stmts; };