From d5a8538192c413414f1e80a9d62eb07e54d53eeb Mon Sep 17 00:00:00 2001 From: Matt Pharr Date: Sat, 29 Oct 2011 10:46:59 -0700 Subject: [PATCH] Move logic for resolving function call overloads. This code previously lived in FunctionCallExpr but is now part of FunctionSymbolExpr. This change doesn't change any current functionality, but lays groundwork for function pointers in the language, where we'll want to do function call overload resolution at other times besides when a function call is actually being made. --- expr.cpp | 1019 ++++++++++++++++++++++++++++-------------------------- expr.h | 19 +- parse.yy | 14 +- stmt.cpp | 19 +- type.cpp | 9 + type.h | 2 +- 6 files changed, 555 insertions(+), 527 deletions(-) diff --git a/expr.cpp b/expr.cpp index 6e522a80..2e1be3c5 100644 --- a/expr.cpp +++ b/expr.cpp @@ -165,7 +165,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, // Convert from type T -> const T; just return a TypeCast expr, which // can handle this if (Type::Equal(toType, fromType->GetAsConstType())) - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); if (dynamic_cast(fromType)) { if (dynamic_cast(toType)) { @@ -173,13 +173,13 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, // this is handled by TypeCastExpr if (Type::Equal(toType->GetReferenceTarget(), fromType->GetReferenceTarget()->GetAsConstType())) - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); const ArrayType *atFrom = dynamic_cast(fromType->GetReferenceTarget()); const ArrayType *atTo = dynamic_cast(toType->GetReferenceTarget()); if (atFrom != NULL && atTo != NULL && Type::Equal(atFrom->GetElementType(), atTo->GetElementType())) - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); else { if (!failureOk) @@ -206,7 +206,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, } else if (Type::Equal(toType, fromType->GetAsNonConstType())) // convert: const T -> T (as long as T isn't a reference) - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); fromType = fromType->GetReferenceTarget(); toType = toType->GetReferenceTarget(); @@ -217,16 +217,19 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, const ArrayType *toArrayType = dynamic_cast(toType); const ArrayType *fromArrayType = dynamic_cast(fromType); if (toArrayType && fromArrayType) { - if (Type::Equal(toArrayType->GetElementType(), fromArrayType->GetElementType())) { + if (Type::Equal(toArrayType->GetElementType(), + fromArrayType->GetElementType())) { // the case of different element counts should have returned // out earlier, yes?? assert(toArrayType->GetElementCount() != fromArrayType->GetElementCount()); - return new TypeCastExpr(new ReferenceType(toType, false), this, pos); + return new TypeCastExpr(new ReferenceType(toType, false), this, + false, pos); } else if (Type::Equal(toArrayType->GetElementType(), fromArrayType->GetElementType()->GetAsConstType())) { // T[x] -> const T[x] - return new TypeCastExpr(new ReferenceType(toType, false), this, pos); + return new TypeCastExpr(new ReferenceType(toType, false), this, + false, pos); } else { if (!failureOk) @@ -248,7 +251,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, toType->GetString().c_str(), errorMsgBase); return NULL; } - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); } const StructType *toStructType = dynamic_cast(toType); @@ -263,7 +266,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, return NULL; } - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); } const EnumType *toEnumType = dynamic_cast(toType); @@ -279,7 +282,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, return NULL; } - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); } const AtomicType *toAtomicType = dynamic_cast(toType); @@ -288,7 +291,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, // enum -> atomic (integer, generally...) is always ok if (fromEnumType != NULL) { assert(toAtomicType != NULL || toVectorType != NULL); - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); } // from here on out, the from type can only be atomic something or @@ -303,7 +306,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, // scalar -> short-vector conversions if (toVectorType != NULL) - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); // ok, it better be a scalar->scalar conversion of some sort by now if (toAtomicType == NULL) { @@ -322,7 +325,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, errorMsgBase); #endif - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); } @@ -2025,451 +2028,12 @@ SelectExpr::Print() const { /////////////////////////////////////////////////////////////////////////// // FunctionCallExpr -static std::string -lGetFunctionDeclaration(const std::string &name, const FunctionType *type) { - std::string ret; - ret += type->GetReturnType()->GetString(); - ret += " "; - ret += name; - ret += "("; - - const std::vector &argTypes = type->GetArgumentTypes(); - const std::vector &argDefaults = type->GetArgumentDefaults(); - - for (unsigned int i = 0; i < argTypes.size(); ++i) { - // If the parameter is a reference to an array, just print its type - // as the array type, since we always pass arrays by reference. - if (dynamic_cast(argTypes[i]) && - dynamic_cast(argTypes[i]->GetReferenceTarget())) - ret += argTypes[i]->GetReferenceTarget()->GetString(); - else - ret += argTypes[i]->GetString(); - ret += " "; - ret += type->GetArgumentName(i); - - // Print the default value if present - if (argDefaults[i] != NULL) { - char buf[32]; - if (argTypes[i]->IsFloatType()) { - double val; - int count = argDefaults[i]->AsDouble(&val); - assert(count == 1); - sprintf(buf, " = %g", val); - } - else if (argTypes[i]->IsBoolType()) { - bool val; - int count = argDefaults[i]->AsBool(&val); - assert(count == 1); - sprintf(buf, " = %s", val ? "true" : "false"); - } - else if (argTypes[i]->IsUnsignedType()) { - uint64_t val; - int count = argDefaults[i]->AsUInt64(&val); - assert(count == 1); -#ifdef ISPC_IS_LINUX - sprintf(buf, " = %lu", val); -#else - sprintf(buf, " = %llu", val); -#endif - } - else { - int64_t val; - int count = argDefaults[i]->AsInt64(&val); - assert(count == 1); -#ifdef ISPC_IS_LINUX - sprintf(buf, " = %ld", val); -#else - sprintf(buf, " = %lld", val); -#endif - } - ret += buf; - } - if (i != argTypes.size() - 1) - ret += ", "; - } - ret += ")"; - return ret; -} - - -static void -lPrintFunctionOverloads(const std::string &name, - const std::vector > &matches) { - fprintf(stderr, "Matching functions:\n"); - int minCost = matches[0].first; - for (unsigned int i = 1; i < matches.size(); ++i) - minCost = std::min(minCost, matches[i].first); - - for (unsigned int i = 0; i < matches.size(); ++i) { - const FunctionType *t = - dynamic_cast(matches[i].second->type); - assert(t != NULL); - if (matches[i].first == minCost) - fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str()); - } -} - - -static void -lPrintFunctionOverloads(const std::string &name, - const std::vector &funcs) { - fprintf(stderr, "Candidate functions:\n"); - for (unsigned int i = 0; i < funcs.size(); ++i) { - const FunctionType *t = - dynamic_cast(funcs[i]->type); - assert(t != NULL); - fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str()); - } -} - - -static void -lPrintPassedTypes(const char *funName, const std::vector &argExprs) { - fprintf(stderr, "Passed types: %*c(", (int)strlen(funName), ' '); - for (unsigned int i = 0; i < argExprs.size(); ++i) { - const Type *t; - if (argExprs[i] != NULL && (t = argExprs[i]->GetType()) != NULL) - fprintf(stderr, "%s%s", t->GetString().c_str(), - (i < argExprs.size()-1) ? ", " : ")\n\n"); - else - fprintf(stderr, "(unknown type)%s", - (i < argExprs.size()-1) ? ", " : ")\n\n"); - } -} - - -/** Helper function used for function overload resolution: returns zero - cost if the call argument's type exactly matches the function argument - type (modulo a conversion to a const type if needed), otherwise reports - failure. - */ -static int -lExactMatch(Expr *callArg, const Type *funcArgType) { - const Type *callType = callArg->GetType(); - - if (dynamic_cast(callType) == NULL) - callType = callType->GetAsNonConstType(); - if (dynamic_cast(funcArgType) != NULL && - dynamic_cast(callType) == NULL) - callType = new ReferenceType(callType, funcArgType->IsConstType()); - - return Type::Equal(callType, funcArgType) ? 0 : -1; -} - -/** Helper function used for function overload resolution: returns a cost - of 1 if the call argument type and the function argument type match, - modulo conversion to a reference type if needed. - */ -static int -lMatchIgnoringReferences(Expr *callArg, const Type *funcArgType) { - int prev = lExactMatch(callArg, funcArgType); - if (prev != -1) - return prev; - - const Type *callType = callArg->GetType()->GetReferenceTarget(); - if (funcArgType->IsConstType()) - callType = callType->GetAsConstType(); - - return Type::Equal(callType, - funcArgType->GetReferenceTarget()) ? 1 : -1; -} - -/** Helper function used for function overload resolution: returns a cost - of 1 if converting the argument to the call type only requires a type - conversion that won't lose information. Otherwise reports failure. -*/ -static int -lMatchWithTypeWidening(Expr *callArg, const Type *funcArgType) { - int prev = lMatchIgnoringReferences(callArg, funcArgType); - if (prev != -1) - return prev; - - const Type *callType = callArg->GetType(); - const AtomicType *callAt = dynamic_cast(callType); - const AtomicType *funcAt = dynamic_cast(funcArgType); - if (callAt == NULL || funcAt == NULL) - return -1; - - if (callAt->IsUniformType() != funcAt->IsUniformType()) - return -1; - - switch (callAt->basicType) { - case AtomicType::TYPE_BOOL: - return 1; - case AtomicType::TYPE_INT8: - case AtomicType::TYPE_UINT8: - return (funcAt->basicType != AtomicType::TYPE_BOOL) ? 1 : -1; - case AtomicType::TYPE_INT16: - case AtomicType::TYPE_UINT16: - return (funcAt->basicType != AtomicType::TYPE_BOOL && - funcAt->basicType != AtomicType::TYPE_INT8 && - funcAt->basicType != AtomicType::TYPE_UINT8) ? 1 : -1; - case AtomicType::TYPE_INT32: - case AtomicType::TYPE_UINT32: - return (funcAt->basicType == AtomicType::TYPE_INT32 || - funcAt->basicType == AtomicType::TYPE_UINT32 || - funcAt->basicType == AtomicType::TYPE_INT64 || - funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; - case AtomicType::TYPE_FLOAT: - return (funcAt->basicType == AtomicType::TYPE_DOUBLE) ? 1 : -1; - case AtomicType::TYPE_INT64: - case AtomicType::TYPE_UINT64: - return (funcAt->basicType == AtomicType::TYPE_INT64 || - funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; - case AtomicType::TYPE_DOUBLE: - return -1; - default: - FATAL("Unhandled atomic type"); - return -1; - } -} - - -/** Helper function used for function overload resolution: returns a cost - of 1 if the call argument type and the function argument type match if - we only do a uniform -> varying type conversion but otherwise have - exactly the same type. - */ -static int -lMatchIgnoringUniform(Expr *callArg, const Type *funcArgType) { - int prev = lMatchWithTypeWidening(callArg, funcArgType); - if (prev != -1) - return prev; - - const Type *callType = callArg->GetType(); - if (dynamic_cast(callType) == NULL) - callType = callType->GetAsNonConstType(); - - return (callType->IsUniformType() && - funcArgType->IsVaryingType() && - Type::Equal(callType->GetAsVaryingType(), funcArgType)) ? 1 : -1; -} - - -/** Helper function used for function overload resolution: returns a cost - of 1 if we can type convert from the call argument type to the function - argument type, but without doing a uniform -> varying conversion. - */ -static int -lMatchWithTypeConvSameVariability(Expr *callArg, const Type *funcArgType) { - int prev = lMatchIgnoringUniform(callArg, funcArgType); - if (prev != -1) - return prev; - - Expr *te = callArg->TypeConv(funcArgType, - "function call argument", true); - if (te != NULL && - te->GetType()->IsUniformType() == callArg->GetType()->IsUniformType()) - return 1; - else - return -1; -} - - -/** Helper function used for function overload resolution: returns a cost - of 1 if there is any type conversion that gets us from the caller - argument type to the function argument type. - */ -static int -lMatchWithTypeConv(Expr *callArg, const Type *funcArgType) { - int prev = lMatchWithTypeConvSameVariability(callArg, funcArgType); - if (prev != -1) - return prev; - - Expr *te = callArg->TypeConv(funcArgType, - "function call argument", true); - return (te != NULL) ? 0 : -1; -} - - -/** Given a set of potential matching functions and their associated cost, - return the one with the lowest cost, if unique. Otherwise, if multiple - functions match with the same cost, return NULL. - */ -static Symbol * -lGetBestMatch(std::vector > &matches) { - assert(matches.size() > 0); - int minCost = matches[0].first; - - for (unsigned int i = 1; i < matches.size(); ++i) - minCost = std::min(minCost, matches[i].first); - - Symbol *match = NULL; - for (unsigned int i = 0; i < matches.size(); ++i) { - if (matches[i].first == minCost) { - if (match != NULL) - // multiple things had the same cost - return NULL; - else - match = matches[i].second; - } - } - return match; -} - - -/** See if we can find a single function from the set of overload options - based on the predicate function passed in. Returns true if no more - tries should be made to find a match, either due to success from - finding a single overloaded function that matches or failure due to - finding multiple ambiguous matches. - */ -bool -FunctionCallExpr::tryResolve(int (*matchFunc)(Expr *, const Type *)) { - FunctionSymbolExpr *fse = dynamic_cast(func); - - const char *funName = fse->candidateFunctions->front()->name.c_str(); - std::vector &callArgs = args->exprs; - - std::vector > matches; - std::vector::iterator iter; - for (iter = fse->candidateFunctions->begin(); - iter != fse->candidateFunctions->end(); ++iter) { - // Loop over the set of candidate functions and try each one - Symbol *candidateFunction = *iter; - const FunctionType *ft = - dynamic_cast(candidateFunction->type); - assert(ft != NULL); - const std::vector &funcArgTypes = ft->GetArgumentTypes(); - const std::vector &argumentDefaults = ft->GetArgumentDefaults(); - - // There's no way to match if the caller is passing more arguments - // than this function instance takes. - if (callArgs.size() > funcArgTypes.size()) - continue; - - unsigned int i; - // Note that we're looping over the caller arguments, not the - // function arguments; it may be ok to have more arguments to the - // function than are passed, if the function has default argument - // values. This case is handled below. - int cost = 0; - for (i = 0; i < callArgs.size(); ++i) { - // This may happen if there's an error earlier in compilation. - // It's kind of a silly to redundantly discover this for each - // potential match versus detecting this earlier in the - // matching process and just giving up. - if (!callArgs[i] || !callArgs[i]->GetType() || !funcArgTypes[i] || - dynamic_cast(callArgs[i]->GetType()) != NULL) - return false; - - int argCost = matchFunc(callArgs[i], funcArgTypes[i]); - if (argCost == -1) - // If the predicate function returns -1, we have failed no - // matter what else happens, so we stop trying - break; - cost += argCost; - } - if (i == callArgs.size()) { - // All of the arguments matched! - if (i == funcArgTypes.size()) - // And we have exactly as many arguments as the function - // wants, so we're done. - matches.push_back(std::make_pair(cost, candidateFunction)); - else if (i < funcArgTypes.size() && argumentDefaults[i] != NULL) - // Otherwise we can still make it if there are default - // arguments for the rest of the arguments! Because in - // Module::AddFunction() we have verified that once the - // default arguments start, then all of the following ones - // have them as well. Therefore, we just need to check if - // the arg we stopped at has a default value and we're - // done. - matches.push_back(std::make_pair(cost, candidateFunction)); - // otherwise, we don't have a match - } - } - - if (matches.size() == 0) - return false; - else if ((fse->matchingFunc = lGetBestMatch(matches)) != NULL) { - // We have a match--fill in with any default argument values - // needed. - const FunctionType *ft = - dynamic_cast(fse->matchingFunc->type); - assert(ft != NULL); - const std::vector &argumentDefaults = ft->GetArgumentDefaults(); - const std::vector &argTypes = ft->GetArgumentTypes(); - assert(argumentDefaults.size() == argTypes.size()); - for (unsigned int i = callArgs.size(); i < argTypes.size(); ++i) { - assert(argumentDefaults[i] != NULL); - args->exprs.push_back(argumentDefaults[i]); - } - return true; - } - else { - Error(fse->pos, "Multiple overloaded instances of function \"%s\" matched.", - funName); - lPrintFunctionOverloads(funName, matches); - lPrintPassedTypes(funName, args->exprs); - // Stop trying to find more matches after failure - return true; - } -} - - -void -FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) { - FunctionSymbolExpr *fse = dynamic_cast(func); - if (!fse) - // error will be issued later if not calling an actual function - return; - assert(args); - - // Is there an exact match that doesn't require any argument type - // conversion (other than converting type -> reference type)? - if (tryResolve(lExactMatch)) - return; - - if (!exactMatchOnly) { - // Try to find a single match ignoring references - if (tryResolve(lMatchIgnoringReferences)) - return; - - // Try to find an exact match via type widening--i.e. int8 -> - // int16, etc.--things that don't lose data. - if (tryResolve(lMatchWithTypeWidening)) - return; - - // Next try to see if there's a match via just uniform -> varying - // promotions. - if (tryResolve(lMatchIgnoringUniform)) - return; - - // Try to find a match via type conversion, but don't change - // unif->varying - if (tryResolve(lMatchWithTypeConvSameVariability)) - return; - - // Last chance: try to find a match via arbitrary type conversion. - if (tryResolve(lMatchWithTypeConv)) - return; - } - - // failure :-( - const char *funName = fse->candidateFunctions->front()->name.c_str(); - Error(pos, "Unable to find matching overload for call to function \"%s\"%s.", - funName, exactMatchOnly ? " only considering exact matches" : ""); - lPrintFunctionOverloads(funName, *fse->candidateFunctions); - lPrintPassedTypes(funName, args->exprs); -} - - FunctionCallExpr::FunctionCallExpr(Expr *f, ExprList *a, SourcePos p, bool il, Expr *lce) : Expr(p), isLaunch(il) { func = f; args = a; launchCountExpr = lce; - - FunctionSymbolExpr *fse = dynamic_cast(func); - // Functions with names that start with "__" should only be various - // builtins. For those, we'll demand an exact match, since we'll - // expect whichever function in stdlib.ispc is calling out to one of - // those to be matching the argument types exactly; this is to be a bit - // extra safe to be sure that the expected builtin is in fact being - // called. - bool exactMatchOnly = (fse != NULL) && (fse->name.substr(0,2) == "__"); - resolveFunctionOverloads(exactMatchOnly); } @@ -2481,18 +2045,16 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { ctx->SetDebugPos(pos); FunctionSymbolExpr *fse = dynamic_cast(func); - if (!fse) { - Error(pos, "No valid function available for function call."); - return NULL; - } + assert(fse != NULL); // should be caught during typechecking - if (!fse->matchingFunc) - // no overload match was found, get out of here.. + Symbol *funSym = fse->GetMatchingFunction(); + if (funSym == NULL) + // No match was found; an error should have been issued earlier, so + // just return. return NULL; - Symbol *funSym = fse->matchingFunc; llvm::Function *callee = funSym->function; - if (!callee) { + if (callee == NULL) { Error(pos, "Symbol \"%s\" is not a function.", funSym->name.c_str()); return NULL; } @@ -2510,7 +2072,7 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { bool err = false; for (unsigned int i = 0; i < callargs.size(); ++i) { Expr *argExpr = callargs[i]; - if (!argExpr) + if (argExpr == NULL) continue; // All arrays should already have been converted to reference types @@ -2546,6 +2108,19 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { if (err) return NULL; + // Fill in any default argument values needed. + // FIXME: should we do this during type checking? + const std::vector &argumentDefaults = ft->GetArgumentDefaults(); + for (unsigned int i = callargs.size(); i < argumentDefaults.size(); ++i) { + assert(argumentDefaults[i] != NULL); + Expr *defaultExpr = argumentDefaults[i]->TypeConv(argTypes[i], + "function call default argument"); + if (defaultExpr == NULL) + return NULL; + + callargs.push_back(defaultExpr); + } + // Now evaluate the values of all of the parameters being passed. We // need to evaluate these first here, since their GetValue() calls may // change the current basic block (e.g. if one of these is itself a @@ -2594,6 +2169,7 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { } } + llvm::Value *retVal = NULL; ctx->SetDebugPos(pos); if (ft->isTask) { @@ -2641,14 +2217,16 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { const Type * FunctionCallExpr::GetType() const { FunctionSymbolExpr *fse = dynamic_cast(func); - if (fse && fse->matchingFunc) { - const FunctionType *ft = - dynamic_cast(fse->matchingFunc->type); - assert(ft != NULL); - return ft->GetReturnType(); - } - else + if (fse == NULL) return NULL; + + Symbol *sym = fse->GetMatchingFunction(); + if (sym == NULL) + return NULL; + + const FunctionType *ft = dynamic_cast(sym->type); + assert(ft != NULL); + return ft->GetReturnType(); } @@ -2667,40 +2245,51 @@ FunctionCallExpr::Optimize() { Expr * FunctionCallExpr::TypeCheck() { - if (func) { - func = func->TypeCheck(); - if (func != NULL) { - const FunctionType *ft = dynamic_cast(func->GetType()); - if (ft != NULL) { - if (ft->isTask) { - if (!isLaunch) - Error(pos, "\"launch\" expression needed to call function " - "with \"task\" qualifier."); - if (!launchCountExpr) - return NULL; + if (args != NULL) + args = args->TypeCheck(); - launchCountExpr = - launchCountExpr->TypeConv(AtomicType::UniformInt32, - "task launch count"); - if (!launchCountExpr) - return NULL; - } - else { - if (isLaunch) - Error(pos, "\"launch\" expression illegal with non-\"task\"-" - "qualified function."); - assert(launchCountExpr == NULL); + if (args != NULL && func != NULL) { + FunctionSymbolExpr *fse = dynamic_cast(func); + + if (fse == NULL) { + Error(pos, "No valid function available for function call."); + return NULL; + } + + if (fse->ResolveOverloads(args->exprs) == true) { + func = fse->TypeCheck(); + + if (func != NULL) { + const FunctionType *ft = + dynamic_cast(func->GetType()); + if (ft != NULL) { + if (ft->isTask) { + if (!isLaunch) + Error(pos, "\"launch\" expression needed to call function " + "with \"task\" qualifier."); + if (!launchCountExpr) + return NULL; + + launchCountExpr = + launchCountExpr->TypeConv(AtomicType::UniformInt32, + "task launch count"); + if (!launchCountExpr) + return NULL; + } + else { + if (isLaunch) + Error(pos, "\"launch\" expression illegal with non-\"task\"-" + "qualified function."); + assert(launchCountExpr == NULL); + } } + else + Error(pos, "Valid function name must be used for function call."); } - else - Error(pos, "Valid function name must be used for function call."); } } - if (args) - args = args->TypeCheck(); - - if (!func || !args) + if (func == NULL || args == NULL) return NULL; return this; } @@ -4324,10 +3913,11 @@ ConstExpr::Print() const { /////////////////////////////////////////////////////////////////////////// // TypeCastExpr -TypeCastExpr::TypeCastExpr(const Type *t, Expr *e, SourcePos p) +TypeCastExpr::TypeCastExpr(const Type *t, Expr *e, bool pu, SourcePos p) : Expr(p) { type = t; expr = e; + preserveUniformity = pu; } @@ -5029,6 +4619,13 @@ TypeCastExpr::TypeCheck() { if (toType == NULL || fromType == NULL) return NULL; + if (preserveUniformity == true && fromType->IsUniformType() && + toType->IsVaryingType()) { + TypeCastExpr *tce = new TypeCastExpr(toType->GetAsUniformType(), + expr, false, pos); + return tce->TypeCheck(); + } + const char *toTypeString = toType->GetString().c_str(); const char *fromTypeString = fromType->GetString().c_str(); @@ -5494,6 +5091,434 @@ FunctionSymbolExpr::Print() const { } + +static std::string +lGetFunctionDeclaration(const std::string &name, const FunctionType *type) { + std::string ret; + ret += type->GetReturnType()->GetString(); + ret += " "; + ret += name; + ret += "("; + + const std::vector &argTypes = type->GetArgumentTypes(); + const std::vector &argDefaults = type->GetArgumentDefaults(); + + for (unsigned int i = 0; i < argTypes.size(); ++i) { + // If the parameter is a reference to an array, just print its type + // as the array type, since we always pass arrays by reference. + if (dynamic_cast(argTypes[i]) && + dynamic_cast(argTypes[i]->GetReferenceTarget())) + ret += argTypes[i]->GetReferenceTarget()->GetString(); + else + ret += argTypes[i]->GetString(); + ret += " "; + ret += type->GetArgumentName(i); + + // Print the default value if present + if (argDefaults[i] != NULL) { + char buf[32]; + if (argTypes[i]->IsFloatType()) { + double val; + int count = argDefaults[i]->AsDouble(&val); + assert(count == 1); + sprintf(buf, " = %g", val); + } + else if (argTypes[i]->IsBoolType()) { + bool val; + int count = argDefaults[i]->AsBool(&val); + assert(count == 1); + sprintf(buf, " = %s", val ? "true" : "false"); + } + else if (argTypes[i]->IsUnsignedType()) { + uint64_t val; + int count = argDefaults[i]->AsUInt64(&val); + assert(count == 1); +#ifdef ISPC_IS_LINUX + sprintf(buf, " = %lu", val); +#else + sprintf(buf, " = %llu", val); +#endif + } + else { + int64_t val; + int count = argDefaults[i]->AsInt64(&val); + assert(count == 1); +#ifdef ISPC_IS_LINUX + sprintf(buf, " = %ld", val); +#else + sprintf(buf, " = %lld", val); +#endif + } + ret += buf; + } + if (i != argTypes.size() - 1) + ret += ", "; + } + ret += ")"; + return ret; +} + + +static void +lPrintFunctionOverloads(const std::string &name, + const std::vector > &matches) { + fprintf(stderr, "Matching functions:\n"); + int minCost = matches[0].first; + for (unsigned int i = 1; i < matches.size(); ++i) + minCost = std::min(minCost, matches[i].first); + + for (unsigned int i = 0; i < matches.size(); ++i) { + const FunctionType *t = + dynamic_cast(matches[i].second->type); + assert(t != NULL); + if (matches[i].first == minCost) + fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str()); + } +} + + +static void +lPrintFunctionOverloads(const std::string &name, + const std::vector &funcs) { + fprintf(stderr, "Candidate functions:\n"); + for (unsigned int i = 0; i < funcs.size(); ++i) { + const FunctionType *t = + dynamic_cast(funcs[i]->type); + assert(t != NULL); + fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str()); + } +} + + +static void +lPrintPassedTypes(const char *funName, const std::vector &argExprs) { + fprintf(stderr, "Passed types: %*c(", (int)strlen(funName), ' '); + for (unsigned int i = 0; i < argExprs.size(); ++i) { + const Type *t; + if (argExprs[i] != NULL && (t = argExprs[i]->GetType()) != NULL) + fprintf(stderr, "%s%s", t->GetString().c_str(), + (i < argExprs.size()-1) ? ", " : ")\n\n"); + else + fprintf(stderr, "(unknown type)%s", + (i < argExprs.size()-1) ? ", " : ")\n\n"); + } +} + + +/** Helper function used for function overload resolution: returns zero + cost if the call argument's type exactly matches the function argument + type (modulo a conversion to a const type if needed), otherwise reports + failure. + */ +static int +lExactMatch(Expr *callArg, const Type *funcArgType) { + const Type *callType = callArg->GetType(); + + if (dynamic_cast(callType) == NULL) + callType = callType->GetAsNonConstType(); + if (dynamic_cast(funcArgType) != NULL && + dynamic_cast(callType) == NULL) + callType = new ReferenceType(callType, funcArgType->IsConstType()); + + return Type::Equal(callType, funcArgType) ? 0 : -1; +} + + +/** Helper function used for function overload resolution: returns a cost + of 1 if the call argument type and the function argument type match, + modulo conversion to a reference type if needed. + */ +static int +lMatchIgnoringReferences(Expr *callArg, const Type *funcArgType) { + int prev = lExactMatch(callArg, funcArgType); + if (prev != -1) + return prev; + + const Type *callType = callArg->GetType()->GetReferenceTarget(); + if (funcArgType->IsConstType()) + callType = callType->GetAsConstType(); + + return Type::Equal(callType, + funcArgType->GetReferenceTarget()) ? 1 : -1; +} + +/** Helper function used for function overload resolution: returns a cost + of 1 if converting the argument to the call type only requires a type + conversion that won't lose information. Otherwise reports failure. +*/ +static int +lMatchWithTypeWidening(Expr *callArg, const Type *funcArgType) { + int prev = lMatchIgnoringReferences(callArg, funcArgType); + if (prev != -1) + return prev; + + const Type *callType = callArg->GetType(); + const AtomicType *callAt = dynamic_cast(callType); + const AtomicType *funcAt = dynamic_cast(funcArgType); + if (callAt == NULL || funcAt == NULL) + return -1; + + if (callAt->IsUniformType() != funcAt->IsUniformType()) + return -1; + + switch (callAt->basicType) { + case AtomicType::TYPE_BOOL: + return 1; + case AtomicType::TYPE_INT8: + case AtomicType::TYPE_UINT8: + return (funcAt->basicType != AtomicType::TYPE_BOOL) ? 1 : -1; + case AtomicType::TYPE_INT16: + case AtomicType::TYPE_UINT16: + return (funcAt->basicType != AtomicType::TYPE_BOOL && + funcAt->basicType != AtomicType::TYPE_INT8 && + funcAt->basicType != AtomicType::TYPE_UINT8) ? 1 : -1; + case AtomicType::TYPE_INT32: + case AtomicType::TYPE_UINT32: + return (funcAt->basicType == AtomicType::TYPE_INT32 || + funcAt->basicType == AtomicType::TYPE_UINT32 || + funcAt->basicType == AtomicType::TYPE_INT64 || + funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; + case AtomicType::TYPE_FLOAT: + return (funcAt->basicType == AtomicType::TYPE_DOUBLE) ? 1 : -1; + case AtomicType::TYPE_INT64: + case AtomicType::TYPE_UINT64: + return (funcAt->basicType == AtomicType::TYPE_INT64 || + funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; + case AtomicType::TYPE_DOUBLE: + return -1; + default: + FATAL("Unhandled atomic type"); + return -1; + } +} + + +/** Helper function used for function overload resolution: returns a cost + of 1 if the call argument type and the function argument type match if + we only do a uniform -> varying type conversion but otherwise have + exactly the same type. + */ +static int +lMatchIgnoringUniform(Expr *callArg, const Type *funcArgType) { + int prev = lMatchWithTypeWidening(callArg, funcArgType); + if (prev != -1) + return prev; + + const Type *callType = callArg->GetType(); + if (dynamic_cast(callType) == NULL) + callType = callType->GetAsNonConstType(); + + return (callType->IsUniformType() && + funcArgType->IsVaryingType() && + Type::Equal(callType->GetAsVaryingType(), funcArgType)) ? 1 : -1; +} + + +/** Helper function used for function overload resolution: returns a cost + of 1 if we can type convert from the call argument type to the function + argument type, but without doing a uniform -> varying conversion. + */ +static int +lMatchWithTypeConvSameVariability(Expr *callArg, const Type *funcArgType) { + int prev = lMatchIgnoringUniform(callArg, funcArgType); + if (prev != -1) + return prev; + + Expr *te = callArg->TypeConv(funcArgType, + "function call argument", true); + if (te != NULL && + te->GetType()->IsUniformType() == callArg->GetType()->IsUniformType()) + return 1; + else + return -1; +} + + +/** Helper function used for function overload resolution: returns a cost + of 1 if there is any type conversion that gets us from the caller + argument type to the function argument type. + */ +static int +lMatchWithTypeConv(Expr *callArg, const Type *funcArgType) { + int prev = lMatchWithTypeConvSameVariability(callArg, funcArgType); + if (prev != -1) + return prev; + + Expr *te = callArg->TypeConv(funcArgType, + "function call argument", true); + return (te != NULL) ? 0 : -1; +} + + +/** Given a set of potential matching functions and their associated cost, + return the one with the lowest cost, if unique. Otherwise, if multiple + functions match with the same cost, return NULL. + */ +static Symbol * +lGetBestMatch(std::vector > &matches) { + assert(matches.size() > 0); + int minCost = matches[0].first; + + for (unsigned int i = 1; i < matches.size(); ++i) + minCost = std::min(minCost, matches[i].first); + + Symbol *match = NULL; + for (unsigned int i = 0; i < matches.size(); ++i) { + if (matches[i].first == minCost) { + if (match != NULL) + // multiple things had the same cost + return NULL; + else + match = matches[i].second; + } + } + return match; +} + + +/** See if we can find a single function from the set of overload options + based on the predicate function passed in. Returns true if no more + tries should be made to find a match, either due to success from + finding a single overloaded function that matches or failure due to + finding multiple ambiguous matches. + */ +bool +FunctionSymbolExpr::tryResolve(int (*matchFunc)(Expr *, const Type *), + const std::vector &callArgs) { + const char *funName = candidateFunctions->front()->name.c_str(); + + std::vector > matches; + std::vector::iterator iter; + for (iter = candidateFunctions->begin(); + iter != candidateFunctions->end(); ++iter) { + // Loop over the set of candidate functions and try each one + Symbol *candidateFunction = *iter; + const FunctionType *ft = + dynamic_cast(candidateFunction->type); + assert(ft != NULL); + const std::vector &funcArgTypes = ft->GetArgumentTypes(); + const std::vector &argumentDefaults = ft->GetArgumentDefaults(); + + // There's no way to match if the caller is passing more arguments + // than this function instance takes. + if (callArgs.size() > funcArgTypes.size()) + continue; + + unsigned int i; + // Note that we're looping over the caller arguments, not the + // function arguments; it may be ok to have more arguments to the + // function than are passed, if the function has default argument + // values. This case is handled below. + int cost = 0; + for (i = 0; i < callArgs.size(); ++i) { + // This may happen if there's an error earlier in compilation. + // It's kind of a silly to redundantly discover this for each + // potential match versus detecting this earlier in the + // matching process and just giving up. + if (!callArgs[i] || !callArgs[i]->GetType() || !funcArgTypes[i] || + dynamic_cast(callArgs[i]->GetType()) != NULL) + return false; + + int argCost = matchFunc(callArgs[i], funcArgTypes[i]); + if (argCost == -1) + // If the predicate function returns -1, we have failed no + // matter what else happens, so we stop trying + break; + cost += argCost; + } + if (i == callArgs.size()) { + // All of the arguments matched! + if (i == funcArgTypes.size()) + // And we have exactly as many arguments as the function + // wants, so we're done. + matches.push_back(std::make_pair(cost, candidateFunction)); + else if (i < funcArgTypes.size() && argumentDefaults[i] != NULL) + // Otherwise we can still make it if there are default + // arguments for the rest of the arguments! Because in + // Module::AddFunction() we have verified that once the + // default arguments start, then all of the following ones + // have them as well. Therefore, we just need to check if + // the arg we stopped at has a default value and we're + // done. + matches.push_back(std::make_pair(cost, candidateFunction)); + // otherwise, we don't have a match + } + } + + if (matches.size() == 0) + return false; + else if ((matchingFunc = lGetBestMatch(matches)) != NULL) + // We have a match! + return true; + else { + Error(pos, "Multiple overloaded instances of function \"%s\" matched.", + funName); + lPrintFunctionOverloads(funName, matches); + lPrintPassedTypes(funName, callArgs); + // Stop trying to find more matches after an ambigious set of + // matches. + return true; + } +} + + +bool +FunctionSymbolExpr::ResolveOverloads(const std::vector &callArgs) { + // Functions with names that start with "__" should only be various + // builtins. For those, we'll demand an exact match, since we'll + // expect whichever function in stdlib.ispc is calling out to one of + // those to be matching the argument types exactly; this is to be a bit + // extra safe to be sure that the expected builtin is in fact being + // called. + bool exactMatchOnly = (name.substr(0,2) == "__"); + + // Is there an exact match that doesn't require any argument type + // conversion (other than converting type -> reference type)? + if (tryResolve(lExactMatch, callArgs)) + return true; + + if (exactMatchOnly == false) { + // Try to find a single match ignoring references + if (tryResolve(lMatchIgnoringReferences, callArgs)) + return true; + + // Try to find an exact match via type widening--i.e. int8 -> + // int16, etc.--things that don't lose data. + if (tryResolve(lMatchWithTypeWidening, callArgs)) + return true; + + // Next try to see if there's a match via just uniform -> varying + // promotions. + if (tryResolve(lMatchIgnoringUniform, callArgs)) + return true; + + // Try to find a match via type conversion, but don't change + // unif->varying + if (tryResolve(lMatchWithTypeConvSameVariability, + callArgs)) + return true; + + // Last chance: try to find a match via arbitrary type conversion. + if (tryResolve(lMatchWithTypeConv, callArgs)) + return true; + } + + // failure :-( + const char *funName = candidateFunctions->front()->name.c_str(); + Error(pos, "Unable to find matching overload for call to function \"%s\"%s.", + funName, exactMatchOnly ? " only considering exact matches" : ""); + lPrintFunctionOverloads(funName, *candidateFunctions); + lPrintPassedTypes(funName, callArgs); + return false; +} + + +Symbol * +FunctionSymbolExpr::GetMatchingFunction() { + return matchingFunc; +} + + /////////////////////////////////////////////////////////////////////////// // SyncExpr diff --git a/expr.h b/expr.h index 0ee6c80c..e571d88a 100644 --- a/expr.h +++ b/expr.h @@ -42,8 +42,6 @@ #include "ast.h" #include "type.h" -class FunctionSymbolExpr; - /** @brief Expr is the abstract base class that defines the interface that all expression types must implement. */ @@ -266,10 +264,6 @@ public: ExprList *args; bool isLaunch; Expr *launchCountExpr; - -private: - void resolveFunctionOverloads(bool exactMatchOnly); - bool tryResolve(int (*matchFunc)(Expr *, const Type *)); }; @@ -495,7 +489,8 @@ private: probably-different type. */ class TypeCastExpr : public Expr { public: - TypeCastExpr(const Type *t, Expr *e, SourcePos p); + TypeCastExpr(const Type *t, Expr *e, bool preserveUniformity, + SourcePos p); llvm::Value *GetValue(FunctionEmitContext *ctx) const; const Type *GetType() const; @@ -506,6 +501,7 @@ public: const Type *type; Expr *expr; + bool preserveUniformity; }; @@ -581,8 +577,12 @@ public: void Print() const; int EstimateCost() const; + bool ResolveOverloads(const std::vector &args); + Symbol *GetMatchingFunction(); + private: - friend class FunctionCallExpr; + bool tryResolve(int (*matchFunc)(Expr *, const Type *), + const std::vector &args); /** Name of the function that is being called. */ std::string name; @@ -592,8 +592,7 @@ private: overload is the best match. */ std::vector *candidateFunctions; - /** The actual matching function found after overload resolution; this - value is set by FunctionCallExpr::resolveFunctionOverloads() */ + /** The actual matching function found after overload resolution. */ Symbol *matchingFunc; }; diff --git a/parse.yy b/parse.yy index 53c97b9b..c7ac848b 100644 --- a/parse.yy +++ b/parse.yy @@ -326,19 +326,13 @@ cast_expression : unary_expression | '(' type_name ')' cast_expression { - // If type_name isn't explicitly a varying, We do a GetUniform() - // call here so that things like: + // Pass true here to try to preserve uniformity + // so that things like: // uniform int y = ...; // uniform float x = 1. / (float)y; // don't issue an error due to (float)y being inadvertently // and undesirably-to-the-user "varying"... - if ($2 == NULL || $4 == NULL || $4->GetType() == NULL) - $$ = NULL; - else { - if ($4->GetType()->IsUniformType()) - $2 = $2->GetAsUniformType(); - $$ = new TypeCastExpr($2, $4, @1); - } + $$ = new TypeCastExpr($2, $4, true, @1); } ; @@ -1463,7 +1457,7 @@ lFinalizeEnumeratorSymbols(std::vector &enums, the actual enum type here and optimize it, which will have us end up with a ConstExpr with the desired EnumType... */ Expr *castExpr = new TypeCastExpr(enumType, enums[i]->constValue, - enums[i]->pos); + false, enums[i]->pos); castExpr = castExpr->Optimize(); enums[i]->constValue = dynamic_cast(castExpr); assert(enums[i]->constValue != NULL); diff --git a/stmt.cpp b/stmt.cpp index c2673112..1be6d3a6 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -519,15 +519,16 @@ Stmt *IfStmt::TypeCheck() { if (test) { const Type *testType = test->GetType(); if (testType) { - bool isUniform = testType->IsUniformType() && !g->opt.disableUniformControlFlow; + bool isUniform = (testType->IsUniformType() && + !g->opt.disableUniformControlFlow); if (!testType->IsNumericType() && !testType->IsBoolType()) { - Error(test->pos, "Type \"%s\" can't be converted to boolean for \"if\" test.", - testType->GetString().c_str()); + Error(test->pos, "Type \"%s\" can't be converted to boolean " + "for \"if\" test.", testType->GetString().c_str()); return NULL; } test = new TypeCastExpr(isUniform ? AtomicType::UniformBool : - AtomicType::VaryingBool, - test, test->pos); + AtomicType::VaryingBool, + test, false, test->pos); assert(test); } } @@ -1171,7 +1172,7 @@ DoStmt::TypeCheck() { !lHasVaryingBreakOrContinue(bodyStmts)); testExpr = new TypeCastExpr(uniformTest ? AtomicType::UniformBool : AtomicType::VaryingBool, - testExpr, testExpr->pos); + testExpr, false, testExpr->pos); } } } @@ -1373,7 +1374,7 @@ ForStmt::TypeCheck() { !lHasVaryingBreakOrContinue(stmts)); test = new TypeCastExpr(uniformTest ? AtomicType::UniformBool : AtomicType::VaryingBool, - test, test->pos); + test, false, test->pos); } } } @@ -1685,7 +1686,7 @@ lProcessPrintArg(Expr *expr, FunctionEmitContext *ctx, std::string &argTypes) { baseType == AtomicType::UniformUInt16) { expr = new TypeCastExpr(type->IsUniformType() ? AtomicType::UniformInt32 : AtomicType::VaryingInt32, - expr, expr->pos); + expr, false, expr->pos); type = expr->GetType(); } @@ -1908,7 +1909,7 @@ AssertStmt::TypeCheck() { } expr = new TypeCastExpr(isUniform ? AtomicType::UniformBool : AtomicType::VaryingBool, - expr, expr->pos); + expr, false, expr->pos); } } return this; diff --git a/type.cpp b/type.cpp index bf56904a..90c713c2 100644 --- a/type.cpp +++ b/type.cpp @@ -1887,6 +1887,15 @@ FunctionType::SetArgumentDefaults(const std::vector &d) const { } +std::string +FunctionType::GetArgumentName(int i) const { + if (i >= (int)argNames.size()) + return ""; + else + return argNames[i]; +} + + /////////////////////////////////////////////////////////////////////////// // Type diff --git a/type.h b/type.h index a985dc2e..56eacc95 100644 --- a/type.h +++ b/type.h @@ -677,7 +677,7 @@ public: const std::vector &GetArgumentTypes() const { return argTypes; } const std::vector &GetArgumentDefaults() const { return argDefaults; } - const std::string &GetArgumentName(int i) const { return argNames[i]; } + std::string GetArgumentName(int i) const; /** @todo It would be nice to pull this information together and pass it when the constructor is called; it's kind of ugly to set it like