diff --git a/docs/ispc.txt b/docs/ispc.txt index f0659b48..1c39cd71 100644 --- a/docs/ispc.txt +++ b/docs/ispc.txt @@ -956,10 +956,16 @@ a given type are found, an error is issued. * All parameter types match exactly. * All parameter types match exactly, where any ``reference``-qualified parameters are considered equivalent to their underlying type. +* Parameters match with only type conversions that don't risk losing any + information (for example, converting an ``int16`` value to an ``int32`` + parameter value.) * Parameters match with only promotions from ``uniform`` to ``varying`` - type. -* Parameters match using standard type conversion (``int`` to ``float``, + types. +* Parameters match using arbitrary type conversion, without changing + variability from ``uniform`` to ``varying`` (e.g., ``int`` to ``float``, ``float`` to ``int``.) +* Parameters match using arbitrary type conversion, including also changing + variability from ``uniform`` to ``varying`` as needed. Also like C, arrays are passed to functions by reference. diff --git a/expr.cpp b/expr.cpp index a85a4086..b19d2eed 100644 --- a/expr.cpp +++ b/expr.cpp @@ -2014,19 +2014,107 @@ SelectExpr::Print() const { /////////////////////////////////////////////////////////////////////////// // FunctionCallExpr +static std::string +lGetFunctionDeclaration(const std::string &name, const FunctionType *type) { + std::string ret; + ret += type->GetReturnType()->GetString(); + ret += " "; + ret += name; + ret += "("; + + const std::vector &argTypes = type->GetArgumentTypes(); + const std::vector &argDefaults = type->GetArgumentDefaults(); + + for (unsigned int i = 0; i < argTypes.size(); ++i) { + // If the parameter is a reference to an array, just print its type + // as the array type, since we always pass arrays by reference. + if (dynamic_cast(argTypes[i]) && + dynamic_cast(argTypes[i]->GetReferenceTarget())) + ret += argTypes[i]->GetReferenceTarget()->GetString(); + else + ret += argTypes[i]->GetString(); + ret += " "; + ret += type->GetArgumentName(i); + + // Print the default value if present + if (argDefaults[i] != NULL) { + char buf[32]; + if (argTypes[i]->IsFloatType()) { + double val; + int count = argDefaults[i]->AsDouble(&val); + assert(count == 1); + sprintf(buf, " = %g", val); + } + else if (argTypes[i]->IsBoolType()) { + bool val; + int count = argDefaults[i]->AsBool(&val); + assert(count == 1); + sprintf(buf, " = %s", val ? "true" : "false"); + } + else if (argTypes[i]->IsUnsignedType()) { + uint64_t val; + int count = argDefaults[i]->AsUInt64(&val); + assert(count == 1); +#ifdef ISPC_IS_LINUX + sprintf(buf, " = %lu", val); +#else + sprintf(buf, " = %llu", val); +#endif + } + else { + int64_t val; + int count = argDefaults[i]->AsInt64(&val); + assert(count == 1); +#ifdef ISPC_IS_LINUX + sprintf(buf, " = %ld", val); +#else + sprintf(buf, " = %lld", val); +#endif + } + ret += buf; + } + if (i != argTypes.size() - 1) + ret += ", "; + } + ret += ")"; + return ret; +} + + static void -lPrintFunctionOverloads(const std::vector &matches) { +lPrintFunctionOverloads(const std::string &name, + const std::vector > &matches) { + fprintf(stderr, "Matching functions:\n"); + int minCost = matches[0].first; + for (unsigned int i = 1; i < matches.size(); ++i) + minCost = std::min(minCost, matches[i].first); + for (unsigned int i = 0; i < matches.size(); ++i) { - const FunctionType *t = dynamic_cast(matches[i]->type); + const FunctionType *t = + dynamic_cast(matches[i].second->type); assert(t != NULL); - fprintf(stderr, "\t%s\n", t->GetString().c_str()); + if (matches[i].first == minCost) + fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str()); + } +} + + +static void +lPrintFunctionOverloads(const std::string &name, + const std::vector &funcs) { + fprintf(stderr, "Candidate functions:\n"); + for (unsigned int i = 0; i < funcs.size(); ++i) { + const FunctionType *t = + dynamic_cast(funcs[i]->type); + assert(t != NULL); + fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str()); } } static void lPrintPassedTypes(const char *funName, const std::vector &argExprs) { - fprintf(stderr, "Passed types:\n\t%s(", funName); + fprintf(stderr, "Passed types: %*c(", (int)strlen(funName), ' '); for (unsigned int i = 0; i < argExprs.size(); ++i) { const Type *t; if (argExprs[i] != NULL && (t = argExprs[i]->GetType()) != NULL) @@ -2039,80 +2127,173 @@ lPrintPassedTypes(const char *funName, const std::vector &argExprs) { } -/** Helper function used for function overload resolution: returns true if - the call argument's type exactly matches the function argument type - (modulo a conversion to a const type if needed). +/** Helper function used for function overload resolution: returns zero + cost if the call argument's type exactly matches the function argument + type (modulo a conversion to a const type if needed), otherwise reports + failure. */ -static bool +static int lExactMatch(Expr *callArg, const Type *funcArgType) { const Type *callType = callArg->GetType(); -// FIXME MOVE THESE TWO TO ALWAYS DO IT... + if (dynamic_cast(callType) == NULL) callType = callType->GetAsNonConstType(); if (dynamic_cast(funcArgType) != NULL && dynamic_cast(callType) == NULL) callType = new ReferenceType(callType, funcArgType->IsConstType()); - return Type::Equal(callType, funcArgType); + return Type::Equal(callType, funcArgType) ? 0 : -1; } -/** Helper function used for function overload resolution: returns true if - the call argument type and the function argument type match, modulo - conversion to a reference type if needed. +/** Helper function used for function overload resolution: returns a cost + of 1 if the call argument type and the function argument type match, + modulo conversion to a reference type if needed. */ -static bool +static int lMatchIgnoringReferences(Expr *callArg, const Type *funcArgType) { + int prev = lExactMatch(callArg, funcArgType); + if (prev != -1) + return prev; + const Type *callType = callArg->GetType()->GetReferenceTarget(); if (funcArgType->IsConstType()) callType = callType->GetAsConstType(); return Type::Equal(callType, - funcArgType->GetReferenceTarget()); + funcArgType->GetReferenceTarget()) ? 1 : -1; +} + +/** Helper function used for function overload resolution: returns a cost + of 1 if converting the argument to the call type only requires a type + conversion that won't lose information. Otherwise reports failure. +*/ +static int +lMatchWithTypeWidening(Expr *callArg, const Type *funcArgType) { + int prev = lMatchIgnoringReferences(callArg, funcArgType); + if (prev != -1) + return prev; + + const Type *callType = callArg->GetType(); + const AtomicType *callAt = dynamic_cast(callType); + const AtomicType *funcAt = dynamic_cast(funcArgType); + if (callAt == NULL || funcAt == NULL) + return -1; + + if (callAt->IsUniformType() != funcAt->IsUniformType()) + return -1; + + switch (callAt->basicType) { + case AtomicType::TYPE_BOOL: + return 1; + case AtomicType::TYPE_INT8: + case AtomicType::TYPE_UINT8: + return (funcAt->basicType != AtomicType::TYPE_BOOL) ? 1 : -1; + case AtomicType::TYPE_INT16: + case AtomicType::TYPE_UINT16: + return (funcAt->basicType != AtomicType::TYPE_BOOL && + funcAt->basicType != AtomicType::TYPE_INT8 && + funcAt->basicType != AtomicType::TYPE_UINT8) ? 1 : -1; + case AtomicType::TYPE_INT32: + case AtomicType::TYPE_UINT32: + return (funcAt->basicType == AtomicType::TYPE_INT32 || + funcAt->basicType == AtomicType::TYPE_UINT32 || + funcAt->basicType == AtomicType::TYPE_INT64 || + funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; + case AtomicType::TYPE_FLOAT: + return (funcAt->basicType == AtomicType::TYPE_DOUBLE) ? 1 : -1; + case AtomicType::TYPE_INT64: + case AtomicType::TYPE_UINT64: + return (funcAt->basicType == AtomicType::TYPE_INT64 || + funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; + case AtomicType::TYPE_DOUBLE: + return -1; + default: + FATAL("Unhandled atomic type"); + return -1; + } } -/** Helper function used for function overload resolution: returns true if - the call argument type and the function argument type match if we only - do a uniform -> varying type conversion but otherwise have exactly the - same type. +/** Helper function used for function overload resolution: returns a cost + of 1 if the call argument type and the function argument type match if + we only do a uniform -> varying type conversion but otherwise have + exactly the same type. */ -static bool +static int lMatchIgnoringUniform(Expr *callArg, const Type *funcArgType) { + int prev = lMatchWithTypeWidening(callArg, funcArgType); + if (prev != -1) + return prev; + const Type *callType = callArg->GetType(); if (dynamic_cast(callType) == NULL) callType = callType->GetAsNonConstType(); - if (Type::Equal(callType, funcArgType)) - return true; - return (callType->IsUniformType() && funcArgType->IsVaryingType() && - Type::Equal(callType->GetAsVaryingType(), funcArgType)); + Type::Equal(callType->GetAsVaryingType(), funcArgType)) ? 1 : -1; } -/** Helper function used for function overload resolution: returns true if - we can type convert from the call argument type to the function +/** Helper function used for function overload resolution: returns a cost + of 1 if we can type convert from the call argument type to the function argument type, but without doing a uniform -> varying conversion. */ -static bool +static int lMatchWithTypeConvSameVariability(Expr *callArg, const Type *funcArgType) { + int prev = lMatchIgnoringUniform(callArg, funcArgType); + if (prev != -1) + return prev; + Expr *te = callArg->TypeConv(funcArgType, "function call argument", true); - return (te != NULL && - te->GetType()->IsUniformType() == callArg->GetType()->IsUniformType()); + if (te != NULL && + te->GetType()->IsUniformType() == callArg->GetType()->IsUniformType()) + return 1; + else + return -1; } -/** Helper function used for function overload resolution: returns true if - there is any type conversion that gets us from the caller argument type - to the function argument type. +/** Helper function used for function overload resolution: returns a cost + of 1 if there is any type conversion that gets us from the caller + argument type to the function argument type. */ -static bool +static int lMatchWithTypeConv(Expr *callArg, const Type *funcArgType) { + int prev = lMatchWithTypeConvSameVariability(callArg, funcArgType); + if (prev != -1) + return prev; + Expr *te = callArg->TypeConv(funcArgType, "function call argument", true); - return (te != NULL); + return (te != NULL) ? 0 : -1; +} + + +/** Given a set of potential matching functions and their associated cost, + return the one with the lowest cost, if unique. Otherwise, if multiple + functions match with the same cost, return NULL. + */ +static Symbol * +lGetBestMatch(std::vector > &matches) { + assert(matches.size() > 0); + int minCost = matches[0].first; + + for (unsigned int i = 1; i < matches.size(); ++i) + minCost = std::min(minCost, matches[i].first); + + Symbol *match = NULL; + for (unsigned int i = 0; i < matches.size(); ++i) { + if (matches[i].first == minCost) { + if (match != NULL) + // multiple things had the same cost + return NULL; + else + match = matches[i].second; + } + } + return match; } @@ -2123,16 +2304,13 @@ lMatchWithTypeConv(Expr *callArg, const Type *funcArgType) { finding multiple ambiguous matches. */ bool -FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) { +FunctionCallExpr::tryResolve(int (*matchFunc)(Expr *, const Type *)) { FunctionSymbolExpr *fse = dynamic_cast(func); - if (!fse) - // error will be issued later if not calling an actual function - return false; const char *funName = fse->candidateFunctions->front()->name.c_str(); std::vector &callArgs = args->exprs; - std::vector matches; + std::vector > matches; std::vector::iterator iter; for (iter = fse->candidateFunctions->begin(); iter != fse->candidateFunctions->end(); ++iter) { @@ -2141,12 +2319,12 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) { const FunctionType *ft = dynamic_cast(candidateFunction->type); assert(ft != NULL); - const std::vector &candArgTypes = ft->GetArgumentTypes(); + const std::vector &funcArgTypes = ft->GetArgumentTypes(); const std::vector &argumentDefaults = ft->GetArgumentDefaults(); // There's no way to match if the caller is passing more arguments // than this function instance takes. - if (callArgs.size() > candArgTypes.size()) + if (callArgs.size() > funcArgTypes.size()) continue; unsigned int i; @@ -2154,28 +2332,30 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) { // function arguments; it may be ok to have more arguments to the // function than are passed, if the function has default argument // values. This case is handled below. + int cost = 0; for (i = 0; i < callArgs.size(); ++i) { // This may happen if there's an error earlier in compilation. // It's kind of a silly to redundantly discover this for each // potential match versus detecting this earlier in the // matching process and just giving up. - if (!callArgs[i] || !callArgs[i]->GetType() || !candArgTypes[i] || + if (!callArgs[i] || !callArgs[i]->GetType() || !funcArgTypes[i] || dynamic_cast(callArgs[i]->GetType()) != NULL) return false; - // See if this caller argument matches the type of the - // corresponding function argument according to the given - // predicate function. If not, break out and stop trying. - if (!matchFunc(callArgs[i], candArgTypes[i])) + int argCost = matchFunc(callArgs[i], funcArgTypes[i]); + if (argCost == -1) + // If the predicate function returns -1, we have failed no + // matter what else happens, so we stop trying break; + cost += argCost; } if (i == callArgs.size()) { // All of the arguments matched! - if (i == candArgTypes.size()) + if (i == funcArgTypes.size()) // And we have exactly as many arguments as the function // wants, so we're done. - matches.push_back(candidateFunction); - else if (i < candArgTypes.size() && argumentDefaults[i] != NULL) + matches.push_back(std::make_pair(cost, candidateFunction)); + else if (i < funcArgTypes.size() && argumentDefaults[i] != NULL) // Otherwise we can still make it if there are default // arguments for the rest of the arguments! Because in // Module::AddFunction() we have verified that once the @@ -2183,17 +2363,16 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) { // have them as well. Therefore, we just need to check if // the arg we stopped at has a default value and we're // done. - matches.push_back(candidateFunction); + matches.push_back(std::make_pair(cost, candidateFunction)); // otherwise, we don't have a match } } if (matches.size() == 0) return false; - else if (matches.size() == 1) { - fse->matchingFunc = matches[0]; - - // fill in any function defaults required + else if ((fse->matchingFunc = lGetBestMatch(matches)) != NULL) { + // We have a match--fill in with any default argument values + // needed. const FunctionType *ft = dynamic_cast(fse->matchingFunc->type); assert(ft != NULL); @@ -2209,7 +2388,7 @@ FunctionCallExpr::tryResolve(bool (*matchFunc)(Expr *, const Type *)) { else { Error(fse->pos, "Multiple overloaded instances of function \"%s\" matched.", funName); - lPrintFunctionOverloads(matches); + lPrintFunctionOverloads(funName, matches); lPrintPassedTypes(funName, args->exprs); // Stop trying to find more matches after failure return true; @@ -2225,8 +2404,6 @@ FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) { return; assert(args); - // Try to find the best overload for the function... - // Is there an exact match that doesn't require any argument type // conversion (other than converting type -> reference type)? if (tryResolve(lExactMatch)) @@ -2237,11 +2414,13 @@ FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) { if (tryResolve(lMatchIgnoringReferences)) return; - // TODO: next, try to find an exact match via type promotion--i.e. char - // -> int, etc--things that don't lose data + // Try to find an exact match via type widening--i.e. int8 -> + // int16, etc.--things that don't lose data. + if (tryResolve(lMatchWithTypeWidening)) + return; // Next try to see if there's a match via just uniform -> varying - // promotions. TODO: look for one with a minimal number of them? + // promotions. if (tryResolve(lMatchIgnoringUniform)) return; @@ -2259,8 +2438,7 @@ FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) { const char *funName = fse->candidateFunctions->front()->name.c_str(); Error(pos, "Unable to find matching overload for call to function \"%s\"%s.", funName, exactMatchOnly ? " only considering exact matches" : ""); - fprintf(stderr, "Candidates are:\n"); - lPrintFunctionOverloads(*fse->candidateFunctions); + lPrintFunctionOverloads(funName, *fse->candidateFunctions); lPrintPassedTypes(funName, args->exprs); } @@ -2293,7 +2471,7 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { FunctionSymbolExpr *fse = dynamic_cast(func); if (!fse) { - Error(pos, "Invalid function name for function call."); + Error(pos, "No valid function available for function call."); return NULL; } diff --git a/expr.h b/expr.h index 80491d61..0ee6c80c 100644 --- a/expr.h +++ b/expr.h @@ -269,7 +269,7 @@ public: private: void resolveFunctionOverloads(bool exactMatchOnly); - bool tryResolve(bool (*matchFunc)(Expr *, const Type *)); + bool tryResolve(int (*matchFunc)(Expr *, const Type *)); };