Transition EstimateCost() AST traversal to WalkAST() as well.

This commit is contained in:
Matt Pharr
2011-12-16 12:16:11 -08:00
parent 701334ccf2
commit f23d030e43
5 changed files with 61 additions and 69 deletions

16
ast.cpp
View File

@@ -273,3 +273,19 @@ Stmt *
TypeCheck(Stmt *stmt) { TypeCheck(Stmt *stmt) {
return (Stmt *)TypeCheck((ASTNode *)stmt); return (Stmt *)TypeCheck((ASTNode *)stmt);
} }
static bool
lCostCallback(ASTNode *node, void *c) {
int *cost = (int *)c;
*cost += node->EstimateCost();
return true;
}
int
EstimateCost(ASTNode *root) {
int cost = 0;
WalkAST(root, lCostCallback, NULL, &cost);
return cost;
}

15
ast.h
View File

@@ -66,6 +66,9 @@ public:
pointer in place of the original ASTNode *. */ pointer in place of the original ASTNode *. */
virtual ASTNode *TypeCheck() = 0; virtual ASTNode *TypeCheck() = 0;
/** Estimate the execution cost of the node (not including the cost of
the children. The value returned should be based on the COST_*
enumerant values defined in ispc.h. */
virtual int EstimateCost() const = 0; virtual int EstimateCost() const = 0;
/** All AST nodes must track the file position where they are /** All AST nodes must track the file position where they are
@@ -127,14 +130,18 @@ extern Expr *Optimize(Expr *);
to a Stmt *). */ to a Stmt *). */
extern Stmt *Optimize(Stmt *); extern Stmt *Optimize(Stmt *);
/* Perform type-checking on the given AST (or portion of one), returning a /** Perform type-checking on the given AST (or portion of one), returning a
pointer to the root of the resulting AST. */ pointer to the root of the resulting AST. */
extern ASTNode *TypeCheck(ASTNode *root); extern ASTNode *TypeCheck(ASTNode *root);
/* Convenience version of TypeCheck() for Expr *s that returns an Expr *. */ /** Convenience version of TypeCheck() for Expr *s that returns an Expr *. */
extern Expr *TypeCheck(Expr *); extern Expr *TypeCheck(Expr *);
/* Convenience version of TypeCheck() for Stmt *s that returns an Stmt *. */ /** Convenience version of TypeCheck() for Stmt *s that returns an Stmt *. */
extern Stmt *TypeCheck(Stmt *); extern Stmt *TypeCheck(Stmt *);
/** Returns an estimate of the execution cost of the tree starting at
the given root. */
extern int EstimateCost(ASTNode *root);
#endif // ISPC_AST_H #endif // ISPC_AST_H

View File

@@ -989,7 +989,7 @@ UnaryExpr::TypeCheck() {
int int
UnaryExpr::EstimateCost() const { UnaryExpr::EstimateCost() const {
return (expr ? expr->EstimateCost() : 0) + COST_SIMPLE_ARITH_LOGIC_OP; return COST_SIMPLE_ARITH_LOGIC_OP;
} }
@@ -1885,10 +1885,8 @@ BinaryExpr::TypeCheck() {
int int
BinaryExpr::EstimateCost() const { BinaryExpr::EstimateCost() const {
return ((arg0 ? arg0->EstimateCost() : 0) + return (op == Div || op == Mod) ? COST_COMPLEX_ARITH_OP :
(arg1 ? arg1->EstimateCost() : 0) + COST_SIMPLE_ARITH_LOGIC_OP;
((op == Div || op == Mod) ? COST_COMPLEX_ARITH_OP :
COST_SIMPLE_ARITH_LOGIC_OP));
} }
@@ -2204,15 +2202,12 @@ AssignExpr::TypeCheck() {
int int
AssignExpr::EstimateCost() const { AssignExpr::EstimateCost() const {
int cost = ((lvalue ? lvalue->EstimateCost() : 0) +
(rvalue ? rvalue->EstimateCost() : 0));
cost += COST_ASSIGN;
if (op == Assign) if (op == Assign)
return cost; return COST_ASSIGN;
if (op == DivAssign || op == ModAssign) if (op == DivAssign || op == ModAssign)
return cost + COST_COMPLEX_ARITH_OP; return COST_ASSIGN + COST_COMPLEX_ARITH_OP;
else else
return cost + COST_SIMPLE_ARITH_LOGIC_OP; return COST_ASSIGN + COST_SIMPLE_ARITH_LOGIC_OP;
} }
@@ -2637,7 +2632,7 @@ FunctionCallExpr::TypeCheck() {
if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL) == false) if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL) == false)
return NULL; return NULL;
func = fse->TypeCheck(); func = ::TypeCheck(fse);
if (func == NULL) if (func == NULL)
return NULL; return NULL;
@@ -2742,28 +2737,20 @@ FunctionCallExpr::TypeCheck() {
int int
FunctionCallExpr::EstimateCost() const { FunctionCallExpr::EstimateCost() const {
int callCost = 0; if (isLaunch)
if (isLaunch) { return COST_TASK_LAUNCH;
callCost = COST_TASK_LAUNCH;
if (launchCountExpr != NULL)
callCost += launchCountExpr->EstimateCost();
}
else if (dynamic_cast<FunctionSymbolExpr *>(func) == NULL) { else if (dynamic_cast<FunctionSymbolExpr *>(func) == NULL) {
// it's going through a function pointer // it's going through a function pointer
const Type *fpType = func->GetType(); const Type *fpType = func->GetType();
if (fpType != NULL) { if (fpType != NULL) {
Assert(dynamic_cast<const PointerType *>(fpType) != NULL); Assert(dynamic_cast<const PointerType *>(fpType) != NULL);
if (fpType->IsUniformType()) if (fpType->IsUniformType())
callCost = COST_FUNPTR_UNIFORM; return COST_FUNPTR_UNIFORM;
else else
callCost = COST_FUNPTR_VARYING; return COST_FUNPTR_VARYING;
} }
} }
else return COST_FUNCALL;
// regular function call
callCost = COST_FUNCALL;
return (args ? args->EstimateCost() : 0) + callCost;
} }
@@ -2880,12 +2867,7 @@ ExprList::GetConstant(const Type *type) const {
int int
ExprList::EstimateCost() const { ExprList::EstimateCost() const {
int cost = 0; return 0;
for (unsigned int i = 0; i < exprs.size(); ++i) {
if (exprs[i] != NULL)
cost += exprs[i]->EstimateCost();
}
return cost;
} }

View File

@@ -277,7 +277,7 @@ Function::emitCode(FunctionEmitContext *ctx, llvm::Function *function,
ctx->SetDebugPos(code->pos); ctx->SetDebugPos(code->pos);
ctx->AddInstrumentationPoint("function entry"); ctx->AddInstrumentationPoint("function entry");
int costEstimate = code->EstimateCost(); int costEstimate = EstimateCost(code);
Debug(code->pos, "Estimated cost for function \"%s\" = %d\n", Debug(code->pos, "Estimated cost for function \"%s\" = %d\n",
sym->name.c_str(), costEstimate); sym->name.c_str(), costEstimate);

View File

@@ -107,7 +107,7 @@ ExprStmt::Print(int indent) const {
int int
ExprStmt::EstimateCost() const { ExprStmt::EstimateCost() const {
return expr ? expr->EstimateCost() : 0; return 0;
} }
@@ -467,11 +467,7 @@ DeclStmt::Print(int indent) const {
int int
DeclStmt::EstimateCost() const { DeclStmt::EstimateCost() const {
int cost = 0; return 0;
for (unsigned int i = 0; i < vars.size(); ++i)
if (vars[i].init != NULL)
cost += vars[i].init->EstimateCost();
return cost;
} }
@@ -577,15 +573,11 @@ IfStmt::TypeCheck() {
int int
IfStmt::EstimateCost() const { IfStmt::EstimateCost() const {
int ifcost = 0;
const Type *type; const Type *type;
if (test && (type = test->GetType()) != NULL) if (test == NULL || (type = test->GetType()) != NULL)
ifcost = type->IsUniformType() ? COST_UNIFORM_IF : COST_VARYING_IF; return 0;
return ifcost + return type->IsUniformType() ? COST_UNIFORM_IF : COST_VARYING_IF;
((test ? test->EstimateCost() : 0) +
(trueStmts ? trueStmts->EstimateCost() : 0) +
(falseStmts ? falseStmts->EstimateCost() : 0));
} }
@@ -766,9 +758,10 @@ IfStmt::emitVaryingIf(FunctionEmitContext *ctx, llvm::Value *ltest) const {
// //
// where our use of blend for conditional assignments doesn't check // where our use of blend for conditional assignments doesn't check
// for the 'all lanes' off case. // for the 'all lanes' off case.
bool costIsAcceptable = ((trueStmts ? trueStmts->EstimateCost() : 0) + int trueFalseCost = (::EstimateCost(trueStmts) +
(falseStmts ? falseStmts->EstimateCost() : 0)) < ::EstimateCost(falseStmts));
PREDICATE_SAFE_IF_STATEMENT_COST; bool costIsAcceptable = (trueFalseCost <
PREDICATE_SAFE_IF_STATEMENT_COST);
bool safeToRunWithAllLanesOff = true; bool safeToRunWithAllLanesOff = true;
WalkAST(trueStmts, lCheckAllOffSafety, NULL, &safeToRunWithAllLanesOff); WalkAST(trueStmts, lCheckAllOffSafety, NULL, &safeToRunWithAllLanesOff);
@@ -1146,8 +1139,11 @@ DoStmt::TypeCheck() {
int int
DoStmt::EstimateCost() const { DoStmt::EstimateCost() const {
return ((testExpr ? testExpr->EstimateCost() : 0) + bool uniformTest = testExpr ? testExpr->GetType()->IsUniformType() :
(bodyStmts ? bodyStmts->EstimateCost() : 0)); (!g->opt.disableUniformControlFlow &&
!lHasVaryingBreakOrContinue(bodyStmts));
return uniformTest ? COST_UNIFORM_LOOP : COST_VARYING_LOOP;
} }
@@ -1336,11 +1332,7 @@ ForStmt::EstimateCost() const {
(!g->opt.disableUniformControlFlow && (!g->opt.disableUniformControlFlow &&
!lHasVaryingBreakOrContinue(stmts)); !lHasVaryingBreakOrContinue(stmts));
return ((init ? init->EstimateCost() : 0) + return uniformTest ? COST_UNIFORM_LOOP : COST_VARYING_LOOP;
(test ? test->EstimateCost() : 0) +
(step ? step->EstimateCost() : 0) +
(stmts ? stmts->EstimateCost() : 0) +
(uniformTest ? COST_UNIFORM_LOOP : COST_VARYING_LOOP));
} }
@@ -1827,8 +1819,7 @@ ForeachStmt::TypeCheck() {
int int
ForeachStmt::EstimateCost() const { ForeachStmt::EstimateCost() const {
return dimVariables.size() * (COST_UNIFORM_LOOP + COST_SIMPLE_ARITH_LOGIC_OP) + return dimVariables.size() * (COST_UNIFORM_LOOP + COST_SIMPLE_ARITH_LOGIC_OP);
(stmts ? stmts->EstimateCost() : 0);
} }
@@ -1908,7 +1899,7 @@ ReturnStmt::TypeCheck() {
int int
ReturnStmt::EstimateCost() const { ReturnStmt::EstimateCost() const {
return COST_RETURN + (val ? val->EstimateCost() : 0); return COST_RETURN;
} }
@@ -1947,11 +1938,7 @@ StmtList::TypeCheck() {
int int
StmtList::EstimateCost() const { StmtList::EstimateCost() const {
int cost = 0; return 0;
for (unsigned int i = 0; i < stmts.size(); ++i)
if (stmts[i])
cost += stmts[i]->EstimateCost();
return cost;
} }
@@ -2156,7 +2143,7 @@ PrintStmt::TypeCheck() {
int int
PrintStmt::EstimateCost() const { PrintStmt::EstimateCost() const {
return COST_FUNCALL + (values ? values->EstimateCost() : 0); return COST_FUNCALL;
} }
@@ -2243,6 +2230,6 @@ AssertStmt::TypeCheck() {
int int
AssertStmt::EstimateCost() const { AssertStmt::EstimateCost() const {
return (expr ? expr->EstimateCost() : 0) + COST_ASSERT; return COST_ASSERT;
} }