diff --git a/ast.cpp b/ast.cpp index 506aa1e2..064226b4 100644 --- a/ast.cpp +++ b/ast.cpp @@ -58,16 +58,18 @@ ASTNode::~ASTNode() { // AST void -AST::AddFunction(Symbol *sym, Stmt *code) { +AST::AddFunction(Symbol *sym, Stmt *code, SymbolTable *symbolTable) { if (sym == NULL) return; Function *f = new Function(sym, code); if (f->IsPolyFunction()) { - std::vector *expanded = f->ExpandPolyArguments(); - for (size_t i=0; isize(); i++) + std::vector *expanded = f->ExpandPolyArguments(symbolTable); + for (size_t i=0; isize(); i++) { functions.push_back((*expanded)[i]); + } + delete expanded; } else { functions.push_back(f); } @@ -540,5 +542,5 @@ TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement) data.polyType = polyType; data.replacement = replacement; - return WalkAST(root, NULL, lTranslatePolyNode, &replacement); + return WalkAST(root, NULL, lTranslatePolyNode, &data); } diff --git a/ast.h b/ast.h index b13120d8..91d3dd5b 100644 --- a/ast.h +++ b/ast.h @@ -148,7 +148,7 @@ class AST { public: /** Add the AST for a function described by the given declaration information and source code. */ - void AddFunction(Symbol *sym, Stmt *code); + void AddFunction(Symbol *sym, Stmt *code, SymbolTable *symbolTable=NULL); /** Generate LLVM IR for all of the functions into the current module. */ @@ -207,6 +207,8 @@ extern Stmt *TypeCheck(Stmt *); the given root. */ extern int EstimateCost(ASTNode *root); +extern ASTNode * TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement); + /** Returns true if it would be safe to run the given code with an "all off" mask. */ extern bool SafeToRunWithMaskAllOff(ASTNode *root); diff --git a/expr.cpp b/expr.cpp index 9412ed12..04a28d40 100644 --- a/expr.cpp +++ b/expr.cpp @@ -4679,11 +4679,11 @@ IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (index == NULL || baseExpr == NULL) return NULL; - if (Type::EqualIgnoringConst(this->GetType()->GetBaseType(), from)) { + if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) { type = PolyType::ReplaceType(type, to); } - if (Type::EqualIgnoringConst(this->GetLValueType()->GetBaseType(), from)) { + if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) { lvalueType = new PointerType(to, lvalueType->GetVariability(), lvalueType->IsConstType()); } @@ -5338,11 +5338,11 @@ MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (expr == NULL) return NULL; - if (Type::EqualIgnoringConst(this->GetType()->GetBaseType(), from)) { + if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) { type = PolyType::ReplaceType(type, to); } - if (Type::EqualIgnoringConst(this->GetLValueType()->GetBaseType(), from)) { + if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) { lvalueType = PolyType::ReplaceType(lvalueType, lvalueType); } @@ -7386,10 +7386,10 @@ TypeCastExpr::Optimize() { Expr * TypeCastExpr::ReplacePolyType(const PolyType *from, const Type *to) { - if (expr == NULL) + if (type == NULL) return NULL; - if (Type::EqualIgnoringConst(type->GetBaseType(), from)) { + if (Type::EqualForReplacement(type->GetBaseType(), from)) { type = PolyType::ReplaceType(type, to); } @@ -8071,7 +8071,7 @@ SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (!symbol) return NULL; - if (Type::EqualIgnoringConst(symbol->type->GetBaseType(), from)) { + if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) { symbol->type = PolyType::ReplaceType(symbol->type, to); } @@ -8881,7 +8881,7 @@ NewExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (!allocType) return this; - if (Type::EqualIgnoringConst(allocType->GetBaseType(), from)) { + if (Type::EqualForReplacement(allocType->GetBaseType(), from)) { allocType = PolyType::ReplaceType(allocType, to); } diff --git a/func.cpp b/func.cpp index b69f2280..fa1e030d 100644 --- a/func.cpp +++ b/func.cpp @@ -640,87 +640,45 @@ Function::IsPolyFunction() const { return false; } -static bool -lPolyTypeLess(const Type *a, const Type *b) { - const PolyType *pa = CastType(a->GetBaseType()); - const PolyType *pb = CastType(b->GetBaseType()); - - if (!pa || !pb) { - char buf[1024]; - snprintf(buf, 1024, "Calling lPolyTypeLess on non-polymorphic types" - "\"%s\" and \"%s\"\n", - a->GetString().c_str(), b->GetString().c_str()); - FATAL(buf); - } - - - if (pa->restriction < pb->restriction) - return true; - if (pa->restriction > pb->restriction) - return false; - - if (pa->GetQuant() < pb->GetQuant()) - return true; - - return false; -} std::vector * -Function::ExpandPolyArguments() const { - std::set toExpand(&lPolyTypeLess); +Function::ExpandPolyArguments(SymbolTable *symbolTable) const { + Assert(symbolTable != NULL); + std::vector *expanded = new std::vector(); - for (size_t i = 0; i < args.size(); i++) { - if (args[i]->type->IsPolymorphicType() && - !toExpand.count(args[i]->type)) { - toExpand.insert(args[i]->type); - } - } + std::vector versions = symbolTable->LookupPolyFunction(sym->name.c_str()); - std::set::iterator te; - for (te = toExpand.begin(); te != toExpand.end(); te++) { - const PolyType *pt = CastType((*te)->GetBaseType()); + const FunctionType *func = CastType(sym->type); - std::vector::iterator expand; - expand = pt->ExpandBegin(); - for (; expand != pt->ExpandEnd(); expand++) { - const Type *replacement = *expand; - Stmt *code_r = code->ReplacePolyType(pt, replacement); + printf("%s before replacing anything:\n", sym->name.c_str()); + code->Print(0); - const FunctionType *ft = CastType(sym->type); - llvm::SmallVector nargs; - llvm::SmallVector nargsn; - llvm::SmallVector nargsd; - llvm::SmallVector nargsp; - for (size_t i = 0; i < args.size(); i++) { - if (Type::EqualIgnoringConst(args[i]->type->GetBaseType(), pt)) { - nargs.push_back(PolyType::ReplaceType(args[i]->type, replacement)); - } else { - nargs.push_back(args[i]->type); - } - nargsn.push_back(ft->GetParameterName(i)); - nargsd.push_back(ft->GetParameterDefault(i)); - nargsp.push_back(ft->GetParameterSourcePos(i)); + for (size_t i=0; i(versions[i]->type); + Stmt *ncode = code; + + for (int j=0; jGetNumParameters(); j++) { + if (func->GetParameterType(j)->IsPolymorphicType()) { + const PolyType *from = CastType( + func->GetParameterType(j)->GetBaseType()); + + ncode = (Stmt*)TranslatePoly(ncode, from, + ft->GetParameterType(j)->GetBaseType()); + printf("%s after replacing %s with %s:\n\n", + sym->name.c_str(), from->GetString().c_str(), + ft->GetParameterType(j)->GetBaseType()->GetString().c_str()); + + ncode->Print(0); + + printf("------------------------------------------\n\n"); } - - - Symbol *nsym = new Symbol(sym->name, sym->pos, - new FunctionType(ft->GetReturnType(), - nargs, - nargsn, - nargsd, - nargsp, - ft->isTask, - ft->isExported, - ft->isExternC, - ft->isUnmasked)); - nsym->function = sym->function; - nsym->exportedFunction = sym->exportedFunction; - - expanded->push_back(new Function(nsym, code_r)); - - replacement = PolyType::ReplaceType(*te, replacement); } + + Symbol *s = symbolTable->LookupFunction(versions[i]->name.c_str(), ft); + + expanded->push_back(new Function(s, ncode)); } + return expanded; } diff --git a/func.h b/func.h index 2ac9cc90..86d801f4 100644 --- a/func.h +++ b/func.h @@ -39,6 +39,7 @@ #define ISPC_FUNC_H 1 #include "ispc.h" +#include "sym.h" #include class Function { @@ -54,7 +55,7 @@ public: /** Checks if the function has polymorphic parameters */ const bool IsPolyFunction() const; - std::vector *ExpandPolyArguments() const; + std::vector *ExpandPolyArguments(SymbolTable *symbolTable) const; private: void emitCode(FunctionEmitContext *ctx, llvm::Function *function, diff --git a/module.cpp b/module.cpp index bf9843ac..85b119ff 100644 --- a/module.cpp +++ b/module.cpp @@ -915,8 +915,6 @@ Module::AddFunctionDeclaration(const std::string &name, SourcePos pos) { Assert(functionType != NULL); - fprintf(stderr, "Adding %s\n", name.c_str()); - // If a global variable with the same name has already been declared // issue an error. if (symbolTable->LookupVariable(name.c_str()) != NULL) { @@ -1020,26 +1018,26 @@ Module::AddFunctionDeclaration(const std::string &name, * these functions will be overloaded if they are not exported, or mangled * if exported */ - std::vector toExpand; + std::set toExpand(&PolyType::Less); std::vector expanded; expanded.push_back(functionType); for (int i=0; iGetNumParameters(); i++) { - if (functionType->GetParameterType(i)->IsPolymorphicType()) { - fprintf(stderr, "Expanding polymorphic function \"%s\"\n", - name.c_str()); + const Type *param = functionType->GetParameterType(i); + if (param->IsPolymorphicType() && + !toExpand.count(param->GetBaseType())) { - toExpand.push_back(i); + toExpand.insert(param->GetBaseType()); } } std::vector nextExpanded; - for (size_t i=0; i::iterator iter; + for (iter = toExpand.begin(); iter != toExpand.end(); iter++) { for (size_t j=0; j( - eft->GetParameterType(toExpand[i])->GetBaseType()); + const PolyType *pt=CastType(*iter); std::vector::iterator te; for (te = pt->ExpandBegin(); te != pt->ExpandEnd(); te++) { @@ -1048,9 +1046,10 @@ Module::AddFunctionDeclaration(const std::string &name, llvm::SmallVector nargsd; llvm::SmallVector nargsp; for (size_t k=0; kGetNumParameters(); k++) { - if (k == toExpand[i]) { + if (Type::Equal(eft->GetParameterType(k)->GetBaseType(), + pt)) { const Type *r; - r = PolyType::ReplaceType(eft->GetParameterType(j),*te); + r = PolyType::ReplaceType(eft->GetParameterType(k),*te); nargs.push_back(r); } else { nargs.push_back(eft->GetParameterType(k)); @@ -1087,8 +1086,8 @@ Module::AddFunctionDeclaration(const std::string &name, } } - fprintf(stderr, "Adding expanded function %s\n", nname.c_str()); + symbolTable->MapPolyFunction(name, nname, expanded[i]); AddFunctionDeclaration(nname, expanded[i], storageClass, isInline, pos); } @@ -1263,14 +1262,7 @@ Module::AddFunctionDefinition(const std::string &name, const FunctionType *type, sym->pos = code->pos; - // FIXME: because we encode the parameter names in the function type, - // we need to override the function type here in case the function had - // earlier been declared with anonymous parameter names but is now - // defined with actual names. This is yet another reason we shouldn't - // include the names in FunctionType... - sym->type = type; - - ast->AddFunction(sym, code); + ast->AddFunction(sym, code, symbolTable); } diff --git a/stmt.cpp b/stmt.cpp index 9b03d340..f03944e8 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -503,7 +503,7 @@ Stmt * DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) { for (size_t i = 0; i < vars.size(); i++) { Symbol *s = vars[i].sym; - if (Type::EqualIgnoringConst(s->type->GetBaseType(), from)) { + if (Type::EqualForReplacement(s->type->GetBaseType(), from)) { s->type = PolyType::ReplaceType(s->type, to); } } @@ -2198,7 +2198,7 @@ ForeachStmt::ReplacePolyType(const PolyType *from, const Type *to) { for (size_t i=0; itype; - if (Type::EqualIgnoringConst(t->GetBaseType(), from)) { + if (Type::EqualForReplacement(t->GetBaseType(), from)) { t = PolyType::ReplaceType(t, to); } } diff --git a/sym.cpp b/sym.cpp index 396ec488..88777d5a 100644 --- a/sym.cpp +++ b/sym.cpp @@ -157,6 +157,14 @@ SymbolTable::AddFunction(Symbol *symbol) { return true; } +void +SymbolTable::MapPolyFunction(std::string name, std::string polyname, + const FunctionType *type) { + std::vector &polyExpansions = polyFunctions[name]; + SourcePos p; + polyExpansions.push_back(new Symbol(polyname, p, type, SC_NONE)); +} + bool SymbolTable::LookupFunction(const char *name, std::vector *matches) { @@ -184,9 +192,20 @@ SymbolTable::LookupFunction(const char *name, const FunctionType *type) { return funcs[j]; } } + // Try looking for a polymorphic function + if (polyFunctions[name].size() > 0) { + std::string n = name; + return new Symbol(name, polyFunctions[name][0]->pos, type); + } + return NULL; } +std::vector& +SymbolTable::LookupPolyFunction(const char *name) { + return polyFunctions[name]; +} + bool SymbolTable::AddType(const char *name, const Type *type, SourcePos pos) { diff --git a/sym.h b/sym.h index d2955f90..802ad58c 100644 --- a/sym.h +++ b/sym.h @@ -108,6 +108,7 @@ public: }; + /** @brief Symbol table that holds all known symbols during parsing and compilation. A single instance of a SymbolTable is stored in the Module class @@ -159,6 +160,14 @@ public: already present in the symbol table. */ bool AddFunction(Symbol *symbol); + /** Adds the given function to the list of polymorphic definitions for the + given name + @param name The name of the original function + @param type The expanded FunctionType */ + + void MapPolyFunction(std::string name, std::string polyname, + const FunctionType *type); + /** Looks for the function or functions with the given name in the symbol name. If a function has been overloaded and multiple definitions are present for a given function name, all of them will @@ -174,6 +183,8 @@ public: @return pointer to matching Symbol; NULL if none is found. */ Symbol *LookupFunction(const char *name, const FunctionType *type); + std::vector& LookupPolyFunction(const char *name); + /** Returns all of the functions in the symbol table that match the given predicate. @@ -276,6 +287,8 @@ private: typedef std::map > FunctionMapType; FunctionMapType functions; + FunctionMapType polyFunctions; + /** Type definitions can't currently be scoped. */ typedef std::map TypeMapType; diff --git a/tests_ispcpp/simple.ispc b/tests_ispcpp/simple.ispc index da4642da..aed27d4f 100644 --- a/tests_ispcpp/simple.ispc +++ b/tests_ispcpp/simple.ispc @@ -1,4 +1,4 @@ -export void foo(uniform int N, floating$1 X[]) +export void foo(uniform int N, uniform floating$1 X[]) { foreach (i = 0 ... N) { X[i] = X[i] + 1.0; diff --git a/type.cpp b/type.cpp index 3340b8b5..da8851b1 100644 --- a/type.cpp +++ b/type.cpp @@ -709,6 +709,9 @@ PolyType::ReplaceType(const Type *from, const Type *to) { t = new ReferenceType(to); } + if (from->IsVaryingType()) + t = t->GetAsVaryingType(); + fprintf(stderr, "Replacing type \"%s\" with \"%s\"\n", from->GetString().c_str(), t->GetString().c_str()); @@ -716,6 +719,31 @@ PolyType::ReplaceType(const Type *from, const Type *to) { return t; } +bool +PolyType::Less(const Type *a, const Type *b) { + const PolyType *pa = CastType(a->GetBaseType()); + const PolyType *pb = CastType(b->GetBaseType()); + + if (!pa || !pb) { + char buf[1024]; + snprintf(buf, 1024, "Calling lPolyTypeLess on non-polymorphic types" + "\"%s\" and \"%s\"\n", + a->GetString().c_str(), b->GetString().c_str()); + FATAL(buf); + } + + + if (pa->restriction < pb->restriction) + return true; + if (pa->restriction > pb->restriction) + return false; + + if (pa->GetQuant() < pb->GetQuant()) + return true; + + return false; +} + PolyType::PolyType(PolyRestriction r, Variability v, bool ic) : Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(-1) { asOtherConstType = NULL; @@ -4137,3 +4165,16 @@ bool Type::EqualIgnoringConst(const Type *a, const Type *b) { return lCheckTypeEquality(a, b, true); } + +bool +Type::EqualForReplacement(const Type *a, const Type *b) { + const PolyType *pa = CastType(a); + const PolyType *pb = CastType(b); + + + if (!pa || !pb) + return false; + + return pa->restriction == pb->restriction && + pa->GetQuant() == pb->GetQuant(); +} diff --git a/type.h b/type.h index 58efe1ca..69d739d2 100644 --- a/type.h +++ b/type.h @@ -244,6 +244,8 @@ public: the same (ignoring const-ness of the type), false otherwise. */ static bool EqualIgnoringConst(const Type *a, const Type *b); + static bool EqualForReplacement(const Type *a, const Type *b); + /** Given two types, returns the least general Type that is more general than both of them. (i.e. that can represent their values without any loss of data.) If there is no such Type, return NULL. @@ -415,6 +417,8 @@ public: static const Type * ReplaceType(const Type *from, const Type *to); + static bool Less(const Type *a, const Type *b); + static const PolyType *UniformInteger, *VaryingInteger; static const PolyType *UniformFloating, *VaryingFloating;