Implement new, simpler function overload resolution algorithm.

We now give each conversion a cost and then find the minimum sum
of costs for all of the possible overloads.

Fixes issue #194.
This commit is contained in:
Matt Pharr
2012-03-27 13:25:11 -07:00
parent 247775d1ec
commit f8a39402a2
6 changed files with 214 additions and 269 deletions

443
expr.cpp
View File

@@ -3406,24 +3406,28 @@ FunctionCallExpr::TypeCheck() {
return NULL;
std::vector<const Type *> argTypes;
std::vector<bool> argCouldBeNULL;
std::vector<bool> argCouldBeNULL, argIsConstant;
for (unsigned int i = 0; i < args->exprs.size(); ++i) {
if (args->exprs[i] == NULL)
Expr *expr = args->exprs[i];
if (expr == NULL)
return NULL;
const Type *t = args->exprs[i]->GetType();
const Type *t = expr->GetType();
if (t == NULL)
return NULL;
argTypes.push_back(t);
argCouldBeNULL.push_back(lIsAllIntZeros(args->exprs[i]) ||
dynamic_cast<NullPointerExpr *>(args->exprs[i]) != NULL);
argTypes.push_back(t);
argCouldBeNULL.push_back(lIsAllIntZeros(expr) ||
dynamic_cast<NullPointerExpr *>(expr));
argIsConstant.push_back(dynamic_cast<ConstExpr *>(expr) ||
dynamic_cast<NullPointerExpr *>(expr));
}
FunctionSymbolExpr *fse = dynamic_cast<FunctionSymbolExpr *>(func);
if (fse != NULL) {
// Regular function call
if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL) == false)
if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL,
&argIsConstant) == false)
return NULL;
func = ::TypeCheck(fse);
@@ -7403,282 +7407,183 @@ lPrintOverloadCandidates(SourcePos pos, const std::vector<Symbol *> &funcs,
}
/** Helper function used for function overload resolution: returns zero
cost if the call argument's type exactly matches the function argument
type (modulo a conversion to a const type if needed), otherwise reports
failure.
*/
static int
lExactMatch(const Type *callType, const Type *funcArgType) {
if (dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = callType->GetAsNonConstType();
if (dynamic_cast<const ReferenceType *>(funcArgType) != NULL &&
dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = new ReferenceType(callType);
return Type::Equal(callType, funcArgType) ? 0 : -1;
}
/** Helper function used for function overload resolution: returns a cost
of 1 if the call argument type and the function argument type match,
modulo conversion to a reference type if needed.
*/
static int
lMatchIgnoringReferences(const Type *callType, const Type *funcArgType) {
int prev = lExactMatch(callType, funcArgType);
if (prev != -1)
return prev;
callType = callType->GetReferenceTarget();
if (funcArgType->IsConstType())
callType = callType->GetAsConstType();
return Type::Equal(callType,
funcArgType->GetReferenceTarget()) ? 1 : -1;
}
/** Helper function used for function overload resolution: returns a cost
of 1 if converting the argument to the call type only requires a type
conversion that won't lose information. Otherwise reports failure.
*/
static int
lMatchWithTypeWidening(const Type *callType, const Type *funcArgType) {
int prev = lMatchIgnoringReferences(callType, funcArgType);
if (prev != -1)
return prev;
/** Helper function used for function overload resolution: returns true if
converting the argument to the call type only requires a type
conversion that won't lose information. Otherwise return false.
*/
static bool
lIsMatchWithTypeWidening(const Type *callType, const Type *funcArgType) {
const AtomicType *callAt = dynamic_cast<const AtomicType *>(callType);
const AtomicType *funcAt = dynamic_cast<const AtomicType *>(funcArgType);
if (callAt == NULL || funcAt == NULL)
return -1;
return false;
if (callAt->IsUniformType() != funcAt->IsUniformType())
return -1;
return false;
switch (callAt->basicType) {
case AtomicType::TYPE_BOOL:
return 1;
return true;
case AtomicType::TYPE_INT8:
case AtomicType::TYPE_UINT8:
return (funcAt->basicType != AtomicType::TYPE_BOOL) ? 1 : -1;
return (funcAt->basicType != AtomicType::TYPE_BOOL);
case AtomicType::TYPE_INT16:
case AtomicType::TYPE_UINT16:
return (funcAt->basicType != AtomicType::TYPE_BOOL &&
funcAt->basicType != AtomicType::TYPE_INT8 &&
funcAt->basicType != AtomicType::TYPE_UINT8) ? 1 : -1;
funcAt->basicType != AtomicType::TYPE_UINT8);
case AtomicType::TYPE_INT32:
case AtomicType::TYPE_UINT32:
return (funcAt->basicType == AtomicType::TYPE_INT32 ||
funcAt->basicType == AtomicType::TYPE_UINT32 ||
funcAt->basicType == AtomicType::TYPE_INT64 ||
funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1;
funcAt->basicType == AtomicType::TYPE_UINT64);
case AtomicType::TYPE_FLOAT:
return (funcAt->basicType == AtomicType::TYPE_DOUBLE) ? 1 : -1;
return (funcAt->basicType == AtomicType::TYPE_DOUBLE);
case AtomicType::TYPE_INT64:
case AtomicType::TYPE_UINT64:
return (funcAt->basicType == AtomicType::TYPE_INT64 ||
funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1;
funcAt->basicType == AtomicType::TYPE_UINT64);
case AtomicType::TYPE_DOUBLE:
return -1;
return false;
default:
FATAL("Unhandled atomic type");
return -1;
return false;
}
}
/** Helper function used for function overload resolution: returns a cost
of 1 if the call argument type and the function argument type match if
we only do a uniform -> varying type conversion but otherwise have
exactly the same type.
/** Helper function used for function overload resolution: returns true if
the call argument type and the function argument type match if we only
do a uniform -> varying type conversion but otherwise have exactly the
same type.
*/
static int
lMatchIgnoringUniform(const Type *callType, const Type *funcArgType) {
int prev = lMatchWithTypeWidening(callType, funcArgType);
if (prev != -1)
return prev;
if (dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = callType->GetAsNonConstType();
static bool
lIsMatchWithUniformToVarying(const Type *callType, const Type *funcArgType) {
return (callType->IsUniformType() &&
funcArgType->IsVaryingType() &&
Type::Equal(callType->GetAsVaryingType(), funcArgType)) ? 1 : -1;
Type::EqualIgnoringConst(callType->GetAsVaryingType(), funcArgType));
}
/** Helper function used for function overload resolution: returns a cost
of 1 if we can type convert from the call argument type to the function
/** Helper function used for function overload resolution: returns true if
we can type convert from the call argument type to the function
argument type, but without doing a uniform -> varying conversion.
*/
static int
lMatchWithTypeConvSameVariability(const Type *callType,
const Type *funcArgType) {
int prev = lMatchIgnoringUniform(callType, funcArgType);
if (prev != -1)
return prev;
if (CanConvertTypes(callType, funcArgType) &&
(callType->IsUniformType() == funcArgType->IsUniformType()))
return 1;
else
return -1;
static bool
lIsMatchWithTypeConvSameVariability(const Type *callType,
const Type *funcArgType) {
return (CanConvertTypes(callType, funcArgType) &&
(callType->GetVariability() == funcArgType->GetVariability()));
}
/** Helper function used for function overload resolution: returns a cost
of 1 if there is any type conversion that gets us from the caller
argument type to the function argument type.
/* Returns the set of function overloads that are potential matches, given
argCount values being passed as arguments to the function call.
*/
static int
lMatchWithTypeConv(const Type *callType, const Type *funcArgType) {
int prev = lMatchWithTypeConvSameVariability(callType, funcArgType);
if (prev != -1)
return prev;
return CanConvertTypes(callType, funcArgType) ? 0 : -1;
}
/** Given a set of potential matching functions and their associated cost,
return the one with the lowest cost, if unique. Otherwise, if multiple
functions match with the same cost, return NULL.
*/
static Symbol *
lGetBestMatch(std::vector<std::pair<int, Symbol *> > &matches) {
Assert(matches.size() > 0);
int minCost = matches[0].first;
for (unsigned int i = 1; i < matches.size(); ++i)
minCost = std::min(minCost, matches[i].first);
Symbol *match = NULL;
for (unsigned int i = 0; i < matches.size(); ++i) {
if (matches[i].first == minCost) {
if (match != NULL)
// multiple things had the same cost
return NULL;
else
match = matches[i].second;
}
}
return match;
}
/** See if we can find a single function from the set of overload options
based on the predicate function passed in. Returns true if no more
tries should be made to find a match, either due to success from
finding a single overloaded function that matches or failure due to
finding multiple ambiguous matches.
*/
bool
FunctionSymbolExpr::tryResolve(int (*matchFunc)(const Type *, const Type *),
SourcePos argPos,
const std::vector<const Type *> &callTypes,
const std::vector<bool> *argCouldBeNULL) {
const char *funName = candidateFunctions.front()->name.c_str();
std::vector<std::pair<int, Symbol *> > matches;
std::vector<Symbol *>::iterator iter;
for (iter = candidateFunctions.begin();
iter != candidateFunctions.end(); ++iter) {
// Loop over the set of candidate functions and try each one
Symbol *candidateFunction = *iter;
std::vector<Symbol *>
FunctionSymbolExpr::getCandidateFunctions(int argCount) const {
std::vector<Symbol *> ret;
for (int i = 0; i < (int)candidateFunctions.size(); ++i) {
const FunctionType *ft =
dynamic_cast<const FunctionType *>(candidateFunction->type);
dynamic_cast<const FunctionType *>(candidateFunctions[i]->type);
Assert(ft != NULL);
// There's no way to match if the caller is passing more arguments
// than this function instance takes.
if ((int)callTypes.size() > ft->GetNumParameters())
if (argCount > ft->GetNumParameters())
continue;
int i;
// Note that we're looping over the caller arguments, not the
// function arguments; it may be ok to have more arguments to the
// function than are passed, if the function has default argument
// values. This case is handled below.
int cost = 0;
for (i = 0; i < (int)callTypes.size(); ++i) {
// This may happen if there's an error earlier in compilation.
// It's kind of a silly to redundantly discover this for each
// potential match versus detecting this earlier in the
// matching process and just giving up.
const Type *paramType = ft->GetParameterType(i);
// Not enough arguments, and no default argument value to save us
if (argCount < ft->GetNumParameters() &&
ft->GetParameterDefault(argCount) == NULL)
continue;
if (callTypes[i] == NULL || paramType == NULL ||
dynamic_cast<const FunctionType *>(callTypes[i]) != NULL)
return false;
// Success
ret.push_back(candidateFunctions[i]);
}
return ret;
}
int argCost = matchFunc(callTypes[i], paramType);
if (argCost == -1) {
if (argCouldBeNULL != NULL && (*argCouldBeNULL)[i] == true &&
dynamic_cast<const PointerType *>(paramType) != NULL)
// If the passed argument value is zero and this is a
// pointer type, then it can convert to a NULL value of
// that pointer type.
argCost = 0;
else
// If the predicate function returns -1, we have failed no
// matter what else happens, so we stop trying
break;
}
cost += argCost;
}
if (i == (int)callTypes.size()) {
// All of the arguments matched!
if (i == ft->GetNumParameters())
// And we have exactly as many arguments as the function
// wants, so we're done.
matches.push_back(std::make_pair(cost, candidateFunction));
else if (i < ft->GetNumParameters() &&
ft->GetParameterDefault(i) != NULL)
// Otherwise we can still make it if there are default
// arguments for the rest of the arguments! Because in
// Module::AddFunction() we have verified that once the
// default arguments start, then all of the following ones
// have them as well. Therefore, we just need to check if
// the arg we stopped at has a default value and we're
// done.
matches.push_back(std::make_pair(cost, candidateFunction));
// otherwise, we don't have a match
/** This function computes the value of a cost function that represents the
cost of calling a function of the given type with arguments of the
given types. If it's not possible to call the function, regardless of
any type conversions applied, a cost of -1 is returned.
*/
int
FunctionSymbolExpr::computeOverloadCost(const FunctionType *ftype,
const std::vector<const Type *> &argTypes,
const std::vector<bool> *argCouldBeNULL,
const std::vector<bool> *argIsConstant) {
int costSum = 0;
// In computing the cost function, we only worry about the actual
// argument types--using function default parameter values is free for
// the purposes here...
for (int i = 0; i < (int)argTypes.size(); ++i) {
// The cost imposed by this argument will be a multiple of
// costScale, which has a value set so that for each of the cost
// buckets, even if all of the function arguments undergo the next
// lower-cost conversion, the sum of their costs will be less than
// a single instance of the next higher-cost conversion.
int costScale = argTypes.size() + 1;
const Type *fargType = ftype->GetParameterType(i);
const Type *callType = argTypes[i];
// For convenience, normalize to non-const types (except for
// references, where const-ness matters). For all other types,
// we're passing by value anyway, so const doesn't matter.
if (dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = callType->GetAsNonConstType();
if (dynamic_cast<const ReferenceType *>(fargType) == NULL)
fargType = fargType->GetAsNonConstType();
if (Type::Equal(callType, fargType))
// Perfect match: no cost
costSum += 0;
else if (argCouldBeNULL && (*argCouldBeNULL)[i] &&
dynamic_cast<const PointerType *>(fargType) != NULL)
// Passing NULL to a pointer-typed parameter is also a no-cost
// operation
costSum += 0;
else {
// If the argument is a compile-time constant, we'd like to
// count the cost of various conversions as much lower than the
// cost if it wasn't--so scale up the cost when this isn't the
// case..
if (argIsConstant == NULL || (*argIsConstant)[i] == false)
costScale *= 32;
if (Type::Equal(callType, fargType))
// Exact match (after dealing with references, above)
costSum += 1 * costScale;
else if (lIsMatchWithTypeWidening(callType, fargType))
costSum += 2 * costScale;
else if (lIsMatchWithUniformToVarying(callType, fargType))
costSum += 4 * costScale;
else if (lIsMatchWithTypeConvSameVariability(callType, fargType))
costSum += 8 * costScale;
else if (CanConvertTypes(callType, fargType))
costSum += 16 * costScale;
else
// Failure--no type conversion possible...
return -1;
}
}
if (matches.size() == 0)
return false;
else if ((matchingFunc = lGetBestMatch(matches)) != NULL)
// We have a match!
return true;
else {
Error(pos, "Multiple overloaded instances of function \"%s\" matched.",
funName);
// select the matches that have the lowest cost
std::vector<Symbol *> bestMatches;
int minCost = matches[0].first;
for (unsigned int i = 1; i < matches.size(); ++i)
minCost = std::min(minCost, matches[i].first);
for (unsigned int i = 0; i < matches.size(); ++i)
if (matches[i].first == minCost)
bestMatches.push_back(matches[i].second);
// And print a useful error message
lPrintOverloadCandidates(argPos, bestMatches, callTypes, argCouldBeNULL);
// Stop trying to find more matches after an ambigious set of
// matches.
return true;
}
return costSum;
}
bool
FunctionSymbolExpr::ResolveOverloads(SourcePos argPos,
const std::vector<const Type *> &argTypes,
const std::vector<bool> *argCouldBeNULL) {
const std::vector<bool> *argCouldBeNULL,
const std::vector<bool> *argIsConstant) {
const char *funName = candidateFunctions.front()->name.c_str();
triedToResolve = true;
// Functions with names that start with "__" should only be various
@@ -7689,45 +7594,67 @@ FunctionSymbolExpr::ResolveOverloads(SourcePos argPos,
// called.
bool exactMatchOnly = (name.substr(0,2) == "__");
// Is there an exact match that doesn't require any argument type
// conversion (other than converting type -> reference type)?
if (tryResolve(lExactMatch, argPos, argTypes, argCouldBeNULL))
return true;
// First, find the subset of overload candidates that take the same
// number of arguments as have parameters (including functions that
// take more arguments but have defaults starting no later than after
// our last parameter).
std::vector<Symbol *> actualCandidates =
getCandidateFunctions(argTypes.size());
if (exactMatchOnly == false) {
// Try to find a single match ignoring references
if (tryResolve(lMatchIgnoringReferences, argPos, argTypes,
argCouldBeNULL))
return true;
int bestMatchCost = 1<<30;
std::vector<Symbol *> matches;
std::vector<int> candidateCosts;
// Try to find an exact match via type widening--i.e. int8 ->
// int16, etc.--things that don't lose data.
if (tryResolve(lMatchWithTypeWidening, argPos, argTypes, argCouldBeNULL))
return true;
if (actualCandidates.size() == 0)
goto failure;
// Next try to see if there's a match via just uniform -> varying
// promotions.
if (tryResolve(lMatchIgnoringUniform, argPos, argTypes, argCouldBeNULL))
return true;
// Try to find a match via type conversion, but don't change
// unif->varying
if (tryResolve(lMatchWithTypeConvSameVariability, argPos, argTypes,
argCouldBeNULL))
return true;
// Last chance: try to find a match via arbitrary type conversion.
if (tryResolve(lMatchWithTypeConv, argPos, argTypes, argCouldBeNULL))
return true;
// Compute the cost for calling each of the candidate functions
for (int i = 0; i < (int)actualCandidates.size(); ++i) {
const FunctionType *ft =
dynamic_cast<const FunctionType *>(actualCandidates[i]->type);
Assert(ft != NULL);
candidateCosts.push_back(computeOverloadCost(ft, argTypes,
argCouldBeNULL,
argIsConstant));
}
// failure :-(
const char *funName = candidateFunctions.front()->name.c_str();
Error(pos, "Unable to find matching overload for call to function \"%s\"%s.",
funName, exactMatchOnly ? " only considering exact matches" : "");
lPrintOverloadCandidates(argPos, candidateFunctions, argTypes,
argCouldBeNULL);
return false;
// Find the best cost, and then the candidate or candidates that have
// that cost.
for (int i = 0; i < (int)candidateCosts.size(); ++i) {
if (candidateCosts[i] != -1 && candidateCosts[i] < bestMatchCost)
bestMatchCost = candidateCosts[i];
}
// None of the candidates matched
if (bestMatchCost == (1<<30))
goto failure;
for (int i = 0; i < (int)candidateCosts.size(); ++i) {
if (candidateCosts[i] == bestMatchCost)
matches.push_back(actualCandidates[i]);
}
if (matches.size() == 1) {
// Only one match: success
matchingFunc = matches[0];
return true;
}
else if (matches.size() > 1) {
// Multiple matches: ambiguous
Error(pos, "Multiple overloaded functions matched call to function "
"\"%s\"%s.", funName,
exactMatchOnly ? " only considering exact matches" : "");
lPrintOverloadCandidates(argPos, matches, argTypes, argCouldBeNULL);
return false;
}
else {
// No matches at all
failure:
Error(pos, "Unable to find any matching overload for call to function "
"\"%s\"%s.", funName,
exactMatchOnly ? " only considering exact matches" : "");
lPrintOverloadCandidates(argPos, candidateFunctions, argTypes,
argCouldBeNULL);
return false;
}
}

22
expr.h
View File

@@ -651,20 +651,26 @@ public:
function overloading, this method resolves which actual function
the arguments match best. If the argCouldBeNULL parameter is
non-NULL, each element indicates whether the corresponding argument
is the number zero, indicating that it could be a NULL pointer.
This parameter may be NULL (for cases where overload resolution is
being done just given type information without the parameter
argument expressions being available. It returns true on success.
is the number zero, indicating that it could be a NULL pointer, and
if argIsConstant is non-NULL, each element indicates whether the
corresponding argument is a compile-time constant value. Both of
these parameters may be NULL (for cases where overload resolution
is being done just given type information without the parameter
argument expressions being available. This function returns true
on success.
*/
bool ResolveOverloads(SourcePos argPos,
const std::vector<const Type *> &argTypes,
const std::vector<bool> *argCouldBeNULL = NULL);
const std::vector<bool> *argCouldBeNULL = NULL,
const std::vector<bool> *argIsConstant = NULL);
Symbol *GetMatchingFunction();
private:
bool tryResolve(int (*matchFunc)(const Type *, const Type *),
SourcePos argPos, const std::vector<const Type *> &argTypes,
const std::vector<bool> *argCouldBeNULL);
std::vector<Symbol *> getCandidateFunctions(int argCount) const;
static int computeOverloadCost(const FunctionType *ftype,
const std::vector<const Type *> &argTypes,
const std::vector<bool> *argCouldBeNULL,
const std::vector<bool> *argIsConstant);
/** Name of the function that is being called. */
std::string name;

View File

@@ -0,0 +1,12 @@
export uniform int width() { return programCount; }
export void f_f(uniform float RET[], uniform float aFOO[]) {
float a = 1. / aFOO[programIndex];
RET[programIndex] = max(0, a);
}
export void result(uniform float RET[]) {
RET[programIndex] = 1. / (1+programIndex);
}

View File

@@ -1,4 +1,4 @@
// Unable to find matching overload for call to function
// Unable to find any matching overload for call to function
void foo(int x);

View File

@@ -1,4 +1,4 @@
// Unable to find matching overload for call to function
// Unable to find any matching overload for call to function
void foo(int x);

View File

@@ -1,4 +1,4 @@
// Unable to find matching overload for call to function
// Unable to find any matching overload for call to function
void foo();