diff --git a/expr.cpp b/expr.cpp index 3b1276f5..9412ed12 100644 --- a/expr.cpp +++ b/expr.cpp @@ -4674,6 +4674,23 @@ IndexExpr::TypeCheck() { return this; } +Expr * +IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) { + if (index == NULL || baseExpr == NULL) + return NULL; + + if (Type::EqualIgnoringConst(this->GetType()->GetBaseType(), from)) { + type = PolyType::ReplaceType(type, to); + } + + if (Type::EqualIgnoringConst(this->GetLValueType()->GetBaseType(), from)) { + lvalueType = new PointerType(to, lvalueType->GetVariability(), + lvalueType->IsConstType()); + } + + return this; +} + int IndexExpr::EstimateCost() const { @@ -5316,6 +5333,23 @@ MemberExpr::Optimize() { return expr ? this : NULL; } +Expr * +MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) { + if (expr == NULL) + return NULL; + + if (Type::EqualIgnoringConst(this->GetType()->GetBaseType(), from)) { + type = PolyType::ReplaceType(type, to); + } + + if (Type::EqualIgnoringConst(this->GetLValueType()->GetBaseType(), from)) { + lvalueType = PolyType::ReplaceType(lvalueType, lvalueType); + } + + return this; +} + + int MemberExpr::EstimateCost() const { @@ -7118,6 +7152,9 @@ TypeCastExpr::GetValue(FunctionEmitContext *ctx) const { else { const AtomicType *toAtomic = CastType(toType); // typechecking should ensure this is the case + if (!toAtomic) { + fprintf(stderr, "I want %s to be atomic\n", toType->GetString().c_str()); + } AssertPos(pos, toAtomic != NULL); return lTypeConvAtomic(ctx, exprVal, toAtomic, fromAtomic, pos); @@ -7347,6 +7384,18 @@ TypeCastExpr::Optimize() { return this; } +Expr * +TypeCastExpr::ReplacePolyType(const PolyType *from, const Type *to) { + if (expr == NULL) + return NULL; + + if (Type::EqualIgnoringConst(type->GetBaseType(), from)) { + type = PolyType::ReplaceType(type, to); + } + + return this; +} + int TypeCastExpr::EstimateCost() const { @@ -8017,6 +8066,18 @@ SymbolExpr::Optimize() { return this; } +Expr * +SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) { + if (!symbol) + return NULL; + + if (Type::EqualIgnoringConst(symbol->type->GetBaseType(), from)) { + symbol->type = PolyType::ReplaceType(symbol->type, to); + } + + return this; +} + int SymbolExpr::EstimateCost() const { @@ -8815,6 +8876,18 @@ NewExpr::Optimize() { return this; } +Expr * +NewExpr::ReplacePolyType(const PolyType *from, const Type *to) { + if (!allocType) + return this; + + if (Type::EqualIgnoringConst(allocType->GetBaseType(), from)) { + allocType = PolyType::ReplaceType(allocType, to); + } + + return this; +} + void NewExpr::Print() const { diff --git a/expr.h b/expr.h index 8b3e8db0..6b655614 100644 --- a/expr.h +++ b/expr.h @@ -328,6 +328,7 @@ public: Expr *Optimize(); Expr *TypeCheck(); + Expr *ReplacePolyType(const PolyType *from, const Type *to); int EstimateCost() const; Expr *baseExpr, *index; @@ -361,6 +362,7 @@ public: void Print() const; Expr *Optimize(); Expr *TypeCheck(); + Expr *ReplacePolyType(const PolyType *from, const Type *to); int EstimateCost() const; virtual int getElementNumber() const = 0; @@ -526,6 +528,7 @@ public: void Print() const; Expr *TypeCheck(); Expr *Optimize(); + Expr *ReplacePolyType(const PolyType *from, const Type *to); int EstimateCost() const; Symbol *GetBaseSymbol() const; llvm::Constant *GetConstant(const Type *type) const; @@ -685,6 +688,7 @@ public: Symbol *GetBaseSymbol() const; Expr *TypeCheck(); Expr *Optimize(); + Expr *ReplacePolyType(const PolyType *from, const Type *to); void Print() const; int EstimateCost() const; @@ -813,6 +817,7 @@ public: const Type *GetType() const; Expr *TypeCheck(); Expr *Optimize(); + Expr *ReplacePolyType(const PolyType *from, const Type *to); void Print() const; int EstimateCost() const; diff --git a/func.cpp b/func.cpp index d97c4dd0..b69f2280 100644 --- a/func.cpp +++ b/func.cpp @@ -45,6 +45,7 @@ #include "sym.h" #include "util.h" #include +#include #if ISPC_LLVM_VERSION == ISPC_LLVM_3_2 // 3.2 #ifdef ISPC_NVPTX_ENABLED @@ -639,41 +640,87 @@ Function::IsPolyFunction() const { return false; } +static bool +lPolyTypeLess(const Type *a, const Type *b) { + const PolyType *pa = CastType(a->GetBaseType()); + const PolyType *pb = CastType(b->GetBaseType()); + + if (!pa || !pb) { + char buf[1024]; + snprintf(buf, 1024, "Calling lPolyTypeLess on non-polymorphic types" + "\"%s\" and \"%s\"\n", + a->GetString().c_str(), b->GetString().c_str()); + FATAL(buf); + } + + + if (pa->restriction < pb->restriction) + return true; + if (pa->restriction > pb->restriction) + return false; + + if (pa->GetQuant() < pb->GetQuant()) + return true; + + return false; +} + std::vector * Function::ExpandPolyArguments() const { - std::vector toExpand; + std::set toExpand(&lPolyTypeLess); std::vector *expanded = new std::vector(); for (size_t i = 0; i < args.size(); i++) { - if (args[i]->type->IsPolymorphicType()) { - toExpand.push_back(args[i]->type); + if (args[i]->type->IsPolymorphicType() && + !toExpand.count(args[i]->type)) { + toExpand.insert(args[i]->type); } } - for (size_t i = 0; i < toExpand.size(); i++) { - const PolyType *pt = CastType(toExpand[i]->GetBaseType()); + std::set::iterator te; + for (te = toExpand.begin(); te != toExpand.end(); te++) { + const PolyType *pt = CastType((*te)->GetBaseType()); - std::vector::iterator expanded; - expanded = pt->ExpandBegin(); - for (; expanded != pt->ExpandEnd(); expanded++) { - Type *replacement = *expanded; + std::vector::iterator expand; + expand = pt->ExpandBegin(); + for (; expand != pt->ExpandEnd(); expand++) { + const Type *replacement = *expand; + Stmt *code_r = code->ReplacePolyType(pt, replacement); - if (toExpand[i]->IsPointerType()) - replacement = new PointerType(replacement, - toExpand[i]->GetVariability(), - toExpand[i]->IsConstType()); - else if (toExpand[i]->IsArrayType()) - replacement = new ArrayType(replacement, - (CastType(toExpand[i]))->GetElementCount()); - else if (toExpand[i]->IsReferenceType()) - replacement = new ReferenceType(replacement); + const FunctionType *ft = CastType(sym->type); + llvm::SmallVector nargs; + llvm::SmallVector nargsn; + llvm::SmallVector nargsd; + llvm::SmallVector nargsp; + for (size_t i = 0; i < args.size(); i++) { + if (Type::EqualIgnoringConst(args[i]->type->GetBaseType(), pt)) { + nargs.push_back(PolyType::ReplaceType(args[i]->type, replacement)); + } else { + nargs.push_back(args[i]->type); + } + nargsn.push_back(ft->GetParameterName(i)); + nargsd.push_back(ft->GetParameterDefault(i)); + nargsp.push_back(ft->GetParameterSourcePos(i)); + } - printf("pretend I'm replacing %s with %s\n", - toExpand[i]->GetString().c_str(), - replacement->GetString().c_str()); + Symbol *nsym = new Symbol(sym->name, sym->pos, + new FunctionType(ft->GetReturnType(), + nargs, + nargsn, + nargsd, + nargsp, + ft->isTask, + ft->isExported, + ft->isExternC, + ft->isUnmasked)); + nsym->function = sym->function; + nsym->exportedFunction = sym->exportedFunction; + + expanded->push_back(new Function(nsym, code_r)); + + replacement = PolyType::ReplaceType(*te, replacement); } } - return expanded; } diff --git a/type.cpp b/type.cpp index f94ac266..3340b8b5 100644 --- a/type.cpp +++ b/type.cpp @@ -694,6 +694,28 @@ const PolyType *PolyType::UniformNumber = const PolyType *PolyType::VaryingNumber = new PolyType(PolyType::TYPE_NUMBER, Variability::Varying, false); +const Type * +PolyType::ReplaceType(const Type *from, const Type *to) { + const Type *t = to; + + if (from->IsPointerType()) { + t = new PointerType(to, + from->GetVariability(), + from->IsConstType()); + } else if (from->IsArrayType()) { + t = new ArrayType(to, + CastType(from)->GetElementCount()); + } else if (from->IsReferenceType()) { + t = new ReferenceType(to); + } + + fprintf(stderr, "Replacing type \"%s\" with \"%s\"\n", + from->GetString().c_str(), + t->GetString().c_str()); + + return t; +} + PolyType::PolyType(PolyRestriction r, Variability v, bool ic) : Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(-1) { asOtherConstType = NULL; @@ -816,7 +838,7 @@ PolyType::GetAsUniformType() const { return asUniformType; } -const std::vector::iterator +const std::vector::iterator PolyType::ExpandBegin() const { if (expandedTypes) return expandedTypes->begin(); @@ -841,7 +863,7 @@ PolyType::ExpandBegin() const { return expandedTypes->begin(); } -const std::vector::iterator +const std::vector::iterator PolyType::ExpandEnd() const { Assert(expandedTypes != NULL); @@ -922,7 +944,7 @@ PolyType::GetString() const { case TYPE_NUMBER: ret += "number"; break; default: FATAL("Logic error in PolyType::GetString()"); } - + if (quant >= 0) { ret += "$"; ret += std::to_string(quant); @@ -1619,9 +1641,9 @@ PointerType::GetCDeclaration(const std::string &name) const { } std::string ret = baseType->GetCDeclaration(""); - + bool baseIsBasicVarying = (IsBasicType(baseType)) && (baseType->IsVaryingType()); - + if (baseIsBasicVarying) ret += std::string("("); ret += std::string(" *"); if (isConst) ret += " const"; @@ -2463,7 +2485,7 @@ StructType::StructType(const std::string &n, const llvm::SmallVector(pt->GetBaseType()) != NULL) { type = new ArrayType(pt->GetBaseType(), 0); } - + if (paramNames[i] != "") ret += type->GetCDeclaration(paramNames[i]); else @@ -3554,11 +3576,11 @@ FunctionType::GetCDeclarationForDispatch(const std::string &fname) const { CastType(pt->GetBaseType()) != NULL) { type = new ArrayType(pt->GetBaseType(), 0); } - + // Change pointers to varying thingies to void * if (pt != NULL && pt->GetBaseType()->IsVaryingType()) { PointerType *t = PointerType::Void; - + if (paramNames[i] != "") ret += t->GetCDeclaration(paramNames[i]); else @@ -3690,10 +3712,10 @@ FunctionType::LLVMFunctionType(llvm::LLVMContext *ctx, bool removeMask) const { llvmArgTypes.push_back(LLVMTypes::MaskType); std::vector callTypes; - if (isTask + if (isTask #ifdef ISPC_NVPTX_ENABLED && (g->target->getISA() != Target::NVPTX) -#endif +#endif ){ // Tasks take three arguments: a pointer to a struct that holds the // actual task arguments, the thread index, and the total number of diff --git a/type.h b/type.h index cdb6300f..58efe1ca 100644 --- a/type.h +++ b/type.h @@ -413,6 +413,8 @@ public: const PolyRestriction restriction; + static const Type * ReplaceType(const Type *from, const Type *to); + static const PolyType *UniformInteger, *VaryingInteger; static const PolyType *UniformFloating, *VaryingFloating;