Rewrite AST optimization infrastructure to be built on top of WalkAST().

Specifically, stmts and exprs are no longer responsible for first recursively
optimizing their children before doing their own optimization (this turned
out to be error-prone, with children sometimes being forgotten.)  They now
are just responsible for their own optimization, when appropriate.
This commit is contained in:
Matt Pharr
2011-12-16 11:35:18 -08:00
parent ced3f1f5fc
commit f48a662ed3
9 changed files with 64 additions and 202 deletions

24
ast.cpp
View File

@@ -225,3 +225,27 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
else
return node;
}
static ASTNode *
lOptimizeNode(ASTNode *node, void *) {
return node->Optimize();
}
ASTNode *
Optimize(ASTNode *root) {
return WalkAST(root, NULL, lOptimizeNode, NULL);
}
Expr *
Optimize(Expr *expr) {
return (Expr *)Optimize((ASTNode *)expr);
}
Stmt *
Optimize(Stmt *stmt) {
return (Stmt *)Optimize((ASTNode *)stmt);
}

13
ast.h
View File

@@ -53,10 +53,11 @@ public:
virtual ~ASTNode();
/** The Optimize() method should perform any appropriate early-stage
optimizations on the node (e.g. constant folding). The caller
should use the returned ASTNode * in place of the original node.
This method may return NULL if an error is encountered during
optimization. */
optimizations on the node (e.g. constant folding). This method
will be called after the node's children have already been
optimized, and the caller will store the returned ASTNode * in
place of the original node. This method should return NULL if an
error is encountered during optimization. */
virtual ASTNode *Optimize() = 0;
/** Type checking should be performed by the node when this method is
@@ -112,4 +113,8 @@ typedef ASTNode * (* ASTPostCallBackFunc)(ASTNode *node, void *data);
extern ASTNode *WalkAST(ASTNode *root, ASTPreCallBackFunc preFunc,
ASTPostCallBackFunc postFunc, void *data);
extern Expr *Optimize(Expr *);
extern Stmt *Optimize(Stmt *);
extern ASTNode *Optimize(ASTNode *root);
#endif // ISPC_AST_H

View File

@@ -382,7 +382,7 @@ Declarator::GetType(const Type *base, DeclSpecs *ds) const {
if (decl->initExpr != NULL &&
(decl->initExpr = decl->initExpr->TypeCheck()) != NULL &&
(decl->initExpr = decl->initExpr->Optimize()) != NULL &&
(decl->initExpr = Optimize(decl->initExpr)) != NULL &&
(init = dynamic_cast<ConstExpr *>(decl->initExpr)) == NULL) {
Error(decl->initExpr->pos, "Default value for parameter "
"\"%s\" must be a compile-time constant.",

View File

@@ -144,7 +144,7 @@ lArrayToPointer(Expr *expr) {
Expr *addr = new AddressOfExpr(index, expr->pos);
addr = addr->TypeCheck();
Assert(addr != NULL);
addr = addr->Optimize();
addr = Optimize(addr);
Assert(addr != NULL);
return addr;
}
@@ -843,11 +843,6 @@ UnaryExpr::GetType() const {
Expr *
UnaryExpr::Optimize() {
if (!expr)
return NULL;
expr = expr->Optimize();
ConstExpr *constExpr = dynamic_cast<ConstExpr *>(expr);
// If the operand isn't a constant, then we can't do any optimization
// here...
@@ -1489,12 +1484,7 @@ lConstFoldBoolBinOp(BinaryExpr::Op op, const bool *v0, const bool *v1,
Expr *
BinaryExpr::Optimize() {
if (arg0 != NULL)
arg0 = arg0->Optimize();
if (arg1 != NULL)
arg1 = arg1->Optimize();
if (!arg0 || !arg1)
if (arg0 == NULL || arg1 == NULL)
return NULL;
ConstExpr *constArg0 = dynamic_cast<ConstExpr *>(arg0);
@@ -1519,7 +1509,7 @@ BinaryExpr::Optimize() {
e = e->TypeCheck();
if (e == NULL)
return NULL;
return e->Optimize();
return ::Optimize(e);
}
}
@@ -1542,7 +1532,7 @@ BinaryExpr::Optimize() {
rcpCall = rcpCall->TypeCheck();
if (rcpCall == NULL)
return NULL;
rcpCall = rcpCall->Optimize();
rcpCall = ::Optimize(rcpCall);
if (rcpCall == NULL)
return NULL;
@@ -1550,7 +1540,7 @@ BinaryExpr::Optimize() {
ret = ret->TypeCheck();
if (ret == NULL)
return NULL;
return ret->Optimize();
return ::Optimize(ret);
}
else
Warning(pos, "rcp() not found from stdlib. Can't apply "
@@ -2089,13 +2079,8 @@ AssignExpr::GetValue(FunctionEmitContext *ctx) const {
Expr *
AssignExpr::Optimize() {
if (lvalue)
lvalue = lvalue->Optimize();
if (rvalue)
rvalue = rvalue->Optimize();
if (lvalue == NULL || rvalue == NULL)
return NULL;
return this;
}
@@ -2412,15 +2397,8 @@ SelectExpr::GetType() const {
Expr *
SelectExpr::Optimize() {
if (test)
test = test->Optimize();
if (expr1)
expr1 = expr1->Optimize();
if (expr2)
expr2 = expr2->Optimize();
if (test == NULL || expr1 == NULL || expr2 == NULL)
return NULL;
return this;
}
@@ -2650,16 +2628,8 @@ FunctionCallExpr::GetType() const {
Expr *
FunctionCallExpr::Optimize() {
if (func)
func = func->Optimize();
if (args)
args = args->Optimize();
if (launchCountExpr != NULL)
launchCountExpr = launchCountExpr->Optimize();
if (!func || !args)
if (func == NULL || args == NULL)
return NULL;
return this;
}
@@ -2858,9 +2828,6 @@ ExprList::GetType() const {
ExprList *
ExprList::Optimize() {
for (unsigned int i = 0; i < exprs.size(); ++i)
if (exprs[i])
exprs[i] = exprs[i]->Optimize();
return this;
}
@@ -3224,13 +3191,8 @@ IndexExpr::GetLValueType() const {
Expr *
IndexExpr::Optimize() {
if (baseExpr)
baseExpr = baseExpr->Optimize();
if (index)
index = index->Optimize();
if (baseExpr == NULL || index == NULL)
return NULL;
return this;
}
@@ -3787,8 +3749,6 @@ MemberExpr::TypeCheck() {
Expr *
MemberExpr::Optimize() {
if (expr)
expr = expr->Optimize();
return expr ? this : NULL;
}
@@ -5310,7 +5270,7 @@ TypeCastExpr::GetValue(FunctionEmitContext *ctx) const {
arrayAsPtr = new TypeCastExpr(toPointerType, arrayAsPtr, false, pos);
arrayAsPtr = arrayAsPtr->TypeCheck();
Assert(arrayAsPtr != NULL);
arrayAsPtr = arrayAsPtr->Optimize();
arrayAsPtr = ::Optimize(arrayAsPtr);
Assert(arrayAsPtr != NULL);
}
Assert(Type::EqualIgnoringConst(arrayAsPtr->GetType(), toPointerType));
@@ -5546,13 +5506,8 @@ TypeCastExpr::TypeCheck() {
Expr *
TypeCastExpr::Optimize() {
if (expr != NULL)
expr = expr->Optimize();
if (expr == NULL)
return NULL;
ConstExpr *constExpr = dynamic_cast<ConstExpr *>(expr);
if (!constExpr)
if (constExpr == NULL)
// We can't do anything if this isn't a const expr
return this;
@@ -5736,11 +5691,8 @@ ReferenceExpr::GetLValueType() const {
Expr *
ReferenceExpr::Optimize() {
if (expr)
expr = expr->Optimize();
if (expr == NULL)
return NULL;
return this;
}
@@ -5855,8 +5807,6 @@ DereferenceExpr::TypeCheck() {
Expr *
DereferenceExpr::Optimize() {
if (expr != NULL)
expr = expr->Optimize();
if (expr == NULL)
return NULL;
return this;
@@ -5954,8 +5904,6 @@ AddressOfExpr::TypeCheck() {
Expr *
AddressOfExpr::Optimize() {
if (expr != NULL)
expr = expr->Optimize();
return this;
}
@@ -6024,8 +5972,6 @@ SizeOfExpr::TypeCheck() {
Expr *
SizeOfExpr::Optimize() {
if (expr != NULL)
expr = expr->Optimize();
return this;
}

View File

@@ -92,7 +92,7 @@ Function::Function(Symbol *s, const std::vector<Symbol *> &a, Stmt *c) {
}
if (code != NULL) {
code = code->Optimize();
code = Optimize(code);
if (g->debugPrint) {
fprintf(stderr, "After optimizing function \"%s\":\n",
sym->name.c_str());

View File

@@ -273,7 +273,7 @@ Module::AddGlobalVariable(Symbol *sym, Expr *initExpr, bool isConst) {
initExpr = TypeConvertExpr(initExpr, sym->type, "initializer");
if (initExpr != NULL) {
initExpr = initExpr->Optimize();
initExpr = Optimize(initExpr);
// Fingers crossed, now let's see if we've got a
// constant value..
llvmInitializer = initExpr->GetConstant(sym->type);

View File

@@ -1677,7 +1677,7 @@ lGetConstantInt(Expr *expr, int *value, SourcePos pos, const char *usage) {
expr = expr->TypeCheck();
if (expr == NULL)
return false;
expr = expr->Optimize();
expr = Optimize(expr);
if (expr == NULL)
return false;
@@ -1754,7 +1754,7 @@ lFinalizeEnumeratorSymbols(std::vector<Symbol *> &enums,
us end up with a ConstExpr with the desired EnumType... */
Expr *castExpr = new TypeCastExpr(enumType, enums[i]->constValue,
false, enums[i]->pos);
castExpr = castExpr->Optimize();
castExpr = Optimize(castExpr);
enums[i]->constValue = dynamic_cast<ConstExpr *>(castExpr);
Assert(enums[i]->constValue != NULL);
}

130
stmt.cpp
View File

@@ -58,6 +58,15 @@
#include <llvm/Support/IRBuilder.h>
#include <llvm/Support/raw_ostream.h>
///////////////////////////////////////////////////////////////////////////
// Stmt
Stmt *
Stmt::Optimize() {
return this;
}
///////////////////////////////////////////////////////////////////////////
// ExprStmt
@@ -77,14 +86,6 @@ ExprStmt::EmitCode(FunctionEmitContext *ctx) const {
}
Stmt *
ExprStmt::Optimize() {
if (expr)
expr = expr->Optimize();
return this;
}
Stmt *
ExprStmt::TypeCheck() {
if (expr)
@@ -345,7 +346,7 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
// FIXME: and this is only needed to re-establish
// constant-ness so that GetConstant below works for
// constant artithmetic expressions...
initExpr = initExpr->Optimize();
initExpr = ::Optimize(initExpr);
}
cinit = initExpr->GetConstant(sym->type);
@@ -388,10 +389,8 @@ DeclStmt::EmitCode(FunctionEmitContext *ctx) const {
Stmt *
DeclStmt::Optimize() {
for (unsigned int i = 0; i < vars.size(); ++i) {
if (vars[i].init != NULL) {
vars[i].init = vars[i].init->Optimize();
Expr *init = vars[i].init;
if (init != NULL && dynamic_cast<ExprList *>(init) == NULL) {
// If the variable is const-qualified, after we've optimized
// the initializer expression, see if we have a ConstExpr. If
// so, save it in Symbol::constValue where it can be used in
@@ -408,8 +407,7 @@ DeclStmt::Optimize() {
// computing array sizes from non-trivial expressions is
// consequently limited.
Symbol *sym = vars[i].sym;
if (sym->type && sym->type->IsConstType() && init != NULL &&
dynamic_cast<ExprList *>(init) == NULL &&
if (sym->type && sym->type->IsConstType() &&
Type::Equal(init->GetType(), sym->type))
sym->constValue = dynamic_cast<ConstExpr *>(init);
}
@@ -566,18 +564,7 @@ IfStmt::EmitCode(FunctionEmitContext *ctx) const {
Stmt *
IfStmt::Optimize() {
if (test != NULL)
test = test->Optimize();
if (trueStmts != NULL)
trueStmts = trueStmts->Optimize();
if (falseStmts != NULL)
falseStmts = falseStmts->Optimize();
return this;
}
Stmt *IfStmt::TypeCheck() {
IfStmt::TypeCheck() {
if (test != NULL) {
test = test->TypeCheck();
if (test != NULL) {
@@ -1133,16 +1120,6 @@ void DoStmt::EmitCode(FunctionEmitContext *ctx) const {
}
Stmt *
DoStmt::Optimize() {
if (testExpr)
testExpr = testExpr->Optimize();
if (bodyStmts)
bodyStmts = bodyStmts->Optimize();
return this;
}
Stmt *
DoStmt::TypeCheck() {
if (testExpr) {
@@ -1345,20 +1322,6 @@ ForStmt::EmitCode(FunctionEmitContext *ctx) const {
}
Stmt *
ForStmt::Optimize() {
if (test)
test = test->Optimize();
if (init)
init = init->Optimize();
if (step)
step = step->Optimize();
if (stmts)
stmts = stmts->Optimize();
return this;
}
Stmt *
ForStmt::TypeCheck() {
if (test) {
@@ -1450,12 +1413,6 @@ BreakStmt::EmitCode(FunctionEmitContext *ctx) const {
}
Stmt *
BreakStmt::Optimize() {
return this;
}
Stmt *
BreakStmt::TypeCheck() {
return this;
@@ -1495,12 +1452,6 @@ ContinueStmt::EmitCode(FunctionEmitContext *ctx) const {
}
Stmt *
ContinueStmt::Optimize() {
return this;
}
Stmt *
ContinueStmt::TypeCheck() {
return this;
@@ -1858,28 +1809,6 @@ ForeachStmt::EmitCode(FunctionEmitContext *ctx) const {
}
Stmt *
ForeachStmt::Optimize() {
bool anyErrors = false;
for (unsigned int i = 0; i < startExprs.size(); ++i) {
if (startExprs[i] != NULL)
startExprs[i] = startExprs[i]->Optimize();
anyErrors |= (startExprs[i] == NULL);
}
for (unsigned int i = 0; i < endExprs.size(); ++i) {
if (endExprs[i] != NULL)
endExprs[i] = endExprs[i]->Optimize();
anyErrors |= (endExprs[i] == NULL);
}
if (stmts != NULL)
stmts = stmts->TypeCheck();
anyErrors |= (stmts == NULL);
return anyErrors ? NULL : this;
}
Stmt *
ForeachStmt::TypeCheck() {
bool anyErrors = false;
@@ -2007,14 +1936,6 @@ ReturnStmt::EmitCode(FunctionEmitContext *ctx) const {
}
Stmt *
ReturnStmt::Optimize() {
if (val)
val = val->Optimize();
return this;
}
Stmt *
ReturnStmt::TypeCheck() {
// FIXME: We don't have ctx->functionType available here; should we?
@@ -2059,15 +1980,6 @@ StmtList::EmitCode(FunctionEmitContext *ctx) const {
}
Stmt *
StmtList::Optimize() {
for (unsigned int i = 0; i < stmts.size(); ++i)
if (stmts[i])
stmts[i] = stmts[i]->Optimize();
return this;
}
Stmt *
StmtList::TypeCheck() {
for (unsigned int i = 0; i < stmts.size(); ++i)
@@ -2280,14 +2192,6 @@ PrintStmt::Print(int indent) const {
}
Stmt *
PrintStmt::Optimize() {
if (values)
values = values->Optimize();
return this;
}
Stmt *
PrintStmt::TypeCheck() {
if (values)
@@ -2364,14 +2268,6 @@ AssertStmt::Print(int indent) const {
}
Stmt *
AssertStmt::Optimize() {
if (expr)
expr = expr->Optimize();
return this;
}
Stmt *
AssertStmt::TypeCheck() {
if (expr)

17
stmt.h
View File

@@ -60,8 +60,10 @@ public:
virtual void Print(int indent) const = 0;
// Redeclare these methods with Stmt * return values, rather than
// ASTNode *s, as in the original ASTNode declarations of them.
virtual Stmt *Optimize() = 0;
// ASTNode *s, as in the original ASTNode declarations of them. We'll
// also provide a default implementation of Optimize(), since most
// Stmts don't have anything to do here.
virtual Stmt *Optimize();
virtual Stmt *TypeCheck() = 0;
};
@@ -74,7 +76,6 @@ public:
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *Optimize();
Stmt *TypeCheck();
int EstimateCost() const;
@@ -117,7 +118,6 @@ public:
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *Optimize();
Stmt *TypeCheck();
int EstimateCost() const;
@@ -158,7 +158,6 @@ public:
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *Optimize();
Stmt *TypeCheck();
int EstimateCost() const;
@@ -179,7 +178,6 @@ public:
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *Optimize();
Stmt *TypeCheck();
int EstimateCost() const;
@@ -206,7 +204,6 @@ public:
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *Optimize();
Stmt *TypeCheck();
int EstimateCost() const;
@@ -228,7 +225,6 @@ public:
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *Optimize();
Stmt *TypeCheck();
int EstimateCost() const;
@@ -253,7 +249,6 @@ public:
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *Optimize();
Stmt *TypeCheck();
int EstimateCost() const;
@@ -275,7 +270,6 @@ public:
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *Optimize();
Stmt *TypeCheck();
int EstimateCost() const;
@@ -297,7 +291,6 @@ public:
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *Optimize();
Stmt *TypeCheck();
int EstimateCost() const;
@@ -323,7 +316,6 @@ public:
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *Optimize();
Stmt *TypeCheck();
int EstimateCost() const;
@@ -350,7 +342,6 @@ public:
void EmitCode(FunctionEmitContext *ctx) const;
void Print(int indent) const;
Stmt *Optimize();
Stmt *TypeCheck();
int EstimateCost() const;