diff --git a/ast.cpp b/ast.cpp index 5a8bec8a..b375f333 100644 --- a/ast.cpp +++ b/ast.cpp @@ -225,3 +225,27 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, else return node; } + + +static ASTNode * +lOptimizeNode(ASTNode *node, void *) { + return node->Optimize(); +} + + +ASTNode * +Optimize(ASTNode *root) { + return WalkAST(root, NULL, lOptimizeNode, NULL); +} + + +Expr * +Optimize(Expr *expr) { + return (Expr *)Optimize((ASTNode *)expr); +} + + +Stmt * +Optimize(Stmt *stmt) { + return (Stmt *)Optimize((ASTNode *)stmt); +} diff --git a/ast.h b/ast.h index 8a24aa92..161f4cde 100644 --- a/ast.h +++ b/ast.h @@ -53,10 +53,11 @@ public: virtual ~ASTNode(); /** The Optimize() method should perform any appropriate early-stage - optimizations on the node (e.g. constant folding). The caller - should use the returned ASTNode * in place of the original node. - This method may return NULL if an error is encountered during - optimization. */ + optimizations on the node (e.g. constant folding). This method + will be called after the node's children have already been + optimized, and the caller will store the returned ASTNode * in + place of the original node. This method should return NULL if an + error is encountered during optimization. */ virtual ASTNode *Optimize() = 0; /** Type checking should be performed by the node when this method is @@ -112,4 +113,8 @@ typedef ASTNode * (* ASTPostCallBackFunc)(ASTNode *node, void *data); extern ASTNode *WalkAST(ASTNode *root, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, void *data); +extern Expr *Optimize(Expr *); +extern Stmt *Optimize(Stmt *); +extern ASTNode *Optimize(ASTNode *root); + #endif // ISPC_AST_H diff --git a/decl.cpp b/decl.cpp index 21736691..11a665de 100644 --- a/decl.cpp +++ b/decl.cpp @@ -382,7 +382,7 @@ Declarator::GetType(const Type *base, DeclSpecs *ds) const { if (decl->initExpr != NULL && (decl->initExpr = decl->initExpr->TypeCheck()) != NULL && - (decl->initExpr = decl->initExpr->Optimize()) != NULL && + (decl->initExpr = Optimize(decl->initExpr)) != NULL && (init = dynamic_cast(decl->initExpr)) == NULL) { Error(decl->initExpr->pos, "Default value for parameter " "\"%s\" must be a compile-time constant.", diff --git a/expr.cpp b/expr.cpp index b36e08db..359abb04 100644 --- a/expr.cpp +++ b/expr.cpp @@ -144,7 +144,7 @@ lArrayToPointer(Expr *expr) { Expr *addr = new AddressOfExpr(index, expr->pos); addr = addr->TypeCheck(); Assert(addr != NULL); - addr = addr->Optimize(); + addr = Optimize(addr); Assert(addr != NULL); return addr; } @@ -843,11 +843,6 @@ UnaryExpr::GetType() const { Expr * UnaryExpr::Optimize() { - if (!expr) - return NULL; - - expr = expr->Optimize(); - ConstExpr *constExpr = dynamic_cast(expr); // If the operand isn't a constant, then we can't do any optimization // here... @@ -1489,12 +1484,7 @@ lConstFoldBoolBinOp(BinaryExpr::Op op, const bool *v0, const bool *v1, Expr * BinaryExpr::Optimize() { - if (arg0 != NULL) - arg0 = arg0->Optimize(); - if (arg1 != NULL) - arg1 = arg1->Optimize(); - - if (!arg0 || !arg1) + if (arg0 == NULL || arg1 == NULL) return NULL; ConstExpr *constArg0 = dynamic_cast(arg0); @@ -1519,7 +1509,7 @@ BinaryExpr::Optimize() { e = e->TypeCheck(); if (e == NULL) return NULL; - return e->Optimize(); + return ::Optimize(e); } } @@ -1542,7 +1532,7 @@ BinaryExpr::Optimize() { rcpCall = rcpCall->TypeCheck(); if (rcpCall == NULL) return NULL; - rcpCall = rcpCall->Optimize(); + rcpCall = ::Optimize(rcpCall); if (rcpCall == NULL) return NULL; @@ -1550,7 +1540,7 @@ BinaryExpr::Optimize() { ret = ret->TypeCheck(); if (ret == NULL) return NULL; - return ret->Optimize(); + return ::Optimize(ret); } else Warning(pos, "rcp() not found from stdlib. Can't apply " @@ -2089,13 +2079,8 @@ AssignExpr::GetValue(FunctionEmitContext *ctx) const { Expr * AssignExpr::Optimize() { - if (lvalue) - lvalue = lvalue->Optimize(); - if (rvalue) - rvalue = rvalue->Optimize(); if (lvalue == NULL || rvalue == NULL) return NULL; - return this; } @@ -2412,15 +2397,8 @@ SelectExpr::GetType() const { Expr * SelectExpr::Optimize() { - if (test) - test = test->Optimize(); - if (expr1) - expr1 = expr1->Optimize(); - if (expr2) - expr2 = expr2->Optimize(); if (test == NULL || expr1 == NULL || expr2 == NULL) return NULL; - return this; } @@ -2650,16 +2628,8 @@ FunctionCallExpr::GetType() const { Expr * FunctionCallExpr::Optimize() { - if (func) - func = func->Optimize(); - if (args) - args = args->Optimize(); - if (launchCountExpr != NULL) - launchCountExpr = launchCountExpr->Optimize(); - - if (!func || !args) + if (func == NULL || args == NULL) return NULL; - return this; } @@ -2858,9 +2828,6 @@ ExprList::GetType() const { ExprList * ExprList::Optimize() { - for (unsigned int i = 0; i < exprs.size(); ++i) - if (exprs[i]) - exprs[i] = exprs[i]->Optimize(); return this; } @@ -3224,13 +3191,8 @@ IndexExpr::GetLValueType() const { Expr * IndexExpr::Optimize() { - if (baseExpr) - baseExpr = baseExpr->Optimize(); - if (index) - index = index->Optimize(); if (baseExpr == NULL || index == NULL) return NULL; - return this; } @@ -3787,8 +3749,6 @@ MemberExpr::TypeCheck() { Expr * MemberExpr::Optimize() { - if (expr) - expr = expr->Optimize(); return expr ? this : NULL; } @@ -5310,7 +5270,7 @@ TypeCastExpr::GetValue(FunctionEmitContext *ctx) const { arrayAsPtr = new TypeCastExpr(toPointerType, arrayAsPtr, false, pos); arrayAsPtr = arrayAsPtr->TypeCheck(); Assert(arrayAsPtr != NULL); - arrayAsPtr = arrayAsPtr->Optimize(); + arrayAsPtr = ::Optimize(arrayAsPtr); Assert(arrayAsPtr != NULL); } Assert(Type::EqualIgnoringConst(arrayAsPtr->GetType(), toPointerType)); @@ -5546,13 +5506,8 @@ TypeCastExpr::TypeCheck() { Expr * TypeCastExpr::Optimize() { - if (expr != NULL) - expr = expr->Optimize(); - if (expr == NULL) - return NULL; - ConstExpr *constExpr = dynamic_cast(expr); - if (!constExpr) + if (constExpr == NULL) // We can't do anything if this isn't a const expr return this; @@ -5736,11 +5691,8 @@ ReferenceExpr::GetLValueType() const { Expr * ReferenceExpr::Optimize() { - if (expr) - expr = expr->Optimize(); if (expr == NULL) return NULL; - return this; } @@ -5855,8 +5807,6 @@ DereferenceExpr::TypeCheck() { Expr * DereferenceExpr::Optimize() { - if (expr != NULL) - expr = expr->Optimize(); if (expr == NULL) return NULL; return this; @@ -5954,8 +5904,6 @@ AddressOfExpr::TypeCheck() { Expr * AddressOfExpr::Optimize() { - if (expr != NULL) - expr = expr->Optimize(); return this; } @@ -6024,8 +5972,6 @@ SizeOfExpr::TypeCheck() { Expr * SizeOfExpr::Optimize() { - if (expr != NULL) - expr = expr->Optimize(); return this; } diff --git a/func.cpp b/func.cpp index 641ff9e3..4dd5b341 100644 --- a/func.cpp +++ b/func.cpp @@ -92,7 +92,7 @@ Function::Function(Symbol *s, const std::vector &a, Stmt *c) { } if (code != NULL) { - code = code->Optimize(); + code = Optimize(code); if (g->debugPrint) { fprintf(stderr, "After optimizing function \"%s\":\n", sym->name.c_str()); diff --git a/module.cpp b/module.cpp index e15d73ae..107b601e 100644 --- a/module.cpp +++ b/module.cpp @@ -273,7 +273,7 @@ Module::AddGlobalVariable(Symbol *sym, Expr *initExpr, bool isConst) { initExpr = TypeConvertExpr(initExpr, sym->type, "initializer"); if (initExpr != NULL) { - initExpr = initExpr->Optimize(); + initExpr = Optimize(initExpr); // Fingers crossed, now let's see if we've got a // constant value.. llvmInitializer = initExpr->GetConstant(sym->type); diff --git a/parse.yy b/parse.yy index 2a919d40..552e63ec 100644 --- a/parse.yy +++ b/parse.yy @@ -1677,7 +1677,7 @@ lGetConstantInt(Expr *expr, int *value, SourcePos pos, const char *usage) { expr = expr->TypeCheck(); if (expr == NULL) return false; - expr = expr->Optimize(); + expr = Optimize(expr); if (expr == NULL) return false; @@ -1754,7 +1754,7 @@ lFinalizeEnumeratorSymbols(std::vector &enums, us end up with a ConstExpr with the desired EnumType... */ Expr *castExpr = new TypeCastExpr(enumType, enums[i]->constValue, false, enums[i]->pos); - castExpr = castExpr->Optimize(); + castExpr = Optimize(castExpr); enums[i]->constValue = dynamic_cast(castExpr); Assert(enums[i]->constValue != NULL); } diff --git a/stmt.cpp b/stmt.cpp index f8c994fa..63a3f803 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -58,6 +58,15 @@ #include #include +/////////////////////////////////////////////////////////////////////////// +// Stmt + +Stmt * +Stmt::Optimize() { + return this; +} + + /////////////////////////////////////////////////////////////////////////// // ExprStmt @@ -77,14 +86,6 @@ ExprStmt::EmitCode(FunctionEmitContext *ctx) const { } -Stmt * -ExprStmt::Optimize() { - if (expr) - expr = expr->Optimize(); - return this; -} - - Stmt * ExprStmt::TypeCheck() { if (expr) @@ -345,7 +346,7 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const { // FIXME: and this is only needed to re-establish // constant-ness so that GetConstant below works for // constant artithmetic expressions... - initExpr = initExpr->Optimize(); + initExpr = ::Optimize(initExpr); } cinit = initExpr->GetConstant(sym->type); @@ -388,10 +389,8 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const { Stmt * DeclStmt::Optimize() { for (unsigned int i = 0; i < vars.size(); ++i) { - if (vars[i].init != NULL) { - vars[i].init = vars[i].init->Optimize(); - Expr *init = vars[i].init; - + Expr *init = vars[i].init; + if (init != NULL && dynamic_cast(init) == NULL) { // If the variable is const-qualified, after we've optimized // the initializer expression, see if we have a ConstExpr. If // so, save it in Symbol::constValue where it can be used in @@ -408,8 +407,7 @@ DeclStmt::Optimize() { // computing array sizes from non-trivial expressions is // consequently limited. Symbol *sym = vars[i].sym; - if (sym->type && sym->type->IsConstType() && init != NULL && - dynamic_cast(init) == NULL && + if (sym->type && sym->type->IsConstType() && Type::Equal(init->GetType(), sym->type)) sym->constValue = dynamic_cast(init); } @@ -566,18 +564,7 @@ IfStmt::EmitCode(FunctionEmitContext *ctx) const { Stmt * -IfStmt::Optimize() { - if (test != NULL) - test = test->Optimize(); - if (trueStmts != NULL) - trueStmts = trueStmts->Optimize(); - if (falseStmts != NULL) - falseStmts = falseStmts->Optimize(); - return this; -} - - -Stmt *IfStmt::TypeCheck() { +IfStmt::TypeCheck() { if (test != NULL) { test = test->TypeCheck(); if (test != NULL) { @@ -1133,16 +1120,6 @@ void DoStmt::EmitCode(FunctionEmitContext *ctx) const { } -Stmt * -DoStmt::Optimize() { - if (testExpr) - testExpr = testExpr->Optimize(); - if (bodyStmts) - bodyStmts = bodyStmts->Optimize(); - return this; -} - - Stmt * DoStmt::TypeCheck() { if (testExpr) { @@ -1345,20 +1322,6 @@ ForStmt::EmitCode(FunctionEmitContext *ctx) const { } -Stmt * -ForStmt::Optimize() { - if (test) - test = test->Optimize(); - if (init) - init = init->Optimize(); - if (step) - step = step->Optimize(); - if (stmts) - stmts = stmts->Optimize(); - return this; -} - - Stmt * ForStmt::TypeCheck() { if (test) { @@ -1450,12 +1413,6 @@ BreakStmt::EmitCode(FunctionEmitContext *ctx) const { } -Stmt * -BreakStmt::Optimize() { - return this; -} - - Stmt * BreakStmt::TypeCheck() { return this; @@ -1495,12 +1452,6 @@ ContinueStmt::EmitCode(FunctionEmitContext *ctx) const { } -Stmt * -ContinueStmt::Optimize() { - return this; -} - - Stmt * ContinueStmt::TypeCheck() { return this; @@ -1858,28 +1809,6 @@ ForeachStmt::EmitCode(FunctionEmitContext *ctx) const { } -Stmt * -ForeachStmt::Optimize() { - bool anyErrors = false; - for (unsigned int i = 0; i < startExprs.size(); ++i) { - if (startExprs[i] != NULL) - startExprs[i] = startExprs[i]->Optimize(); - anyErrors |= (startExprs[i] == NULL); - } - for (unsigned int i = 0; i < endExprs.size(); ++i) { - if (endExprs[i] != NULL) - endExprs[i] = endExprs[i]->Optimize(); - anyErrors |= (endExprs[i] == NULL); - } - - if (stmts != NULL) - stmts = stmts->TypeCheck(); - anyErrors |= (stmts == NULL); - - return anyErrors ? NULL : this; -} - - Stmt * ForeachStmt::TypeCheck() { bool anyErrors = false; @@ -2007,14 +1936,6 @@ ReturnStmt::EmitCode(FunctionEmitContext *ctx) const { } -Stmt * -ReturnStmt::Optimize() { - if (val) - val = val->Optimize(); - return this; -} - - Stmt * ReturnStmt::TypeCheck() { // FIXME: We don't have ctx->functionType available here; should we? @@ -2059,15 +1980,6 @@ StmtList::EmitCode(FunctionEmitContext *ctx) const { } -Stmt * -StmtList::Optimize() { - for (unsigned int i = 0; i < stmts.size(); ++i) - if (stmts[i]) - stmts[i] = stmts[i]->Optimize(); - return this; -} - - Stmt * StmtList::TypeCheck() { for (unsigned int i = 0; i < stmts.size(); ++i) @@ -2280,14 +2192,6 @@ PrintStmt::Print(int indent) const { } -Stmt * -PrintStmt::Optimize() { - if (values) - values = values->Optimize(); - return this; -} - - Stmt * PrintStmt::TypeCheck() { if (values) @@ -2364,14 +2268,6 @@ AssertStmt::Print(int indent) const { } -Stmt * -AssertStmt::Optimize() { - if (expr) - expr = expr->Optimize(); - return this; -} - - Stmt * AssertStmt::TypeCheck() { if (expr) diff --git a/stmt.h b/stmt.h index e74ce125..73142197 100644 --- a/stmt.h +++ b/stmt.h @@ -60,8 +60,10 @@ public: virtual void Print(int indent) const = 0; // Redeclare these methods with Stmt * return values, rather than - // ASTNode *s, as in the original ASTNode declarations of them. - virtual Stmt *Optimize() = 0; + // ASTNode *s, as in the original ASTNode declarations of them. We'll + // also provide a default implementation of Optimize(), since most + // Stmts don't have anything to do here. + virtual Stmt *Optimize(); virtual Stmt *TypeCheck() = 0; }; @@ -74,7 +76,6 @@ public: void EmitCode(FunctionEmitContext *ctx) const; void Print(int indent) const; - Stmt *Optimize(); Stmt *TypeCheck(); int EstimateCost() const; @@ -117,7 +118,6 @@ public: void EmitCode(FunctionEmitContext *ctx) const; void Print(int indent) const; - Stmt *Optimize(); Stmt *TypeCheck(); int EstimateCost() const; @@ -158,7 +158,6 @@ public: void EmitCode(FunctionEmitContext *ctx) const; void Print(int indent) const; - Stmt *Optimize(); Stmt *TypeCheck(); int EstimateCost() const; @@ -179,7 +178,6 @@ public: void EmitCode(FunctionEmitContext *ctx) const; void Print(int indent) const; - Stmt *Optimize(); Stmt *TypeCheck(); int EstimateCost() const; @@ -206,7 +204,6 @@ public: void EmitCode(FunctionEmitContext *ctx) const; void Print(int indent) const; - Stmt *Optimize(); Stmt *TypeCheck(); int EstimateCost() const; @@ -228,7 +225,6 @@ public: void EmitCode(FunctionEmitContext *ctx) const; void Print(int indent) const; - Stmt *Optimize(); Stmt *TypeCheck(); int EstimateCost() const; @@ -253,7 +249,6 @@ public: void EmitCode(FunctionEmitContext *ctx) const; void Print(int indent) const; - Stmt *Optimize(); Stmt *TypeCheck(); int EstimateCost() const; @@ -275,7 +270,6 @@ public: void EmitCode(FunctionEmitContext *ctx) const; void Print(int indent) const; - Stmt *Optimize(); Stmt *TypeCheck(); int EstimateCost() const; @@ -297,7 +291,6 @@ public: void EmitCode(FunctionEmitContext *ctx) const; void Print(int indent) const; - Stmt *Optimize(); Stmt *TypeCheck(); int EstimateCost() const; @@ -323,7 +316,6 @@ public: void EmitCode(FunctionEmitContext *ctx) const; void Print(int indent) const; - Stmt *Optimize(); Stmt *TypeCheck(); int EstimateCost() const; @@ -350,7 +342,6 @@ public: void EmitCode(FunctionEmitContext *ctx) const; void Print(int indent) const; - Stmt *Optimize(); Stmt *TypeCheck(); int EstimateCost() const;