Implement new, simpler function overload resolution algorithm.

We now give each conversion a cost and then find the minimum sum
of costs for all of the possible overloads.

Fixes issue #194.
This commit is contained in:
Matt Pharr
2012-03-27 13:25:11 -07:00
parent 247775d1ec
commit f8a39402a2
6 changed files with 214 additions and 269 deletions

443
expr.cpp
View File

@@ -3406,24 +3406,28 @@ FunctionCallExpr::TypeCheck() {
return NULL;
std::vector<const Type *> argTypes;
std::vector<bool> argCouldBeNULL;
std::vector<bool> argCouldBeNULL, argIsConstant;
for (unsigned int i = 0; i < args->exprs.size(); ++i) {
if (args->exprs[i] == NULL)
Expr *expr = args->exprs[i];
if (expr == NULL)
return NULL;
const Type *t = args->exprs[i]->GetType();
const Type *t = expr->GetType();
if (t == NULL)
return NULL;
argTypes.push_back(t);
argCouldBeNULL.push_back(lIsAllIntZeros(args->exprs[i]) ||
dynamic_cast<NullPointerExpr *>(args->exprs[i]) != NULL);
argTypes.push_back(t);
argCouldBeNULL.push_back(lIsAllIntZeros(expr) ||
dynamic_cast<NullPointerExpr *>(expr));
argIsConstant.push_back(dynamic_cast<ConstExpr *>(expr) ||
dynamic_cast<NullPointerExpr *>(expr));
}
FunctionSymbolExpr *fse = dynamic_cast<FunctionSymbolExpr *>(func);
if (fse != NULL) {
// Regular function call
if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL) == false)
if (fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL,
&argIsConstant) == false)
return NULL;
func = ::TypeCheck(fse);
@@ -7403,282 +7407,183 @@ lPrintOverloadCandidates(SourcePos pos, const std::vector<Symbol *> &funcs,
}
/** Helper function used for function overload resolution: returns zero
cost if the call argument's type exactly matches the function argument
type (modulo a conversion to a const type if needed), otherwise reports
failure.
*/
static int
lExactMatch(const Type *callType, const Type *funcArgType) {
if (dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = callType->GetAsNonConstType();
if (dynamic_cast<const ReferenceType *>(funcArgType) != NULL &&
dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = new ReferenceType(callType);
return Type::Equal(callType, funcArgType) ? 0 : -1;
}
/** Helper function used for function overload resolution: returns a cost
of 1 if the call argument type and the function argument type match,
modulo conversion to a reference type if needed.
*/
static int
lMatchIgnoringReferences(const Type *callType, const Type *funcArgType) {
int prev = lExactMatch(callType, funcArgType);
if (prev != -1)
return prev;
callType = callType->GetReferenceTarget();
if (funcArgType->IsConstType())
callType = callType->GetAsConstType();
return Type::Equal(callType,
funcArgType->GetReferenceTarget()) ? 1 : -1;
}
/** Helper function used for function overload resolution: returns a cost
of 1 if converting the argument to the call type only requires a type
conversion that won't lose information. Otherwise reports failure.
*/
static int
lMatchWithTypeWidening(const Type *callType, const Type *funcArgType) {
int prev = lMatchIgnoringReferences(callType, funcArgType);
if (prev != -1)
return prev;
/** Helper function used for function overload resolution: returns true if
converting the argument to the call type only requires a type
conversion that won't lose information. Otherwise return false.
*/
static bool
lIsMatchWithTypeWidening(const Type *callType, const Type *funcArgType) {
const AtomicType *callAt = dynamic_cast<const AtomicType *>(callType);
const AtomicType *funcAt = dynamic_cast<const AtomicType *>(funcArgType);
if (callAt == NULL || funcAt == NULL)
return -1;
return false;
if (callAt->IsUniformType() != funcAt->IsUniformType())
return -1;
return false;
switch (callAt->basicType) {
case AtomicType::TYPE_BOOL:
return 1;
return true;
case AtomicType::TYPE_INT8:
case AtomicType::TYPE_UINT8:
return (funcAt->basicType != AtomicType::TYPE_BOOL) ? 1 : -1;
return (funcAt->basicType != AtomicType::TYPE_BOOL);
case AtomicType::TYPE_INT16:
case AtomicType::TYPE_UINT16:
return (funcAt->basicType != AtomicType::TYPE_BOOL &&
funcAt->basicType != AtomicType::TYPE_INT8 &&
funcAt->basicType != AtomicType::TYPE_UINT8) ? 1 : -1;
funcAt->basicType != AtomicType::TYPE_UINT8);
case AtomicType::TYPE_INT32:
case AtomicType::TYPE_UINT32:
return (funcAt->basicType == AtomicType::TYPE_INT32 ||
funcAt->basicType == AtomicType::TYPE_UINT32 ||
funcAt->basicType == AtomicType::TYPE_INT64 ||
funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1;
funcAt->basicType == AtomicType::TYPE_UINT64);
case AtomicType::TYPE_FLOAT:
return (funcAt->basicType == AtomicType::TYPE_DOUBLE) ? 1 : -1;
return (funcAt->basicType == AtomicType::TYPE_DOUBLE);
case AtomicType::TYPE_INT64:
case AtomicType::TYPE_UINT64:
return (funcAt->basicType == AtomicType::TYPE_INT64 ||
funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1;
funcAt->basicType == AtomicType::TYPE_UINT64);
case AtomicType::TYPE_DOUBLE:
return -1;
return false;
default:
FATAL("Unhandled atomic type");
return -1;
return false;
}
}
/** Helper function used for function overload resolution: returns a cost
of 1 if the call argument type and the function argument type match if
we only do a uniform -> varying type conversion but otherwise have
exactly the same type.
/** Helper function used for function overload resolution: returns true if
the call argument type and the function argument type match if we only
do a uniform -> varying type conversion but otherwise have exactly the
same type.
*/
static int
lMatchIgnoringUniform(const Type *callType, const Type *funcArgType) {
int prev = lMatchWithTypeWidening(callType, funcArgType);
if (prev != -1)
return prev;
if (dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = callType->GetAsNonConstType();
static bool
lIsMatchWithUniformToVarying(const Type *callType, const Type *funcArgType) {
return (callType->IsUniformType() &&
funcArgType->IsVaryingType() &&
Type::Equal(callType->GetAsVaryingType(), funcArgType)) ? 1 : -1;
Type::EqualIgnoringConst(callType->GetAsVaryingType(), funcArgType));
}
/** Helper function used for function overload resolution: returns a cost
of 1 if we can type convert from the call argument type to the function
/** Helper function used for function overload resolution: returns true if
we can type convert from the call argument type to the function
argument type, but without doing a uniform -> varying conversion.
*/
static int
lMatchWithTypeConvSameVariability(const Type *callType,
const Type *funcArgType) {
int prev = lMatchIgnoringUniform(callType, funcArgType);
if (prev != -1)
return prev;
if (CanConvertTypes(callType, funcArgType) &&
(callType->IsUniformType() == funcArgType->IsUniformType()))
return 1;
else
return -1;
static bool
lIsMatchWithTypeConvSameVariability(const Type *callType,
const Type *funcArgType) {
return (CanConvertTypes(callType, funcArgType) &&
(callType->GetVariability() == funcArgType->GetVariability()));
}
/** Helper function used for function overload resolution: returns a cost
of 1 if there is any type conversion that gets us from the caller
argument type to the function argument type.
/* Returns the set of function overloads that are potential matches, given
argCount values being passed as arguments to the function call.
*/
static int
lMatchWithTypeConv(const Type *callType, const Type *funcArgType) {
int prev = lMatchWithTypeConvSameVariability(callType, funcArgType);
if (prev != -1)
return prev;
return CanConvertTypes(callType, funcArgType) ? 0 : -1;
}
/** Given a set of potential matching functions and their associated cost,
return the one with the lowest cost, if unique. Otherwise, if multiple
functions match with the same cost, return NULL.
*/
static Symbol *
lGetBestMatch(std::vector<std::pair<int, Symbol *> > &matches) {
Assert(matches.size() > 0);
int minCost = matches[0].first;
for (unsigned int i = 1; i < matches.size(); ++i)
minCost = std::min(minCost, matches[i].first);
Symbol *match = NULL;
for (unsigned int i = 0; i < matches.size(); ++i) {
if (matches[i].first == minCost) {
if (match != NULL)
// multiple things had the same cost
return NULL;
else
match = matches[i].second;
}
}
return match;
}
/** See if we can find a single function from the set of overload options
based on the predicate function passed in. Returns true if no more
tries should be made to find a match, either due to success from
finding a single overloaded function that matches or failure due to
finding multiple ambiguous matches.
*/
bool
FunctionSymbolExpr::tryResolve(int (*matchFunc)(const Type *, const Type *),
SourcePos argPos,
const std::vector<const Type *> &callTypes,
const std::vector<bool> *argCouldBeNULL) {
const char *funName = candidateFunctions.front()->name.c_str();
std::vector<std::pair<int, Symbol *> > matches;
std::vector<Symbol *>::iterator iter;
for (iter = candidateFunctions.begin();
iter != candidateFunctions.end(); ++iter) {
// Loop over the set of candidate functions and try each one
Symbol *candidateFunction = *iter;
std::vector<Symbol *>
FunctionSymbolExpr::getCandidateFunctions(int argCount) const {
std::vector<Symbol *> ret;
for (int i = 0; i < (int)candidateFunctions.size(); ++i) {
const FunctionType *ft =
dynamic_cast<const FunctionType *>(candidateFunction->type);
dynamic_cast<const FunctionType *>(candidateFunctions[i]->type);
Assert(ft != NULL);
// There's no way to match if the caller is passing more arguments
// than this function instance takes.
if ((int)callTypes.size() > ft->GetNumParameters())
if (argCount > ft->GetNumParameters())
continue;
int i;
// Note that we're looping over the caller arguments, not the
// function arguments; it may be ok to have more arguments to the
// function than are passed, if the function has default argument
// values. This case is handled below.
int cost = 0;
for (i = 0; i < (int)callTypes.size(); ++i) {
// This may happen if there's an error earlier in compilation.
// It's kind of a silly to redundantly discover this for each
// potential match versus detecting this earlier in the
// matching process and just giving up.
const Type *paramType = ft->GetParameterType(i);
// Not enough arguments, and no default argument value to save us
if (argCount < ft->GetNumParameters() &&
ft->GetParameterDefault(argCount) == NULL)
continue;
if (callTypes[i] == NULL || paramType == NULL ||
dynamic_cast<const FunctionType *>(callTypes[i]) != NULL)
return false;
// Success
ret.push_back(candidateFunctions[i]);
}
return ret;
}
int argCost = matchFunc(callTypes[i], paramType);
if (argCost == -1) {
if (argCouldBeNULL != NULL && (*argCouldBeNULL)[i] == true &&
dynamic_cast<const PointerType *>(paramType) != NULL)
// If the passed argument value is zero and this is a
// pointer type, then it can convert to a NULL value of
// that pointer type.
argCost = 0;
else
// If the predicate function returns -1, we have failed no
// matter what else happens, so we stop trying
break;
}
cost += argCost;
}
if (i == (int)callTypes.size()) {
// All of the arguments matched!
if (i == ft->GetNumParameters())
// And we have exactly as many arguments as the function
// wants, so we're done.
matches.push_back(std::make_pair(cost, candidateFunction));
else if (i < ft->GetNumParameters() &&
ft->GetParameterDefault(i) != NULL)
// Otherwise we can still make it if there are default
// arguments for the rest of the arguments! Because in
// Module::AddFunction() we have verified that once the
// default arguments start, then all of the following ones
// have them as well. Therefore, we just need to check if
// the arg we stopped at has a default value and we're
// done.
matches.push_back(std::make_pair(cost, candidateFunction));
// otherwise, we don't have a match
/** This function computes the value of a cost function that represents the
cost of calling a function of the given type with arguments of the
given types. If it's not possible to call the function, regardless of
any type conversions applied, a cost of -1 is returned.
*/
int
FunctionSymbolExpr::computeOverloadCost(const FunctionType *ftype,
const std::vector<const Type *> &argTypes,
const std::vector<bool> *argCouldBeNULL,
const std::vector<bool> *argIsConstant) {
int costSum = 0;
// In computing the cost function, we only worry about the actual
// argument types--using function default parameter values is free for
// the purposes here...
for (int i = 0; i < (int)argTypes.size(); ++i) {
// The cost imposed by this argument will be a multiple of
// costScale, which has a value set so that for each of the cost
// buckets, even if all of the function arguments undergo the next
// lower-cost conversion, the sum of their costs will be less than
// a single instance of the next higher-cost conversion.
int costScale = argTypes.size() + 1;
const Type *fargType = ftype->GetParameterType(i);
const Type *callType = argTypes[i];
// For convenience, normalize to non-const types (except for
// references, where const-ness matters). For all other types,
// we're passing by value anyway, so const doesn't matter.
if (dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = callType->GetAsNonConstType();
if (dynamic_cast<const ReferenceType *>(fargType) == NULL)
fargType = fargType->GetAsNonConstType();
if (Type::Equal(callType, fargType))
// Perfect match: no cost
costSum += 0;
else if (argCouldBeNULL && (*argCouldBeNULL)[i] &&
dynamic_cast<const PointerType *>(fargType) != NULL)
// Passing NULL to a pointer-typed parameter is also a no-cost
// operation
costSum += 0;
else {
// If the argument is a compile-time constant, we'd like to
// count the cost of various conversions as much lower than the
// cost if it wasn't--so scale up the cost when this isn't the
// case..
if (argIsConstant == NULL || (*argIsConstant)[i] == false)
costScale *= 32;
if (Type::Equal(callType, fargType))
// Exact match (after dealing with references, above)
costSum += 1 * costScale;
else if (lIsMatchWithTypeWidening(callType, fargType))
costSum += 2 * costScale;
else if (lIsMatchWithUniformToVarying(callType, fargType))
costSum += 4 * costScale;
else if (lIsMatchWithTypeConvSameVariability(callType, fargType))
costSum += 8 * costScale;
else if (CanConvertTypes(callType, fargType))
costSum += 16 * costScale;
else
// Failure--no type conversion possible...
return -1;
}
}
if (matches.size() == 0)
return false;
else if ((matchingFunc = lGetBestMatch(matches)) != NULL)
// We have a match!
return true;
else {
Error(pos, "Multiple overloaded instances of function \"%s\" matched.",
funName);
// select the matches that have the lowest cost
std::vector<Symbol *> bestMatches;
int minCost = matches[0].first;
for (unsigned int i = 1; i < matches.size(); ++i)
minCost = std::min(minCost, matches[i].first);
for (unsigned int i = 0; i < matches.size(); ++i)
if (matches[i].first == minCost)
bestMatches.push_back(matches[i].second);
// And print a useful error message
lPrintOverloadCandidates(argPos, bestMatches, callTypes, argCouldBeNULL);
// Stop trying to find more matches after an ambigious set of
// matches.
return true;
}
return costSum;
}
bool
FunctionSymbolExpr::ResolveOverloads(SourcePos argPos,
const std::vector<const Type *> &argTypes,
const std::vector<bool> *argCouldBeNULL) {
const std::vector<bool> *argCouldBeNULL,
const std::vector<bool> *argIsConstant) {
const char *funName = candidateFunctions.front()->name.c_str();
triedToResolve = true;
// Functions with names that start with "__" should only be various
@@ -7689,45 +7594,67 @@ FunctionSymbolExpr::ResolveOverloads(SourcePos argPos,
// called.
bool exactMatchOnly = (name.substr(0,2) == "__");
// Is there an exact match that doesn't require any argument type
// conversion (other than converting type -> reference type)?
if (tryResolve(lExactMatch, argPos, argTypes, argCouldBeNULL))
return true;
// First, find the subset of overload candidates that take the same
// number of arguments as have parameters (including functions that
// take more arguments but have defaults starting no later than after
// our last parameter).
std::vector<Symbol *> actualCandidates =
getCandidateFunctions(argTypes.size());
if (exactMatchOnly == false) {
// Try to find a single match ignoring references
if (tryResolve(lMatchIgnoringReferences, argPos, argTypes,
argCouldBeNULL))
return true;
int bestMatchCost = 1<<30;
std::vector<Symbol *> matches;
std::vector<int> candidateCosts;
// Try to find an exact match via type widening--i.e. int8 ->
// int16, etc.--things that don't lose data.
if (tryResolve(lMatchWithTypeWidening, argPos, argTypes, argCouldBeNULL))
return true;
if (actualCandidates.size() == 0)
goto failure;
// Next try to see if there's a match via just uniform -> varying
// promotions.
if (tryResolve(lMatchIgnoringUniform, argPos, argTypes, argCouldBeNULL))
return true;
// Try to find a match via type conversion, but don't change
// unif->varying
if (tryResolve(lMatchWithTypeConvSameVariability, argPos, argTypes,
argCouldBeNULL))
return true;
// Last chance: try to find a match via arbitrary type conversion.
if (tryResolve(lMatchWithTypeConv, argPos, argTypes, argCouldBeNULL))
return true;
// Compute the cost for calling each of the candidate functions
for (int i = 0; i < (int)actualCandidates.size(); ++i) {
const FunctionType *ft =
dynamic_cast<const FunctionType *>(actualCandidates[i]->type);
Assert(ft != NULL);
candidateCosts.push_back(computeOverloadCost(ft, argTypes,
argCouldBeNULL,
argIsConstant));
}
// failure :-(
const char *funName = candidateFunctions.front()->name.c_str();
Error(pos, "Unable to find matching overload for call to function \"%s\"%s.",
funName, exactMatchOnly ? " only considering exact matches" : "");
lPrintOverloadCandidates(argPos, candidateFunctions, argTypes,
argCouldBeNULL);
return false;
// Find the best cost, and then the candidate or candidates that have
// that cost.
for (int i = 0; i < (int)candidateCosts.size(); ++i) {
if (candidateCosts[i] != -1 && candidateCosts[i] < bestMatchCost)
bestMatchCost = candidateCosts[i];
}
// None of the candidates matched
if (bestMatchCost == (1<<30))
goto failure;
for (int i = 0; i < (int)candidateCosts.size(); ++i) {
if (candidateCosts[i] == bestMatchCost)
matches.push_back(actualCandidates[i]);
}
if (matches.size() == 1) {
// Only one match: success
matchingFunc = matches[0];
return true;
}
else if (matches.size() > 1) {
// Multiple matches: ambiguous
Error(pos, "Multiple overloaded functions matched call to function "
"\"%s\"%s.", funName,
exactMatchOnly ? " only considering exact matches" : "");
lPrintOverloadCandidates(argPos, matches, argTypes, argCouldBeNULL);
return false;
}
else {
// No matches at all
failure:
Error(pos, "Unable to find any matching overload for call to function "
"\"%s\"%s.", funName,
exactMatchOnly ? " only considering exact matches" : "");
lPrintOverloadCandidates(argPos, candidateFunctions, argTypes,
argCouldBeNULL);
return false;
}
}