diff --git a/ast.cpp b/ast.cpp index 7f601d04..445ef4c7 100644 --- a/ast.cpp +++ b/ast.cpp @@ -307,19 +307,39 @@ TypeCheck(Stmt *stmt) { } +struct CostData { + CostData() { cost = foreachDepth = 0; } + + int cost; + int foreachDepth; +}; + + static bool -lCostCallback(ASTNode *node, void *c) { - int *cost = (int *)c; - *cost += node->EstimateCost(); +lCostCallbackPre(ASTNode *node, void *d) { + CostData *data = (CostData *)d; + if (dynamic_cast(node) != NULL) + ++data->foreachDepth; + if (data->foreachDepth == 0) + data->cost += node->EstimateCost(); return true; } +static ASTNode * +lCostCallbackPost(ASTNode *node, void *d) { + CostData *data = (CostData *)d; + if (dynamic_cast(node) != NULL) + --data->foreachDepth; + return node; +} + + int EstimateCost(ASTNode *root) { - int cost = 0; - WalkAST(root, lCostCallback, NULL, &cost); - return cost; + CostData data; + WalkAST(root, lCostCallbackPre, lCostCallbackPost, &data); + return data.cost; }