Various improvements to function overload resolution.

Generalize the overload resolution code to be based on estimating a
  cost for various overload options and picking the one with the
  minimal cost.
Add a step that considers type conversions that are guaranteed to
  not lose information in function overload resolution.
Print better diagnostics when we can't find an unambiguous match.
This commit is contained in:
Matt Pharr
2011-10-16 20:46:56 -04:00
parent 209d093720
commit 39ed7e14b2
3 changed files with 250 additions and 66 deletions

304
expr.cpp
View File

@@ -2014,19 +2014,107 @@ SelectExpr::Print() const {
///////////////////////////////////////////////////////////////////////////
// FunctionCallExpr
static std::string
lGetFunctionDeclaration(const std::string &name, const FunctionType *type) {
std::string ret;
ret += type->GetReturnType()->GetString();
ret += " ";
ret += name;
ret += "(";
const std::vector<const Type *> &argTypes = type->GetArgumentTypes();
const std::vector<ConstExpr *> &argDefaults = type->GetArgumentDefaults();
for (unsigned int i = 0; i < argTypes.size(); ++i) {
// If the parameter is a reference to an array, just print its type
// as the array type, since we always pass arrays by reference.
if (dynamic_cast<const ReferenceType *>(argTypes[i]) &&
dynamic_cast<const ArrayType *>(argTypes[i]->GetReferenceTarget()))
ret += argTypes[i]->GetReferenceTarget()->GetString();
else
ret += argTypes[i]->GetString();
ret += " ";
ret += type->GetArgumentName(i);
// Print the default value if present
if (argDefaults[i] != NULL) {
char buf[32];
if (argTypes[i]->IsFloatType()) {
double val;
int count = argDefaults[i]->AsDouble(&val);
assert(count == 1);
sprintf(buf, " = %g", val);
}
else if (argTypes[i]->IsBoolType()) {
bool val;
int count = argDefaults[i]->AsBool(&val);
assert(count == 1);
sprintf(buf, " = %s", val ? "true" : "false");
}
else if (argTypes[i]->IsUnsignedType()) {
uint64_t val;
int count = argDefaults[i]->AsUInt64(&val);
assert(count == 1);
#ifdef ISPC_IS_LINUX
sprintf(buf, " = %lu", val);
#else
sprintf(buf, " = %llu", val);
#endif
}
else {
int64_t val;
int count = argDefaults[i]->AsInt64(&val);
assert(count == 1);
#ifdef ISPC_IS_LINUX
sprintf(buf, " = %ld", val);
#else
sprintf(buf, " = %lld", val);
#endif
}
ret += buf;
}
if (i != argTypes.size() - 1)
ret += ", ";
}
ret += ")";
return ret;
}
static void
lPrintFunctionOverloads(const std::vector<Symbol *> &matches) {
lPrintFunctionOverloads(const std::string &name,
const std::vector<std::pair<int, Symbol *> > &matches) {
fprintf(stderr, "Matching functions:\n");
int minCost = matches[0].first;
for (unsigned int i = 1; i < matches.size(); ++i)
minCost = std::min(minCost, matches[i].first);
for (unsigned int i = 0; i < matches.size(); ++i) {
const FunctionType *t = dynamic_cast<const FunctionType *>(matches[i]->type);
const FunctionType *t =
dynamic_cast<const FunctionType *>(matches[i].second->type);
assert(t != NULL);
fprintf(stderr, "\t%s\n", t->GetString().c_str());
if (matches[i].first == minCost)
fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str());
}
}
static void
lPrintFunctionOverloads(const std::string &name,
const std::vector<Symbol *> &funcs) {
fprintf(stderr, "Candidate functions:\n");
for (unsigned int i = 0; i < funcs.size(); ++i) {
const FunctionType *t =
dynamic_cast<const FunctionType *>(funcs[i]->type);
assert(t != NULL);
fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str());
}
}
static void
lPrintPassedTypes(const char *funName, const std::vector<Expr *> &argExprs) {
fprintf(stderr, "Passed types:\n\t%s(", funName);
fprintf(stderr, "Passed types: %*c(", (int)strlen(funName), ' ');
for (unsigned int i = 0; i < argExprs.size(); ++i) {
const Type *t;
if (argExprs[i] != NULL && (t = argExprs[i]->GetType()) != NULL)
@@ -2039,80 +2127,173 @@ lPrintPassedTypes(const char *funName, const std::vector<Expr *> &argExprs) {
}
/** Helper function used for function overload resolution: returns true if
the call argument's type exactly matches the function argument type
(modulo a conversion to a const type if needed).
/** Helper function used for function overload resolution: returns zero
cost if the call argument's type exactly matches the function argument
type (modulo a conversion to a const type if needed), otherwise reports
failure.
*/
static bool
static int
lExactMatch(Expr *callArg, const Type *funcArgType) {
const Type *callType = callArg->GetType();
// FIXME MOVE THESE TWO TO ALWAYS DO IT...
if (dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = callType->GetAsNonConstType();
if (dynamic_cast<const ReferenceType *>(funcArgType) != NULL &&
dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = new ReferenceType(callType, funcArgType->IsConstType());
return Type::Equal(callType, funcArgType);
return Type::Equal(callType, funcArgType) ? 0 : -1;
}
/** Helper function used for function overload resolution: returns true if
the call argument type and the function argument type match, modulo
conversion to a reference type if needed.
/** Helper function used for function overload resolution: returns a cost
of 1 if the call argument type and the function argument type match,
modulo conversion to a reference type if needed.
*/
static bool
static int
lMatchIgnoringReferences(Expr *callArg, const Type *funcArgType) {
int prev = lExactMatch(callArg, funcArgType);
if (prev != -1)
return prev;
const Type *callType = callArg->GetType()->GetReferenceTarget();
if (funcArgType->IsConstType())
callType = callType->GetAsConstType();
return Type::Equal(callType,
funcArgType->GetReferenceTarget());
funcArgType->GetReferenceTarget()) ? 1 : -1;
}
/** Helper function used for function overload resolution: returns a cost
of 1 if converting the argument to the call type only requires a type
conversion that won't lose information. Otherwise reports failure.
*/
static int
lMatchWithTypeWidening(Expr *callArg, const Type *funcArgType) {
int prev = lMatchIgnoringReferences(callArg, funcArgType);
if (prev != -1)
return prev;
const Type *callType = callArg->GetType();
const AtomicType *callAt = dynamic_cast<const AtomicType *>(callType);
const AtomicType *funcAt = dynamic_cast<const AtomicType *>(funcArgType);
if (callAt == NULL || funcAt == NULL)
return -1;
if (callAt->IsUniformType() != funcAt->IsUniformType())
return -1;
switch (callAt->basicType) {
case AtomicType::TYPE_BOOL:
return 1;
case AtomicType::TYPE_INT8:
case AtomicType::TYPE_UINT8:
return (funcAt->basicType != AtomicType::TYPE_BOOL) ? 1 : -1;
case AtomicType::TYPE_INT16:
case AtomicType::TYPE_UINT16:
return (funcAt->basicType != AtomicType::TYPE_BOOL &&
funcAt->basicType != AtomicType::TYPE_INT8 &&
funcAt->basicType != AtomicType::TYPE_UINT8) ? 1 : -1;
case AtomicType::TYPE_INT32:
case AtomicType::TYPE_UINT32:
return (funcAt->basicType == AtomicType::TYPE_INT32 ||
funcAt->basicType == AtomicType::TYPE_UINT32 ||
funcAt->basicType == AtomicType::TYPE_INT64 ||
funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1;
case AtomicType::TYPE_FLOAT:
return (funcAt->basicType == AtomicType::TYPE_DOUBLE) ? 1 : -1;
case AtomicType::TYPE_INT64:
case AtomicType::TYPE_UINT64:
return (funcAt->basicType == AtomicType::TYPE_INT64 ||
funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1;
case AtomicType::TYPE_DOUBLE:
return -1;
default:
FATAL("Unhandled atomic type");
return -1;
}
}
/** Helper function used for function overload resolution: returns true if
the call argument type and the function argument type match if we only
do a uniform -> varying type conversion but otherwise have exactly the
same type.
/** Helper function used for function overload resolution: returns a cost
of 1 if the call argument type and the function argument type match if
we only do a uniform -> varying type conversion but otherwise have
exactly the same type.
*/
static bool
static int
lMatchIgnoringUniform(Expr *callArg, const Type *funcArgType) {
int prev = lMatchWithTypeWidening(callArg, funcArgType);
if (prev != -1)
return prev;
const Type *callType = callArg->GetType();
if (dynamic_cast<const ReferenceType *>(callType) == NULL)
callType = callType->GetAsNonConstType();
if (Type::Equal(callType, funcArgType))
return true;
return (callType->IsUniformType() &&
funcArgType->IsVaryingType() &&
Type::Equal(callType->GetAsVaryingType(), funcArgType));
Type::Equal(callType->GetAsVaryingType(), funcArgType)) ? 1 : -1;
}
/** Helper function used for function overload resolution: returns true if
we can type convert from the call argument type to the function
/** Helper function used for function overload resolution: returns a cost
of 1 if we can type convert from the call argument type to the function
argument type, but without doing a uniform -> varying conversion.
*/
static bool
static int
lMatchWithTypeConvSameVariability(Expr *callArg, const Type *funcArgType) {
int prev = lMatchIgnoringUniform(callArg, funcArgType);
if (prev != -1)
return prev;
Expr *te = callArg->TypeConv(funcArgType,
"function call argument", true);
return (te != NULL &&
te->GetType()->IsUniformType() == callArg->GetType()->IsUniformType());
if (te != NULL &&
te->GetType()->IsUniformType() == callArg->GetType()->IsUniformType())
return 1;
else
return -1;
}
/** Helper function used for function overload resolution: returns true if
there is any type conversion that gets us from the caller argument type
to the function argument type.
/** Helper function used for function overload resolution: returns a cost
of 1 if there is any type conversion that gets us from the caller
argument type to the function argument type.
*/
static bool
static int
lMatchWithTypeConv(Expr *callArg, const Type *funcArgType) {
int prev = lMatchWithTypeConvSameVariability(callArg, funcArgType);
if (prev != -1)
return prev;
Expr *te = callArg->TypeConv(funcArgType,
"function call argument", true);
return (te != NULL);
return (te != NULL) ? 0 : -1;
}
/** Given a set of potential matching functions and their associated cost,
return the one with the lowest cost, if unique. Otherwise, if multiple
functions match with the same cost, return NULL.
*/
static Symbol *
lGetBestMatch(std::vector<std::pair<int, Symbol *> > &matches) {
assert(matches.size() > 0);
int minCost = matches[0].first;
for (unsigned int i = 1; i < matches.size(); ++i)
minCost = std::min(minCost, matches[i].first);
Symbol *match = NULL;
for (unsigned int i = 0; i < matches.size(); ++i) {
if (matches[i].first == minCost) {
if (match != NULL)
// multiple things had the same cost
return NULL;
else
match = matches[i].second;
}
}
return match;
}
@@ -2123,16 +2304,13 @@ lMatchWithTypeConv(Expr *callArg, const Type *funcArgType) {
finding multiple ambiguous matches.
*/
bool
FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) {
FunctionCallExpr::tryResolve(int (*matchFunc)(Expr *, const Type *)) {
FunctionSymbolExpr *fse = dynamic_cast<FunctionSymbolExpr *>(func);
if (!fse)
// error will be issued later if not calling an actual function
return false;
const char *funName = fse->candidateFunctions->front()->name.c_str();
std::vector<Expr *> &callArgs = args->exprs;
std::vector<Symbol *> matches;
std::vector<std::pair<int, Symbol *> > matches;
std::vector<Symbol *>::iterator iter;
for (iter = fse->candidateFunctions->begin();
iter != fse->candidateFunctions->end(); ++iter) {
@@ -2141,12 +2319,12 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) {
const FunctionType *ft =
dynamic_cast<const FunctionType *>(candidateFunction->type);
assert(ft != NULL);
const std::vector<const Type *> &candArgTypes = ft->GetArgumentTypes();
const std::vector<const Type *> &funcArgTypes = ft->GetArgumentTypes();
const std::vector<ConstExpr *> &argumentDefaults = ft->GetArgumentDefaults();
// There's no way to match if the caller is passing more arguments
// than this function instance takes.
if (callArgs.size() > candArgTypes.size())
if (callArgs.size() > funcArgTypes.size())
continue;
unsigned int i;
@@ -2154,28 +2332,30 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) {
// function arguments; it may be ok to have more arguments to the
// function than are passed, if the function has default argument
// values. This case is handled below.
int cost = 0;
for (i = 0; i < callArgs.size(); ++i) {
// This may happen if there's an error earlier in compilation.
// It's kind of a silly to redundantly discover this for each
// potential match versus detecting this earlier in the
// matching process and just giving up.
if (!callArgs[i] || !callArgs[i]->GetType() || !candArgTypes[i] ||
if (!callArgs[i] || !callArgs[i]->GetType() || !funcArgTypes[i] ||
dynamic_cast<const FunctionType *>(callArgs[i]->GetType()) != NULL)
return false;
// See if this caller argument matches the type of the
// corresponding function argument according to the given
// predicate function. If not, break out and stop trying.
if (!matchFunc(callArgs[i], candArgTypes[i]))
int argCost = matchFunc(callArgs[i], funcArgTypes[i]);
if (argCost == -1)
// If the predicate function returns -1, we have failed no
// matter what else happens, so we stop trying
break;
cost += argCost;
}
if (i == callArgs.size()) {
// All of the arguments matched!
if (i == candArgTypes.size())
if (i == funcArgTypes.size())
// And we have exactly as many arguments as the function
// wants, so we're done.
matches.push_back(candidateFunction);
else if (i < candArgTypes.size() && argumentDefaults[i] != NULL)
matches.push_back(std::make_pair(cost, candidateFunction));
else if (i < funcArgTypes.size() && argumentDefaults[i] != NULL)
// Otherwise we can still make it if there are default
// arguments for the rest of the arguments! Because in
// Module::AddFunction() we have verified that once the
@@ -2183,17 +2363,16 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) {
// have them as well. Therefore, we just need to check if
// the arg we stopped at has a default value and we're
// done.
matches.push_back(candidateFunction);
matches.push_back(std::make_pair(cost, candidateFunction));
// otherwise, we don't have a match
}
}
if (matches.size() == 0)
return false;
else if (matches.size() == 1) {
fse->matchingFunc = matches[0];
// fill in any function defaults required
else if ((fse->matchingFunc = lGetBestMatch(matches)) != NULL) {
// We have a match--fill in with any default argument values
// needed.
const FunctionType *ft =
dynamic_cast<const FunctionType *>(fse->matchingFunc->type);
assert(ft != NULL);
@@ -2209,7 +2388,7 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) {
else {
Error(fse->pos, "Multiple overloaded instances of function \"%s\" matched.",
funName);
lPrintFunctionOverloads(matches);
lPrintFunctionOverloads(funName, matches);
lPrintPassedTypes(funName, args->exprs);
// Stop trying to find more matches after failure
return true;
@@ -2225,8 +2404,6 @@ FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) {
return;
assert(args);
// Try to find the best overload for the function...
// Is there an exact match that doesn't require any argument type
// conversion (other than converting type -> reference type)?
if (tryResolve(lExactMatch))
@@ -2237,11 +2414,13 @@ FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) {
if (tryResolve(lMatchIgnoringReferences))
return;
// TODO: next, try to find an exact match via type promotion--i.e. char
// -> int, etc--things that don't lose data
// Try to find an exact match via type widening--i.e. int8 ->
// int16, etc.--things that don't lose data.
if (tryResolve(lMatchWithTypeWidening))
return;
// Next try to see if there's a match via just uniform -> varying
// promotions. TODO: look for one with a minimal number of them?
// promotions.
if (tryResolve(lMatchIgnoringUniform))
return;
@@ -2259,8 +2438,7 @@ FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) {
const char *funName = fse->candidateFunctions->front()->name.c_str();
Error(pos, "Unable to find matching overload for call to function \"%s\"%s.",
funName, exactMatchOnly ? " only considering exact matches" : "");
fprintf(stderr, "Candidates are:\n");
lPrintFunctionOverloads(*fse->candidateFunctions);
lPrintFunctionOverloads(funName, *fse->candidateFunctions);
lPrintPassedTypes(funName, args->exprs);
}
@@ -2293,7 +2471,7 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const {
FunctionSymbolExpr *fse = dynamic_cast<FunctionSymbolExpr *>(func);
if (!fse) {
Error(pos, "Invalid function name for function call.");
Error(pos, "No valid function available for function call.");
return NULL;
}