Various improvements to function overload resolution.
Generalize the overload resolution code to be based on estimating a cost for various overload options and picking the one with the minimal cost. Add a step that considers type conversions that are guaranteed to not lose information in function overload resolution. Print better diagnostics when we can't find an unambiguous match.
This commit is contained in:
@@ -956,10 +956,16 @@ a given type are found, an error is issued.
|
||||
* All parameter types match exactly.
|
||||
* All parameter types match exactly, where any ``reference``-qualified
|
||||
parameters are considered equivalent to their underlying type.
|
||||
* Parameters match with only type conversions that don't risk losing any
|
||||
information (for example, converting an ``int16`` value to an ``int32``
|
||||
parameter value.)
|
||||
* Parameters match with only promotions from ``uniform`` to ``varying``
|
||||
type.
|
||||
* Parameters match using standard type conversion (``int`` to ``float``,
|
||||
types.
|
||||
* Parameters match using arbitrary type conversion, without changing
|
||||
variability from ``uniform`` to ``varying`` (e.g., ``int`` to ``float``,
|
||||
``float`` to ``int``.)
|
||||
* Parameters match using arbitrary type conversion, including also changing
|
||||
variability from ``uniform`` to ``varying`` as needed.
|
||||
|
||||
Also like C, arrays are passed to functions by reference.
|
||||
|
||||
|
||||
304
expr.cpp
304
expr.cpp
@@ -2014,19 +2014,107 @@ SelectExpr::Print() const {
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// FunctionCallExpr
|
||||
|
||||
static std::string
|
||||
lGetFunctionDeclaration(const std::string &name, const FunctionType *type) {
|
||||
std::string ret;
|
||||
ret += type->GetReturnType()->GetString();
|
||||
ret += " ";
|
||||
ret += name;
|
||||
ret += "(";
|
||||
|
||||
const std::vector<const Type *> &argTypes = type->GetArgumentTypes();
|
||||
const std::vector<ConstExpr *> &argDefaults = type->GetArgumentDefaults();
|
||||
|
||||
for (unsigned int i = 0; i < argTypes.size(); ++i) {
|
||||
// If the parameter is a reference to an array, just print its type
|
||||
// as the array type, since we always pass arrays by reference.
|
||||
if (dynamic_cast<const ReferenceType *>(argTypes[i]) &&
|
||||
dynamic_cast<const ArrayType *>(argTypes[i]->GetReferenceTarget()))
|
||||
ret += argTypes[i]->GetReferenceTarget()->GetString();
|
||||
else
|
||||
ret += argTypes[i]->GetString();
|
||||
ret += " ";
|
||||
ret += type->GetArgumentName(i);
|
||||
|
||||
// Print the default value if present
|
||||
if (argDefaults[i] != NULL) {
|
||||
char buf[32];
|
||||
if (argTypes[i]->IsFloatType()) {
|
||||
double val;
|
||||
int count = argDefaults[i]->AsDouble(&val);
|
||||
assert(count == 1);
|
||||
sprintf(buf, " = %g", val);
|
||||
}
|
||||
else if (argTypes[i]->IsBoolType()) {
|
||||
bool val;
|
||||
int count = argDefaults[i]->AsBool(&val);
|
||||
assert(count == 1);
|
||||
sprintf(buf, " = %s", val ? "true" : "false");
|
||||
}
|
||||
else if (argTypes[i]->IsUnsignedType()) {
|
||||
uint64_t val;
|
||||
int count = argDefaults[i]->AsUInt64(&val);
|
||||
assert(count == 1);
|
||||
#ifdef ISPC_IS_LINUX
|
||||
sprintf(buf, " = %lu", val);
|
||||
#else
|
||||
sprintf(buf, " = %llu", val);
|
||||
#endif
|
||||
}
|
||||
else {
|
||||
int64_t val;
|
||||
int count = argDefaults[i]->AsInt64(&val);
|
||||
assert(count == 1);
|
||||
#ifdef ISPC_IS_LINUX
|
||||
sprintf(buf, " = %ld", val);
|
||||
#else
|
||||
sprintf(buf, " = %lld", val);
|
||||
#endif
|
||||
}
|
||||
ret += buf;
|
||||
}
|
||||
if (i != argTypes.size() - 1)
|
||||
ret += ", ";
|
||||
}
|
||||
ret += ")";
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
static void
|
||||
lPrintFunctionOverloads(const std::vector<Symbol *> &matches) {
|
||||
lPrintFunctionOverloads(const std::string &name,
|
||||
const std::vector<std::pair<int, Symbol *> > &matches) {
|
||||
fprintf(stderr, "Matching functions:\n");
|
||||
int minCost = matches[0].first;
|
||||
for (unsigned int i = 1; i < matches.size(); ++i)
|
||||
minCost = std::min(minCost, matches[i].first);
|
||||
|
||||
for (unsigned int i = 0; i < matches.size(); ++i) {
|
||||
const FunctionType *t = dynamic_cast<const FunctionType *>(matches[i]->type);
|
||||
const FunctionType *t =
|
||||
dynamic_cast<const FunctionType *>(matches[i].second->type);
|
||||
assert(t != NULL);
|
||||
fprintf(stderr, "\t%s\n", t->GetString().c_str());
|
||||
if (matches[i].first == minCost)
|
||||
fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void
|
||||
lPrintFunctionOverloads(const std::string &name,
|
||||
const std::vector<Symbol *> &funcs) {
|
||||
fprintf(stderr, "Candidate functions:\n");
|
||||
for (unsigned int i = 0; i < funcs.size(); ++i) {
|
||||
const FunctionType *t =
|
||||
dynamic_cast<const FunctionType *>(funcs[i]->type);
|
||||
assert(t != NULL);
|
||||
fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void
|
||||
lPrintPassedTypes(const char *funName, const std::vector<Expr *> &argExprs) {
|
||||
fprintf(stderr, "Passed types:\n\t%s(", funName);
|
||||
fprintf(stderr, "Passed types: %*c(", (int)strlen(funName), ' ');
|
||||
for (unsigned int i = 0; i < argExprs.size(); ++i) {
|
||||
const Type *t;
|
||||
if (argExprs[i] != NULL && (t = argExprs[i]->GetType()) != NULL)
|
||||
@@ -2039,80 +2127,173 @@ lPrintPassedTypes(const char *funName, const std::vector<Expr *> &argExprs) {
|
||||
}
|
||||
|
||||
|
||||
/** Helper function used for function overload resolution: returns true if
|
||||
the call argument's type exactly matches the function argument type
|
||||
(modulo a conversion to a const type if needed).
|
||||
/** Helper function used for function overload resolution: returns zero
|
||||
cost if the call argument's type exactly matches the function argument
|
||||
type (modulo a conversion to a const type if needed), otherwise reports
|
||||
failure.
|
||||
*/
|
||||
static bool
|
||||
static int
|
||||
lExactMatch(Expr *callArg, const Type *funcArgType) {
|
||||
const Type *callType = callArg->GetType();
|
||||
// FIXME MOVE THESE TWO TO ALWAYS DO IT...
|
||||
|
||||
if (dynamic_cast<const ReferenceType *>(callType) == NULL)
|
||||
callType = callType->GetAsNonConstType();
|
||||
if (dynamic_cast<const ReferenceType *>(funcArgType) != NULL &&
|
||||
dynamic_cast<const ReferenceType *>(callType) == NULL)
|
||||
callType = new ReferenceType(callType, funcArgType->IsConstType());
|
||||
|
||||
return Type::Equal(callType, funcArgType);
|
||||
return Type::Equal(callType, funcArgType) ? 0 : -1;
|
||||
}
|
||||
|
||||
/** Helper function used for function overload resolution: returns true if
|
||||
the call argument type and the function argument type match, modulo
|
||||
conversion to a reference type if needed.
|
||||
/** Helper function used for function overload resolution: returns a cost
|
||||
of 1 if the call argument type and the function argument type match,
|
||||
modulo conversion to a reference type if needed.
|
||||
*/
|
||||
static bool
|
||||
static int
|
||||
lMatchIgnoringReferences(Expr *callArg, const Type *funcArgType) {
|
||||
int prev = lExactMatch(callArg, funcArgType);
|
||||
if (prev != -1)
|
||||
return prev;
|
||||
|
||||
const Type *callType = callArg->GetType()->GetReferenceTarget();
|
||||
if (funcArgType->IsConstType())
|
||||
callType = callType->GetAsConstType();
|
||||
|
||||
return Type::Equal(callType,
|
||||
funcArgType->GetReferenceTarget());
|
||||
funcArgType->GetReferenceTarget()) ? 1 : -1;
|
||||
}
|
||||
|
||||
/** Helper function used for function overload resolution: returns a cost
|
||||
of 1 if converting the argument to the call type only requires a type
|
||||
conversion that won't lose information. Otherwise reports failure.
|
||||
*/
|
||||
static int
|
||||
lMatchWithTypeWidening(Expr *callArg, const Type *funcArgType) {
|
||||
int prev = lMatchIgnoringReferences(callArg, funcArgType);
|
||||
if (prev != -1)
|
||||
return prev;
|
||||
|
||||
const Type *callType = callArg->GetType();
|
||||
const AtomicType *callAt = dynamic_cast<const AtomicType *>(callType);
|
||||
const AtomicType *funcAt = dynamic_cast<const AtomicType *>(funcArgType);
|
||||
if (callAt == NULL || funcAt == NULL)
|
||||
return -1;
|
||||
|
||||
if (callAt->IsUniformType() != funcAt->IsUniformType())
|
||||
return -1;
|
||||
|
||||
switch (callAt->basicType) {
|
||||
case AtomicType::TYPE_BOOL:
|
||||
return 1;
|
||||
case AtomicType::TYPE_INT8:
|
||||
case AtomicType::TYPE_UINT8:
|
||||
return (funcAt->basicType != AtomicType::TYPE_BOOL) ? 1 : -1;
|
||||
case AtomicType::TYPE_INT16:
|
||||
case AtomicType::TYPE_UINT16:
|
||||
return (funcAt->basicType != AtomicType::TYPE_BOOL &&
|
||||
funcAt->basicType != AtomicType::TYPE_INT8 &&
|
||||
funcAt->basicType != AtomicType::TYPE_UINT8) ? 1 : -1;
|
||||
case AtomicType::TYPE_INT32:
|
||||
case AtomicType::TYPE_UINT32:
|
||||
return (funcAt->basicType == AtomicType::TYPE_INT32 ||
|
||||
funcAt->basicType == AtomicType::TYPE_UINT32 ||
|
||||
funcAt->basicType == AtomicType::TYPE_INT64 ||
|
||||
funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1;
|
||||
case AtomicType::TYPE_FLOAT:
|
||||
return (funcAt->basicType == AtomicType::TYPE_DOUBLE) ? 1 : -1;
|
||||
case AtomicType::TYPE_INT64:
|
||||
case AtomicType::TYPE_UINT64:
|
||||
return (funcAt->basicType == AtomicType::TYPE_INT64 ||
|
||||
funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1;
|
||||
case AtomicType::TYPE_DOUBLE:
|
||||
return -1;
|
||||
default:
|
||||
FATAL("Unhandled atomic type");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/** Helper function used for function overload resolution: returns true if
|
||||
the call argument type and the function argument type match if we only
|
||||
do a uniform -> varying type conversion but otherwise have exactly the
|
||||
same type.
|
||||
/** Helper function used for function overload resolution: returns a cost
|
||||
of 1 if the call argument type and the function argument type match if
|
||||
we only do a uniform -> varying type conversion but otherwise have
|
||||
exactly the same type.
|
||||
*/
|
||||
static bool
|
||||
static int
|
||||
lMatchIgnoringUniform(Expr *callArg, const Type *funcArgType) {
|
||||
int prev = lMatchWithTypeWidening(callArg, funcArgType);
|
||||
if (prev != -1)
|
||||
return prev;
|
||||
|
||||
const Type *callType = callArg->GetType();
|
||||
if (dynamic_cast<const ReferenceType *>(callType) == NULL)
|
||||
callType = callType->GetAsNonConstType();
|
||||
|
||||
if (Type::Equal(callType, funcArgType))
|
||||
return true;
|
||||
|
||||
return (callType->IsUniformType() &&
|
||||
funcArgType->IsVaryingType() &&
|
||||
Type::Equal(callType->GetAsVaryingType(), funcArgType));
|
||||
Type::Equal(callType->GetAsVaryingType(), funcArgType)) ? 1 : -1;
|
||||
}
|
||||
|
||||
|
||||
/** Helper function used for function overload resolution: returns true if
|
||||
we can type convert from the call argument type to the function
|
||||
/** Helper function used for function overload resolution: returns a cost
|
||||
of 1 if we can type convert from the call argument type to the function
|
||||
argument type, but without doing a uniform -> varying conversion.
|
||||
*/
|
||||
static bool
|
||||
static int
|
||||
lMatchWithTypeConvSameVariability(Expr *callArg, const Type *funcArgType) {
|
||||
int prev = lMatchIgnoringUniform(callArg, funcArgType);
|
||||
if (prev != -1)
|
||||
return prev;
|
||||
|
||||
Expr *te = callArg->TypeConv(funcArgType,
|
||||
"function call argument", true);
|
||||
return (te != NULL &&
|
||||
te->GetType()->IsUniformType() == callArg->GetType()->IsUniformType());
|
||||
if (te != NULL &&
|
||||
te->GetType()->IsUniformType() == callArg->GetType()->IsUniformType())
|
||||
return 1;
|
||||
else
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
/** Helper function used for function overload resolution: returns true if
|
||||
there is any type conversion that gets us from the caller argument type
|
||||
to the function argument type.
|
||||
/** Helper function used for function overload resolution: returns a cost
|
||||
of 1 if there is any type conversion that gets us from the caller
|
||||
argument type to the function argument type.
|
||||
*/
|
||||
static bool
|
||||
static int
|
||||
lMatchWithTypeConv(Expr *callArg, const Type *funcArgType) {
|
||||
int prev = lMatchWithTypeConvSameVariability(callArg, funcArgType);
|
||||
if (prev != -1)
|
||||
return prev;
|
||||
|
||||
Expr *te = callArg->TypeConv(funcArgType,
|
||||
"function call argument", true);
|
||||
return (te != NULL);
|
||||
return (te != NULL) ? 0 : -1;
|
||||
}
|
||||
|
||||
|
||||
/** Given a set of potential matching functions and their associated cost,
|
||||
return the one with the lowest cost, if unique. Otherwise, if multiple
|
||||
functions match with the same cost, return NULL.
|
||||
*/
|
||||
static Symbol *
|
||||
lGetBestMatch(std::vector<std::pair<int, Symbol *> > &matches) {
|
||||
assert(matches.size() > 0);
|
||||
int minCost = matches[0].first;
|
||||
|
||||
for (unsigned int i = 1; i < matches.size(); ++i)
|
||||
minCost = std::min(minCost, matches[i].first);
|
||||
|
||||
Symbol *match = NULL;
|
||||
for (unsigned int i = 0; i < matches.size(); ++i) {
|
||||
if (matches[i].first == minCost) {
|
||||
if (match != NULL)
|
||||
// multiple things had the same cost
|
||||
return NULL;
|
||||
else
|
||||
match = matches[i].second;
|
||||
}
|
||||
}
|
||||
return match;
|
||||
}
|
||||
|
||||
|
||||
@@ -2123,16 +2304,13 @@ lMatchWithTypeConv(Expr *callArg, const Type *funcArgType) {
|
||||
finding multiple ambiguous matches.
|
||||
*/
|
||||
bool
|
||||
FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) {
|
||||
FunctionCallExpr::tryResolve(int (*matchFunc)(Expr *, const Type *)) {
|
||||
FunctionSymbolExpr *fse = dynamic_cast<FunctionSymbolExpr *>(func);
|
||||
if (!fse)
|
||||
// error will be issued later if not calling an actual function
|
||||
return false;
|
||||
|
||||
const char *funName = fse->candidateFunctions->front()->name.c_str();
|
||||
std::vector<Expr *> &callArgs = args->exprs;
|
||||
|
||||
std::vector<Symbol *> matches;
|
||||
std::vector<std::pair<int, Symbol *> > matches;
|
||||
std::vector<Symbol *>::iterator iter;
|
||||
for (iter = fse->candidateFunctions->begin();
|
||||
iter != fse->candidateFunctions->end(); ++iter) {
|
||||
@@ -2141,12 +2319,12 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) {
|
||||
const FunctionType *ft =
|
||||
dynamic_cast<const FunctionType *>(candidateFunction->type);
|
||||
assert(ft != NULL);
|
||||
const std::vector<const Type *> &candArgTypes = ft->GetArgumentTypes();
|
||||
const std::vector<const Type *> &funcArgTypes = ft->GetArgumentTypes();
|
||||
const std::vector<ConstExpr *> &argumentDefaults = ft->GetArgumentDefaults();
|
||||
|
||||
// There's no way to match if the caller is passing more arguments
|
||||
// than this function instance takes.
|
||||
if (callArgs.size() > candArgTypes.size())
|
||||
if (callArgs.size() > funcArgTypes.size())
|
||||
continue;
|
||||
|
||||
unsigned int i;
|
||||
@@ -2154,28 +2332,30 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) {
|
||||
// function arguments; it may be ok to have more arguments to the
|
||||
// function than are passed, if the function has default argument
|
||||
// values. This case is handled below.
|
||||
int cost = 0;
|
||||
for (i = 0; i < callArgs.size(); ++i) {
|
||||
// This may happen if there's an error earlier in compilation.
|
||||
// It's kind of a silly to redundantly discover this for each
|
||||
// potential match versus detecting this earlier in the
|
||||
// matching process and just giving up.
|
||||
if (!callArgs[i] || !callArgs[i]->GetType() || !candArgTypes[i] ||
|
||||
if (!callArgs[i] || !callArgs[i]->GetType() || !funcArgTypes[i] ||
|
||||
dynamic_cast<const FunctionType *>(callArgs[i]->GetType()) != NULL)
|
||||
return false;
|
||||
|
||||
// See if this caller argument matches the type of the
|
||||
// corresponding function argument according to the given
|
||||
// predicate function. If not, break out and stop trying.
|
||||
if (!matchFunc(callArgs[i], candArgTypes[i]))
|
||||
int argCost = matchFunc(callArgs[i], funcArgTypes[i]);
|
||||
if (argCost == -1)
|
||||
// If the predicate function returns -1, we have failed no
|
||||
// matter what else happens, so we stop trying
|
||||
break;
|
||||
cost += argCost;
|
||||
}
|
||||
if (i == callArgs.size()) {
|
||||
// All of the arguments matched!
|
||||
if (i == candArgTypes.size())
|
||||
if (i == funcArgTypes.size())
|
||||
// And we have exactly as many arguments as the function
|
||||
// wants, so we're done.
|
||||
matches.push_back(candidateFunction);
|
||||
else if (i < candArgTypes.size() && argumentDefaults[i] != NULL)
|
||||
matches.push_back(std::make_pair(cost, candidateFunction));
|
||||
else if (i < funcArgTypes.size() && argumentDefaults[i] != NULL)
|
||||
// Otherwise we can still make it if there are default
|
||||
// arguments for the rest of the arguments! Because in
|
||||
// Module::AddFunction() we have verified that once the
|
||||
@@ -2183,17 +2363,16 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) {
|
||||
// have them as well. Therefore, we just need to check if
|
||||
// the arg we stopped at has a default value and we're
|
||||
// done.
|
||||
matches.push_back(candidateFunction);
|
||||
matches.push_back(std::make_pair(cost, candidateFunction));
|
||||
// otherwise, we don't have a match
|
||||
}
|
||||
}
|
||||
|
||||
if (matches.size() == 0)
|
||||
return false;
|
||||
else if (matches.size() == 1) {
|
||||
fse->matchingFunc = matches[0];
|
||||
|
||||
// fill in any function defaults required
|
||||
else if ((fse->matchingFunc = lGetBestMatch(matches)) != NULL) {
|
||||
// We have a match--fill in with any default argument values
|
||||
// needed.
|
||||
const FunctionType *ft =
|
||||
dynamic_cast<const FunctionType *>(fse->matchingFunc->type);
|
||||
assert(ft != NULL);
|
||||
@@ -2209,7 +2388,7 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) {
|
||||
else {
|
||||
Error(fse->pos, "Multiple overloaded instances of function \"%s\" matched.",
|
||||
funName);
|
||||
lPrintFunctionOverloads(matches);
|
||||
lPrintFunctionOverloads(funName, matches);
|
||||
lPrintPassedTypes(funName, args->exprs);
|
||||
// Stop trying to find more matches after failure
|
||||
return true;
|
||||
@@ -2225,8 +2404,6 @@ FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) {
|
||||
return;
|
||||
assert(args);
|
||||
|
||||
// Try to find the best overload for the function...
|
||||
|
||||
// Is there an exact match that doesn't require any argument type
|
||||
// conversion (other than converting type -> reference type)?
|
||||
if (tryResolve(lExactMatch))
|
||||
@@ -2237,11 +2414,13 @@ FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) {
|
||||
if (tryResolve(lMatchIgnoringReferences))
|
||||
return;
|
||||
|
||||
// TODO: next, try to find an exact match via type promotion--i.e. char
|
||||
// -> int, etc--things that don't lose data
|
||||
// Try to find an exact match via type widening--i.e. int8 ->
|
||||
// int16, etc.--things that don't lose data.
|
||||
if (tryResolve(lMatchWithTypeWidening))
|
||||
return;
|
||||
|
||||
// Next try to see if there's a match via just uniform -> varying
|
||||
// promotions. TODO: look for one with a minimal number of them?
|
||||
// promotions.
|
||||
if (tryResolve(lMatchIgnoringUniform))
|
||||
return;
|
||||
|
||||
@@ -2259,8 +2438,7 @@ FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) {
|
||||
const char *funName = fse->candidateFunctions->front()->name.c_str();
|
||||
Error(pos, "Unable to find matching overload for call to function \"%s\"%s.",
|
||||
funName, exactMatchOnly ? " only considering exact matches" : "");
|
||||
fprintf(stderr, "Candidates are:\n");
|
||||
lPrintFunctionOverloads(*fse->candidateFunctions);
|
||||
lPrintFunctionOverloads(funName, *fse->candidateFunctions);
|
||||
lPrintPassedTypes(funName, args->exprs);
|
||||
}
|
||||
|
||||
@@ -2293,7 +2471,7 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const {
|
||||
|
||||
FunctionSymbolExpr *fse = dynamic_cast<FunctionSymbolExpr *>(func);
|
||||
if (!fse) {
|
||||
Error(pos, "Invalid function name for function call.");
|
||||
Error(pos, "No valid function available for function call.");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user