diff --git a/expr.cpp b/expr.cpp index f46347ee..81f51151 100644 --- a/expr.cpp +++ b/expr.cpp @@ -7497,7 +7497,24 @@ lPrintOverloadCandidates(SourcePos pos, const std::vector &funcs, Error(pos, "%s", passedTypes.c_str()); } - + +static bool +lIsMatchToNonConstReference(const Type *callType, const Type *funcArgType) { + return (dynamic_cast(funcArgType) && + (funcArgType->IsConstType() == false) && + Type::Equal(callType, funcArgType->GetReferenceTarget())); +} + + +static bool +lIsMatchToNonConstReferenceUnifToVarying(const Type *callType, + const Type *funcArgType) { + return (dynamic_cast(funcArgType) && + (funcArgType->IsConstType() == false) && + Type::Equal(callType->GetAsVaryingType(), + funcArgType->GetReferenceTarget())); +} + /** Helper function used for function overload resolution: returns true if converting the argument to the call type only requires a type conversion that won't lose information. Otherwise return false. @@ -7597,6 +7614,20 @@ FunctionSymbolExpr::getCandidateFunctions(int argCount) const { } +static bool +lArgIsPointerType(const Type *type) { + if (dynamic_cast(type) != NULL) + return true; + + const ReferenceType *rt = dynamic_cast(type); + if (rt == NULL) + return false; + + const Type *t = rt->GetReferenceTarget(); + return (dynamic_cast(t) != NULL); +} + + /** This function computes the value of a cost function that represents the cost of calling a function of the given type with arguments of the given types. If it's not possible to call the function, regardless of @@ -7623,19 +7654,11 @@ FunctionSymbolExpr::computeOverloadCost(const FunctionType *ftype, const Type *fargType = ftype->GetParameterType(i); const Type *callType = argTypes[i]; - // For convenience, normalize to non-const types (except for - // references, where const-ness matters). For all other types, - // we're passing by value anyway, so const doesn't matter. - if (dynamic_cast(callType) == NULL) - callType = callType->GetAsNonConstType(); - if (dynamic_cast(fargType) == NULL) - fargType = fargType->GetAsNonConstType(); - if (Type::Equal(callType, fargType)) // Perfect match: no cost costSum += 0; else if (argCouldBeNULL && (*argCouldBeNULL)[i] && - dynamic_cast(fargType) != NULL) + lArgIsPointerType(fargType)) // Passing NULL to a pointer-typed parameter is also a no-cost // operation costSum += 0; @@ -7645,19 +7668,33 @@ FunctionSymbolExpr::computeOverloadCost(const FunctionType *ftype, // cost if it wasn't--so scale up the cost when this isn't the // case.. if (argIsConstant == NULL || (*argIsConstant)[i] == false) - costScale *= 32; + costScale *= 128; + + // For convenience, normalize to non-const types (except for + // references, where const-ness matters). For all other types, + // we're passing by value anyway, so const doesn't matter. + const Type *callTypeNC = callType, *fargTypeNC = fargType; + if (dynamic_cast(callType) == NULL) + callTypeNC = callType->GetAsNonConstType(); + if (dynamic_cast(fargType) == NULL) + fargTypeNC = fargType->GetAsNonConstType(); - if (Type::Equal(callType, fargType)) + if (Type::Equal(callTypeNC, fargTypeNC)) // Exact match (after dealing with references, above) costSum += 1 * costScale; - else if (lIsMatchWithTypeWidening(callType, fargType)) + // note: orig fargType for the next two... + else if (lIsMatchToNonConstReference(callTypeNC, fargType)) costSum += 2 * costScale; - else if (lIsMatchWithUniformToVarying(callType, fargType)) + else if (lIsMatchToNonConstReferenceUnifToVarying(callTypeNC, fargType)) costSum += 4 * costScale; - else if (lIsMatchWithTypeConvSameVariability(callType, fargType)) + else if (lIsMatchWithTypeWidening(callTypeNC, fargTypeNC)) costSum += 8 * costScale; - else if (CanConvertTypes(callType, fargType)) + else if (lIsMatchWithUniformToVarying(callTypeNC, fargTypeNC)) costSum += 16 * costScale; + else if (lIsMatchWithTypeConvSameVariability(callTypeNC, fargTypeNC)) + costSum += 32 * costScale; + else if (CanConvertTypes(callTypeNC, fargTypeNC)) + costSum += 64 * costScale; else // Failure--no type conversion possible... return -1; diff --git a/tests/func-overload-refs.ispc b/tests/func-overload-refs.ispc new file mode 100644 index 00000000..89184812 --- /dev/null +++ b/tests/func-overload-refs.ispc @@ -0,0 +1,14 @@ + +export uniform int width() { return programCount; } + +float foo(float &a) { return 1; } +float foo(const float &a) { return 2; } + +export void f_f(uniform float RET[], uniform float aFOO[]) { + float x = 0; + RET[programIndex] = foo(x); +} + +export void result(uniform float RET[]) { + RET[programIndex] = 1; +}