diff --git a/ast.cpp b/ast.cpp index 81070eab..023e4ba9 100644 --- a/ast.cpp +++ b/ast.cpp @@ -273,3 +273,19 @@ Stmt * TypeCheck(Stmt *stmt) { return (Stmt *)TypeCheck((ASTNode *)stmt); } + + +static bool +lCostCallback(ASTNode *node, void *c) { + int *cost = (int *)c; + *cost += node->EstimateCost(); + return true; +} + + +int +EstimateCost(ASTNode *root) { + int cost = 0; + WalkAST(root, lCostCallback, NULL, &cost); + return cost; +} diff --git a/ast.h b/ast.h index 68fc9a03..0c3d4b64 100644 --- a/ast.h +++ b/ast.h @@ -66,6 +66,9 @@ public: pointer in place of the original ASTNode *. */ virtual ASTNode *TypeCheck() = 0; + /** Estimate the execution cost of the node (not including the cost of + the children. The value returned should be based on the COST_* + enumerant values defined in ispc.h. */ virtual int EstimateCost() const = 0; /** All AST nodes must track the file position where they are @@ -127,14 +130,18 @@ extern Expr *Optimize(Expr *); to a Stmt *). */ extern Stmt *Optimize(Stmt *); -/* Perform type-checking on the given AST (or portion of one), returning a - pointer to the root of the resulting AST. */ +/** Perform type-checking on the given AST (or portion of one), returning a + pointer to the root of the resulting AST. */ extern ASTNode *TypeCheck(ASTNode *root); -/* Convenience version of TypeCheck() for Expr *s that returns an Expr *. */ +/** Convenience version of TypeCheck() for Expr *s that returns an Expr *. */ extern Expr *TypeCheck(Expr *); -/* Convenience version of TypeCheck() for Stmt *s that returns an Stmt *. */ +/** Convenience version of TypeCheck() for Stmt *s that returns an Stmt *. */ extern Stmt *TypeCheck(Stmt *); +/** Returns an estimate of the execution cost of the tree starting at + the given root. */ +extern int EstimateCost(ASTNode *root); + #endif // ISPC_AST_H diff --git a/expr.cpp b/expr.cpp index ade233f1..af6acd28 100644 --- a/expr.cpp +++ b/expr.cpp @@ -989,7 +989,7 @@ UnaryExpr::TypeCheck() { int UnaryExpr::EstimateCost() const { - return (expr ? expr->EstimateCost() : 0) + COST_SIMPLE_ARITH_LOGIC_OP; + return COST_SIMPLE_ARITH_LOGIC_OP; } @@ -1885,10 +1885,8 @@ BinaryExpr::TypeCheck() { int BinaryExpr::EstimateCost() const { - return ((arg0 ? arg0->EstimateCost() : 0) + - (arg1 ? arg1->EstimateCost() : 0) + - ((op == Div || op == Mod) ? COST_COMPLEX_ARITH_OP : - COST_SIMPLE_ARITH_LOGIC_OP)); + return (op == Div || op == Mod) ? COST_COMPLEX_ARITH_OP : + COST_SIMPLE_ARITH_LOGIC_OP; } @@ -2204,15 +2202,12 @@ AssignExpr::TypeCheck() { int AssignExpr::EstimateCost() const { - int cost = ((lvalue ? lvalue->EstimateCost() : 0) + - (rvalue ? rvalue->EstimateCost() : 0)); - cost += COST_ASSIGN; if (op == Assign) - return cost; + return COST_ASSIGN; if (op == DivAssign || op == ModAssign) - return cost + COST_COMPLEX_ARITH_OP; + return COST_ASSIGN + COST_COMPLEX_ARITH_OP; else - return cost + COST_SIMPLE_ARITH_LOGIC_OP; + return COST_ASSIGN + COST_SIMPLE_ARITH_LOGIC_OP; } @@ -2637,7 +2632,7 @@ FunctionCallExpr::TypeCheck() { if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL) == false) return NULL; - func = fse->TypeCheck(); + func = ::TypeCheck(fse); if (func == NULL) return NULL; @@ -2742,28 +2737,20 @@ FunctionCallExpr::TypeCheck() { int FunctionCallExpr::EstimateCost() const { - int callCost = 0; - if (isLaunch) { - callCost = COST_TASK_LAUNCH; - if (launchCountExpr != NULL) - callCost += launchCountExpr->EstimateCost(); - } + if (isLaunch) + return COST_TASK_LAUNCH; else if (dynamic_cast(func) == NULL) { // it's going through a function pointer const Type *fpType = func->GetType(); if (fpType != NULL) { Assert(dynamic_cast(fpType) != NULL); if (fpType->IsUniformType()) - callCost = COST_FUNPTR_UNIFORM; + return COST_FUNPTR_UNIFORM; else - callCost = COST_FUNPTR_VARYING; + return COST_FUNPTR_VARYING; } } - else - // regular function call - callCost = COST_FUNCALL; - - return (args ? args->EstimateCost() : 0) + callCost; + return COST_FUNCALL; } @@ -2880,12 +2867,7 @@ ExprList::GetConstant(const Type *type) const { int ExprList::EstimateCost() const { - int cost = 0; - for (unsigned int i = 0; i < exprs.size(); ++i) { - if (exprs[i] != NULL) - cost += exprs[i]->EstimateCost(); - } - return cost; + return 0; } diff --git a/func.cpp b/func.cpp index 0327bff9..61dfb784 100644 --- a/func.cpp +++ b/func.cpp @@ -277,7 +277,7 @@ Function::emitCode(FunctionEmitContext *ctx, llvm::Function *function, ctx->SetDebugPos(code->pos); ctx->AddInstrumentationPoint("function entry"); - int costEstimate = code->EstimateCost(); + int costEstimate = EstimateCost(code); Debug(code->pos, "Estimated cost for function \"%s\" = %d\n", sym->name.c_str(), costEstimate); diff --git a/stmt.cpp b/stmt.cpp index b3ec0c39..e799fc0b 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -107,7 +107,7 @@ ExprStmt::Print(int indent) const { int ExprStmt::EstimateCost() const { - return expr ? expr->EstimateCost() : 0; + return 0; } @@ -467,11 +467,7 @@ DeclStmt::Print(int indent) const { int DeclStmt::EstimateCost() const { - int cost = 0; - for (unsigned int i = 0; i < vars.size(); ++i) - if (vars[i].init != NULL) - cost += vars[i].init->EstimateCost(); - return cost; + return 0; } @@ -577,15 +573,11 @@ IfStmt::TypeCheck() { int IfStmt::EstimateCost() const { - int ifcost = 0; const Type *type; - if (test && (type = test->GetType()) != NULL) - ifcost = type->IsUniformType() ? COST_UNIFORM_IF : COST_VARYING_IF; + if (test == NULL || (type = test->GetType()) != NULL) + return 0; - return ifcost + - ((test ? test->EstimateCost() : 0) + - (trueStmts ? trueStmts->EstimateCost() : 0) + - (falseStmts ? falseStmts->EstimateCost() : 0)); + return type->IsUniformType() ? COST_UNIFORM_IF : COST_VARYING_IF; } @@ -766,9 +758,10 @@ IfStmt::emitVaryingIf(FunctionEmitContext *ctx, llvm::Value *ltest) const { // // where our use of blend for conditional assignments doesn't check // for the 'all lanes' off case. - bool costIsAcceptable = ((trueStmts ? trueStmts->EstimateCost() : 0) + - (falseStmts ? falseStmts->EstimateCost() : 0)) < - PREDICATE_SAFE_IF_STATEMENT_COST; + int trueFalseCost = (::EstimateCost(trueStmts) + + ::EstimateCost(falseStmts)); + bool costIsAcceptable = (trueFalseCost < + PREDICATE_SAFE_IF_STATEMENT_COST); bool safeToRunWithAllLanesOff = true; WalkAST(trueStmts, lCheckAllOffSafety, NULL, &safeToRunWithAllLanesOff); @@ -1146,8 +1139,11 @@ DoStmt::TypeCheck() { int DoStmt::EstimateCost() const { - return ((testExpr ? testExpr->EstimateCost() : 0) + - (bodyStmts ? bodyStmts->EstimateCost() : 0)); + bool uniformTest = testExpr ? testExpr->GetType()->IsUniformType() : + (!g->opt.disableUniformControlFlow && + !lHasVaryingBreakOrContinue(bodyStmts)); + + return uniformTest ? COST_UNIFORM_LOOP : COST_VARYING_LOOP; } @@ -1336,11 +1332,7 @@ ForStmt::EstimateCost() const { (!g->opt.disableUniformControlFlow && !lHasVaryingBreakOrContinue(stmts)); - return ((init ? init->EstimateCost() : 0) + - (test ? test->EstimateCost() : 0) + - (step ? step->EstimateCost() : 0) + - (stmts ? stmts->EstimateCost() : 0) + - (uniformTest ? COST_UNIFORM_LOOP : COST_VARYING_LOOP)); + return uniformTest ? COST_UNIFORM_LOOP : COST_VARYING_LOOP; } @@ -1827,8 +1819,7 @@ ForeachStmt::TypeCheck() { int ForeachStmt::EstimateCost() const { - return dimVariables.size() * (COST_UNIFORM_LOOP + COST_SIMPLE_ARITH_LOGIC_OP) + - (stmts ? stmts->EstimateCost() : 0); + return dimVariables.size() * (COST_UNIFORM_LOOP + COST_SIMPLE_ARITH_LOGIC_OP); } @@ -1908,7 +1899,7 @@ ReturnStmt::TypeCheck() { int ReturnStmt::EstimateCost() const { - return COST_RETURN + (val ? val->EstimateCost() : 0); + return COST_RETURN; } @@ -1947,11 +1938,7 @@ StmtList::TypeCheck() { int StmtList::EstimateCost() const { - int cost = 0; - for (unsigned int i = 0; i < stmts.size(); ++i) - if (stmts[i]) - cost += stmts[i]->EstimateCost(); - return cost; + return 0; } @@ -2156,7 +2143,7 @@ PrintStmt::TypeCheck() { int PrintStmt::EstimateCost() const { - return COST_FUNCALL + (values ? values->EstimateCost() : 0); + return COST_FUNCALL; } @@ -2243,6 +2230,6 @@ AssertStmt::TypeCheck() { int AssertStmt::EstimateCost() const { - return (expr ? expr->EstimateCost() : 0) + COST_ASSERT; + return COST_ASSERT; }