Implement new, simpler function overload resolution algorithm.

We now give each conversion a cost and then find the minimum sum
of costs for all of the possible overloads.

Fixes issue #194.
This commit is contained in:
Matt Pharr
2012-03-27 13:25:11 -07:00
parent 247775d1ec
commit f8a39402a2
6 changed files with 214 additions and 269 deletions

443
expr.cpp
View File

@@ -3406,24 +3406,28 @@ FunctionCallExpr::TypeCheck() {
return NULL; return NULL;
std::vector<const Type *> argTypes; std::vector<const Type *> argTypes;
std::vector<bool> argCouldBeNULL; std::vector<bool> argCouldBeNULL, argIsConstant;
for (unsigned int i = 0; i < args->exprs.size(); ++i) { for (unsigned int i = 0; i < args->exprs.size(); ++i) {
if (args->exprs[i] == NULL) Expr *expr = args->exprs[i];
if (expr == NULL)
return NULL; return NULL;
const Type *t = args->exprs[i]->GetType(); const Type *t = expr->GetType();
if (t == NULL) if (t == NULL)
return NULL; return NULL;
argTypes.push_back(t);
argCouldBeNULL.push_back(lIsAllIntZeros(args->exprs[i]) || argTypes.push_back(t);
dynamic_cast<NullPointerExpr *>(args->exprs[i]) != NULL); argCouldBeNULL.push_back(lIsAllIntZeros(expr) ||
dynamic_cast<NullPointerExpr *>(expr));
argIsConstant.push_back(dynamic_cast<ConstExpr *>(expr) ||
dynamic_cast<NullPointerExpr *>(expr));
} }
FunctionSymbolExpr *fse = dynamic_cast<FunctionSymbolExpr *>(func); FunctionSymbolExpr *fse = dynamic_cast<FunctionSymbolExpr *>(func);
if (fse != NULL) { if (fse != NULL) {
// Regular function call // Regular function call
if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL,
if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL) == false) &argIsConstant) == false)
return NULL; return NULL;
func = ::TypeCheck(fse); func = ::TypeCheck(fse);
@@ -7403,282 +7407,183 @@ lPrintOverloadCandidates(SourcePos pos, const std::vector<Symbol *> &funcs,
} }
/** Helper function used for function overload resolution: returns zero /** Helper function used for function overload resolution: returns true if
cost if the call argument's type exactly matches the function argument converting the argument to the call type only requires a type
type (modulo a conversion to a const type if needed), otherwise reports conversion that won't lose information. Otherwise return false.
failure. */
*/ static bool
static int lIsMatchWithTypeWidening(const Type *callType, const Type *funcArgType) {
lExactMatch(const Type *callType, const Type *funcArgType) {
if (dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = callType->GetAsNonConstType();
if (dynamic_cast<const ReferenceType *>(funcArgType) != NULL &&
dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = new ReferenceType(callType);
return Type::Equal(callType, funcArgType) ? 0 : -1;
}
/** Helper function used for function overload resolution: returns a cost
of 1 if the call argument type and the function argument type match,
modulo conversion to a reference type if needed.
*/
static int
lMatchIgnoringReferences(const Type *callType, const Type *funcArgType) {
int prev = lExactMatch(callType, funcArgType);
if (prev != -1)
return prev;
callType = callType->GetReferenceTarget();
if (funcArgType->IsConstType())
callType = callType->GetAsConstType();
return Type::Equal(callType,
funcArgType->GetReferenceTarget()) ? 1 : -1;
}
/** Helper function used for function overload resolution: returns a cost
of 1 if converting the argument to the call type only requires a type
conversion that won't lose information. Otherwise reports failure.
*/
static int
lMatchWithTypeWidening(const Type *callType, const Type *funcArgType) {
int prev = lMatchIgnoringReferences(callType, funcArgType);
if (prev != -1)
return prev;
const AtomicType *callAt = dynamic_cast<const AtomicType *>(callType); const AtomicType *callAt = dynamic_cast<const AtomicType *>(callType);
const AtomicType *funcAt = dynamic_cast<const AtomicType *>(funcArgType); const AtomicType *funcAt = dynamic_cast<const AtomicType *>(funcArgType);
if (callAt == NULL || funcAt == NULL) if (callAt == NULL || funcAt == NULL)
return -1; return false;
if (callAt->IsUniformType() != funcAt->IsUniformType()) if (callAt->IsUniformType() != funcAt->IsUniformType())
return -1; return false;
switch (callAt->basicType) { switch (callAt->basicType) {
case AtomicType::TYPE_BOOL: case AtomicType::TYPE_BOOL:
return 1; return true;
case AtomicType::TYPE_INT8: case AtomicType::TYPE_INT8:
case AtomicType::TYPE_UINT8: case AtomicType::TYPE_UINT8:
return (funcAt->basicType != AtomicType::TYPE_BOOL) ? 1 : -1; return (funcAt->basicType != AtomicType::TYPE_BOOL);
case AtomicType::TYPE_INT16: case AtomicType::TYPE_INT16:
case AtomicType::TYPE_UINT16: case AtomicType::TYPE_UINT16:
return (funcAt->basicType != AtomicType::TYPE_BOOL && return (funcAt->basicType != AtomicType::TYPE_BOOL &&
funcAt->basicType != AtomicType::TYPE_INT8 && funcAt->basicType != AtomicType::TYPE_INT8 &&
funcAt->basicType != AtomicType::TYPE_UINT8) ? 1 : -1; funcAt->basicType != AtomicType::TYPE_UINT8);
case AtomicType::TYPE_INT32: case AtomicType::TYPE_INT32:
case AtomicType::TYPE_UINT32: case AtomicType::TYPE_UINT32:
return (funcAt->basicType == AtomicType::TYPE_INT32 || return (funcAt->basicType == AtomicType::TYPE_INT32 ||
funcAt->basicType == AtomicType::TYPE_UINT32 || funcAt->basicType == AtomicType::TYPE_UINT32 ||
funcAt->basicType == AtomicType::TYPE_INT64 || funcAt->basicType == AtomicType::TYPE_INT64 ||
funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; funcAt->basicType == AtomicType::TYPE_UINT64);
case AtomicType::TYPE_FLOAT: case AtomicType::TYPE_FLOAT:
return (funcAt->basicType == AtomicType::TYPE_DOUBLE) ? 1 : -1; return (funcAt->basicType == AtomicType::TYPE_DOUBLE);
case AtomicType::TYPE_INT64: case AtomicType::TYPE_INT64:
case AtomicType::TYPE_UINT64: case AtomicType::TYPE_UINT64:
return (funcAt->basicType == AtomicType::TYPE_INT64 || return (funcAt->basicType == AtomicType::TYPE_INT64 ||
funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; funcAt->basicType == AtomicType::TYPE_UINT64);
case AtomicType::TYPE_DOUBLE: case AtomicType::TYPE_DOUBLE:
return -1; return false;
default: default:
FATAL("Unhandled atomic type"); FATAL("Unhandled atomic type");
return -1; return false;
} }
} }
/** Helper function used for function overload resolution: returns a cost /** Helper function used for function overload resolution: returns true if
of 1 if the call argument type and the function argument type match if the call argument type and the function argument type match if we only
we only do a uniform -> varying type conversion but otherwise have do a uniform -> varying type conversion but otherwise have exactly the
exactly the same type. same type.
*/ */
static int static bool
lMatchIgnoringUniform(const Type *callType, const Type *funcArgType) { lIsMatchWithUniformToVarying(const Type *callType, const Type *funcArgType) {
int prev = lMatchWithTypeWidening(callType, funcArgType);
if (prev != -1)
return prev;
if (dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = callType->GetAsNonConstType();
return (callType->IsUniformType() && return (callType->IsUniformType() &&
funcArgType->IsVaryingType() && funcArgType->IsVaryingType() &&
Type::Equal(callType->GetAsVaryingType(), funcArgType)) ? 1 : -1; Type::EqualIgnoringConst(callType->GetAsVaryingType(), funcArgType));
} }
/** Helper function used for function overload resolution: returns a cost /** Helper function used for function overload resolution: returns true if
of 1 if we can type convert from the call argument type to the function we can type convert from the call argument type to the function
argument type, but without doing a uniform -> varying conversion. argument type, but without doing a uniform -> varying conversion.
*/ */
static int static bool
lMatchWithTypeConvSameVariability(const Type *callType, lIsMatchWithTypeConvSameVariability(const Type *callType,
const Type *funcArgType) { const Type *funcArgType) {
int prev = lMatchIgnoringUniform(callType, funcArgType); return (CanConvertTypes(callType, funcArgType) &&
if (prev != -1) (callType->GetVariability() == funcArgType->GetVariability()));
return prev;
if (CanConvertTypes(callType, funcArgType) &&
(callType->IsUniformType() == funcArgType->IsUniformType()))
return 1;
else
return -1;
} }
/** Helper function used for function overload resolution: returns a cost /* Returns the set of function overloads that are potential matches, given
of 1 if there is any type conversion that gets us from the caller argCount values being passed as arguments to the function call.
argument type to the function argument type.
*/ */
static int std::vector<Symbol *>
lMatchWithTypeConv(const Type *callType, const Type *funcArgType) { FunctionSymbolExpr::getCandidateFunctions(int argCount) const {
int prev = lMatchWithTypeConvSameVariability(callType, funcArgType); std::vector<Symbol *> ret;
if (prev != -1) for (int i = 0; i < (int)candidateFunctions.size(); ++i) {
return prev;
return CanConvertTypes(callType, funcArgType) ? 0 : -1;
}
/** Given a set of potential matching functions and their associated cost,
return the one with the lowest cost, if unique. Otherwise, if multiple
functions match with the same cost, return NULL.
*/
static Symbol *
lGetBestMatch(std::vector<std::pair<int, Symbol *> > &matches) {
Assert(matches.size() > 0);
int minCost = matches[0].first;
for (unsigned int i = 1; i < matches.size(); ++i)
minCost = std::min(minCost, matches[i].first);
Symbol *match = NULL;
for (unsigned int i = 0; i < matches.size(); ++i) {
if (matches[i].first == minCost) {
if (match != NULL)
// multiple things had the same cost
return NULL;
else
match = matches[i].second;
}
}
return match;
}
/** See if we can find a single function from the set of overload options
based on the predicate function passed in. Returns true if no more
tries should be made to find a match, either due to success from
finding a single overloaded function that matches or failure due to
finding multiple ambiguous matches.
*/
bool
FunctionSymbolExpr::tryResolve(int (*matchFunc)(const Type *, const Type *),
SourcePos argPos,
const std::vector<const Type *> &callTypes,
const std::vector<bool> *argCouldBeNULL) {
const char *funName = candidateFunctions.front()->name.c_str();
std::vector<std::pair<int, Symbol *> > matches;
std::vector<Symbol *>::iterator iter;
for (iter = candidateFunctions.begin();
iter != candidateFunctions.end(); ++iter) {
// Loop over the set of candidate functions and try each one
Symbol *candidateFunction = *iter;
const FunctionType *ft = const FunctionType *ft =
dynamic_cast<const FunctionType *>(candidateFunction->type); dynamic_cast<const FunctionType *>(candidateFunctions[i]->type);
Assert(ft != NULL); Assert(ft != NULL);
// There's no way to match if the caller is passing more arguments // There's no way to match if the caller is passing more arguments
// than this function instance takes. // than this function instance takes.
if ((int)callTypes.size() > ft->GetNumParameters()) if (argCount > ft->GetNumParameters())
continue; continue;
int i; // Not enough arguments, and no default argument value to save us
// Note that we're looping over the caller arguments, not the if (argCount < ft->GetNumParameters() &&
// function arguments; it may be ok to have more arguments to the ft->GetParameterDefault(argCount) == NULL)
// function than are passed, if the function has default argument continue;
// values. This case is handled below.
int cost = 0;
for (i = 0; i < (int)callTypes.size(); ++i) {
// This may happen if there's an error earlier in compilation.
// It's kind of a silly to redundantly discover this for each
// potential match versus detecting this earlier in the
// matching process and just giving up.
const Type *paramType = ft->GetParameterType(i);
if (callTypes[i] == NULL || paramType == NULL || // Success
dynamic_cast<const FunctionType *>(callTypes[i]) != NULL) ret.push_back(candidateFunctions[i]);
return false; }
return ret;
}
int argCost = matchFunc(callTypes[i], paramType);
if (argCost == -1) { /** This function computes the value of a cost function that represents the
if (argCouldBeNULL != NULL && (*argCouldBeNULL)[i] == true && cost of calling a function of the given type with arguments of the
dynamic_cast<const PointerType *>(paramType) != NULL) given types. If it's not possible to call the function, regardless of
// If the passed argument value is zero and this is a any type conversions applied, a cost of -1 is returned.
// pointer type, then it can convert to a NULL value of */
// that pointer type. int
argCost = 0; FunctionSymbolExpr::computeOverloadCost(const FunctionType *ftype,
else const std::vector<const Type *> &argTypes,
// If the predicate function returns -1, we have failed no const std::vector<bool> *argCouldBeNULL,
// matter what else happens, so we stop trying const std::vector<bool> *argIsConstant) {
break; int costSum = 0;
}
cost += argCost; // In computing the cost function, we only worry about the actual
} // argument types--using function default parameter values is free for
if (i == (int)callTypes.size()) { // the purposes here...
// All of the arguments matched! for (int i = 0; i < (int)argTypes.size(); ++i) {
if (i == ft->GetNumParameters()) // The cost imposed by this argument will be a multiple of
// And we have exactly as many arguments as the function // costScale, which has a value set so that for each of the cost
// wants, so we're done. // buckets, even if all of the function arguments undergo the next
matches.push_back(std::make_pair(cost, candidateFunction)); // lower-cost conversion, the sum of their costs will be less than
else if (i < ft->GetNumParameters() && // a single instance of the next higher-cost conversion.
ft->GetParameterDefault(i) != NULL) int costScale = argTypes.size() + 1;
// Otherwise we can still make it if there are default
// arguments for the rest of the arguments! Because in const Type *fargType = ftype->GetParameterType(i);
// Module::AddFunction() we have verified that once the const Type *callType = argTypes[i];
// default arguments start, then all of the following ones
// have them as well. Therefore, we just need to check if // For convenience, normalize to non-const types (except for
// the arg we stopped at has a default value and we're // references, where const-ness matters). For all other types,
// done. // we're passing by value anyway, so const doesn't matter.
matches.push_back(std::make_pair(cost, candidateFunction)); if (dynamic_cast<const ReferenceType *>(callType) == NULL)
// otherwise, we don't have a match callType = callType->GetAsNonConstType();
if (dynamic_cast<const ReferenceType *>(fargType) == NULL)
fargType = fargType->GetAsNonConstType();
if (Type::Equal(callType, fargType))
// Perfect match: no cost
costSum += 0;
else if (argCouldBeNULL && (*argCouldBeNULL)[i] &&
dynamic_cast<const PointerType *>(fargType) != NULL)
// Passing NULL to a pointer-typed parameter is also a no-cost
// operation
costSum += 0;
else {
// If the argument is a compile-time constant, we'd like to
// count the cost of various conversions as much lower than the
// cost if it wasn't--so scale up the cost when this isn't the
// case..
if (argIsConstant == NULL || (*argIsConstant)[i] == false)
costScale *= 32;
if (Type::Equal(callType, fargType))
// Exact match (after dealing with references, above)
costSum += 1 * costScale;
else if (lIsMatchWithTypeWidening(callType, fargType))
costSum += 2 * costScale;
else if (lIsMatchWithUniformToVarying(callType, fargType))
costSum += 4 * costScale;
else if (lIsMatchWithTypeConvSameVariability(callType, fargType))
costSum += 8 * costScale;
else if (CanConvertTypes(callType, fargType))
costSum += 16 * costScale;
else
// Failure--no type conversion possible...
return -1;
} }
} }
if (matches.size() == 0) return costSum;
return false;
else if ((matchingFunc = lGetBestMatch(matches)) != NULL)
// We have a match!
return true;
else {
Error(pos, "Multiple overloaded instances of function \"%s\" matched.",
funName);
// select the matches that have the lowest cost
std::vector<Symbol *> bestMatches;
int minCost = matches[0].first;
for (unsigned int i = 1; i < matches.size(); ++i)
minCost = std::min(minCost, matches[i].first);
for (unsigned int i = 0; i < matches.size(); ++i)
if (matches[i].first == minCost)
bestMatches.push_back(matches[i].second);
// And print a useful error message
lPrintOverloadCandidates(argPos, bestMatches, callTypes, argCouldBeNULL);
// Stop trying to find more matches after an ambigious set of
// matches.
return true;
}
} }
bool bool
FunctionSymbolExpr::ResolveOverloads(SourcePos argPos, FunctionSymbolExpr::ResolveOverloads(SourcePos argPos,
const std::vector<const Type *> &argTypes, const std::vector<const Type *> &argTypes,
const std::vector<bool> *argCouldBeNULL) { const std::vector<bool> *argCouldBeNULL,
const std::vector<bool> *argIsConstant) {
const char *funName = candidateFunctions.front()->name.c_str();
triedToResolve = true; triedToResolve = true;
// Functions with names that start with "__" should only be various // Functions with names that start with "__" should only be various
@@ -7689,45 +7594,67 @@ FunctionSymbolExpr::ResolveOverloads(SourcePos argPos,
// called. // called.
bool exactMatchOnly = (name.substr(0,2) == "__"); bool exactMatchOnly = (name.substr(0,2) == "__");
// Is there an exact match that doesn't require any argument type // First, find the subset of overload candidates that take the same
// conversion (other than converting type -> reference type)? // number of arguments as have parameters (including functions that
if (tryResolve(lExactMatch, argPos, argTypes, argCouldBeNULL)) // take more arguments but have defaults starting no later than after
return true; // our last parameter).
std::vector<Symbol *> actualCandidates =
getCandidateFunctions(argTypes.size());
if (exactMatchOnly == false) { int bestMatchCost = 1<<30;
// Try to find a single match ignoring references std::vector<Symbol *> matches;
if (tryResolve(lMatchIgnoringReferences, argPos, argTypes, std::vector<int> candidateCosts;
argCouldBeNULL))
return true;
// Try to find an exact match via type widening--i.e. int8 -> if (actualCandidates.size() == 0)
// int16, etc.--things that don't lose data. goto failure;
if (tryResolve(lMatchWithTypeWidening, argPos, argTypes, argCouldBeNULL))
return true;
// Next try to see if there's a match via just uniform -> varying // Compute the cost for calling each of the candidate functions
// promotions. for (int i = 0; i < (int)actualCandidates.size(); ++i) {
if (tryResolve(lMatchIgnoringUniform, argPos, argTypes, argCouldBeNULL)) const FunctionType *ft =
return true; dynamic_cast<const FunctionType *>(actualCandidates[i]->type);
Assert(ft != NULL);
// Try to find a match via type conversion, but don't change candidateCosts.push_back(computeOverloadCost(ft, argTypes,
// unif->varying argCouldBeNULL,
if (tryResolve(lMatchWithTypeConvSameVariability, argPos, argTypes, argIsConstant));
argCouldBeNULL))
return true;
// Last chance: try to find a match via arbitrary type conversion.
if (tryResolve(lMatchWithTypeConv, argPos, argTypes, argCouldBeNULL))
return true;
} }
// failure :-( // Find the best cost, and then the candidate or candidates that have
const char *funName = candidateFunctions.front()->name.c_str(); // that cost.
Error(pos, "Unable to find matching overload for call to function \"%s\"%s.", for (int i = 0; i < (int)candidateCosts.size(); ++i) {
funName, exactMatchOnly ? " only considering exact matches" : ""); if (candidateCosts[i] != -1 && candidateCosts[i] < bestMatchCost)
lPrintOverloadCandidates(argPos, candidateFunctions, argTypes, bestMatchCost = candidateCosts[i];
argCouldBeNULL); }
return false; // None of the candidates matched
if (bestMatchCost == (1<<30))
goto failure;
for (int i = 0; i < (int)candidateCosts.size(); ++i) {
if (candidateCosts[i] == bestMatchCost)
matches.push_back(actualCandidates[i]);
}
if (matches.size() == 1) {
// Only one match: success
matchingFunc = matches[0];
return true;
}
else if (matches.size() > 1) {
// Multiple matches: ambiguous
Error(pos, "Multiple overloaded functions matched call to function "
"\"%s\"%s.", funName,
exactMatchOnly ? " only considering exact matches" : "");
lPrintOverloadCandidates(argPos, matches, argTypes, argCouldBeNULL);
return false;
}
else {
// No matches at all
failure:
Error(pos, "Unable to find any matching overload for call to function "
"\"%s\"%s.", funName,
exactMatchOnly ? " only considering exact matches" : "");
lPrintOverloadCandidates(argPos, candidateFunctions, argTypes,
argCouldBeNULL);
return false;
}
} }

22
expr.h
View File

@@ -651,20 +651,26 @@ public:
function overloading, this method resolves which actual function function overloading, this method resolves which actual function
the arguments match best. If the argCouldBeNULL parameter is the arguments match best. If the argCouldBeNULL parameter is
non-NULL, each element indicates whether the corresponding argument non-NULL, each element indicates whether the corresponding argument
is the number zero, indicating that it could be a NULL pointer. is the number zero, indicating that it could be a NULL pointer, and
This parameter may be NULL (for cases where overload resolution is if argIsConstant is non-NULL, each element indicates whether the
being done just given type information without the parameter corresponding argument is a compile-time constant value. Both of
argument expressions being available. It returns true on success. these parameters may be NULL (for cases where overload resolution
is being done just given type information without the parameter
argument expressions being available. This function returns true
on success.
*/ */
bool ResolveOverloads(SourcePos argPos, bool ResolveOverloads(SourcePos argPos,
const std::vector<const Type *> &argTypes, const std::vector<const Type *> &argTypes,
const std::vector<bool> *argCouldBeNULL = NULL); const std::vector<bool> *argCouldBeNULL = NULL,
const std::vector<bool> *argIsConstant = NULL);
Symbol *GetMatchingFunction(); Symbol *GetMatchingFunction();
private: private:
bool tryResolve(int (*matchFunc)(const Type *, const Type *), std::vector<Symbol *> getCandidateFunctions(int argCount) const;
SourcePos argPos, const std::vector<const Type *> &argTypes, static int computeOverloadCost(const FunctionType *ftype,
const std::vector<bool> *argCouldBeNULL); const std::vector<const Type *> &argTypes,
const std::vector<bool> *argCouldBeNULL,
const std::vector<bool> *argIsConstant);
/** Name of the function that is being called. */ /** Name of the function that is being called. */
std::string name; std::string name;

View File

@@ -0,0 +1,12 @@
export uniform int width() { return programCount; }
export void f_f(uniform float RET[], uniform float aFOO[]) {
float a = 1. / aFOO[programIndex];
RET[programIndex] = max(0, a);
}
export void result(uniform float RET[]) {
RET[programIndex] = 1. / (1+programIndex);
}

View File

@@ -1,4 +1,4 @@
// Unable to find matching overload for call to function // Unable to find any matching overload for call to function
void foo(int x); void foo(int x);

View File

@@ -1,4 +1,4 @@
// Unable to find matching overload for call to function // Unable to find any matching overload for call to function
void foo(int x); void foo(int x);

View File

@@ -1,4 +1,4 @@
// Unable to find matching overload for call to function // Unable to find any matching overload for call to function
void foo(); void foo();