From f8a39402a21f48a50aa9b9c0d5ab6420e8215940 Mon Sep 17 00:00:00 2001 From: Matt Pharr Date: Tue, 27 Mar 2012 13:25:11 -0700 Subject: [PATCH] Implement new, simpler function overload resolution algorithm. We now give each conversion a cost and then find the minimum sum of costs for all of the possible overloads. Fixes issue #194. --- expr.cpp | 443 ++++++++++-------------- expr.h | 22 +- tests/func-overload-max.ispc | 12 + tests_errors/func-param-mismatch-2.ispc | 2 +- tests_errors/func-param-mismatch-3.ispc | 2 +- tests_errors/func-param-mismatch.ispc | 2 +- 6 files changed, 214 insertions(+), 269 deletions(-) create mode 100644 tests/func-overload-max.ispc diff --git a/expr.cpp b/expr.cpp index 5b162d15..907f1a84 100644 --- a/expr.cpp +++ b/expr.cpp @@ -3406,24 +3406,28 @@ FunctionCallExpr::TypeCheck() { return NULL; std::vector argTypes; - std::vector argCouldBeNULL; + std::vector argCouldBeNULL, argIsConstant; for (unsigned int i = 0; i < args->exprs.size(); ++i) { - if (args->exprs[i] == NULL) + Expr *expr = args->exprs[i]; + + if (expr == NULL) return NULL; - const Type *t = args->exprs[i]->GetType(); + const Type *t = expr->GetType(); if (t == NULL) return NULL; - argTypes.push_back(t); - argCouldBeNULL.push_back(lIsAllIntZeros(args->exprs[i]) || - dynamic_cast(args->exprs[i]) != NULL); + argTypes.push_back(t); + argCouldBeNULL.push_back(lIsAllIntZeros(expr) || + dynamic_cast(expr)); + argIsConstant.push_back(dynamic_cast(expr) || + dynamic_cast(expr)); } FunctionSymbolExpr *fse = dynamic_cast(func); if (fse != NULL) { // Regular function call - - if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL) == false) + if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL, + &argIsConstant) == false) return NULL; func = ::TypeCheck(fse); @@ -7403,282 +7407,183 @@ lPrintOverloadCandidates(SourcePos pos, const std::vector &funcs, } -/** Helper function used for function overload resolution: returns zero - cost if the call argument's type exactly matches the function argument - type (modulo a conversion to a const type if needed), otherwise reports - failure. - */ -static int -lExactMatch(const Type *callType, const Type *funcArgType) { - if (dynamic_cast(callType) == NULL) - callType = callType->GetAsNonConstType(); - if (dynamic_cast(funcArgType) != NULL && - dynamic_cast(callType) == NULL) - callType = new ReferenceType(callType); - - return Type::Equal(callType, funcArgType) ? 0 : -1; -} - - -/** Helper function used for function overload resolution: returns a cost - of 1 if the call argument type and the function argument type match, - modulo conversion to a reference type if needed. - */ -static int -lMatchIgnoringReferences(const Type *callType, const Type *funcArgType) { - int prev = lExactMatch(callType, funcArgType); - if (prev != -1) - return prev; - - callType = callType->GetReferenceTarget(); - if (funcArgType->IsConstType()) - callType = callType->GetAsConstType(); - - return Type::Equal(callType, - funcArgType->GetReferenceTarget()) ? 1 : -1; -} - -/** Helper function used for function overload resolution: returns a cost - of 1 if converting the argument to the call type only requires a type - conversion that won't lose information. Otherwise reports failure. -*/ -static int -lMatchWithTypeWidening(const Type *callType, const Type *funcArgType) { - int prev = lMatchIgnoringReferences(callType, funcArgType); - if (prev != -1) - return prev; - +/** Helper function used for function overload resolution: returns true if + converting the argument to the call type only requires a type + conversion that won't lose information. Otherwise return false. + */ +static bool +lIsMatchWithTypeWidening(const Type *callType, const Type *funcArgType) { const AtomicType *callAt = dynamic_cast(callType); const AtomicType *funcAt = dynamic_cast(funcArgType); if (callAt == NULL || funcAt == NULL) - return -1; + return false; if (callAt->IsUniformType() != funcAt->IsUniformType()) - return -1; + return false; switch (callAt->basicType) { case AtomicType::TYPE_BOOL: - return 1; + return true; case AtomicType::TYPE_INT8: case AtomicType::TYPE_UINT8: - return (funcAt->basicType != AtomicType::TYPE_BOOL) ? 1 : -1; + return (funcAt->basicType != AtomicType::TYPE_BOOL); case AtomicType::TYPE_INT16: case AtomicType::TYPE_UINT16: return (funcAt->basicType != AtomicType::TYPE_BOOL && funcAt->basicType != AtomicType::TYPE_INT8 && - funcAt->basicType != AtomicType::TYPE_UINT8) ? 1 : -1; + funcAt->basicType != AtomicType::TYPE_UINT8); case AtomicType::TYPE_INT32: case AtomicType::TYPE_UINT32: return (funcAt->basicType == AtomicType::TYPE_INT32 || funcAt->basicType == AtomicType::TYPE_UINT32 || funcAt->basicType == AtomicType::TYPE_INT64 || - funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; + funcAt->basicType == AtomicType::TYPE_UINT64); case AtomicType::TYPE_FLOAT: - return (funcAt->basicType == AtomicType::TYPE_DOUBLE) ? 1 : -1; + return (funcAt->basicType == AtomicType::TYPE_DOUBLE); case AtomicType::TYPE_INT64: case AtomicType::TYPE_UINT64: return (funcAt->basicType == AtomicType::TYPE_INT64 || - funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; + funcAt->basicType == AtomicType::TYPE_UINT64); case AtomicType::TYPE_DOUBLE: - return -1; + return false; default: FATAL("Unhandled atomic type"); - return -1; + return false; } } -/** Helper function used for function overload resolution: returns a cost - of 1 if the call argument type and the function argument type match if - we only do a uniform -> varying type conversion but otherwise have - exactly the same type. +/** Helper function used for function overload resolution: returns true if + the call argument type and the function argument type match if we only + do a uniform -> varying type conversion but otherwise have exactly the + same type. */ -static int -lMatchIgnoringUniform(const Type *callType, const Type *funcArgType) { - int prev = lMatchWithTypeWidening(callType, funcArgType); - if (prev != -1) - return prev; - - if (dynamic_cast(callType) == NULL) - callType = callType->GetAsNonConstType(); - +static bool +lIsMatchWithUniformToVarying(const Type *callType, const Type *funcArgType) { return (callType->IsUniformType() && funcArgType->IsVaryingType() && - Type::Equal(callType->GetAsVaryingType(), funcArgType)) ? 1 : -1; + Type::EqualIgnoringConst(callType->GetAsVaryingType(), funcArgType)); } -/** Helper function used for function overload resolution: returns a cost - of 1 if we can type convert from the call argument type to the function +/** Helper function used for function overload resolution: returns true if + we can type convert from the call argument type to the function argument type, but without doing a uniform -> varying conversion. */ -static int -lMatchWithTypeConvSameVariability(const Type *callType, - const Type *funcArgType) { - int prev = lMatchIgnoringUniform(callType, funcArgType); - if (prev != -1) - return prev; - - if (CanConvertTypes(callType, funcArgType) && - (callType->IsUniformType() == funcArgType->IsUniformType())) - return 1; - else - return -1; +static bool +lIsMatchWithTypeConvSameVariability(const Type *callType, + const Type *funcArgType) { + return (CanConvertTypes(callType, funcArgType) && + (callType->GetVariability() == funcArgType->GetVariability())); } -/** Helper function used for function overload resolution: returns a cost - of 1 if there is any type conversion that gets us from the caller - argument type to the function argument type. +/* Returns the set of function overloads that are potential matches, given + argCount values being passed as arguments to the function call. */ -static int -lMatchWithTypeConv(const Type *callType, const Type *funcArgType) { - int prev = lMatchWithTypeConvSameVariability(callType, funcArgType); - if (prev != -1) - return prev; - - return CanConvertTypes(callType, funcArgType) ? 0 : -1; -} - - -/** Given a set of potential matching functions and their associated cost, - return the one with the lowest cost, if unique. Otherwise, if multiple - functions match with the same cost, return NULL. - */ -static Symbol * -lGetBestMatch(std::vector > &matches) { - Assert(matches.size() > 0); - int minCost = matches[0].first; - - for (unsigned int i = 1; i < matches.size(); ++i) - minCost = std::min(minCost, matches[i].first); - - Symbol *match = NULL; - for (unsigned int i = 0; i < matches.size(); ++i) { - if (matches[i].first == minCost) { - if (match != NULL) - // multiple things had the same cost - return NULL; - else - match = matches[i].second; - } - } - return match; -} - - -/** See if we can find a single function from the set of overload options - based on the predicate function passed in. Returns true if no more - tries should be made to find a match, either due to success from - finding a single overloaded function that matches or failure due to - finding multiple ambiguous matches. - */ -bool -FunctionSymbolExpr::tryResolve(int (*matchFunc)(const Type *, const Type *), - SourcePos argPos, - const std::vector &callTypes, - const std::vector *argCouldBeNULL) { - const char *funName = candidateFunctions.front()->name.c_str(); - - std::vector > matches; - std::vector::iterator iter; - for (iter = candidateFunctions.begin(); - iter != candidateFunctions.end(); ++iter) { - // Loop over the set of candidate functions and try each one - Symbol *candidateFunction = *iter; +std::vector +FunctionSymbolExpr::getCandidateFunctions(int argCount) const { + std::vector ret; + for (int i = 0; i < (int)candidateFunctions.size(); ++i) { const FunctionType *ft = - dynamic_cast(candidateFunction->type); + dynamic_cast(candidateFunctions[i]->type); Assert(ft != NULL); // There's no way to match if the caller is passing more arguments // than this function instance takes. - if ((int)callTypes.size() > ft->GetNumParameters()) + if (argCount > ft->GetNumParameters()) continue; - int i; - // Note that we're looping over the caller arguments, not the - // function arguments; it may be ok to have more arguments to the - // function than are passed, if the function has default argument - // values. This case is handled below. - int cost = 0; - for (i = 0; i < (int)callTypes.size(); ++i) { - // This may happen if there's an error earlier in compilation. - // It's kind of a silly to redundantly discover this for each - // potential match versus detecting this earlier in the - // matching process and just giving up. - const Type *paramType = ft->GetParameterType(i); + // Not enough arguments, and no default argument value to save us + if (argCount < ft->GetNumParameters() && + ft->GetParameterDefault(argCount) == NULL) + continue; - if (callTypes[i] == NULL || paramType == NULL || - dynamic_cast(callTypes[i]) != NULL) - return false; + // Success + ret.push_back(candidateFunctions[i]); + } + return ret; +} - int argCost = matchFunc(callTypes[i], paramType); - if (argCost == -1) { - if (argCouldBeNULL != NULL && (*argCouldBeNULL)[i] == true && - dynamic_cast(paramType) != NULL) - // If the passed argument value is zero and this is a - // pointer type, then it can convert to a NULL value of - // that pointer type. - argCost = 0; - else - // If the predicate function returns -1, we have failed no - // matter what else happens, so we stop trying - break; - } - cost += argCost; - } - if (i == (int)callTypes.size()) { - // All of the arguments matched! - if (i == ft->GetNumParameters()) - // And we have exactly as many arguments as the function - // wants, so we're done. - matches.push_back(std::make_pair(cost, candidateFunction)); - else if (i < ft->GetNumParameters() && - ft->GetParameterDefault(i) != NULL) - // Otherwise we can still make it if there are default - // arguments for the rest of the arguments! Because in - // Module::AddFunction() we have verified that once the - // default arguments start, then all of the following ones - // have them as well. Therefore, we just need to check if - // the arg we stopped at has a default value and we're - // done. - matches.push_back(std::make_pair(cost, candidateFunction)); - // otherwise, we don't have a match + +/** This function computes the value of a cost function that represents the + cost of calling a function of the given type with arguments of the + given types. If it's not possible to call the function, regardless of + any type conversions applied, a cost of -1 is returned. + */ +int +FunctionSymbolExpr::computeOverloadCost(const FunctionType *ftype, + const std::vector &argTypes, + const std::vector *argCouldBeNULL, + const std::vector *argIsConstant) { + int costSum = 0; + + // In computing the cost function, we only worry about the actual + // argument types--using function default parameter values is free for + // the purposes here... + for (int i = 0; i < (int)argTypes.size(); ++i) { + // The cost imposed by this argument will be a multiple of + // costScale, which has a value set so that for each of the cost + // buckets, even if all of the function arguments undergo the next + // lower-cost conversion, the sum of their costs will be less than + // a single instance of the next higher-cost conversion. + int costScale = argTypes.size() + 1; + + const Type *fargType = ftype->GetParameterType(i); + const Type *callType = argTypes[i]; + + // For convenience, normalize to non-const types (except for + // references, where const-ness matters). For all other types, + // we're passing by value anyway, so const doesn't matter. + if (dynamic_cast(callType) == NULL) + callType = callType->GetAsNonConstType(); + if (dynamic_cast(fargType) == NULL) + fargType = fargType->GetAsNonConstType(); + + if (Type::Equal(callType, fargType)) + // Perfect match: no cost + costSum += 0; + else if (argCouldBeNULL && (*argCouldBeNULL)[i] && + dynamic_cast(fargType) != NULL) + // Passing NULL to a pointer-typed parameter is also a no-cost + // operation + costSum += 0; + else { + // If the argument is a compile-time constant, we'd like to + // count the cost of various conversions as much lower than the + // cost if it wasn't--so scale up the cost when this isn't the + // case.. + if (argIsConstant == NULL || (*argIsConstant)[i] == false) + costScale *= 32; + + if (Type::Equal(callType, fargType)) + // Exact match (after dealing with references, above) + costSum += 1 * costScale; + else if (lIsMatchWithTypeWidening(callType, fargType)) + costSum += 2 * costScale; + else if (lIsMatchWithUniformToVarying(callType, fargType)) + costSum += 4 * costScale; + else if (lIsMatchWithTypeConvSameVariability(callType, fargType)) + costSum += 8 * costScale; + else if (CanConvertTypes(callType, fargType)) + costSum += 16 * costScale; + else + // Failure--no type conversion possible... + return -1; } } - if (matches.size() == 0) - return false; - else if ((matchingFunc = lGetBestMatch(matches)) != NULL) - // We have a match! - return true; - else { - Error(pos, "Multiple overloaded instances of function \"%s\" matched.", - funName); - - // select the matches that have the lowest cost - std::vector bestMatches; - int minCost = matches[0].first; - for (unsigned int i = 1; i < matches.size(); ++i) - minCost = std::min(minCost, matches[i].first); - for (unsigned int i = 0; i < matches.size(); ++i) - if (matches[i].first == minCost) - bestMatches.push_back(matches[i].second); - - // And print a useful error message - lPrintOverloadCandidates(argPos, bestMatches, callTypes, argCouldBeNULL); - - // Stop trying to find more matches after an ambigious set of - // matches. - return true; - } + return costSum; } bool FunctionSymbolExpr::ResolveOverloads(SourcePos argPos, const std::vector &argTypes, - const std::vector *argCouldBeNULL) { + const std::vector *argCouldBeNULL, + const std::vector *argIsConstant) { + const char *funName = candidateFunctions.front()->name.c_str(); + triedToResolve = true; // Functions with names that start with "__" should only be various @@ -7689,45 +7594,67 @@ FunctionSymbolExpr::ResolveOverloads(SourcePos argPos, // called. bool exactMatchOnly = (name.substr(0,2) == "__"); - // Is there an exact match that doesn't require any argument type - // conversion (other than converting type -> reference type)? - if (tryResolve(lExactMatch, argPos, argTypes, argCouldBeNULL)) - return true; + // First, find the subset of overload candidates that take the same + // number of arguments as have parameters (including functions that + // take more arguments but have defaults starting no later than after + // our last parameter). + std::vector actualCandidates = + getCandidateFunctions(argTypes.size()); - if (exactMatchOnly == false) { - // Try to find a single match ignoring references - if (tryResolve(lMatchIgnoringReferences, argPos, argTypes, - argCouldBeNULL)) - return true; + int bestMatchCost = 1<<30; + std::vector matches; + std::vector candidateCosts; - // Try to find an exact match via type widening--i.e. int8 -> - // int16, etc.--things that don't lose data. - if (tryResolve(lMatchWithTypeWidening, argPos, argTypes, argCouldBeNULL)) - return true; + if (actualCandidates.size() == 0) + goto failure; - // Next try to see if there's a match via just uniform -> varying - // promotions. - if (tryResolve(lMatchIgnoringUniform, argPos, argTypes, argCouldBeNULL)) - return true; - - // Try to find a match via type conversion, but don't change - // unif->varying - if (tryResolve(lMatchWithTypeConvSameVariability, argPos, argTypes, - argCouldBeNULL)) - return true; - - // Last chance: try to find a match via arbitrary type conversion. - if (tryResolve(lMatchWithTypeConv, argPos, argTypes, argCouldBeNULL)) - return true; + // Compute the cost for calling each of the candidate functions + for (int i = 0; i < (int)actualCandidates.size(); ++i) { + const FunctionType *ft = + dynamic_cast(actualCandidates[i]->type); + Assert(ft != NULL); + candidateCosts.push_back(computeOverloadCost(ft, argTypes, + argCouldBeNULL, + argIsConstant)); } - // failure :-( - const char *funName = candidateFunctions.front()->name.c_str(); - Error(pos, "Unable to find matching overload for call to function \"%s\"%s.", - funName, exactMatchOnly ? " only considering exact matches" : ""); - lPrintOverloadCandidates(argPos, candidateFunctions, argTypes, - argCouldBeNULL); - return false; + // Find the best cost, and then the candidate or candidates that have + // that cost. + for (int i = 0; i < (int)candidateCosts.size(); ++i) { + if (candidateCosts[i] != -1 && candidateCosts[i] < bestMatchCost) + bestMatchCost = candidateCosts[i]; + } + // None of the candidates matched + if (bestMatchCost == (1<<30)) + goto failure; + for (int i = 0; i < (int)candidateCosts.size(); ++i) { + if (candidateCosts[i] == bestMatchCost) + matches.push_back(actualCandidates[i]); + } + + if (matches.size() == 1) { + // Only one match: success + matchingFunc = matches[0]; + return true; + } + else if (matches.size() > 1) { + // Multiple matches: ambiguous + Error(pos, "Multiple overloaded functions matched call to function " + "\"%s\"%s.", funName, + exactMatchOnly ? " only considering exact matches" : ""); + lPrintOverloadCandidates(argPos, matches, argTypes, argCouldBeNULL); + return false; + } + else { + // No matches at all + failure: + Error(pos, "Unable to find any matching overload for call to function " + "\"%s\"%s.", funName, + exactMatchOnly ? " only considering exact matches" : ""); + lPrintOverloadCandidates(argPos, candidateFunctions, argTypes, + argCouldBeNULL); + return false; + } } diff --git a/expr.h b/expr.h index 5c59ae83..e7461a1a 100644 --- a/expr.h +++ b/expr.h @@ -651,20 +651,26 @@ public: function overloading, this method resolves which actual function the arguments match best. If the argCouldBeNULL parameter is non-NULL, each element indicates whether the corresponding argument - is the number zero, indicating that it could be a NULL pointer. - This parameter may be NULL (for cases where overload resolution is - being done just given type information without the parameter - argument expressions being available. It returns true on success. + is the number zero, indicating that it could be a NULL pointer, and + if argIsConstant is non-NULL, each element indicates whether the + corresponding argument is a compile-time constant value. Both of + these parameters may be NULL (for cases where overload resolution + is being done just given type information without the parameter + argument expressions being available. This function returns true + on success. */ bool ResolveOverloads(SourcePos argPos, const std::vector &argTypes, - const std::vector *argCouldBeNULL = NULL); + const std::vector *argCouldBeNULL = NULL, + const std::vector *argIsConstant = NULL); Symbol *GetMatchingFunction(); private: - bool tryResolve(int (*matchFunc)(const Type *, const Type *), - SourcePos argPos, const std::vector &argTypes, - const std::vector *argCouldBeNULL); + std::vector getCandidateFunctions(int argCount) const; + static int computeOverloadCost(const FunctionType *ftype, + const std::vector &argTypes, + const std::vector *argCouldBeNULL, + const std::vector *argIsConstant); /** Name of the function that is being called. */ std::string name; diff --git a/tests/func-overload-max.ispc b/tests/func-overload-max.ispc new file mode 100644 index 00000000..37360030 --- /dev/null +++ b/tests/func-overload-max.ispc @@ -0,0 +1,12 @@ + +export uniform int width() { return programCount; } + + +export void f_f(uniform float RET[], uniform float aFOO[]) { + float a = 1. / aFOO[programIndex]; + RET[programIndex] = max(0, a); +} + +export void result(uniform float RET[]) { + RET[programIndex] = 1. / (1+programIndex); +} diff --git a/tests_errors/func-param-mismatch-2.ispc b/tests_errors/func-param-mismatch-2.ispc index 09b27064..63c0239a 100644 --- a/tests_errors/func-param-mismatch-2.ispc +++ b/tests_errors/func-param-mismatch-2.ispc @@ -1,4 +1,4 @@ -// Unable to find matching overload for call to function +// Unable to find any matching overload for call to function void foo(int x); diff --git a/tests_errors/func-param-mismatch-3.ispc b/tests_errors/func-param-mismatch-3.ispc index 7e5f2b99..cb34c8a7 100644 --- a/tests_errors/func-param-mismatch-3.ispc +++ b/tests_errors/func-param-mismatch-3.ispc @@ -1,4 +1,4 @@ -// Unable to find matching overload for call to function +// Unable to find any matching overload for call to function void foo(int x); diff --git a/tests_errors/func-param-mismatch.ispc b/tests_errors/func-param-mismatch.ispc index c2bac94f..44a50903 100644 --- a/tests_errors/func-param-mismatch.ispc +++ b/tests_errors/func-param-mismatch.ispc @@ -1,4 +1,4 @@ -// Unable to find matching overload for call to function +// Unable to find any matching overload for call to function void foo();