From 32a0a30cf5baa3d0dbe1b692a86f2931028de2e6 Mon Sep 17 00:00:00 2001 From: Matt Pharr Date: Wed, 28 Sep 2011 17:20:31 -0700 Subject: [PATCH] Only allow exact matches for function overload resolution for builtins. The intent is that the code in stdlib.ispc that is calling out to the built-ins should match argument types exactly (using explicit casts as needed), just for maximal clarity/safety. --- builtins.m4 | 13 ++++---- expr.cpp | 56 +++++++++++++++++++------------ expr.h | 7 ++-- parse.yy | 2 +- stdlib.ispc | 96 +++++++++++++++++++++++++++-------------------------- 5 files changed, 96 insertions(+), 78 deletions(-) diff --git a/builtins.m4 b/builtins.m4 index fe6990f4..26beb376 100644 --- a/builtins.m4 +++ b/builtins.m4 @@ -647,16 +647,17 @@ forloop(i, 1, eval($1-1), ` define(`global_atomic_associative', ` -;; note that the mask is expected to be of type $3, so the caller must ensure -;; that for 64-bit types, the mask is cast to a signed int before being passed -;; to this so that it is properly sign extended... (The code in stdlib.ispc -;; does do this..) - define internal <$1 x $3> @__atomic_$2_$4_global($3 * %ptr, <$1 x $3> %val, - <$1 x $3> %mask) nounwind alwaysinline { + <$1 x i32> %m) nounwind alwaysinline { ; first, for any lanes where the mask is off, compute a vector where those lanes ; hold the identity value.. + ifelse($3, `i64', `%mask = sext <$1 x i32> %m to <$1 x i64>') + ifelse($3, `i32', ` + %maskmem = alloca <$1 x i32> + store <$1 x i32> %m, <$1 x i32> * %maskmem + %mask = load <$1 x i32> * %maskmem' + ) ; zero out any lanes that are off %valoff = and <$1 x $3> %val, %mask diff --git a/expr.cpp b/expr.cpp index 81e3d10b..b36a423c 100644 --- a/expr.cpp +++ b/expr.cpp @@ -1189,7 +1189,7 @@ BinaryExpr::Optimize() { m->symbolTable->LookupFunction("rcp"); if (rcpFuns != NULL) { assert(rcpFuns->size() == 2); - Expr *rcpSymExpr = new FunctionSymbolExpr(rcpFuns, pos); + Expr *rcpSymExpr = new FunctionSymbolExpr("rcp", rcpFuns, pos); ExprList *args = new ExprList(arg1, arg1->pos); Expr *rcpCall = new FunctionCallExpr(rcpSymExpr, args, arg1->pos, false); @@ -2213,7 +2213,7 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) { void -FunctionCallExpr::resolveFunctionOverloads() { +FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) { FunctionSymbolExpr *fse = dynamic_cast(func); if (!fse) // error will be issued later if not calling an actual function @@ -2227,31 +2227,33 @@ FunctionCallExpr::resolveFunctionOverloads() { if (tryResolve(lExactMatch)) return; - // Try to find a single match ignoring references - if (tryResolve(lMatchIgnoringReferences)) - return; + if (!exactMatchOnly) { + // Try to find a single match ignoring references + if (tryResolve(lMatchIgnoringReferences)) + return; - // TODO: next, try to find an exact match via type promotion--i.e. char - // -> int, etc--things that don't lose data + // TODO: next, try to find an exact match via type promotion--i.e. char + // -> int, etc--things that don't lose data - // Next try to see if there's a match via just uniform -> varying - // promotions. TODO: look for one with a minimal number of them? - if (tryResolve(lMatchIgnoringUniform)) - return; + // Next try to see if there's a match via just uniform -> varying + // promotions. TODO: look for one with a minimal number of them? 
+ if (tryResolve(lMatchIgnoringUniform)) + return; - // Try to find a match via type conversion, but don't change - // unif->varying - if (tryResolve(lMatchWithTypeConvSameVariability)) - return; + // Try to find a match via type conversion, but don't change + // unif->varying + if (tryResolve(lMatchWithTypeConvSameVariability)) + return; - // Last chance: try to find a match via arbitrary type conversion. - if (tryResolve(lMatchWithTypeConv)) - return; + // Last chance: try to find a match via arbitrary type conversion. + if (tryResolve(lMatchWithTypeConv)) + return; + } // failure :-( const char *funName = fse->candidateFunctions->front()->name.c_str(); - Error(pos, "Unable to find matching overload for call to function \"%s\".", - funName); + Error(pos, "Unable to find matching overload for call to function \"%s\"%s.", + funName, exactMatchOnly ? " only considering exact matches" : ""); fprintf(stderr, "Candidates are:\n"); lPrintFunctionOverloads(*fse->candidateFunctions); lPrintPassedTypes(funName, args->exprs); @@ -2264,7 +2266,15 @@ FunctionCallExpr::FunctionCallExpr(Expr *f, ExprList *a, SourcePos p, bool il) args = a; isLaunch = il; - resolveFunctionOverloads(); + FunctionSymbolExpr *fse = dynamic_cast(func); + // Functions with names that start with "__" should only be various + // builtins. For those, we'll demand an exact match, since we'll + // expect whichever function in stdlib.ispc is calling out to one of + // those to be matching the argument types exactly; this is to be a bit + // extra safe to be sure that the expected builtin is in fact being + // called. + bool exactMatchOnly = (fse != NULL) && (fse->name.substr(0,2) == "__"); + resolveFunctionOverloads(exactMatchOnly); } @@ -5201,9 +5211,11 @@ SymbolExpr::Print() const { /////////////////////////////////////////////////////////////////////////// // FunctionSymbolExpr -FunctionSymbolExpr::FunctionSymbolExpr(std::vector *candidates, +FunctionSymbolExpr::FunctionSymbolExpr(const char *n, + std::vector *candidates, SourcePos p) : Expr(p) { + name = n; matchingFunc = NULL; candidateFunctions = candidates; } diff --git a/expr.h b/expr.h index 73e09647..9af33f8b 100644 --- a/expr.h +++ b/expr.h @@ -265,7 +265,7 @@ public: bool isLaunch; private: - void resolveFunctionOverloads(); + void resolveFunctionOverloads(bool exactMatchOnly); bool tryResolve(bool (*matchFunc)(Expr *, const Type *)); }; @@ -567,7 +567,7 @@ private: */ class FunctionSymbolExpr : public Expr { public: - FunctionSymbolExpr(std::vector *candidateFunctions, + FunctionSymbolExpr(const char *name, std::vector *candidateFunctions, SourcePos pos); llvm::Value *GetValue(FunctionEmitContext *ctx) const; @@ -581,6 +581,9 @@ public: private: friend class FunctionCallExpr; + /** Name of the function that is being called. */ + std::string name; + /** All of the functions with the name given in the function call; there may be more then one, in which case we need to resolve which overload is the best match. 
*/ diff --git a/parse.yy b/parse.yy index ecd1c27e..ac75075b 100644 --- a/parse.yy +++ b/parse.yy @@ -222,7 +222,7 @@ primary_expression else { std::vector *funs = m->symbolTable->LookupFunction(name); if (funs) - $$ = new FunctionSymbolExpr(funs, @1); + $$ = new FunctionSymbolExpr(name, funs, @1); } if ($$ == NULL) { std::vector alternates = diff --git a/stdlib.ispc b/stdlib.ispc index 1e8ddc4f..fea10fa5 100644 --- a/stdlib.ispc +++ b/stdlib.ispc @@ -369,7 +369,7 @@ static inline uniform float reduce_min(float v) { static inline uniform float reduce_max(float v) { // For the lanes where the mask is off, replace the given value with // negative infinity, so that it doesn't affect the result. - const uniform int iflt_neg_max = 0xff800000; // -infinity + const int iflt_neg_max = 0xff800000; // -infinity // Must use __floatbits_varying_int32, not floatbits(), since with the // latter the current mask enters into the returned result... return __reduce_max_float(__mask ? v : __floatbits_varying_int32(iflt_neg_max)); @@ -427,7 +427,7 @@ static inline uniform double reduce_min(double v) { } static inline uniform double reduce_max(double v) { - const uniform int64 iflt_neg_max = 0xfff0000000000000; // -infinity + const int64 iflt_neg_max = 0xfff0000000000000; // -infinity // Must use __doublebits_varying_int64, not doublebits(), since with the // latter the current mask enters into the returned result... return __reduce_max_double(__mask ? v : __doublebits_varying_int64(iflt_neg_max)); @@ -471,21 +471,21 @@ static inline uniform unsigned int64 reduce_max(unsigned int64 v) { return __reduce_max_uint64(__mask ? v : 0); } -#define REDUCE_EQUAL(TYPE, FUNCTYPE) \ +#define REDUCE_EQUAL(TYPE, FUNCTYPE, MASKTYPE) \ static inline uniform bool reduce_equal(TYPE v) { \ uniform TYPE unusedValue; \ - return __reduce_equal_##FUNCTYPE(v, unusedValue, (int32)__mask); \ + return __reduce_equal_##FUNCTYPE(v, unusedValue, (MASKTYPE)__mask); \ } \ static inline uniform bool reduce_equal(TYPE v, reference uniform TYPE value) { \ - return __reduce_equal_##FUNCTYPE(v, value, (int32)__mask); \ + return __reduce_equal_##FUNCTYPE(v, value, (MASKTYPE)__mask); \ } -REDUCE_EQUAL(int32, int32) -REDUCE_EQUAL(unsigned int32, int32) -REDUCE_EQUAL(float, float) -REDUCE_EQUAL(int64, int64) -REDUCE_EQUAL(unsigned int64, int64) -REDUCE_EQUAL(double, double) +REDUCE_EQUAL(int32, int32, int32) +REDUCE_EQUAL(unsigned int32, int32, unsigned int32) +REDUCE_EQUAL(float, float, int32) +REDUCE_EQUAL(int64, int64, int32) +REDUCE_EQUAL(unsigned int64, int64, unsigned int32) +REDUCE_EQUAL(double, double, int32) static int32 exclusive_scan_add(int32 v) { return __exclusive_scan_add_i32(v, (int32)__mask); @@ -549,23 +549,25 @@ static unsigned int64 exclusive_scan_or(unsigned int64 v) { static inline uniform int packed_load_active(uniform unsigned int a[], uniform int start, reference unsigned int vals) { - return __packed_load_active(a, start, vals, __mask); + return __packed_load_active(a, (unsigned int)start, vals, + (unsigned int32)__mask); } static inline uniform int packed_store_active(uniform unsigned int a[], uniform int start, unsigned int vals) { - return __packed_store_active(a, start, vals, __mask); + return __packed_store_active(a, (unsigned int)start, vals, + (unsigned int32)__mask); } static inline uniform int packed_load_active(uniform int a[], uniform int start, reference int vals) { - return __packed_load_active(a, start, vals, __mask); + return __packed_load_active(a, start, vals, (int32)__mask); } static inline uniform int 
packed_store_active(uniform int a[], uniform int start, int vals) { - return __packed_store_active(a, start, vals, __mask); + return __packed_store_active(a, start, vals, (int32)__mask); } /////////////////////////////////////////////////////////////////////////// @@ -597,13 +599,13 @@ static inline uniform TA atomic_##OPA##_global(uniform reference TA ref, \ return ret; \ } -#define DEFINE_ATOMIC_MINMAX_OP(TA,TB,OPA,OPB) \ +#define DEFINE_ATOMIC_MINMAX_OP(TA,TB,OPA,OPB, MASKTYPE) \ static inline TA atomic_##OPA##_global(uniform reference TA ref, TA value) { \ uniform TA oneval = reduce_##OPA(value); \ TA ret; \ if (lanemask() != 0) { \ memory_barrier(); \ - ret = __atomic_##OPB##_uniform_##TB##_global(ref, oneval, __mask); \ + ret = __atomic_##OPB##_uniform_##TB##_global(ref, oneval, (MASKTYPE)__mask); \ memory_barrier(); \ } \ return ret; \ @@ -611,15 +613,15 @@ static inline TA atomic_##OPA##_global(uniform reference TA ref, TA value) { \ static inline uniform TA atomic_##OPA##_global(uniform reference TA ref, \ uniform TA value) { \ memory_barrier(); \ - uniform TA ret = __atomic_##OPB##_uniform_##TB##_global(ref, value, __mask); \ + uniform TA ret = __atomic_##OPB##_uniform_##TB##_global(ref, value, (MASKTYPE)__mask); \ memory_barrier(); \ return ret; \ } DEFINE_ATOMIC_OP(int32,int32,add,add,int32) DEFINE_ATOMIC_OP(int32,int32,subtract,sub,int32) -DEFINE_ATOMIC_MINMAX_OP(int32,int32,min,min) -DEFINE_ATOMIC_MINMAX_OP(int32,int32,max,max) +DEFINE_ATOMIC_MINMAX_OP(int32,int32,min,min,int32) +DEFINE_ATOMIC_MINMAX_OP(int32,int32,max,max,int32) DEFINE_ATOMIC_OP(int32,int32,and,and,int32) DEFINE_ATOMIC_OP(int32,int32,or,or,int32) DEFINE_ATOMIC_OP(int32,int32,xor,xor,int32) @@ -627,21 +629,21 @@ DEFINE_ATOMIC_OP(int32,int32,swap,swap,int32) // For everything but atomic min and max, we can use the same // implementations for unsigned as for signed. -DEFINE_ATOMIC_OP(unsigned int32,int32,add,add,int32) -DEFINE_ATOMIC_OP(unsigned int32,int32,subtract,sub,int32) -DEFINE_ATOMIC_MINMAX_OP(unsigned int32,uint32,min,umin) -DEFINE_ATOMIC_MINMAX_OP(unsigned int32,uint32,max,umax) -DEFINE_ATOMIC_OP(unsigned int32,int32,and,and,int32) -DEFINE_ATOMIC_OP(unsigned int32,int32,or,or,int32) -DEFINE_ATOMIC_OP(unsigned int32,int32,xor,xor,int32) -DEFINE_ATOMIC_OP(unsigned int32,int32,swap,swap,int32) +DEFINE_ATOMIC_OP(unsigned int32,int32,add,add,unsigned int32) +DEFINE_ATOMIC_OP(unsigned int32,int32,subtract,sub,unsigned int32) +DEFINE_ATOMIC_MINMAX_OP(unsigned int32,uint32,min,umin,unsigned int32) +DEFINE_ATOMIC_MINMAX_OP(unsigned int32,uint32,max,umax,unsigned int32) +DEFINE_ATOMIC_OP(unsigned int32,int32,and,and,unsigned int32) +DEFINE_ATOMIC_OP(unsigned int32,int32,or,or,unsigned int32) +DEFINE_ATOMIC_OP(unsigned int32,int32,xor,xor,unsigned int32) +DEFINE_ATOMIC_OP(unsigned int32,int32,swap,swap,unsigned int32) DEFINE_ATOMIC_OP(float,float,swap,swap,int32) DEFINE_ATOMIC_OP(int64,int64,add,add,int32) DEFINE_ATOMIC_OP(int64,int64,subtract,sub,int32) -DEFINE_ATOMIC_MINMAX_OP(int64,int64,min,min) -DEFINE_ATOMIC_MINMAX_OP(int64,int64,max,max) +DEFINE_ATOMIC_MINMAX_OP(int64,int64,min,min,int32) +DEFINE_ATOMIC_MINMAX_OP(int64,int64,max,max,int32) DEFINE_ATOMIC_OP(int64,int64,and,and,int32) DEFINE_ATOMIC_OP(int64,int64,or,or,int32) DEFINE_ATOMIC_OP(int64,int64,xor,xor,int32) @@ -649,41 +651,41 @@ DEFINE_ATOMIC_OP(int64,int64,swap,swap,int32) // For everything but atomic min and max, we can use the same // implementations for unsigned as for signed. 
-DEFINE_ATOMIC_OP(unsigned int64,int64,add,add,int32) -DEFINE_ATOMIC_OP(unsigned int64,int64,subtract,sub,int32) -DEFINE_ATOMIC_MINMAX_OP(unsigned int64,uint64,min,umin) -DEFINE_ATOMIC_MINMAX_OP(unsigned int64,uint64,max,umax) -DEFINE_ATOMIC_OP(unsigned int64,int64,and,and,int32) -DEFINE_ATOMIC_OP(unsigned int64,int64,or,or,int32) -DEFINE_ATOMIC_OP(unsigned int64,int64,xor,xor,int32) -DEFINE_ATOMIC_OP(unsigned int64,int64,swap,swap,int32) +DEFINE_ATOMIC_OP(unsigned int64,int64,add,add,unsigned int32) +DEFINE_ATOMIC_OP(unsigned int64,int64,subtract,sub,unsigned int32) +DEFINE_ATOMIC_MINMAX_OP(unsigned int64,uint64,min,umin,unsigned int32) +DEFINE_ATOMIC_MINMAX_OP(unsigned int64,uint64,max,umax,unsigned int32) +DEFINE_ATOMIC_OP(unsigned int64,int64,and,and,unsigned int32) +DEFINE_ATOMIC_OP(unsigned int64,int64,or,or,unsigned int32) +DEFINE_ATOMIC_OP(unsigned int64,int64,xor,xor,unsigned int32) +DEFINE_ATOMIC_OP(unsigned int64,int64,swap,swap,unsigned int32) DEFINE_ATOMIC_OP(double,double,swap,swap,int32) #undef DEFINE_ATOMIC_OP -#define ATOMIC_DECL_CMPXCHG(TA, TB) \ +#define ATOMIC_DECL_CMPXCHG(TA, TB, MASKTYPE) \ static inline TA atomic_compare_exchange_global( \ uniform reference TA ref, TA oldval, TA newval) { \ memory_barrier(); \ - TA ret = __atomic_compare_exchange_##TB##_global(ref, oldval, newval, __mask); \ + TA ret = __atomic_compare_exchange_##TB##_global(ref, oldval, newval, (MASKTYPE)__mask); \ memory_barrier(); \ return ret; \ } \ static inline uniform TA atomic_compare_exchange_global( \ uniform reference TA ref, uniform TA oldval, uniform TA newval) { \ memory_barrier(); \ - uniform TA ret = __atomic_compare_exchange_uniform_##TB##_global(ref, oldval, newval, __mask); \ + uniform TA ret = __atomic_compare_exchange_uniform_##TB##_global(ref, oldval, newval, (MASKTYPE)__mask); \ memory_barrier(); \ return ret; \ } -ATOMIC_DECL_CMPXCHG(int32, int32) -ATOMIC_DECL_CMPXCHG(unsigned int32, int32) -ATOMIC_DECL_CMPXCHG(float, float) -ATOMIC_DECL_CMPXCHG(int64, int64) -ATOMIC_DECL_CMPXCHG(unsigned int64, int64) -ATOMIC_DECL_CMPXCHG(double, double) +ATOMIC_DECL_CMPXCHG(int32, int32, int32) +ATOMIC_DECL_CMPXCHG(unsigned int32, int32, unsigned int32) +ATOMIC_DECL_CMPXCHG(float, float, int32) +ATOMIC_DECL_CMPXCHG(int64, int64, int32) +ATOMIC_DECL_CMPXCHG(unsigned int64, int64, unsigned int32) +ATOMIC_DECL_CMPXCHG(double, double, int32) #undef ATOMIC_DECL_CMPXCHG
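
The core of the change is the new exactMatchOnly path in FunctionCallExpr::resolveFunctionOverloads(): the exact-match pass always runs first, and the looser passes are skipped entirely for "__"-prefixed builtins. A minimal, self-contained sketch of that ordering follows; types are modeled as plain strings, the several looser passes are collapsed into one toy conversion rule, and none of the helper names here are ispc's actual API:

#include <functional>
#include <string>
#include <vector>

struct Candidate {
    std::string name;
    std::vector<std::string> paramTypes;
};

using Match = std::function<bool(const std::string &, const std::string &)>;

// Return the unique candidate whose parameter types all satisfy 'match';
// return nullptr if nothing matches or the match is ambiguous.
static const Candidate *tryResolve(const std::vector<Candidate> &cands,
                                   const std::vector<std::string> &argTypes,
                                   const Match &match) {
    const Candidate *found = nullptr;
    for (const Candidate &c : cands) {
        if (c.paramTypes.size() != argTypes.size())
            continue;
        bool ok = true;
        for (size_t i = 0; i < argTypes.size(); ++i)
            ok = ok && match(argTypes[i], c.paramTypes[i]);
        if (ok) {
            if (found != nullptr)
                return nullptr;   // ambiguous
            found = &c;
        }
    }
    return found;
}

static const Candidate *resolveOverload(const std::string &callee,
                                        const std::vector<Candidate> &cands,
                                        const std::vector<std::string> &argTypes) {
    // Builtins (names starting with "__") only get the exact-match pass.
    bool exactMatchOnly = callee.compare(0, 2, "__") == 0;

    auto exact = [](const std::string &a, const std::string &p) { return a == p; };
    if (const Candidate *c = tryResolve(cands, argTypes, exact))
        return c;

    if (!exactMatchOnly) {
        // Stand-in for the looser passes (ignoring references, uniform->varying
        // promotion, general type conversion) that expr.cpp tries in turn.
        auto convertible = [](const std::string &a, const std::string &p) {
            return a == p || (a == "int32" && p == "int64");
        };
        if (const Candidate *c = tryResolve(cands, argTypes, convertible))
            return c;
    }
    return nullptr;  // caller reports "Unable to find matching overload ..."
}

With this ordering, a stdlib call such as __reduce_equal_int32(v, value, (int32)__mask) resolves only when the cast makes the argument types line up exactly with the builtin's declaration, which is the extra bit of safety the commit message describes.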
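
The stdlib.ispc half of the patch adds explicit casts like (int32)__mask and (unsigned int32)__mask so each call site presents exactly the mask type its builtin declares. A C++ toy of the same idea, with a hypothetical overload set standing in for the ispc builtins:

#include <cstdint>

// Hypothetical stand-ins for a builtin that is overloaded on its mask type.
static int32_t  example_builtin(int32_t mask)  { return mask; }
static uint32_t example_builtin(uint32_t mask) { return mask; }

int64_t caller(int64_t mask) {
    // example_builtin(mask) would be ambiguous here: both overloads need a
    // conversion and neither is a better match.  The explicit cast resolves
    // it, in the same spirit that (int32)__mask / (unsigned int32)__mask
    // pick out the intended builtin signature in stdlib.ispc.
    return example_builtin((int32_t)mask);
}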
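
On the builtins.m4 side, the atomic helpers now take the mask as <N x i32> and sign-extend it themselves for 64-bit element types (the sext in the patch). Sign extension is the right widening because an enabled lane is all-ones at any width; a small standalone check of that property, not tied to the ispc sources:

#include <cassert>
#include <cstdint>

int main() {
    int32_t on  = -1;   // 0xffffffff: lane enabled
    int32_t off = 0;    //             lane disabled

    // Sign extension (what the IR-level sext does) keeps an enabled lane
    // all-ones at 64 bits, so the widened mask still works with the
    // bitwise and/or that the atomic helpers apply to 64-bit values.
    assert((uint64_t)(int64_t)on == 0xffffffffffffffffull);
    assert((int64_t)off == 0);
    return 0;
}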