diff --git a/expr.cpp b/expr.cpp index 6e522a80..2e1be3c5 100644 --- a/expr.cpp +++ b/expr.cpp @@ -165,7 +165,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, // Convert from type T -> const T; just return a TypeCast expr, which // can handle this if (Type::Equal(toType, fromType->GetAsConstType())) - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); if (dynamic_cast(fromType)) { if (dynamic_cast(toType)) { @@ -173,13 +173,13 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, // this is handled by TypeCastExpr if (Type::Equal(toType->GetReferenceTarget(), fromType->GetReferenceTarget()->GetAsConstType())) - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); const ArrayType *atFrom = dynamic_cast(fromType->GetReferenceTarget()); const ArrayType *atTo = dynamic_cast(toType->GetReferenceTarget()); if (atFrom != NULL && atTo != NULL && Type::Equal(atFrom->GetElementType(), atTo->GetElementType())) - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); else { if (!failureOk) @@ -206,7 +206,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, } else if (Type::Equal(toType, fromType->GetAsNonConstType())) // convert: const T -> T (as long as T isn't a reference) - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); fromType = fromType->GetReferenceTarget(); toType = toType->GetReferenceTarget(); @@ -217,16 +217,19 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, const ArrayType *toArrayType = dynamic_cast(toType); const ArrayType *fromArrayType = dynamic_cast(fromType); if (toArrayType && fromArrayType) { - if (Type::Equal(toArrayType->GetElementType(), fromArrayType->GetElementType())) { + if (Type::Equal(toArrayType->GetElementType(), + fromArrayType->GetElementType())) { // the case of different element counts should have returned // out earlier, yes?? assert(toArrayType->GetElementCount() != fromArrayType->GetElementCount()); - return new TypeCastExpr(new ReferenceType(toType, false), this, pos); + return new TypeCastExpr(new ReferenceType(toType, false), this, + false, pos); } else if (Type::Equal(toArrayType->GetElementType(), fromArrayType->GetElementType()->GetAsConstType())) { // T[x] -> const T[x] - return new TypeCastExpr(new ReferenceType(toType, false), this, pos); + return new TypeCastExpr(new ReferenceType(toType, false), this, + false, pos); } else { if (!failureOk) @@ -248,7 +251,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, toType->GetString().c_str(), errorMsgBase); return NULL; } - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); } const StructType *toStructType = dynamic_cast(toType); @@ -263,7 +266,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, return NULL; } - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); } const EnumType *toEnumType = dynamic_cast(toType); @@ -279,7 +282,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, return NULL; } - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); } const AtomicType *toAtomicType = dynamic_cast(toType); @@ -288,7 +291,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, // enum -> atomic (integer, generally...) is always ok if (fromEnumType != NULL) { assert(toAtomicType != NULL || toVectorType != NULL); - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); } // from here on out, the from type can only be atomic something or @@ -303,7 +306,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, // scalar -> short-vector conversions if (toVectorType != NULL) - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); // ok, it better be a scalar->scalar conversion of some sort by now if (toAtomicType == NULL) { @@ -322,7 +325,7 @@ Expr::TypeConv(const Type *toType, const char *errorMsgBase, bool failureOk, errorMsgBase); #endif - return new TypeCastExpr(toType, this, pos); + return new TypeCastExpr(toType, this, false, pos); } @@ -2025,451 +2028,12 @@ SelectExpr::Print() const { /////////////////////////////////////////////////////////////////////////// // FunctionCallExpr -static std::string -lGetFunctionDeclaration(const std::string &name, const FunctionType *type) { - std::string ret; - ret += type->GetReturnType()->GetString(); - ret += " "; - ret += name; - ret += "("; - - const std::vector &argTypes = type->GetArgumentTypes(); - const std::vector &argDefaults = type->GetArgumentDefaults(); - - for (unsigned int i = 0; i < argTypes.size(); ++i) { - // If the parameter is a reference to an array, just print its type - // as the array type, since we always pass arrays by reference. - if (dynamic_cast(argTypes[i]) && - dynamic_cast(argTypes[i]->GetReferenceTarget())) - ret += argTypes[i]->GetReferenceTarget()->GetString(); - else - ret += argTypes[i]->GetString(); - ret += " "; - ret += type->GetArgumentName(i); - - // Print the default value if present - if (argDefaults[i] != NULL) { - char buf[32]; - if (argTypes[i]->IsFloatType()) { - double val; - int count = argDefaults[i]->AsDouble(&val); - assert(count == 1); - sprintf(buf, " = %g", val); - } - else if (argTypes[i]->IsBoolType()) { - bool val; - int count = argDefaults[i]->AsBool(&val); - assert(count == 1); - sprintf(buf, " = %s", val ? "true" : "false"); - } - else if (argTypes[i]->IsUnsignedType()) { - uint64_t val; - int count = argDefaults[i]->AsUInt64(&val); - assert(count == 1); -#ifdef ISPC_IS_LINUX - sprintf(buf, " = %lu", val); -#else - sprintf(buf, " = %llu", val); -#endif - } - else { - int64_t val; - int count = argDefaults[i]->AsInt64(&val); - assert(count == 1); -#ifdef ISPC_IS_LINUX - sprintf(buf, " = %ld", val); -#else - sprintf(buf, " = %lld", val); -#endif - } - ret += buf; - } - if (i != argTypes.size() - 1) - ret += ", "; - } - ret += ")"; - return ret; -} - - -static void -lPrintFunctionOverloads(const std::string &name, - const std::vector > &matches) { - fprintf(stderr, "Matching functions:\n"); - int minCost = matches[0].first; - for (unsigned int i = 1; i < matches.size(); ++i) - minCost = std::min(minCost, matches[i].first); - - for (unsigned int i = 0; i < matches.size(); ++i) { - const FunctionType *t = - dynamic_cast(matches[i].second->type); - assert(t != NULL); - if (matches[i].first == minCost) - fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str()); - } -} - - -static void -lPrintFunctionOverloads(const std::string &name, - const std::vector &funcs) { - fprintf(stderr, "Candidate functions:\n"); - for (unsigned int i = 0; i < funcs.size(); ++i) { - const FunctionType *t = - dynamic_cast(funcs[i]->type); - assert(t != NULL); - fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str()); - } -} - - -static void -lPrintPassedTypes(const char *funName, const std::vector &argExprs) { - fprintf(stderr, "Passed types: %*c(", (int)strlen(funName), ' '); - for (unsigned int i = 0; i < argExprs.size(); ++i) { - const Type *t; - if (argExprs[i] != NULL && (t = argExprs[i]->GetType()) != NULL) - fprintf(stderr, "%s%s", t->GetString().c_str(), - (i < argExprs.size()-1) ? ", " : ")\n\n"); - else - fprintf(stderr, "(unknown type)%s", - (i < argExprs.size()-1) ? ", " : ")\n\n"); - } -} - - -/** Helper function used for function overload resolution: returns zero - cost if the call argument's type exactly matches the function argument - type (modulo a conversion to a const type if needed), otherwise reports - failure. - */ -static int -lExactMatch(Expr *callArg, const Type *funcArgType) { - const Type *callType = callArg->GetType(); - - if (dynamic_cast(callType) == NULL) - callType = callType->GetAsNonConstType(); - if (dynamic_cast(funcArgType) != NULL && - dynamic_cast(callType) == NULL) - callType = new ReferenceType(callType, funcArgType->IsConstType()); - - return Type::Equal(callType, funcArgType) ? 0 : -1; -} - -/** Helper function used for function overload resolution: returns a cost - of 1 if the call argument type and the function argument type match, - modulo conversion to a reference type if needed. - */ -static int -lMatchIgnoringReferences(Expr *callArg, const Type *funcArgType) { - int prev = lExactMatch(callArg, funcArgType); - if (prev != -1) - return prev; - - const Type *callType = callArg->GetType()->GetReferenceTarget(); - if (funcArgType->IsConstType()) - callType = callType->GetAsConstType(); - - return Type::Equal(callType, - funcArgType->GetReferenceTarget()) ? 1 : -1; -} - -/** Helper function used for function overload resolution: returns a cost - of 1 if converting the argument to the call type only requires a type - conversion that won't lose information. Otherwise reports failure. -*/ -static int -lMatchWithTypeWidening(Expr *callArg, const Type *funcArgType) { - int prev = lMatchIgnoringReferences(callArg, funcArgType); - if (prev != -1) - return prev; - - const Type *callType = callArg->GetType(); - const AtomicType *callAt = dynamic_cast(callType); - const AtomicType *funcAt = dynamic_cast(funcArgType); - if (callAt == NULL || funcAt == NULL) - return -1; - - if (callAt->IsUniformType() != funcAt->IsUniformType()) - return -1; - - switch (callAt->basicType) { - case AtomicType::TYPE_BOOL: - return 1; - case AtomicType::TYPE_INT8: - case AtomicType::TYPE_UINT8: - return (funcAt->basicType != AtomicType::TYPE_BOOL) ? 1 : -1; - case AtomicType::TYPE_INT16: - case AtomicType::TYPE_UINT16: - return (funcAt->basicType != AtomicType::TYPE_BOOL && - funcAt->basicType != AtomicType::TYPE_INT8 && - funcAt->basicType != AtomicType::TYPE_UINT8) ? 1 : -1; - case AtomicType::TYPE_INT32: - case AtomicType::TYPE_UINT32: - return (funcAt->basicType == AtomicType::TYPE_INT32 || - funcAt->basicType == AtomicType::TYPE_UINT32 || - funcAt->basicType == AtomicType::TYPE_INT64 || - funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; - case AtomicType::TYPE_FLOAT: - return (funcAt->basicType == AtomicType::TYPE_DOUBLE) ? 1 : -1; - case AtomicType::TYPE_INT64: - case AtomicType::TYPE_UINT64: - return (funcAt->basicType == AtomicType::TYPE_INT64 || - funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; - case AtomicType::TYPE_DOUBLE: - return -1; - default: - FATAL("Unhandled atomic type"); - return -1; - } -} - - -/** Helper function used for function overload resolution: returns a cost - of 1 if the call argument type and the function argument type match if - we only do a uniform -> varying type conversion but otherwise have - exactly the same type. - */ -static int -lMatchIgnoringUniform(Expr *callArg, const Type *funcArgType) { - int prev = lMatchWithTypeWidening(callArg, funcArgType); - if (prev != -1) - return prev; - - const Type *callType = callArg->GetType(); - if (dynamic_cast(callType) == NULL) - callType = callType->GetAsNonConstType(); - - return (callType->IsUniformType() && - funcArgType->IsVaryingType() && - Type::Equal(callType->GetAsVaryingType(), funcArgType)) ? 1 : -1; -} - - -/** Helper function used for function overload resolution: returns a cost - of 1 if we can type convert from the call argument type to the function - argument type, but without doing a uniform -> varying conversion. - */ -static int -lMatchWithTypeConvSameVariability(Expr *callArg, const Type *funcArgType) { - int prev = lMatchIgnoringUniform(callArg, funcArgType); - if (prev != -1) - return prev; - - Expr *te = callArg->TypeConv(funcArgType, - "function call argument", true); - if (te != NULL && - te->GetType()->IsUniformType() == callArg->GetType()->IsUniformType()) - return 1; - else - return -1; -} - - -/** Helper function used for function overload resolution: returns a cost - of 1 if there is any type conversion that gets us from the caller - argument type to the function argument type. - */ -static int -lMatchWithTypeConv(Expr *callArg, const Type *funcArgType) { - int prev = lMatchWithTypeConvSameVariability(callArg, funcArgType); - if (prev != -1) - return prev; - - Expr *te = callArg->TypeConv(funcArgType, - "function call argument", true); - return (te != NULL) ? 0 : -1; -} - - -/** Given a set of potential matching functions and their associated cost, - return the one with the lowest cost, if unique. Otherwise, if multiple - functions match with the same cost, return NULL. - */ -static Symbol * -lGetBestMatch(std::vector > &matches) { - assert(matches.size() > 0); - int minCost = matches[0].first; - - for (unsigned int i = 1; i < matches.size(); ++i) - minCost = std::min(minCost, matches[i].first); - - Symbol *match = NULL; - for (unsigned int i = 0; i < matches.size(); ++i) { - if (matches[i].first == minCost) { - if (match != NULL) - // multiple things had the same cost - return NULL; - else - match = matches[i].second; - } - } - return match; -} - - -/** See if we can find a single function from the set of overload options - based on the predicate function passed in. Returns true if no more - tries should be made to find a match, either due to success from - finding a single overloaded function that matches or failure due to - finding multiple ambiguous matches. - */ -bool -FunctionCallExpr::tryResolve(int (*matchFunc)(Expr *, const Type *)) { - FunctionSymbolExpr *fse = dynamic_cast(func); - - const char *funName = fse->candidateFunctions->front()->name.c_str(); - std::vector &callArgs = args->exprs; - - std::vector > matches; - std::vector::iterator iter; - for (iter = fse->candidateFunctions->begin(); - iter != fse->candidateFunctions->end(); ++iter) { - // Loop over the set of candidate functions and try each one - Symbol *candidateFunction = *iter; - const FunctionType *ft = - dynamic_cast(candidateFunction->type); - assert(ft != NULL); - const std::vector &funcArgTypes = ft->GetArgumentTypes(); - const std::vector &argumentDefaults = ft->GetArgumentDefaults(); - - // There's no way to match if the caller is passing more arguments - // than this function instance takes. - if (callArgs.size() > funcArgTypes.size()) - continue; - - unsigned int i; - // Note that we're looping over the caller arguments, not the - // function arguments; it may be ok to have more arguments to the - // function than are passed, if the function has default argument - // values. This case is handled below. - int cost = 0; - for (i = 0; i < callArgs.size(); ++i) { - // This may happen if there's an error earlier in compilation. - // It's kind of a silly to redundantly discover this for each - // potential match versus detecting this earlier in the - // matching process and just giving up. - if (!callArgs[i] || !callArgs[i]->GetType() || !funcArgTypes[i] || - dynamic_cast(callArgs[i]->GetType()) != NULL) - return false; - - int argCost = matchFunc(callArgs[i], funcArgTypes[i]); - if (argCost == -1) - // If the predicate function returns -1, we have failed no - // matter what else happens, so we stop trying - break; - cost += argCost; - } - if (i == callArgs.size()) { - // All of the arguments matched! - if (i == funcArgTypes.size()) - // And we have exactly as many arguments as the function - // wants, so we're done. - matches.push_back(std::make_pair(cost, candidateFunction)); - else if (i < funcArgTypes.size() && argumentDefaults[i] != NULL) - // Otherwise we can still make it if there are default - // arguments for the rest of the arguments! Because in - // Module::AddFunction() we have verified that once the - // default arguments start, then all of the following ones - // have them as well. Therefore, we just need to check if - // the arg we stopped at has a default value and we're - // done. - matches.push_back(std::make_pair(cost, candidateFunction)); - // otherwise, we don't have a match - } - } - - if (matches.size() == 0) - return false; - else if ((fse->matchingFunc = lGetBestMatch(matches)) != NULL) { - // We have a match--fill in with any default argument values - // needed. - const FunctionType *ft = - dynamic_cast(fse->matchingFunc->type); - assert(ft != NULL); - const std::vector &argumentDefaults = ft->GetArgumentDefaults(); - const std::vector &argTypes = ft->GetArgumentTypes(); - assert(argumentDefaults.size() == argTypes.size()); - for (unsigned int i = callArgs.size(); i < argTypes.size(); ++i) { - assert(argumentDefaults[i] != NULL); - args->exprs.push_back(argumentDefaults[i]); - } - return true; - } - else { - Error(fse->pos, "Multiple overloaded instances of function \"%s\" matched.", - funName); - lPrintFunctionOverloads(funName, matches); - lPrintPassedTypes(funName, args->exprs); - // Stop trying to find more matches after failure - return true; - } -} - - -void -FunctionCallExpr::resolveFunctionOverloads(bool exactMatchOnly) { - FunctionSymbolExpr *fse = dynamic_cast(func); - if (!fse) - // error will be issued later if not calling an actual function - return; - assert(args); - - // Is there an exact match that doesn't require any argument type - // conversion (other than converting type -> reference type)? - if (tryResolve(lExactMatch)) - return; - - if (!exactMatchOnly) { - // Try to find a single match ignoring references - if (tryResolve(lMatchIgnoringReferences)) - return; - - // Try to find an exact match via type widening--i.e. int8 -> - // int16, etc.--things that don't lose data. - if (tryResolve(lMatchWithTypeWidening)) - return; - - // Next try to see if there's a match via just uniform -> varying - // promotions. - if (tryResolve(lMatchIgnoringUniform)) - return; - - // Try to find a match via type conversion, but don't change - // unif->varying - if (tryResolve(lMatchWithTypeConvSameVariability)) - return; - - // Last chance: try to find a match via arbitrary type conversion. - if (tryResolve(lMatchWithTypeConv)) - return; - } - - // failure :-( - const char *funName = fse->candidateFunctions->front()->name.c_str(); - Error(pos, "Unable to find matching overload for call to function \"%s\"%s.", - funName, exactMatchOnly ? " only considering exact matches" : ""); - lPrintFunctionOverloads(funName, *fse->candidateFunctions); - lPrintPassedTypes(funName, args->exprs); -} - - FunctionCallExpr::FunctionCallExpr(Expr *f, ExprList *a, SourcePos p, bool il, Expr *lce) : Expr(p), isLaunch(il) { func = f; args = a; launchCountExpr = lce; - - FunctionSymbolExpr *fse = dynamic_cast(func); - // Functions with names that start with "__" should only be various - // builtins. For those, we'll demand an exact match, since we'll - // expect whichever function in stdlib.ispc is calling out to one of - // those to be matching the argument types exactly; this is to be a bit - // extra safe to be sure that the expected builtin is in fact being - // called. - bool exactMatchOnly = (fse != NULL) && (fse->name.substr(0,2) == "__"); - resolveFunctionOverloads(exactMatchOnly); } @@ -2481,18 +2045,16 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { ctx->SetDebugPos(pos); FunctionSymbolExpr *fse = dynamic_cast(func); - if (!fse) { - Error(pos, "No valid function available for function call."); - return NULL; - } + assert(fse != NULL); // should be caught during typechecking - if (!fse->matchingFunc) - // no overload match was found, get out of here.. + Symbol *funSym = fse->GetMatchingFunction(); + if (funSym == NULL) + // No match was found; an error should have been issued earlier, so + // just return. return NULL; - Symbol *funSym = fse->matchingFunc; llvm::Function *callee = funSym->function; - if (!callee) { + if (callee == NULL) { Error(pos, "Symbol \"%s\" is not a function.", funSym->name.c_str()); return NULL; } @@ -2510,7 +2072,7 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { bool err = false; for (unsigned int i = 0; i < callargs.size(); ++i) { Expr *argExpr = callargs[i]; - if (!argExpr) + if (argExpr == NULL) continue; // All arrays should already have been converted to reference types @@ -2546,6 +2108,19 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { if (err) return NULL; + // Fill in any default argument values needed. + // FIXME: should we do this during type checking? + const std::vector &argumentDefaults = ft->GetArgumentDefaults(); + for (unsigned int i = callargs.size(); i < argumentDefaults.size(); ++i) { + assert(argumentDefaults[i] != NULL); + Expr *defaultExpr = argumentDefaults[i]->TypeConv(argTypes[i], + "function call default argument"); + if (defaultExpr == NULL) + return NULL; + + callargs.push_back(defaultExpr); + } + // Now evaluate the values of all of the parameters being passed. We // need to evaluate these first here, since their GetValue() calls may // change the current basic block (e.g. if one of these is itself a @@ -2594,6 +2169,7 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { } } + llvm::Value *retVal = NULL; ctx->SetDebugPos(pos); if (ft->isTask) { @@ -2641,14 +2217,16 @@ FunctionCallExpr::GetValue(FunctionEmitContext *ctx) const { const Type * FunctionCallExpr::GetType() const { FunctionSymbolExpr *fse = dynamic_cast(func); - if (fse && fse->matchingFunc) { - const FunctionType *ft = - dynamic_cast(fse->matchingFunc->type); - assert(ft != NULL); - return ft->GetReturnType(); - } - else + if (fse == NULL) return NULL; + + Symbol *sym = fse->GetMatchingFunction(); + if (sym == NULL) + return NULL; + + const FunctionType *ft = dynamic_cast(sym->type); + assert(ft != NULL); + return ft->GetReturnType(); } @@ -2667,40 +2245,51 @@ FunctionCallExpr::Optimize() { Expr * FunctionCallExpr::TypeCheck() { - if (func) { - func = func->TypeCheck(); - if (func != NULL) { - const FunctionType *ft = dynamic_cast(func->GetType()); - if (ft != NULL) { - if (ft->isTask) { - if (!isLaunch) - Error(pos, "\"launch\" expression needed to call function " - "with \"task\" qualifier."); - if (!launchCountExpr) - return NULL; + if (args != NULL) + args = args->TypeCheck(); - launchCountExpr = - launchCountExpr->TypeConv(AtomicType::UniformInt32, - "task launch count"); - if (!launchCountExpr) - return NULL; - } - else { - if (isLaunch) - Error(pos, "\"launch\" expression illegal with non-\"task\"-" - "qualified function."); - assert(launchCountExpr == NULL); + if (args != NULL && func != NULL) { + FunctionSymbolExpr *fse = dynamic_cast(func); + + if (fse == NULL) { + Error(pos, "No valid function available for function call."); + return NULL; + } + + if (fse->ResolveOverloads(args->exprs) == true) { + func = fse->TypeCheck(); + + if (func != NULL) { + const FunctionType *ft = + dynamic_cast(func->GetType()); + if (ft != NULL) { + if (ft->isTask) { + if (!isLaunch) + Error(pos, "\"launch\" expression needed to call function " + "with \"task\" qualifier."); + if (!launchCountExpr) + return NULL; + + launchCountExpr = + launchCountExpr->TypeConv(AtomicType::UniformInt32, + "task launch count"); + if (!launchCountExpr) + return NULL; + } + else { + if (isLaunch) + Error(pos, "\"launch\" expression illegal with non-\"task\"-" + "qualified function."); + assert(launchCountExpr == NULL); + } } + else + Error(pos, "Valid function name must be used for function call."); } - else - Error(pos, "Valid function name must be used for function call."); } } - if (args) - args = args->TypeCheck(); - - if (!func || !args) + if (func == NULL || args == NULL) return NULL; return this; } @@ -4324,10 +3913,11 @@ ConstExpr::Print() const { /////////////////////////////////////////////////////////////////////////// // TypeCastExpr -TypeCastExpr::TypeCastExpr(const Type *t, Expr *e, SourcePos p) +TypeCastExpr::TypeCastExpr(const Type *t, Expr *e, bool pu, SourcePos p) : Expr(p) { type = t; expr = e; + preserveUniformity = pu; } @@ -5029,6 +4619,13 @@ TypeCastExpr::TypeCheck() { if (toType == NULL || fromType == NULL) return NULL; + if (preserveUniformity == true && fromType->IsUniformType() && + toType->IsVaryingType()) { + TypeCastExpr *tce = new TypeCastExpr(toType->GetAsUniformType(), + expr, false, pos); + return tce->TypeCheck(); + } + const char *toTypeString = toType->GetString().c_str(); const char *fromTypeString = fromType->GetString().c_str(); @@ -5494,6 +5091,434 @@ FunctionSymbolExpr::Print() const { } + +static std::string +lGetFunctionDeclaration(const std::string &name, const FunctionType *type) { + std::string ret; + ret += type->GetReturnType()->GetString(); + ret += " "; + ret += name; + ret += "("; + + const std::vector &argTypes = type->GetArgumentTypes(); + const std::vector &argDefaults = type->GetArgumentDefaults(); + + for (unsigned int i = 0; i < argTypes.size(); ++i) { + // If the parameter is a reference to an array, just print its type + // as the array type, since we always pass arrays by reference. + if (dynamic_cast(argTypes[i]) && + dynamic_cast(argTypes[i]->GetReferenceTarget())) + ret += argTypes[i]->GetReferenceTarget()->GetString(); + else + ret += argTypes[i]->GetString(); + ret += " "; + ret += type->GetArgumentName(i); + + // Print the default value if present + if (argDefaults[i] != NULL) { + char buf[32]; + if (argTypes[i]->IsFloatType()) { + double val; + int count = argDefaults[i]->AsDouble(&val); + assert(count == 1); + sprintf(buf, " = %g", val); + } + else if (argTypes[i]->IsBoolType()) { + bool val; + int count = argDefaults[i]->AsBool(&val); + assert(count == 1); + sprintf(buf, " = %s", val ? "true" : "false"); + } + else if (argTypes[i]->IsUnsignedType()) { + uint64_t val; + int count = argDefaults[i]->AsUInt64(&val); + assert(count == 1); +#ifdef ISPC_IS_LINUX + sprintf(buf, " = %lu", val); +#else + sprintf(buf, " = %llu", val); +#endif + } + else { + int64_t val; + int count = argDefaults[i]->AsInt64(&val); + assert(count == 1); +#ifdef ISPC_IS_LINUX + sprintf(buf, " = %ld", val); +#else + sprintf(buf, " = %lld", val); +#endif + } + ret += buf; + } + if (i != argTypes.size() - 1) + ret += ", "; + } + ret += ")"; + return ret; +} + + +static void +lPrintFunctionOverloads(const std::string &name, + const std::vector > &matches) { + fprintf(stderr, "Matching functions:\n"); + int minCost = matches[0].first; + for (unsigned int i = 1; i < matches.size(); ++i) + minCost = std::min(minCost, matches[i].first); + + for (unsigned int i = 0; i < matches.size(); ++i) { + const FunctionType *t = + dynamic_cast(matches[i].second->type); + assert(t != NULL); + if (matches[i].first == minCost) + fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str()); + } +} + + +static void +lPrintFunctionOverloads(const std::string &name, + const std::vector &funcs) { + fprintf(stderr, "Candidate functions:\n"); + for (unsigned int i = 0; i < funcs.size(); ++i) { + const FunctionType *t = + dynamic_cast(funcs[i]->type); + assert(t != NULL); + fprintf(stderr, "\t%s\n", lGetFunctionDeclaration(name, t).c_str()); + } +} + + +static void +lPrintPassedTypes(const char *funName, const std::vector &argExprs) { + fprintf(stderr, "Passed types: %*c(", (int)strlen(funName), ' '); + for (unsigned int i = 0; i < argExprs.size(); ++i) { + const Type *t; + if (argExprs[i] != NULL && (t = argExprs[i]->GetType()) != NULL) + fprintf(stderr, "%s%s", t->GetString().c_str(), + (i < argExprs.size()-1) ? ", " : ")\n\n"); + else + fprintf(stderr, "(unknown type)%s", + (i < argExprs.size()-1) ? ", " : ")\n\n"); + } +} + + +/** Helper function used for function overload resolution: returns zero + cost if the call argument's type exactly matches the function argument + type (modulo a conversion to a const type if needed), otherwise reports + failure. + */ +static int +lExactMatch(Expr *callArg, const Type *funcArgType) { + const Type *callType = callArg->GetType(); + + if (dynamic_cast(callType) == NULL) + callType = callType->GetAsNonConstType(); + if (dynamic_cast(funcArgType) != NULL && + dynamic_cast(callType) == NULL) + callType = new ReferenceType(callType, funcArgType->IsConstType()); + + return Type::Equal(callType, funcArgType) ? 0 : -1; +} + + +/** Helper function used for function overload resolution: returns a cost + of 1 if the call argument type and the function argument type match, + modulo conversion to a reference type if needed. + */ +static int +lMatchIgnoringReferences(Expr *callArg, const Type *funcArgType) { + int prev = lExactMatch(callArg, funcArgType); + if (prev != -1) + return prev; + + const Type *callType = callArg->GetType()->GetReferenceTarget(); + if (funcArgType->IsConstType()) + callType = callType->GetAsConstType(); + + return Type::Equal(callType, + funcArgType->GetReferenceTarget()) ? 1 : -1; +} + +/** Helper function used for function overload resolution: returns a cost + of 1 if converting the argument to the call type only requires a type + conversion that won't lose information. Otherwise reports failure. +*/ +static int +lMatchWithTypeWidening(Expr *callArg, const Type *funcArgType) { + int prev = lMatchIgnoringReferences(callArg, funcArgType); + if (prev != -1) + return prev; + + const Type *callType = callArg->GetType(); + const AtomicType *callAt = dynamic_cast(callType); + const AtomicType *funcAt = dynamic_cast(funcArgType); + if (callAt == NULL || funcAt == NULL) + return -1; + + if (callAt->IsUniformType() != funcAt->IsUniformType()) + return -1; + + switch (callAt->basicType) { + case AtomicType::TYPE_BOOL: + return 1; + case AtomicType::TYPE_INT8: + case AtomicType::TYPE_UINT8: + return (funcAt->basicType != AtomicType::TYPE_BOOL) ? 1 : -1; + case AtomicType::TYPE_INT16: + case AtomicType::TYPE_UINT16: + return (funcAt->basicType != AtomicType::TYPE_BOOL && + funcAt->basicType != AtomicType::TYPE_INT8 && + funcAt->basicType != AtomicType::TYPE_UINT8) ? 1 : -1; + case AtomicType::TYPE_INT32: + case AtomicType::TYPE_UINT32: + return (funcAt->basicType == AtomicType::TYPE_INT32 || + funcAt->basicType == AtomicType::TYPE_UINT32 || + funcAt->basicType == AtomicType::TYPE_INT64 || + funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; + case AtomicType::TYPE_FLOAT: + return (funcAt->basicType == AtomicType::TYPE_DOUBLE) ? 1 : -1; + case AtomicType::TYPE_INT64: + case AtomicType::TYPE_UINT64: + return (funcAt->basicType == AtomicType::TYPE_INT64 || + funcAt->basicType == AtomicType::TYPE_UINT64) ? 1 : -1; + case AtomicType::TYPE_DOUBLE: + return -1; + default: + FATAL("Unhandled atomic type"); + return -1; + } +} + + +/** Helper function used for function overload resolution: returns a cost + of 1 if the call argument type and the function argument type match if + we only do a uniform -> varying type conversion but otherwise have + exactly the same type. + */ +static int +lMatchIgnoringUniform(Expr *callArg, const Type *funcArgType) { + int prev = lMatchWithTypeWidening(callArg, funcArgType); + if (prev != -1) + return prev; + + const Type *callType = callArg->GetType(); + if (dynamic_cast(callType) == NULL) + callType = callType->GetAsNonConstType(); + + return (callType->IsUniformType() && + funcArgType->IsVaryingType() && + Type::Equal(callType->GetAsVaryingType(), funcArgType)) ? 1 : -1; +} + + +/** Helper function used for function overload resolution: returns a cost + of 1 if we can type convert from the call argument type to the function + argument type, but without doing a uniform -> varying conversion. + */ +static int +lMatchWithTypeConvSameVariability(Expr *callArg, const Type *funcArgType) { + int prev = lMatchIgnoringUniform(callArg, funcArgType); + if (prev != -1) + return prev; + + Expr *te = callArg->TypeConv(funcArgType, + "function call argument", true); + if (te != NULL && + te->GetType()->IsUniformType() == callArg->GetType()->IsUniformType()) + return 1; + else + return -1; +} + + +/** Helper function used for function overload resolution: returns a cost + of 1 if there is any type conversion that gets us from the caller + argument type to the function argument type. + */ +static int +lMatchWithTypeConv(Expr *callArg, const Type *funcArgType) { + int prev = lMatchWithTypeConvSameVariability(callArg, funcArgType); + if (prev != -1) + return prev; + + Expr *te = callArg->TypeConv(funcArgType, + "function call argument", true); + return (te != NULL) ? 0 : -1; +} + + +/** Given a set of potential matching functions and their associated cost, + return the one with the lowest cost, if unique. Otherwise, if multiple + functions match with the same cost, return NULL. + */ +static Symbol * +lGetBestMatch(std::vector > &matches) { + assert(matches.size() > 0); + int minCost = matches[0].first; + + for (unsigned int i = 1; i < matches.size(); ++i) + minCost = std::min(minCost, matches[i].first); + + Symbol *match = NULL; + for (unsigned int i = 0; i < matches.size(); ++i) { + if (matches[i].first == minCost) { + if (match != NULL) + // multiple things had the same cost + return NULL; + else + match = matches[i].second; + } + } + return match; +} + + +/** See if we can find a single function from the set of overload options + based on the predicate function passed in. Returns true if no more + tries should be made to find a match, either due to success from + finding a single overloaded function that matches or failure due to + finding multiple ambiguous matches. + */ +bool +FunctionSymbolExpr::tryResolve(int (*matchFunc)(Expr *, const Type *), + const std::vector &callArgs) { + const char *funName = candidateFunctions->front()->name.c_str(); + + std::vector > matches; + std::vector::iterator iter; + for (iter = candidateFunctions->begin(); + iter != candidateFunctions->end(); ++iter) { + // Loop over the set of candidate functions and try each one + Symbol *candidateFunction = *iter; + const FunctionType *ft = + dynamic_cast(candidateFunction->type); + assert(ft != NULL); + const std::vector &funcArgTypes = ft->GetArgumentTypes(); + const std::vector &argumentDefaults = ft->GetArgumentDefaults(); + + // There's no way to match if the caller is passing more arguments + // than this function instance takes. + if (callArgs.size() > funcArgTypes.size()) + continue; + + unsigned int i; + // Note that we're looping over the caller arguments, not the + // function arguments; it may be ok to have more arguments to the + // function than are passed, if the function has default argument + // values. This case is handled below. + int cost = 0; + for (i = 0; i < callArgs.size(); ++i) { + // This may happen if there's an error earlier in compilation. + // It's kind of a silly to redundantly discover this for each + // potential match versus detecting this earlier in the + // matching process and just giving up. + if (!callArgs[i] || !callArgs[i]->GetType() || !funcArgTypes[i] || + dynamic_cast(callArgs[i]->GetType()) != NULL) + return false; + + int argCost = matchFunc(callArgs[i], funcArgTypes[i]); + if (argCost == -1) + // If the predicate function returns -1, we have failed no + // matter what else happens, so we stop trying + break; + cost += argCost; + } + if (i == callArgs.size()) { + // All of the arguments matched! + if (i == funcArgTypes.size()) + // And we have exactly as many arguments as the function + // wants, so we're done. + matches.push_back(std::make_pair(cost, candidateFunction)); + else if (i < funcArgTypes.size() && argumentDefaults[i] != NULL) + // Otherwise we can still make it if there are default + // arguments for the rest of the arguments! Because in + // Module::AddFunction() we have verified that once the + // default arguments start, then all of the following ones + // have them as well. Therefore, we just need to check if + // the arg we stopped at has a default value and we're + // done. + matches.push_back(std::make_pair(cost, candidateFunction)); + // otherwise, we don't have a match + } + } + + if (matches.size() == 0) + return false; + else if ((matchingFunc = lGetBestMatch(matches)) != NULL) + // We have a match! + return true; + else { + Error(pos, "Multiple overloaded instances of function \"%s\" matched.", + funName); + lPrintFunctionOverloads(funName, matches); + lPrintPassedTypes(funName, callArgs); + // Stop trying to find more matches after an ambigious set of + // matches. + return true; + } +} + + +bool +FunctionSymbolExpr::ResolveOverloads(const std::vector &callArgs) { + // Functions with names that start with "__" should only be various + // builtins. For those, we'll demand an exact match, since we'll + // expect whichever function in stdlib.ispc is calling out to one of + // those to be matching the argument types exactly; this is to be a bit + // extra safe to be sure that the expected builtin is in fact being + // called. + bool exactMatchOnly = (name.substr(0,2) == "__"); + + // Is there an exact match that doesn't require any argument type + // conversion (other than converting type -> reference type)? + if (tryResolve(lExactMatch, callArgs)) + return true; + + if (exactMatchOnly == false) { + // Try to find a single match ignoring references + if (tryResolve(lMatchIgnoringReferences, callArgs)) + return true; + + // Try to find an exact match via type widening--i.e. int8 -> + // int16, etc.--things that don't lose data. + if (tryResolve(lMatchWithTypeWidening, callArgs)) + return true; + + // Next try to see if there's a match via just uniform -> varying + // promotions. + if (tryResolve(lMatchIgnoringUniform, callArgs)) + return true; + + // Try to find a match via type conversion, but don't change + // unif->varying + if (tryResolve(lMatchWithTypeConvSameVariability, + callArgs)) + return true; + + // Last chance: try to find a match via arbitrary type conversion. + if (tryResolve(lMatchWithTypeConv, callArgs)) + return true; + } + + // failure :-( + const char *funName = candidateFunctions->front()->name.c_str(); + Error(pos, "Unable to find matching overload for call to function \"%s\"%s.", + funName, exactMatchOnly ? " only considering exact matches" : ""); + lPrintFunctionOverloads(funName, *candidateFunctions); + lPrintPassedTypes(funName, callArgs); + return false; +} + + +Symbol * +FunctionSymbolExpr::GetMatchingFunction() { + return matchingFunc; +} + + /////////////////////////////////////////////////////////////////////////// // SyncExpr diff --git a/expr.h b/expr.h index 0ee6c80c..e571d88a 100644 --- a/expr.h +++ b/expr.h @@ -42,8 +42,6 @@ #include "ast.h" #include "type.h" -class FunctionSymbolExpr; - /** @brief Expr is the abstract base class that defines the interface that all expression types must implement. */ @@ -266,10 +264,6 @@ public: ExprList *args; bool isLaunch; Expr *launchCountExpr; - -private: - void resolveFunctionOverloads(bool exactMatchOnly); - bool tryResolve(int (*matchFunc)(Expr *, const Type *)); }; @@ -495,7 +489,8 @@ private: probably-different type. */ class TypeCastExpr : public Expr { public: - TypeCastExpr(const Type *t, Expr *e, SourcePos p); + TypeCastExpr(const Type *t, Expr *e, bool preserveUniformity, + SourcePos p); llvm::Value *GetValue(FunctionEmitContext *ctx) const; const Type *GetType() const; @@ -506,6 +501,7 @@ public: const Type *type; Expr *expr; + bool preserveUniformity; }; @@ -581,8 +577,12 @@ public: void Print() const; int EstimateCost() const; + bool ResolveOverloads(const std::vector &args); + Symbol *GetMatchingFunction(); + private: - friend class FunctionCallExpr; + bool tryResolve(int (*matchFunc)(Expr *, const Type *), + const std::vector &args); /** Name of the function that is being called. */ std::string name; @@ -592,8 +592,7 @@ private: overload is the best match. */ std::vector *candidateFunctions; - /** The actual matching function found after overload resolution; this - value is set by FunctionCallExpr::resolveFunctionOverloads() */ + /** The actual matching function found after overload resolution. */ Symbol *matchingFunc; }; diff --git a/parse.yy b/parse.yy index 53c97b9b..c7ac848b 100644 --- a/parse.yy +++ b/parse.yy @@ -326,19 +326,13 @@ cast_expression : unary_expression | '(' type_name ')' cast_expression { - // If type_name isn't explicitly a varying, We do a GetUniform() - // call here so that things like: + // Pass true here to try to preserve uniformity + // so that things like: // uniform int y = ...; // uniform float x = 1. / (float)y; // don't issue an error due to (float)y being inadvertently // and undesirably-to-the-user "varying"... - if ($2 == NULL || $4 == NULL || $4->GetType() == NULL) - $$ = NULL; - else { - if ($4->GetType()->IsUniformType()) - $2 = $2->GetAsUniformType(); - $$ = new TypeCastExpr($2, $4, @1); - } + $$ = new TypeCastExpr($2, $4, true, @1); } ; @@ -1463,7 +1457,7 @@ lFinalizeEnumeratorSymbols(std::vector &enums, the actual enum type here and optimize it, which will have us end up with a ConstExpr with the desired EnumType... */ Expr *castExpr = new TypeCastExpr(enumType, enums[i]->constValue, - enums[i]->pos); + false, enums[i]->pos); castExpr = castExpr->Optimize(); enums[i]->constValue = dynamic_cast(castExpr); assert(enums[i]->constValue != NULL); diff --git a/stmt.cpp b/stmt.cpp index c2673112..1be6d3a6 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -519,15 +519,16 @@ Stmt *IfStmt::TypeCheck() { if (test) { const Type *testType = test->GetType(); if (testType) { - bool isUniform = testType->IsUniformType() && !g->opt.disableUniformControlFlow; + bool isUniform = (testType->IsUniformType() && + !g->opt.disableUniformControlFlow); if (!testType->IsNumericType() && !testType->IsBoolType()) { - Error(test->pos, "Type \"%s\" can't be converted to boolean for \"if\" test.", - testType->GetString().c_str()); + Error(test->pos, "Type \"%s\" can't be converted to boolean " + "for \"if\" test.", testType->GetString().c_str()); return NULL; } test = new TypeCastExpr(isUniform ? AtomicType::UniformBool : - AtomicType::VaryingBool, - test, test->pos); + AtomicType::VaryingBool, + test, false, test->pos); assert(test); } } @@ -1171,7 +1172,7 @@ DoStmt::TypeCheck() { !lHasVaryingBreakOrContinue(bodyStmts)); testExpr = new TypeCastExpr(uniformTest ? AtomicType::UniformBool : AtomicType::VaryingBool, - testExpr, testExpr->pos); + testExpr, false, testExpr->pos); } } } @@ -1373,7 +1374,7 @@ ForStmt::TypeCheck() { !lHasVaryingBreakOrContinue(stmts)); test = new TypeCastExpr(uniformTest ? AtomicType::UniformBool : AtomicType::VaryingBool, - test, test->pos); + test, false, test->pos); } } } @@ -1685,7 +1686,7 @@ lProcessPrintArg(Expr *expr, FunctionEmitContext *ctx, std::string &argTypes) { baseType == AtomicType::UniformUInt16) { expr = new TypeCastExpr(type->IsUniformType() ? AtomicType::UniformInt32 : AtomicType::VaryingInt32, - expr, expr->pos); + expr, false, expr->pos); type = expr->GetType(); } @@ -1908,7 +1909,7 @@ AssertStmt::TypeCheck() { } expr = new TypeCastExpr(isUniform ? AtomicType::UniformBool : AtomicType::VaryingBool, - expr, expr->pos); + expr, false, expr->pos); } } return this; diff --git a/type.cpp b/type.cpp index bf56904a..90c713c2 100644 --- a/type.cpp +++ b/type.cpp @@ -1887,6 +1887,15 @@ FunctionType::SetArgumentDefaults(const std::vector &d) const { } +std::string +FunctionType::GetArgumentName(int i) const { + if (i >= (int)argNames.size()) + return ""; + else + return argNames[i]; +} + + /////////////////////////////////////////////////////////////////////////// // Type diff --git a/type.h b/type.h index a985dc2e..56eacc95 100644 --- a/type.h +++ b/type.h @@ -677,7 +677,7 @@ public: const std::vector &GetArgumentTypes() const { return argTypes; } const std::vector &GetArgumentDefaults() const { return argDefaults; } - const std::string &GetArgumentName(int i) const { return argNames[i]; } + std::string GetArgumentName(int i) const; /** @todo It would be nice to pull this information together and pass it when the constructor is called; it's kind of ugly to set it like