Have WalkAST postorder callback function return an ASTNode *

In general, it should just return the original node pointer, but for type checking
and optimization passes, it can return a new value for the node (that will be
assigned where the old one was in the tree.)

Along the way, fixed some bugs in WalkAST() where the postorder callback wouldn't
end up being called for a few expr types (sizeof, dereference, address of, 
reference).
This commit is contained in:
Matt Pharr
2011-12-16 11:06:09 -08:00
parent 018aa96c8b
commit ced3f1f5fc
5 changed files with 73 additions and 59 deletions

99
ast.cpp
View File

@@ -68,17 +68,17 @@ AST::GenerateIR() {
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
void ASTNode *
WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc, WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
void *data) { void *data) {
if (node == NULL) if (node == NULL)
return; return node;
// Call the callback function // Call the callback function
if (preFunc != NULL) { if (preFunc != NULL) {
if (preFunc(node, data) == false) if (preFunc(node, data) == false)
// The function asked us to not continue recursively, so stop. // The function asked us to not continue recursively, so stop.
return; return node;
} }
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
@@ -96,48 +96,55 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc,
AssertStmt *as; AssertStmt *as;
if ((es = dynamic_cast<ExprStmt *>(node)) != NULL) if ((es = dynamic_cast<ExprStmt *>(node)) != NULL)
WalkAST(es->expr, preFunc, postFunc, data); es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data);
else if ((ds = dynamic_cast<DeclStmt *>(node)) != NULL) { else if ((ds = dynamic_cast<DeclStmt *>(node)) != NULL) {
for (unsigned int i = 0; i < ds->vars.size(); ++i) for (unsigned int i = 0; i < ds->vars.size(); ++i)
WalkAST(ds->vars[i].init, preFunc, postFunc, data); ds->vars[i].init = (Expr *)WalkAST(ds->vars[i].init, preFunc,
postFunc, data);
} }
else if ((is = dynamic_cast<IfStmt *>(node)) != NULL) { else if ((is = dynamic_cast<IfStmt *>(node)) != NULL) {
WalkAST(is->test, preFunc, postFunc, data); is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data);
WalkAST(is->trueStmts, preFunc, postFunc, data); is->trueStmts = (Stmt *)WalkAST(is->trueStmts, preFunc,
WalkAST(is->falseStmts, preFunc, postFunc, data); postFunc, data);
is->falseStmts = (Stmt *)WalkAST(is->falseStmts, preFunc,
postFunc, data);
} }
else if ((dos = dynamic_cast<DoStmt *>(node)) != NULL) { else if ((dos = dynamic_cast<DoStmt *>(node)) != NULL) {
WalkAST(dos->testExpr, preFunc, postFunc, data); dos->testExpr = (Expr *)WalkAST(dos->testExpr, preFunc,
WalkAST(dos->bodyStmts, preFunc, postFunc, data); postFunc, data);
dos->bodyStmts = (Stmt *)WalkAST(dos->bodyStmts, preFunc,
postFunc, data);
} }
else if ((fs = dynamic_cast<ForStmt *>(node)) != NULL) { else if ((fs = dynamic_cast<ForStmt *>(node)) != NULL) {
WalkAST(fs->init, preFunc, postFunc, data); fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data);
WalkAST(fs->test, preFunc, postFunc, data); fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data);
WalkAST(fs->step, preFunc, postFunc, data); fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data);
WalkAST(fs->stmts, preFunc, postFunc, data); fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data);
} }
else if ((fes = dynamic_cast<ForeachStmt *>(node)) != NULL) { else if ((fes = dynamic_cast<ForeachStmt *>(node)) != NULL) {
for (unsigned int i = 0; i < fes->startExprs.size(); ++i) for (unsigned int i = 0; i < fes->startExprs.size(); ++i)
WalkAST(fes->startExprs[i], preFunc, postFunc, data); fes->startExprs[i] = (Expr *)WalkAST(fes->startExprs[i], preFunc,
postFunc, data);
for (unsigned int i = 0; i < fes->endExprs.size(); ++i) for (unsigned int i = 0; i < fes->endExprs.size(); ++i)
WalkAST(fes->endExprs[i], preFunc, postFunc, data); fes->endExprs[i] = (Expr *)WalkAST(fes->endExprs[i], preFunc,
WalkAST(fes->stmts, preFunc, postFunc, data); postFunc, data);
fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data);
} }
else if (dynamic_cast<BreakStmt *>(node) != NULL || else if (dynamic_cast<BreakStmt *>(node) != NULL ||
dynamic_cast<ContinueStmt *>(node) != NULL) { dynamic_cast<ContinueStmt *>(node) != NULL) {
// nothing // nothing
} }
else if ((rs = dynamic_cast<ReturnStmt *>(node)) != NULL) else if ((rs = dynamic_cast<ReturnStmt *>(node)) != NULL)
WalkAST(rs->val, preFunc, postFunc, data); rs->val = (Expr *)WalkAST(rs->val, preFunc, postFunc, data);
else if ((sl = dynamic_cast<StmtList *>(node)) != NULL) { else if ((sl = dynamic_cast<StmtList *>(node)) != NULL) {
const std::vector<Stmt *> &sls = sl->GetStatements(); std::vector<Stmt *> &sls = sl->stmts;
for (unsigned int i = 0; i < sls.size(); ++i) for (unsigned int i = 0; i < sls.size(); ++i)
WalkAST(sls[i], preFunc, postFunc, data); sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data);
} }
else if ((ps = dynamic_cast<PrintStmt *>(node)) != NULL) else if ((ps = dynamic_cast<PrintStmt *>(node)) != NULL)
WalkAST(ps->values, preFunc, postFunc, data); ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data);
else if ((as = dynamic_cast<AssertStmt *>(node)) != NULL) else if ((as = dynamic_cast<AssertStmt *>(node)) != NULL)
return WalkAST(as->expr, preFunc, postFunc, data); as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data);
else else
FATAL("Unhandled statement type in WalkAST()"); FATAL("Unhandled statement type in WalkAST()");
} }
@@ -160,45 +167,47 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc,
AddressOfExpr *aoe; AddressOfExpr *aoe;
if ((ue = dynamic_cast<UnaryExpr *>(node)) != NULL) if ((ue = dynamic_cast<UnaryExpr *>(node)) != NULL)
WalkAST(ue->expr, preFunc, postFunc, data); ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data);
else if ((be = dynamic_cast<BinaryExpr *>(node)) != NULL) { else if ((be = dynamic_cast<BinaryExpr *>(node)) != NULL) {
WalkAST(be->arg0, preFunc, postFunc, data); be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data);
WalkAST(be->arg1, preFunc, postFunc, data); be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data);
} }
else if ((ae = dynamic_cast<AssignExpr *>(node)) != NULL) { else if ((ae = dynamic_cast<AssignExpr *>(node)) != NULL) {
WalkAST(ae->lvalue, preFunc, postFunc, data); ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data);
WalkAST(ae->rvalue, preFunc, postFunc, data); ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data);
} }
else if ((se = dynamic_cast<SelectExpr *>(node)) != NULL) { else if ((se = dynamic_cast<SelectExpr *>(node)) != NULL) {
WalkAST(se->test, preFunc, postFunc, data); se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data);
WalkAST(se->expr1, preFunc, postFunc, data); se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data);
WalkAST(se->expr2, preFunc, postFunc, data); se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data);
} }
else if ((el = dynamic_cast<ExprList *>(node)) != NULL) { else if ((el = dynamic_cast<ExprList *>(node)) != NULL) {
for (unsigned int i = 0; i < el->exprs.size(); ++i) for (unsigned int i = 0; i < el->exprs.size(); ++i)
WalkAST(el->exprs[i], preFunc, postFunc, data); el->exprs[i] = (Expr *)WalkAST(el->exprs[i], preFunc,
postFunc, data);
} }
else if ((fce = dynamic_cast<FunctionCallExpr *>(node)) != NULL) { else if ((fce = dynamic_cast<FunctionCallExpr *>(node)) != NULL) {
WalkAST(fce->func, preFunc, postFunc, data); fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data);
WalkAST(fce->args, preFunc, postFunc, data); fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data);
WalkAST(fce->launchCountExpr, preFunc, postFunc, data); fce->launchCountExpr = (Expr *)WalkAST(fce->launchCountExpr, preFunc,
postFunc, data);
} }
else if ((ie = dynamic_cast<IndexExpr *>(node)) != NULL) { else if ((ie = dynamic_cast<IndexExpr *>(node)) != NULL) {
WalkAST(ie->baseExpr, preFunc, postFunc, data); ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data);
WalkAST(ie->index, preFunc, postFunc, data); ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data);
} }
else if ((me = dynamic_cast<MemberExpr *>(node)) != NULL) else if ((me = dynamic_cast<MemberExpr *>(node)) != NULL)
WalkAST(me->expr, preFunc, postFunc, data); me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data);
else if ((tce = dynamic_cast<TypeCastExpr *>(node)) != NULL) else if ((tce = dynamic_cast<TypeCastExpr *>(node)) != NULL)
return WalkAST(tce->expr, preFunc, postFunc, data); tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data);
else if ((re = dynamic_cast<ReferenceExpr *>(node)) != NULL) else if ((re = dynamic_cast<ReferenceExpr *>(node)) != NULL)
return WalkAST(re->expr, preFunc, postFunc, data); re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data);
else if ((dre = dynamic_cast<DereferenceExpr *>(node)) != NULL) else if ((dre = dynamic_cast<DereferenceExpr *>(node)) != NULL)
return WalkAST(dre->expr, preFunc, postFunc, data); dre->expr = (Expr *)WalkAST(dre->expr, preFunc, postFunc, data);
else if ((soe = dynamic_cast<SizeOfExpr *>(node)) != NULL) else if ((soe = dynamic_cast<SizeOfExpr *>(node)) != NULL)
return WalkAST(soe->expr, preFunc, postFunc, data); soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data);
else if ((aoe = dynamic_cast<AddressOfExpr *>(node)) != NULL) else if ((aoe = dynamic_cast<AddressOfExpr *>(node)) != NULL)
WalkAST(aoe->expr, preFunc, postFunc, data); aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data);
else if (dynamic_cast<SymbolExpr *>(node) != NULL || else if (dynamic_cast<SymbolExpr *>(node) != NULL ||
dynamic_cast<ConstExpr *>(node) != NULL || dynamic_cast<ConstExpr *>(node) != NULL ||
dynamic_cast<FunctionSymbolExpr *>(node) != NULL || dynamic_cast<FunctionSymbolExpr *>(node) != NULL ||
@@ -212,5 +221,7 @@ WalkAST(ASTNode *node, ASTCallBackFunc preFunc, ASTCallBackFunc postFunc,
// Call the callback function // Call the callback function
if (postFunc != NULL) if (postFunc != NULL)
postFunc(node, data); return postFunc(node, data);
else
return node;
} }

22
ast.h
View File

@@ -92,18 +92,24 @@ private:
}; };
/** Callback function type for the AST walk. /** Callback function type for preorder traversial visiting function for
the AST walk.
*/ */
typedef bool (* ASTCallBackFunc)(ASTNode *node, void *data); typedef bool (* ASTPreCallBackFunc)(ASTNode *node, void *data);
/** Callback function type for postorder traversial visiting function for
the AST walk.
*/
typedef ASTNode * (* ASTPostCallBackFunc)(ASTNode *node, void *data);
/** Walk (some portion of) an AST, starting from the given root node. At /** Walk (some portion of) an AST, starting from the given root node. At
each node, if preFunc is non-NULL, call it, passing the given void each node, if preFunc is non-NULL, call it, passing the given void
*data pointer; if the call to preFunc function returns false, then the *data pointer; if the call to preFunc function returns false, then the
children of the node aren't visited. This then makes recursive calls children of the node aren't visited. This function then makes
to WalkAST() to process the node's children; after doing so, calls recursive calls to WalkAST() to process the node's children; after
postFunc, at the node. The return value from the postFunc call is doing so, calls postFunc, at the node. The return value from the
ignored. */ postFunc call is ignored. */
extern void WalkAST(ASTNode *root, ASTCallBackFunc preFunc, extern ASTNode *WalkAST(ASTNode *root, ASTPreCallBackFunc preFunc,
ASTCallBackFunc postFunc, void *data); ASTPostCallBackFunc postFunc, void *data);
#endif // ISPC_AST_H #endif // ISPC_AST_H

View File

@@ -387,9 +387,8 @@ Function::GenerateIR() {
SourcePos firstStmtPos = sym->pos; SourcePos firstStmtPos = sym->pos;
if (code) { if (code) {
StmtList *sl = dynamic_cast<StmtList *>(code); StmtList *sl = dynamic_cast<StmtList *>(code);
if (sl && sl->GetStatements().size() > 0 && if (sl && sl->stmts.size() > 0 && sl->stmts[0] != NULL)
sl->GetStatements()[0] != NULL) firstStmtPos = sl->stmts[0]->pos;
firstStmtPos = sl->GetStatements()[0]->pos;
else else
firstStmtPos = code->pos; firstStmtPos = code->pos;
} }

View File

@@ -997,12 +997,12 @@ lVaryingBCPreFunc(ASTNode *node, void *d) {
/** Postorder callback function for checking for varying breaks or /** Postorder callback function for checking for varying breaks or
continues; decrement the varying control flow depth after the node's continues; decrement the varying control flow depth after the node's
children have been processed, if this is a varying if statement. */ children have been processed, if this is a varying if statement. */
static bool static ASTNode *
lVaryingBCPostFunc(ASTNode *node, void *d) { lVaryingBCPostFunc(ASTNode *node, void *d) {
VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d; VaryingBCCheckInfo *info = (VaryingBCCheckInfo *)d;
if (lIsVaryingFor(node)) if (lIsVaryingFor(node))
--info->varyingControlFlowDepth; --info->varyingControlFlowDepth;
return true; return node;
} }

2
stmt.h
View File

@@ -302,9 +302,7 @@ public:
int EstimateCost() const; int EstimateCost() const;
void Add(Stmt *s) { if (s) stmts.push_back(s); } void Add(Stmt *s) { if (s) stmts.push_back(s); }
const std::vector<Stmt *> &GetStatements() { return stmts; }
private:
std::vector<Stmt *> stmts; std::vector<Stmt *> stmts;
}; };