Only allow exact matches for function overload resolution for builtins.

The intent is that the code in stdlib.ispc that calls out to the built-ins
should match argument types exactly (using explicit casts as needed), for
maximal clarity and safety.
Matt Pharr
2011-09-28 17:20:31 -07:00
parent 6d39d5fc3e
commit 32a0a30cf5
5 changed files with 96 additions and 78 deletions
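
In rough terms, the change below keys overload resolution off the callee's name: calls to "__"-prefixed functions (the builtins) must resolve to an exact overload match, while ordinary calls keep the existing fallback matching rules. A condensed C++ sketch of that control flow (a simplified stand-in, not the actual expr.cpp code; the helper names are invented):

    #include <string>

    // Stand-ins for the real matching passes in the compiler.
    bool tryResolveExact();
    bool tryResolveWithConversions();

    bool resolveOverloads(const std::string &calleeName) {
        // Builtins are named with a leading "__"; for them, only an exact
        // overload match is accepted.
        bool exactMatchOnly = (calleeName.substr(0, 2) == "__");
        if (tryResolveExact())
            return true;
        if (!exactMatchOnly && tryResolveWithConversions())
            return true;
        return false;  // no acceptable overload; an error is reported
    }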

@@ -647,16 +647,17 @@ forloop(i, 1, eval($1-1), `
 define(`global_atomic_associative', `
-;; note that the mask is expected to be of type $3, so the caller must ensure
-;; that for 64-bit types, the mask is cast to a signed int before being passed
-;; to this so that it is properly sign extended... (The code in stdlib.ispc
-;; does do this..)
 define internal <$1 x $3> @__atomic_$2_$4_global($3 * %ptr, <$1 x $3> %val,
-                                                 <$1 x $3> %mask) nounwind alwaysinline {
+                                                 <$1 x i32> %m) nounwind alwaysinline {
   ; first, for any lanes where the mask is off, compute a vector where those lanes
   ; hold the identity value..
+  ifelse($3, `i64', `%mask = sext <$1 x i32> %m to <$1 x i64>')
+  ifelse($3, `i32', `
+  %maskmem = alloca <$1 x i32>
+  store <$1 x i32> %m, <$1 x i32> * %maskmem
+  %mask = load <$1 x i32> * %maskmem'
+  )
   ; zero out any lanes that are off
   %valoff = and <$1 x $3> %val, %mask
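
For the 64-bit case above, the reason signedness matters when widening the mask: converting an unsigned 32-bit value to 64 bits zero-extends, while converting a signed value sign-extends, and an all-on lane mask (0xffffffff) has to stay all-on after widening. A small standalone C++ illustration of that difference (not part of the commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
        uint32_t lane_on = 0xffffffffu;            // all bits set: lane is active
        uint64_t zero_ext = (uint64_t)lane_on;     // 0x00000000ffffffff: a broken 64-bit mask
        uint64_t sign_ext = (uint64_t)(int64_t)(int32_t)lane_on;  // 0xffffffffffffffff: what we want
        std::printf("%016llx\n%016llx\n",
                    (unsigned long long)zero_ext, (unsigned long long)sign_ext);
        return 0;
    }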

@@ -1189,7 +1189,7 @@ BinaryExpr::Optimize() {
             m->symbolTable->LookupFunction("rcp");
         if (rcpFuns != NULL) {
             assert(rcpFuns->size() == 2);
-            Expr *rcpSymExpr = new FunctionSymbolExpr(rcpFuns, pos);
+            Expr *rcpSymExpr = new FunctionSymbolExpr("rcp", rcpFuns, pos);
             ExprList *args = new ExprList(arg1, arg1->pos);
             Expr *rcpCall = new FunctionCallExpr(rcpSymExpr, args,
                                                  arg1->pos, false);
@@ -2213,7 +2213,7 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) {
 void
-FunctionCallExpr::resolveFunctionOverloads() {
+FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) {
     FunctionSymbolExpr *fse = dynamic_cast<FunctionSymbolExpr *>(func);
     if (!fse)
         // error will be issued later if not calling an actual function
@@ -2227,31 +2227,33 @@ FunctionCallExpr::resolveFunctionOverloads() {
     if (tryResolve(lExactMatch))
         return;

-    // Try to find a single match ignoring references
-    if (tryResolve(lMatchIgnoringReferences))
-        return;
-
-    // TODO: next, try to find an exact match via type promotion--i.e. char
-    // -> int, etc--things that don't lose data
-
-    // Next try to see if there's a match via just uniform -> varying
-    // promotions.  TODO: look for one with a minimal number of them?
-    if (tryResolve(lMatchIgnoringUniform))
-        return;
-
-    // Try to find a match via type conversion, but don't change
-    // unif->varying
-    if (tryResolve(lMatchWithTypeConvSameVariability))
-        return;
-
-    // Last chance: try to find a match via arbitrary type conversion.
-    if (tryResolve(lMatchWithTypeConv))
-        return;
+    if (!exactMatchOnly) {
+        // Try to find a single match ignoring references
+        if (tryResolve(lMatchIgnoringReferences))
+            return;
+
+        // TODO: next, try to find an exact match via type promotion--i.e. char
+        // -> int, etc--things that don't lose data
+
+        // Next try to see if there's a match via just uniform -> varying
+        // promotions.  TODO: look for one with a minimal number of them?
+        if (tryResolve(lMatchIgnoringUniform))
+            return;
+
+        // Try to find a match via type conversion, but don't change
+        // unif->varying
+        if (tryResolve(lMatchWithTypeConvSameVariability))
+            return;
+
+        // Last chance: try to find a match via arbitrary type conversion.
+        if (tryResolve(lMatchWithTypeConv))
+            return;
+    }

     // failure :-(
     const char *funName = fse->candidateFunctions->front()->name.c_str();
-    Error(pos, "Unable to find matching overload for call to function \"%s\".",
-          funName);
+    Error(pos, "Unable to find matching overload for call to function \"%s\"%s.",
+          funName, exactMatchOnly ? " only considering exact matches" : "");
     fprintf(stderr, "Candidates are:\n");
     lPrintFunctionOverloads(*fse->candidateFunctions);
     lPrintPassedTypes(funName, args->exprs);
@@ -2264,7 +2266,15 @@ FunctionCallExpr::FunctionCallExpr(Expr *f, ExprList *a, SourcePos p, bool il)
     args = a;
     isLaunch = il;
-    resolveFunctionOverloads();
+    FunctionSymbolExpr *fse = dynamic_cast<FunctionSymbolExpr *>(func);
+    // Functions with names that start with "__" should only be various
+    // builtins.  For those, we'll demand an exact match, since we'll
+    // expect whichever function in stdlib.ispc is calling out to one of
+    // those to be matching the argument types exactly; this is to be a bit
+    // extra safe to be sure that the expected builtin is in fact being
+    // called.
+    bool exactMatchOnly = (fse != NULL) && (fse->name.substr(0,2) == "__");
+    resolveFunctionOverloads(exactMatchOnly);
 }
@@ -5201,9 +5211,11 @@ SymbolExpr::Print() const {
 ///////////////////////////////////////////////////////////////////////////
 // FunctionSymbolExpr

-FunctionSymbolExpr::FunctionSymbolExpr(std::vector<Symbol *> *candidates,
+FunctionSymbolExpr::FunctionSymbolExpr(const char *n,
+                                       std::vector<Symbol *> *candidates,
                                        SourcePos p)
     : Expr(p) {
+    name = n;
     matchingFunc = NULL;
     candidateFunctions = candidates;
 }

expr.h

@@ -265,7 +265,7 @@ public:
     bool isLaunch;

 private:
-    void resolveFunctionOverloads();
+    void resolveFunctionOverloads(bool exactMatchOnly);
     bool tryResolve(bool (*matchFunc)(Expr *, const Type *));
 };
@@ -567,7 +567,7 @@ private:
  */
 class FunctionSymbolExpr : public Expr {
 public:
-    FunctionSymbolExpr(std::vector<Symbol *> *candidateFunctions,
+    FunctionSymbolExpr(const char *name, std::vector<Symbol *> *candidateFunctions,
                        SourcePos pos);
     llvm::Value *GetValue(FunctionEmitContext *ctx) const;
@@ -581,6 +581,9 @@ public:
 private:
     friend class FunctionCallExpr;

+    /** Name of the function that is being called. */
+    std::string name;
+
     /** All of the functions with the name given in the function call;
         there may be more then one, in which case we need to resolve which
        overload is the best match. */

@@ -222,7 +222,7 @@ primary_expression
     else {
         std::vector<Symbol *> *funs = m->symbolTable->LookupFunction(name);
         if (funs)
-            $$ = new FunctionSymbolExpr(funs, @1);
+            $$ = new FunctionSymbolExpr(name, funs, @1);
     }
     if ($$ == NULL) {
         std::vector<std::string> alternates =

@@ -369,7 +369,7 @@ static inline uniform float reduce_min(float v) {
 static inline uniform float reduce_max(float v) {
     // For the lanes where the mask is off, replace the given value with
     // negative infinity, so that it doesn't affect the result.
-    const uniform int iflt_neg_max = 0xff800000; // -infinity
+    const int iflt_neg_max = 0xff800000; // -infinity
     // Must use __floatbits_varying_int32, not floatbits(), since with the
     // latter the current mask enters into the returned result...
     return __reduce_max_float(__mask ? v : __floatbits_varying_int32(iflt_neg_max));
@@ -427,7 +427,7 @@ static inline uniform double reduce_min(double v) {
 }

 static inline uniform double reduce_max(double v) {
-    const uniform int64 iflt_neg_max = 0xfff0000000000000; // -infinity
+    const int64 iflt_neg_max = 0xfff0000000000000; // -infinity
     // Must use __doublebits_varying_int64, not doublebits(), since with the
     // latter the current mask enters into the returned result...
     return __reduce_max_double(__mask ? v : __doublebits_varying_int64(iflt_neg_max));
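
As the comments above note, masked-off lanes are replaced with negative infinity so they cannot win the max-reduction; the constants really are the IEEE-754 bit patterns for -infinity. A quick standalone C++20 check of that claim (not part of the commit):

    #include <bit>
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Reinterpret the raw bit patterns as float/double, much as the stdlib
        // code does via __floatbits_varying_int32 / __doublebits_varying_int64.
        float f  = std::bit_cast<float>(uint32_t{0xff800000u});
        double d = std::bit_cast<double>(uint64_t{0xfff0000000000000ull});
        std::printf("%f %f\n", f, d);  // prints: -inf -inf
        return 0;
    }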
@@ -471,21 +471,21 @@ static inline uniform unsigned int64 reduce_max(unsigned int64 v) {
     return __reduce_max_uint64(__mask ? v : 0);
 }

-#define REDUCE_EQUAL(TYPE, FUNCTYPE) \
+#define REDUCE_EQUAL(TYPE, FUNCTYPE, MASKTYPE) \
 static inline uniform bool reduce_equal(TYPE v) { \
     uniform TYPE unusedValue; \
-    return __reduce_equal_##FUNCTYPE(v, unusedValue, (int32)__mask); \
+    return __reduce_equal_##FUNCTYPE(v, unusedValue, (MASKTYPE)__mask); \
 } \
 static inline uniform bool reduce_equal(TYPE v, reference uniform TYPE value) { \
-    return __reduce_equal_##FUNCTYPE(v, value, (int32)__mask); \
+    return __reduce_equal_##FUNCTYPE(v, value, (MASKTYPE)__mask); \
 }

-REDUCE_EQUAL(int32, int32)
-REDUCE_EQUAL(unsigned int32, int32)
-REDUCE_EQUAL(float, float)
-REDUCE_EQUAL(int64, int64)
-REDUCE_EQUAL(unsigned int64, int64)
-REDUCE_EQUAL(double, double)
+REDUCE_EQUAL(int32, int32, int32)
+REDUCE_EQUAL(unsigned int32, int32, unsigned int32)
+REDUCE_EQUAL(float, float, int32)
+REDUCE_EQUAL(int64, int64, int32)
+REDUCE_EQUAL(unsigned int64, int64, unsigned int32)
+REDUCE_EQUAL(double, double, int32)

 static int32 exclusive_scan_add(int32 v) {
     return __exclusive_scan_add_i32(v, (int32)__mask);
@@ -549,23 +549,25 @@ static unsigned int64 exclusive_scan_or(unsigned int64 v) {
 static inline uniform int
 packed_load_active(uniform unsigned int a[], uniform int start,
                    reference unsigned int vals) {
-    return __packed_load_active(a, start, vals, __mask);
+    return __packed_load_active(a, (unsigned int)start, vals,
+                                (unsigned int32)__mask);
 }

 static inline uniform int
 packed_store_active(uniform unsigned int a[], uniform int start,
                     unsigned int vals) {
-    return __packed_store_active(a, start, vals, __mask);
+    return __packed_store_active(a, (unsigned int)start, vals,
+                                 (unsigned int32)__mask);
 }

 static inline uniform int packed_load_active(uniform int a[], uniform int start,
                                              reference int vals) {
-    return __packed_load_active(a, start, vals, __mask);
+    return __packed_load_active(a, start, vals, (int32)__mask);
 }

 static inline uniform int packed_store_active(uniform int a[], uniform int start,
                                               int vals) {
-    return __packed_store_active(a, start, vals, __mask);
+    return __packed_store_active(a, start, vals, (int32)__mask);
 }

 ///////////////////////////////////////////////////////////////////////////
@@ -597,13 +599,13 @@ static inline uniform TA atomic_##OPA##_global(uniform reference TA ref, \
     return ret; \
 }

-#define DEFINE_ATOMIC_MINMAX_OP(TA,TB,OPA,OPB) \
+#define DEFINE_ATOMIC_MINMAX_OP(TA,TB,OPA,OPB, MASKTYPE) \
 static inline TA atomic_##OPA##_global(uniform reference TA ref, TA value) { \
     uniform TA oneval = reduce_##OPA(value); \
     TA ret; \
     if (lanemask() != 0) { \
         memory_barrier(); \
-        ret = __atomic_##OPB##_uniform_##TB##_global(ref, oneval, __mask); \
+        ret = __atomic_##OPB##_uniform_##TB##_global(ref, oneval, (MASKTYPE)__mask); \
         memory_barrier(); \
     } \
     return ret; \
@@ -611,15 +613,15 @@ static inline TA atomic_##OPA##_global(uniform reference TA ref, TA value) { \
 static inline uniform TA atomic_##OPA##_global(uniform reference TA ref, \
                                                uniform TA value) { \
     memory_barrier(); \
-    uniform TA ret = __atomic_##OPB##_uniform_##TB##_global(ref, value, __mask); \
+    uniform TA ret = __atomic_##OPB##_uniform_##TB##_global(ref, value, (MASKTYPE)__mask); \
     memory_barrier(); \
     return ret; \
 }

 DEFINE_ATOMIC_OP(int32,int32,add,add,int32)
 DEFINE_ATOMIC_OP(int32,int32,subtract,sub,int32)
-DEFINE_ATOMIC_MINMAX_OP(int32,int32,min,min)
-DEFINE_ATOMIC_MINMAX_OP(int32,int32,max,max)
+DEFINE_ATOMIC_MINMAX_OP(int32,int32,min,min,int32)
+DEFINE_ATOMIC_MINMAX_OP(int32,int32,max,max,int32)
 DEFINE_ATOMIC_OP(int32,int32,and,and,int32)
 DEFINE_ATOMIC_OP(int32,int32,or,or,int32)
 DEFINE_ATOMIC_OP(int32,int32,xor,xor,int32)
@@ -627,21 +629,21 @@ DEFINE_ATOMIC_OP(int32,int32,swap,swap,int32)
 // For everything but atomic min and max, we can use the same
 // implementations for unsigned as for signed.
-DEFINE_ATOMIC_OP(unsigned int32,int32,add,add,int32)
-DEFINE_ATOMIC_OP(unsigned int32,int32,subtract,sub,int32)
-DEFINE_ATOMIC_MINMAX_OP(unsigned int32,uint32,min,umin)
-DEFINE_ATOMIC_MINMAX_OP(unsigned int32,uint32,max,umax)
-DEFINE_ATOMIC_OP(unsigned int32,int32,and,and,int32)
-DEFINE_ATOMIC_OP(unsigned int32,int32,or,or,int32)
-DEFINE_ATOMIC_OP(unsigned int32,int32,xor,xor,int32)
-DEFINE_ATOMIC_OP(unsigned int32,int32,swap,swap,int32)
+DEFINE_ATOMIC_OP(unsigned int32,int32,add,add,unsigned int32)
+DEFINE_ATOMIC_OP(unsigned int32,int32,subtract,sub,unsigned int32)
+DEFINE_ATOMIC_MINMAX_OP(unsigned int32,uint32,min,umin,unsigned int32)
+DEFINE_ATOMIC_MINMAX_OP(unsigned int32,uint32,max,umax,unsigned int32)
+DEFINE_ATOMIC_OP(unsigned int32,int32,and,and,unsigned int32)
+DEFINE_ATOMIC_OP(unsigned int32,int32,or,or,unsigned int32)
+DEFINE_ATOMIC_OP(unsigned int32,int32,xor,xor,unsigned int32)
+DEFINE_ATOMIC_OP(unsigned int32,int32,swap,swap,unsigned int32)

 DEFINE_ATOMIC_OP(float,float,swap,swap,int32)

 DEFINE_ATOMIC_OP(int64,int64,add,add,int32)
 DEFINE_ATOMIC_OP(int64,int64,subtract,sub,int32)
-DEFINE_ATOMIC_MINMAX_OP(int64,int64,min,min)
-DEFINE_ATOMIC_MINMAX_OP(int64,int64,max,max)
+DEFINE_ATOMIC_MINMAX_OP(int64,int64,min,min,int32)
+DEFINE_ATOMIC_MINMAX_OP(int64,int64,max,max,int32)
 DEFINE_ATOMIC_OP(int64,int64,and,and,int32)
 DEFINE_ATOMIC_OP(int64,int64,or,or,int32)
 DEFINE_ATOMIC_OP(int64,int64,xor,xor,int32)
@@ -649,41 +651,41 @@ DEFINE_ATOMIC_OP(int64,int64,swap,swap,int32)
 // For everything but atomic min and max, we can use the same
 // implementations for unsigned as for signed.
-DEFINE_ATOMIC_OP(unsigned int64,int64,add,add,int32)
-DEFINE_ATOMIC_OP(unsigned int64,int64,subtract,sub,int32)
-DEFINE_ATOMIC_MINMAX_OP(unsigned int64,uint64,min,umin)
-DEFINE_ATOMIC_MINMAX_OP(unsigned int64,uint64,max,umax)
-DEFINE_ATOMIC_OP(unsigned int64,int64,and,and,int32)
-DEFINE_ATOMIC_OP(unsigned int64,int64,or,or,int32)
-DEFINE_ATOMIC_OP(unsigned int64,int64,xor,xor,int32)
-DEFINE_ATOMIC_OP(unsigned int64,int64,swap,swap,int32)
+DEFINE_ATOMIC_OP(unsigned int64,int64,add,add,unsigned int32)
+DEFINE_ATOMIC_OP(unsigned int64,int64,subtract,sub,unsigned int32)
+DEFINE_ATOMIC_MINMAX_OP(unsigned int64,uint64,min,umin,unsigned int32)
+DEFINE_ATOMIC_MINMAX_OP(unsigned int64,uint64,max,umax,unsigned int32)
+DEFINE_ATOMIC_OP(unsigned int64,int64,and,and,unsigned int32)
+DEFINE_ATOMIC_OP(unsigned int64,int64,or,or,unsigned int32)
+DEFINE_ATOMIC_OP(unsigned int64,int64,xor,xor,unsigned int32)
+DEFINE_ATOMIC_OP(unsigned int64,int64,swap,swap,unsigned int32)

 DEFINE_ATOMIC_OP(double,double,swap,swap,int32)

 #undef DEFINE_ATOMIC_OP

-#define ATOMIC_DECL_CMPXCHG(TA, TB) \
+#define ATOMIC_DECL_CMPXCHG(TA, TB, MASKTYPE) \
 static inline TA atomic_compare_exchange_global( \
     uniform reference TA ref, TA oldval, TA newval) { \
     memory_barrier(); \
-    TA ret = __atomic_compare_exchange_##TB##_global(ref, oldval, newval, __mask); \
+    TA ret = __atomic_compare_exchange_##TB##_global(ref, oldval, newval, (MASKTYPE)__mask); \
     memory_barrier(); \
     return ret; \
 } \
 static inline uniform TA atomic_compare_exchange_global( \
     uniform reference TA ref, uniform TA oldval, uniform TA newval) { \
     memory_barrier(); \
-    uniform TA ret = __atomic_compare_exchange_uniform_##TB##_global(ref, oldval, newval, __mask); \
+    uniform TA ret = __atomic_compare_exchange_uniform_##TB##_global(ref, oldval, newval, (MASKTYPE)__mask); \
     memory_barrier(); \
     return ret; \
 }

-ATOMIC_DECL_CMPXCHG(int32, int32)
-ATOMIC_DECL_CMPXCHG(unsigned int32, int32)
-ATOMIC_DECL_CMPXCHG(float, float)
-ATOMIC_DECL_CMPXCHG(int64, int64)
-ATOMIC_DECL_CMPXCHG(unsigned int64, int64)
-ATOMIC_DECL_CMPXCHG(double, double)
+ATOMIC_DECL_CMPXCHG(int32, int32, int32)
+ATOMIC_DECL_CMPXCHG(unsigned int32, int32, unsigned int32)
+ATOMIC_DECL_CMPXCHG(float, float, int32)
+ATOMIC_DECL_CMPXCHG(int64, int64, int32)
+ATOMIC_DECL_CMPXCHG(unsigned int64, int64, unsigned int32)
+ATOMIC_DECL_CMPXCHG(double, double, int32)

 #undef ATOMIC_DECL_CMPXCHG
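
The changes in this file repeat one pattern: each macro gains a MASKTYPE parameter so the call site can cast __mask to exactly the mask type the corresponding builtin declares, which is what the new exact-match resolution for "__" functions requires. A hypothetical C++ sketch of that macro pattern (the names below are invented stand-ins, not the real ispc builtins):

    #include <cstdint>

    // Illustrative stand-ins; the real builtins live in the ispc runtime.
    extern uint64_t __mask_bits;
    int32_t  __example_op_int32(int32_t v, int32_t mask);
    uint32_t __example_op_uint32(uint32_t v, uint32_t mask);

    // The macro threads MASKTYPE through so the cast at the call site matches
    // the builtin's declared mask type exactly.
    #define DEFINE_OP(TYPE, FUNCTYPE, MASKTYPE)                       \
        static inline TYPE example_op(TYPE v) {                       \
            return __example_op_##FUNCTYPE(v, (MASKTYPE)__mask_bits); \
        }

    DEFINE_OP(int32_t, int32, int32_t)
    DEFINE_OP(uint32_t, uint32, uint32_t)
    #undef DEFINE_OP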