[WIP] replace polymorphic types from expressions

This commit is contained in:
2017-05-09 01:46:36 -04:00
parent 9c0f9be022
commit aeb4c0b6f9
5 changed files with 182 additions and 33 deletions

View File

@@ -4674,6 +4674,23 @@ IndexExpr::TypeCheck() {
return this; return this;
} }
Expr *
IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (index == NULL || baseExpr == NULL)
return NULL;
if (Type::EqualIgnoringConst(this->GetType()->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
}
if (Type::EqualIgnoringConst(this->GetLValueType()->GetBaseType(), from)) {
lvalueType = new PointerType(to, lvalueType->GetVariability(),
lvalueType->IsConstType());
}
return this;
}
int int
IndexExpr::EstimateCost() const { IndexExpr::EstimateCost() const {
@@ -5316,6 +5333,23 @@ MemberExpr::Optimize() {
return expr ? this : NULL; return expr ? this : NULL;
} }
Expr *
MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (expr == NULL)
return NULL;
if (Type::EqualIgnoringConst(this->GetType()->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
}
if (Type::EqualIgnoringConst(this->GetLValueType()->GetBaseType(), from)) {
lvalueType = PolyType::ReplaceType(lvalueType, lvalueType);
}
return this;
}
int int
MemberExpr::EstimateCost() const { MemberExpr::EstimateCost() const {
@@ -7118,6 +7152,9 @@ TypeCastExpr::GetValue(FunctionEmitContext *ctx) const {
else { else {
const AtomicType *toAtomic = CastType<AtomicType>(toType); const AtomicType *toAtomic = CastType<AtomicType>(toType);
// typechecking should ensure this is the case // typechecking should ensure this is the case
if (!toAtomic) {
fprintf(stderr, "I want %s to be atomic\n", toType->GetString().c_str());
}
AssertPos(pos, toAtomic != NULL); AssertPos(pos, toAtomic != NULL);
return lTypeConvAtomic(ctx, exprVal, toAtomic, fromAtomic, pos); return lTypeConvAtomic(ctx, exprVal, toAtomic, fromAtomic, pos);
@@ -7347,6 +7384,18 @@ TypeCastExpr::Optimize() {
return this; return this;
} }
Expr *
TypeCastExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (expr == NULL)
return NULL;
if (Type::EqualIgnoringConst(type->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
}
return this;
}
int int
TypeCastExpr::EstimateCost() const { TypeCastExpr::EstimateCost() const {
@@ -8017,6 +8066,18 @@ SymbolExpr::Optimize() {
return this; return this;
} }
Expr *
SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (!symbol)
return NULL;
if (Type::EqualIgnoringConst(symbol->type->GetBaseType(), from)) {
symbol->type = PolyType::ReplaceType(symbol->type, to);
}
return this;
}
int int
SymbolExpr::EstimateCost() const { SymbolExpr::EstimateCost() const {
@@ -8815,6 +8876,18 @@ NewExpr::Optimize() {
return this; return this;
} }
Expr *
NewExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (!allocType)
return this;
if (Type::EqualIgnoringConst(allocType->GetBaseType(), from)) {
allocType = PolyType::ReplaceType(allocType, to);
}
return this;
}
void void
NewExpr::Print() const { NewExpr::Print() const {

5
expr.h
View File

@@ -328,6 +328,7 @@ public:
Expr *Optimize(); Expr *Optimize();
Expr *TypeCheck(); Expr *TypeCheck();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const; int EstimateCost() const;
Expr *baseExpr, *index; Expr *baseExpr, *index;
@@ -361,6 +362,7 @@ public:
void Print() const; void Print() const;
Expr *Optimize(); Expr *Optimize();
Expr *TypeCheck(); Expr *TypeCheck();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const; int EstimateCost() const;
virtual int getElementNumber() const = 0; virtual int getElementNumber() const = 0;
@@ -526,6 +528,7 @@ public:
void Print() const; void Print() const;
Expr *TypeCheck(); Expr *TypeCheck();
Expr *Optimize(); Expr *Optimize();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const; int EstimateCost() const;
Symbol *GetBaseSymbol() const; Symbol *GetBaseSymbol() const;
llvm::Constant *GetConstant(const Type *type) const; llvm::Constant *GetConstant(const Type *type) const;
@@ -685,6 +688,7 @@ public:
Symbol *GetBaseSymbol() const; Symbol *GetBaseSymbol() const;
Expr *TypeCheck(); Expr *TypeCheck();
Expr *Optimize(); Expr *Optimize();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
void Print() const; void Print() const;
int EstimateCost() const; int EstimateCost() const;
@@ -813,6 +817,7 @@ public:
const Type *GetType() const; const Type *GetType() const;
Expr *TypeCheck(); Expr *TypeCheck();
Expr *Optimize(); Expr *Optimize();
Expr *ReplacePolyType(const PolyType *from, const Type *to);
void Print() const; void Print() const;
int EstimateCost() const; int EstimateCost() const;

View File

@@ -45,6 +45,7 @@
#include "sym.h" #include "sym.h"
#include "util.h" #include "util.h"
#include <stdio.h> #include <stdio.h>
#include <set>
#if ISPC_LLVM_VERSION == ISPC_LLVM_3_2 // 3.2 #if ISPC_LLVM_VERSION == ISPC_LLVM_3_2 // 3.2
#ifdef ISPC_NVPTX_ENABLED #ifdef ISPC_NVPTX_ENABLED
@@ -639,41 +640,87 @@ Function::IsPolyFunction() const {
return false; return false;
} }
static bool
lPolyTypeLess(const Type *a, const Type *b) {
const PolyType *pa = CastType<PolyType>(a->GetBaseType());
const PolyType *pb = CastType<PolyType>(b->GetBaseType());
if (!pa || !pb) {
char buf[1024];
snprintf(buf, 1024, "Calling lPolyTypeLess on non-polymorphic types"
"\"%s\" and \"%s\"\n",
a->GetString().c_str(), b->GetString().c_str());
FATAL(buf);
}
if (pa->restriction < pb->restriction)
return true;
if (pa->restriction > pb->restriction)
return false;
if (pa->GetQuant() < pb->GetQuant())
return true;
return false;
}
std::vector<Function *> * std::vector<Function *> *
Function::ExpandPolyArguments() const { Function::ExpandPolyArguments() const {
std::vector<const Type *> toExpand; std::set<const Type *, bool(*)(const Type *, const Type *)> toExpand(&lPolyTypeLess);
std::vector<Function *> *expanded = new std::vector<Function *>(); std::vector<Function *> *expanded = new std::vector<Function *>();
for (size_t i = 0; i < args.size(); i++) { for (size_t i = 0; i < args.size(); i++) {
if (args[i]->type->IsPolymorphicType()) { if (args[i]->type->IsPolymorphicType() &&
toExpand.push_back(args[i]->type); !toExpand.count(args[i]->type)) {
toExpand.insert(args[i]->type);
} }
} }
for (size_t i = 0; i < toExpand.size(); i++) { std::set<const Type *>::iterator te;
const PolyType *pt = CastType<PolyType>(toExpand[i]->GetBaseType()); for (te = toExpand.begin(); te != toExpand.end(); te++) {
const PolyType *pt = CastType<PolyType>((*te)->GetBaseType());
std::vector<AtomicType *>::iterator expanded; std::vector<AtomicType *>::iterator expand;
expanded = pt->ExpandBegin(); expand = pt->ExpandBegin();
for (; expanded != pt->ExpandEnd(); expanded++) { for (; expand != pt->ExpandEnd(); expand++) {
Type *replacement = *expanded; const Type *replacement = *expand;
Stmt *code_r = code->ReplacePolyType(pt, replacement);
if (toExpand[i]->IsPointerType()) const FunctionType *ft = CastType<FunctionType>(sym->type);
replacement = new PointerType(replacement, llvm::SmallVector<const Type *, 8> nargs;
toExpand[i]->GetVariability(), llvm::SmallVector<std::string, 8> nargsn;
toExpand[i]->IsConstType()); llvm::SmallVector<Expr *, 8> nargsd;
else if (toExpand[i]->IsArrayType()) llvm::SmallVector<SourcePos, 8> nargsp;
replacement = new ArrayType(replacement, for (size_t i = 0; i < args.size(); i++) {
(CastType<ArrayType>(toExpand[i]))->GetElementCount()); if (Type::EqualIgnoringConst(args[i]->type->GetBaseType(), pt)) {
else if (toExpand[i]->IsReferenceType()) nargs.push_back(PolyType::ReplaceType(args[i]->type, replacement));
replacement = new ReferenceType(replacement); } else {
nargs.push_back(args[i]->type);
}
nargsn.push_back(ft->GetParameterName(i));
nargsd.push_back(ft->GetParameterDefault(i));
nargsp.push_back(ft->GetParameterSourcePos(i));
}
printf("pretend I'm replacing %s with %s\n", Symbol *nsym = new Symbol(sym->name, sym->pos,
toExpand[i]->GetString().c_str(), new FunctionType(ft->GetReturnType(),
replacement->GetString().c_str()); nargs,
nargsn,
nargsd,
nargsp,
ft->isTask,
ft->isExported,
ft->isExternC,
ft->isUnmasked));
nsym->function = sym->function;
nsym->exportedFunction = sym->exportedFunction;
expanded->push_back(new Function(nsym, code_r));
replacement = PolyType::ReplaceType(*te, replacement);
} }
} }
return expanded; return expanded;
} }

View File

@@ -694,6 +694,28 @@ const PolyType *PolyType::UniformNumber =
const PolyType *PolyType::VaryingNumber = const PolyType *PolyType::VaryingNumber =
new PolyType(PolyType::TYPE_NUMBER, Variability::Varying, false); new PolyType(PolyType::TYPE_NUMBER, Variability::Varying, false);
const Type *
PolyType::ReplaceType(const Type *from, const Type *to) {
const Type *t = to;
if (from->IsPointerType()) {
t = new PointerType(to,
from->GetVariability(),
from->IsConstType());
} else if (from->IsArrayType()) {
t = new ArrayType(to,
CastType<ArrayType>(from)->GetElementCount());
} else if (from->IsReferenceType()) {
t = new ReferenceType(to);
}
fprintf(stderr, "Replacing type \"%s\" with \"%s\"\n",
from->GetString().c_str(),
t->GetString().c_str());
return t;
}
PolyType::PolyType(PolyRestriction r, Variability v, bool ic) PolyType::PolyType(PolyRestriction r, Variability v, bool ic)
: Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(-1) { : Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(-1) {
asOtherConstType = NULL; asOtherConstType = NULL;
@@ -816,7 +838,7 @@ PolyType::GetAsUniformType() const {
return asUniformType; return asUniformType;
} }
const std::vector<AtomicType *>::iterator const std::vector<AtomicType *>::iterator
PolyType::ExpandBegin() const { PolyType::ExpandBegin() const {
if (expandedTypes) if (expandedTypes)
return expandedTypes->begin(); return expandedTypes->begin();
@@ -841,7 +863,7 @@ PolyType::ExpandBegin() const {
return expandedTypes->begin(); return expandedTypes->begin();
} }
const std::vector<AtomicType *>::iterator const std::vector<AtomicType *>::iterator
PolyType::ExpandEnd() const { PolyType::ExpandEnd() const {
Assert(expandedTypes != NULL); Assert(expandedTypes != NULL);
@@ -922,7 +944,7 @@ PolyType::GetString() const {
case TYPE_NUMBER: ret += "number"; break; case TYPE_NUMBER: ret += "number"; break;
default: FATAL("Logic error in PolyType::GetString()"); default: FATAL("Logic error in PolyType::GetString()");
} }
if (quant >= 0) { if (quant >= 0) {
ret += "$"; ret += "$";
ret += std::to_string(quant); ret += std::to_string(quant);
@@ -1619,9 +1641,9 @@ PointerType::GetCDeclaration(const std::string &name) const {
} }
std::string ret = baseType->GetCDeclaration(""); std::string ret = baseType->GetCDeclaration("");
bool baseIsBasicVarying = (IsBasicType(baseType)) && (baseType->IsVaryingType()); bool baseIsBasicVarying = (IsBasicType(baseType)) && (baseType->IsVaryingType());
if (baseIsBasicVarying) ret += std::string("("); if (baseIsBasicVarying) ret += std::string("(");
ret += std::string(" *"); ret += std::string(" *");
if (isConst) ret += " const"; if (isConst) ret += " const";
@@ -2463,7 +2485,7 @@ StructType::StructType(const std::string &n, const llvm::SmallVector<const Type
} }
} }
const std::string const std::string
StructType::GetCStructName() const { StructType::GetCStructName() const {
// only return mangled name for varying structs for backwards // only return mangled name for varying structs for backwards
// compatibility... // compatibility...
@@ -3523,7 +3545,7 @@ FunctionType::GetCDeclaration(const std::string &fname) const {
CastType<ArrayType>(pt->GetBaseType()) != NULL) { CastType<ArrayType>(pt->GetBaseType()) != NULL) {
type = new ArrayType(pt->GetBaseType(), 0); type = new ArrayType(pt->GetBaseType(), 0);
} }
if (paramNames[i] != "") if (paramNames[i] != "")
ret += type->GetCDeclaration(paramNames[i]); ret += type->GetCDeclaration(paramNames[i]);
else else
@@ -3554,11 +3576,11 @@ FunctionType::GetCDeclarationForDispatch(const std::string &fname) const {
CastType<ArrayType>(pt->GetBaseType()) != NULL) { CastType<ArrayType>(pt->GetBaseType()) != NULL) {
type = new ArrayType(pt->GetBaseType(), 0); type = new ArrayType(pt->GetBaseType(), 0);
} }
// Change pointers to varying thingies to void * // Change pointers to varying thingies to void *
if (pt != NULL && pt->GetBaseType()->IsVaryingType()) { if (pt != NULL && pt->GetBaseType()->IsVaryingType()) {
PointerType *t = PointerType::Void; PointerType *t = PointerType::Void;
if (paramNames[i] != "") if (paramNames[i] != "")
ret += t->GetCDeclaration(paramNames[i]); ret += t->GetCDeclaration(paramNames[i]);
else else
@@ -3690,10 +3712,10 @@ FunctionType::LLVMFunctionType(llvm::LLVMContext *ctx, bool removeMask) const {
llvmArgTypes.push_back(LLVMTypes::MaskType); llvmArgTypes.push_back(LLVMTypes::MaskType);
std::vector<llvm::Type *> callTypes; std::vector<llvm::Type *> callTypes;
if (isTask if (isTask
#ifdef ISPC_NVPTX_ENABLED #ifdef ISPC_NVPTX_ENABLED
&& (g->target->getISA() != Target::NVPTX) && (g->target->getISA() != Target::NVPTX)
#endif #endif
){ ){
// Tasks take three arguments: a pointer to a struct that holds the // Tasks take three arguments: a pointer to a struct that holds the
// actual task arguments, the thread index, and the total number of // actual task arguments, the thread index, and the total number of

2
type.h
View File

@@ -413,6 +413,8 @@ public:
const PolyRestriction restriction; const PolyRestriction restriction;
static const Type * ReplaceType(const Type *from, const Type *to);
static const PolyType *UniformInteger, *VaryingInteger; static const PolyType *UniformInteger, *VaryingInteger;
static const PolyType *UniformFloating, *VaryingFloating; static const PolyType *UniformFloating, *VaryingFloating;