From ca87579f23a2fb3649f7c265102232e4d401d3f8 Mon Sep 17 00:00:00 2001
From: Matt Pharr
Date: Fri, 16 Sep 2011 15:09:17 -0700
Subject: [PATCH] Add a very simple cost model to estimate runtime cost of
 running code.

This is currently only used to decide whether it's worth doing an
"are all lanes running" check at the start of functions--for small
functions, it's not worth the overhead.

The cost is estimated relatively early in compilation (e.g. before
we know if an array access is a scatter/gather or not, before
constant folding, etc.), so there are many known shortcomings.
---
Review notes (ignored by git-am; not part of the commit message):

 * BinaryExpr::EstimateCost() and ForStmt::EstimateCost(): the trailing
   "cond ? A : B" is now parenthesized.  "+" binds tighter than "?:", so
   the unparenthesized form evaluated "(sum + cond) ? A : B" and threw
   away the accumulated cost of the sub-expressions / loop body.
 * The fixes do not change the number of lines in any hunk, so the hunk
   headers and the diffstat below are still accurate.  The blob ids on
   the "index" lines of expr.cpp and stmt.cpp predate the fixes.

 expr.cpp   | 117 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 expr.h     |  30 ++++++++++----
 ispc.h     |  21 ++++++++++
 module.cpp |   9 ++++-
 stmt.cpp   |  83 +++++++++++++++++++++++++++++++++++++
 stmt.h     |  12 +++++-
 6 files changed, 263 insertions(+), 9 deletions(-)

diff --git a/expr.cpp b/expr.cpp
index 3008923f..2cfb2d95 100644
--- a/expr.cpp
+++ b/expr.cpp
@@ -741,6 +741,12 @@ UnaryExpr::TypeCheck() {
 }
 
 
+int
+UnaryExpr::EstimateCost() const {
+    return (expr ? expr->EstimateCost() : 0) + COST_SIMPLE_ARITH_LOGIC_OP;
+}
+
+
 void
 UnaryExpr::Print() const {
     if (!expr || !GetType())
@@ -1445,6 +1451,15 @@ BinaryExpr::TypeCheck() {
 }
 
 
+int
+BinaryExpr::EstimateCost() const {
+    return ((arg0 ? arg0->EstimateCost() : 0) +
+            (arg1 ? arg1->EstimateCost() : 0) +
+            ((op == Div || op == Mod) ? COST_COMPLEX_ARITH_OP :
+                                        COST_SIMPLE_ARITH_LOGIC_OP));
+}
+
+
 void
 BinaryExpr::Print() const {
     if (!arg0 || !arg1 || !GetType())
@@ -1696,6 +1711,19 @@ AssignExpr::TypeCheck() {
 }
 
 
+int
+AssignExpr::EstimateCost() const {
+    int cost = ((lvalue ? lvalue->EstimateCost() : 0) +
+                (rvalue ? rvalue->EstimateCost() : 0));
+    if (op == Assign)
+        return cost;
+    if (op == DivAssign || op == ModAssign)
+        return cost + COST_COMPLEX_ARITH_OP;
+    else
+        return cost + COST_SIMPLE_ARITH_LOGIC_OP;
+}
+
+
 void
 AssignExpr::Print() const {
     if (!lvalue || !rvalue || !GetType())
@@ -1944,6 +1972,12 @@ SelectExpr::TypeCheck() {
 }
 
 
+int
+SelectExpr::EstimateCost() const {
+    return COST_SELECT;
+}
+
+
 void
 SelectExpr::Print() const {
     if (!test || !expr1 || !expr2 || !GetType())
@@ -2440,6 +2474,13 @@ FunctionCallExpr::TypeCheck() {
 }
 
 
+int
+FunctionCallExpr::EstimateCost() const {
+    return ((args ? args->EstimateCost() : 0) +
+            (isLaunch ? COST_TASK_LAUNCH : COST_FUNCALL));
+}
+
+
 void
 FunctionCallExpr::Print() const {
     if (!func || !args || !GetType())
@@ -2551,6 +2592,17 @@ ExprList::GetConstant(const Type *type) const {
 }
 
 
+int
+ExprList::EstimateCost() const {
+    int cost = 0;
+    for (unsigned int i = 0; i < exprs.size(); ++i) {
+        if (exprs[i] != NULL)
+            cost += exprs[i]->EstimateCost();
+    }
+    return cost;
+}
+
+
 void
 ExprList::Print() const {
     printf("expr list (");
@@ -2749,6 +2801,16 @@ IndexExpr::TypeCheck() {
 }
 
 
+int
+IndexExpr::EstimateCost() const {
+    // be pessimistic
+    if (index && index->GetType()->IsVaryingType())
+        return COST_GATHER;
+    else
+        return COST_LOAD;
+}
+
+
 void
 IndexExpr::Print() const {
     if (!arrayOrVector || !index || !GetType())
@@ -3048,6 +3110,7 @@ MemberExpr::create(Expr *e, const char *id, SourcePos p, SourcePos idpos) {
     return new MemberExpr(e, id, p, idpos);
 }
 
+
 MemberExpr::MemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos)
     : Expr(p), identifierPos(idpos) {
     expr = e;
@@ -3144,6 +3207,14 @@ MemberExpr::Optimize() {
 }
 
 
+int
+MemberExpr::EstimateCost() const {
+    // FIXME: return gather cost when we can tell a gather is going to be
+    // needed
+    return COST_SIMPLE_ARITH_LOGIC_OP;
+}
+
+
 void
 MemberExpr::Print() const {
     if (!expr || !GetType())
@@ -3939,6 +4010,12 @@ ConstExpr::TypeCheck() {
 }
 
 
+int
+ConstExpr::EstimateCost() const {
+    return 0;
+}
+
+
 void
 ConstExpr::Print() const {
     printf("[%s] (", GetType()->GetString().c_str());
@@ -4859,6 +4936,13 @@ TypeCastExpr::Optimize() {
 }
 
 
+int
+TypeCastExpr::EstimateCost() const {
+    // FIXME: return COST_TYPECAST_COMPLEX when appropriate
+    return COST_TYPECAST_SIMPLE;
+}
+
+
 void
 TypeCastExpr::Print() const {
     printf("[%s] type cast (", GetType()->GetString().c_str());
@@ -4924,6 +5008,12 @@ ReferenceExpr::TypeCheck() {
 }
 
 
+int
+ReferenceExpr::EstimateCost() const {
+    return 0;
+}
+
+
 void
 ReferenceExpr::Print() const {
     if (expr == NULL || GetType() == NULL)
@@ -5002,6 +5092,12 @@ DereferenceExpr::Optimize() {
 }
 
 
+int
+DereferenceExpr::EstimateCost() const {
+    return COST_DEREF;
+}
+
+
 void
 DereferenceExpr::Print() const {
     if (expr == NULL || GetType() == NULL)
@@ -5073,6 +5169,15 @@ SymbolExpr::Optimize() {
 }
 
 
+int
+SymbolExpr::EstimateCost() const {
+    if (symbol->constValue != NULL)
+        return 0;
+    else
+        return COST_LOAD;
+}
+
+
 void
 SymbolExpr::Print() const {
     if (symbol == NULL || GetType() == NULL)
@@ -5126,6 +5231,12 @@ FunctionSymbolExpr::Optimize() {
 }
 
 
+int
+FunctionSymbolExpr::EstimateCost() const {
+    return 0;
+}
+
+
 void
 FunctionSymbolExpr::Print() const {
     if (!matchingFunc || !GetType())
@@ -5160,6 +5271,12 @@ SyncExpr::GetValue(FunctionEmitContext *ctx) const {
 }
 
 
+int
+SyncExpr::EstimateCost() const {
+    return COST_SYNC;
+}
+
+
 void
 SyncExpr::Print() const {
     printf("sync");
diff --git a/expr.h b/expr.h
index 9be0aec7..3764b19c 100644
--- a/expr.h
+++ b/expr.h
@@ -121,6 +121,7 @@ public:
     void Print() const;
     Expr *Optimize();
     Expr *TypeCheck();
+    int EstimateCost() const;
 
 private:
     const Op op;
@@ -164,6 +165,7 @@ public:
 
     Expr *Optimize();
     Expr *TypeCheck();
+    int EstimateCost() const;
 
 private:
     const Op op;
@@ -196,6 +198,7 @@ public:
 
     Expr *Optimize();
     Expr *TypeCheck();
+    int EstimateCost() const;
 
 private:
     const Op op;
@@ -217,6 +220,7 @@ public:
 
     Expr *Optimize();
     Expr *TypeCheck();
+    int EstimateCost() const;
 
 private:
     Expr *test, *expr1, *expr2;
@@ -240,6 +244,7 @@ public:
     llvm::Constant *GetConstant(const Type *type) const;
     ExprList *Optimize();
     ExprList *TypeCheck();
+    int EstimateCost() const;
 
     std::vector<Expr *> exprs;
 };
@@ -257,6 +262,7 @@ public:
 
     Expr *Optimize();
     Expr *TypeCheck();
+    int EstimateCost() const;
 
 private:
     Expr *func;
@@ -285,6 +291,7 @@ public:
 
     Expr *Optimize();
     Expr *TypeCheck();
+    int EstimateCost() const;
 
 private:
     Expr *arrayOrVector, *index;
@@ -303,13 +310,15 @@ public:
     MemberExpr(Expr *expr, const char *identifier, SourcePos pos,
                SourcePos identifierPos);
 
-    virtual llvm::Value *GetValue(FunctionEmitContext *ctx) const;
-    virtual llvm::Value *GetLValue(FunctionEmitContext *ctx) const;
-    virtual const Type *GetType() const;
-    virtual Symbol *GetBaseSymbol() const;
-    virtual void Print() const;
-    virtual Expr *Optimize();
-    virtual Expr *TypeCheck();
+    llvm::Value *GetValue(FunctionEmitContext *ctx) const;
+    llvm::Value *GetLValue(FunctionEmitContext *ctx) const;
+    const Type *GetType() const;
+    Symbol *GetBaseSymbol() const;
+    void Print() const;
+    Expr *Optimize();
+    Expr *TypeCheck();
+    int EstimateCost() const;
+
     virtual int getElementNumber() const;
 
 protected:
@@ -392,6 +401,7 @@ public:
 
     Expr *TypeCheck();
     Expr *Optimize();
+    int EstimateCost() const;
 
     /** Return the ConstExpr's values as booleans, doing type conversion
         from the actual type if needed. If forceVarying is true, then type
@@ -495,6 +505,7 @@ public:
     void Print() const;
     Expr *TypeCheck();
     Expr *Optimize();
+    int EstimateCost() const;
 
 private:
     const Type *type;
@@ -514,6 +525,7 @@ public:
     void Print() const;
     Expr *TypeCheck();
     Expr *Optimize();
+    int EstimateCost() const;
 
 private:
     Expr *expr;
@@ -533,6 +545,7 @@ public:
     void Print() const;
     Expr *TypeCheck();
     Expr *Optimize();
+    int EstimateCost() const;
 
 private:
     Expr *expr;
@@ -551,6 +564,7 @@ public:
     Expr *TypeCheck();
     Expr *Optimize();
     void Print() const;
+    int EstimateCost() const;
 
 private:
     Symbol *symbol;
@@ -571,6 +585,7 @@ public:
     Expr *TypeCheck();
     Expr *Optimize();
     void Print() const;
+    int EstimateCost() const;
 
 private:
     friend class FunctionCallExpr;
@@ -597,6 +612,7 @@ public:
     Expr *TypeCheck();
     Expr *Optimize();
     void Print() const;
+    int EstimateCost() const;
 };
 
 #endif // ISPC_EXPR_H
diff --git a/ispc.h b/ispc.h
index 4b9bec65..b53134c6 100644
--- a/ispc.h
+++ b/ispc.h
@@ -148,6 +148,8 @@ public:
         pointer in place of the original ASTNode *. */
     virtual ASTNode *TypeCheck() = 0;
 
+    virtual int EstimateCost() const = 0;
+
     /** All AST nodes must track the file position where they are
         defined. */
     const SourcePos pos;
@@ -365,6 +367,25 @@ struct Globals {
     std::vector<std::string> cppArgs;
 };
 
+enum {
+    COST_FUNCALL = 4,
+    COST_TASK_LAUNCH = 16,
+    COST_SELECT = 4,
+    COST_RETURN = 4,
+    COST_SIMPLE_ARITH_LOGIC_OP = 1,
+    COST_COMPLEX_ARITH_OP = 4,
+    COST_COHERENT_BREAK_CONTINE = 4,
+    COST_REGULAR_BREAK_CONTINUE = 2,
+    COST_UNIFORM_LOOP = 4,
+    COST_VARYING_LOOP = 6,
+    COST_SYNC = 32,
+    COST_LOAD = 2,
+    COST_DEREF = 4,
+    COST_TYPECAST_SIMPLE = 1,
+    COST_TYPECAST_COMPLEX = 4,
+    COST_GATHER = 8
+};
+
 extern Globals *g;
 extern Module *m;
 
diff --git a/module.cpp b/module.cpp
index e816a356..df596cf6 100644
--- a/module.cpp
+++ b/module.cpp
@@ -708,7 +708,14 @@ lEmitFunctionCode(FunctionEmitContext *ctx, llvm::Function *function,
     // Finally, we can generate code for the function
     if (code != NULL) {
         bool checkMask = (ft->isTask == true) ||
-            (function->hasFnAttr(llvm::Attribute::AlwaysInline) == false);
+            ((function->hasFnAttr(llvm::Attribute::AlwaysInline) == false) &&
+             code->EstimateCost() > 16);
+        // If the body of the function is non-trivial, then we wrap the
+        // entire thing around a varying "cif (true)" test in order to reap
+        // the side-effect benefit of checking to see if the execution mask
+        // is all on and thence having a specialized code path for that
+        // case. If this is a simple function, then this isn't worth the
+        // code bloat / overhead.
         if (checkMask) {
             bool allTrue[ISPC_MAX_NVEC];
             for (int i = 0; i < g->target.vectorWidth; ++i)
diff --git a/stmt.cpp b/stmt.cpp
index 0fccb9ae..07c60571 100644
--- a/stmt.cpp
+++ b/stmt.cpp
@@ -107,6 +107,12 @@ ExprStmt::Print(int indent) const {
 }
 
 
+int
+ExprStmt::EstimateCost() const {
+    return expr ? expr->EstimateCost() : 0;
+}
+
+
 ///////////////////////////////////////////////////////////////////////////
 // DeclStmt
 
@@ -399,6 +405,16 @@ DeclStmt::Print(int indent) const {
 }
 
 
+int
+DeclStmt::EstimateCost() const {
+    int cost = 0;
+    for (unsigned int i = 0; i < declaration->declarators.size(); ++i)
+        if (declaration->declarators[i]->initExpr)
+            cost += declaration->declarators[i]->initExpr->EstimateCost();
+    return cost;
+}
+
+
 ///////////////////////////////////////////////////////////////////////////
 // IfStmt
 
@@ -522,6 +538,14 @@ Stmt *IfStmt::TypeCheck() {
 }
 
 
+int
+IfStmt::EstimateCost() const {
+    return ((test ? test->EstimateCost() : 0) +
+            (trueStmts ? trueStmts->EstimateCost() : 0) +
+            (falseStmts ? falseStmts->EstimateCost() : 0));
+}
+
+
 void
 IfStmt::Print(int indent) const {
     printf("%*cIf Stmt %s", indent, ' ', doAllCheck ? "DO ALL CHECK" : "");
@@ -929,6 +953,13 @@ DoStmt::TypeCheck() {
 }
 
 
+int
+DoStmt::EstimateCost() const {
+    return ((testExpr ? testExpr->EstimateCost() : 0) +
+            (bodyStmts ? bodyStmts->EstimateCost() : 0));
+}
+
+
 void
 DoStmt::Print(int indent) const {
     printf("%*cDo Stmt", indent, ' ');
@@ -1136,6 +1167,20 @@ ForStmt::TypeCheck() {
 }
 
 
+int
+ForStmt::EstimateCost() const {
+    bool uniformTest = test ? test->GetType()->IsUniformType() :
+        (!g->opt.disableUniformControlFlow &&
+         !lHasVaryingBreakOrContinue(stmts));
+
+    return ((init ? init->EstimateCost() : 0) +
+            (test ? test->EstimateCost() : 0) +
+            (step ? step->EstimateCost() : 0) +
+            (stmts ? stmts->EstimateCost() : 0) +
+            (uniformTest ? COST_UNIFORM_LOOP : COST_VARYING_LOOP));
+}
+
+
 void
 ForStmt::Print(int indent) const {
     printf("%*cFor Stmt", indent, ' ');
@@ -1190,6 +1235,13 @@ BreakStmt::TypeCheck() {
 }
 
 
+int
+BreakStmt::EstimateCost() const {
+    return doCoherenceCheck ? COST_COHERENT_BREAK_CONTINE :
+        COST_REGULAR_BREAK_CONTINUE;
+}
+
+
 void
 BreakStmt::Print(int indent) const {
     printf("%*c%sBreak Stmt", indent, ' ', doCoherenceCheck ? "Coherent " : "");
@@ -1228,6 +1280,13 @@ ContinueStmt::TypeCheck() {
 }
 
 
+int
+ContinueStmt::EstimateCost() const {
+    return doCoherenceCheck ? COST_COHERENT_BREAK_CONTINE :
+        COST_REGULAR_BREAK_CONTINUE;
+}
+
+
 void
 ContinueStmt::Print(int indent) const {
     printf("%*c%sContinue Stmt", indent, ' ', doCoherenceCheck ? "Coherent " : "");
@@ -1274,6 +1333,12 @@ ReturnStmt::TypeCheck() {
 }
 
 
+int
+ReturnStmt::EstimateCost() const {
+    return COST_RETURN + (val ? val->EstimateCost() : 0);
+}
+
+
 void
 ReturnStmt::Print(int indent) const {
     printf("%*c%sReturn Stmt", indent, ' ', doCoherenceCheck ? "Coherent " : "");
@@ -1319,6 +1384,16 @@ StmtList::TypeCheck() {
 }
 
 
+int
+StmtList::EstimateCost() const {
+    int cost = 0;
+    for (unsigned int i = 0; i < stmts.size(); ++i)
+        if (stmts[i])
+            cost += stmts[i]->EstimateCost();
+    return cost;
+}
+
+
 void
 StmtList::Print(int indent) const {
     printf("%*cStmt List", indent, ' ');
@@ -1519,3 +1594,11 @@ PrintStmt::TypeCheck() {
         values = values->TypeCheck();
     return this;
 }
+
+
+int
+PrintStmt::EstimateCost() const {
+    return COST_FUNCALL + (values ? values->EstimateCost() : 0);
+}
+
+
diff --git a/stmt.h b/stmt.h
index 759026cf..3dac745a 100644
--- a/stmt.h
+++ b/stmt.h
@@ -75,6 +75,7 @@ public:
 
     Stmt *Optimize();
     Stmt *TypeCheck();
+    int EstimateCost() const;
 
 private:
     Expr *expr;
@@ -92,8 +93,9 @@ public:
 
     Stmt *Optimize();
     Stmt *TypeCheck();
+    int EstimateCost() const;
 
-private:
+ private:
     Declaration *declaration;
 };
 
@@ -110,6 +112,7 @@ public:
 
     Stmt *Optimize();
     Stmt *TypeCheck();
+    int EstimateCost() const;
 
     // @todo these are only public for lHasVaryingBreakOrContinue(); would
     // be nice to clean that up...
@@ -151,6 +154,7 @@ public:
 
     Stmt *Optimize();
     Stmt *TypeCheck();
+    int EstimateCost() const;
 
 private:
     Expr *testExpr;
@@ -172,6 +176,7 @@ public:
 
     Stmt *Optimize();
     Stmt *TypeCheck();
+    int EstimateCost() const;
 
 private:
     /** 'for' statment initializer; may be NULL, indicating no intitializer */
@@ -199,6 +204,7 @@ public:
 
     Stmt *Optimize();
     Stmt *TypeCheck();
+    int EstimateCost() const;
 
 private:
     /** This indicates whether the generated code will check to see if no
@@ -220,6 +226,7 @@ public:
 
     Stmt *Optimize();
     Stmt *TypeCheck();
+    int EstimateCost() const;
 
 private:
     /** This indicates whether the generated code will check to see if no
@@ -241,6 +248,7 @@ public:
 
     Stmt *Optimize();
     Stmt *TypeCheck();
+    int EstimateCost() const;
 
 private:
     Expr *val;
@@ -263,6 +271,7 @@ public:
 
     Stmt *Optimize();
     Stmt *TypeCheck();
+    int EstimateCost() const;
 
     void Add(Stmt *s) { if (s) stmts.push_back(s); }
     const std::vector<Stmt *> &GetStatements() { return stmts; }
@@ -290,6 +299,7 @@ public:
 
     Stmt *Optimize();
     Stmt *TypeCheck();
+    int EstimateCost() const;
 
 private:
     /** Format string for the print() statement. */