Attempt to replicate AST when expanding polytypes

This commit is contained in:
2017-05-10 11:11:39 -04:00
parent 192b99f21d
commit 6a91c5d5ac
4 changed files with 199 additions and 26 deletions

124
expr.cpp
View File

@@ -114,7 +114,52 @@ Expr::GetBaseSymbol() const {
Expr * Expr *
Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) { Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
return this; Expr *copy;
switch (getValueID()) {
case AddressOfExprID:
copy = (Expr*)new AddressOfExpr(*(AddressOfExpr*)this);
break;
case AssignExprID:
copy = (Expr*)new AssignExpr(*(AssignExpr*)this);
break;
case BinaryExprID:
copy = (Expr*)new BinaryExpr(*(BinaryExpr*)this);
break;
case ConstExprID:
copy = (Expr*)new ConstExpr(*(ConstExpr*)this);
break;
case PtrDerefExprID:
copy = (Expr*)new PtrDerefExpr(*(PtrDerefExpr*)this);
break;
case RefDerefExprID:
copy = (Expr*)new RefDerefExpr(*(RefDerefExpr*)this);
break;
case ExprListID:
copy = (Expr*)new ExprList(*(ExprList*)this);
break;
case FunctionCallExprID:
copy = (Expr*)new FunctionCallExpr(*(FunctionCallExpr*)this);
break;
case NullPointerExprID:
copy = (Expr*)new NullPointerExpr(*(NullPointerExpr*)this);
break;
case ReferenceExprID:
copy = (Expr*)new ReferenceExpr(*(ReferenceExpr*)this);
break;
case SelectExprID:
copy = (Expr*)new SelectExpr(*(SelectExpr*)this);
break;
case SizeOfExprID:
copy = (Expr*)new SizeOfExpr(*(SizeOfExpr*)this);
break;
case UnaryExprID:
copy = (Expr*)new UnaryExpr(*(UnaryExpr*)this);
break;
default:
FATAL("Unmatched case in ReplacePolyType (expr)");
copy = this; // just to silence the compiler
}
return copy;
} }
@@ -4147,6 +4192,14 @@ IndexExpr::IndexExpr(Expr *a, Expr *i, SourcePos p)
type = lvalueType = NULL; type = lvalueType = NULL;
} }
IndexExpr::IndexExpr(IndexExpr *base)
: Expr(base->pos, IndexExprID) {
baseExpr = base->baseExpr;
index = base->index;
type = base->type;
lvalueType = base->lvalueType;
}
/** When computing pointer values, we need to apply a per-lane offset when /** When computing pointer values, we need to apply a per-lane offset when
we have a varying pointer that is itself indexing into varying data. we have a varying pointer that is itself indexing into varying data.
@@ -4679,16 +4732,18 @@ IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (index == NULL || baseExpr == NULL) if (index == NULL || baseExpr == NULL)
return NULL; return NULL;
if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) { IndexExpr *copy = new IndexExpr(this);
type = PolyType::ReplaceType(type, to);
if (Type::EqualForReplacement(copy->GetType()->GetBaseType(), from)) {
copy->type = PolyType::ReplaceType(type, to);
} }
if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) { if (Type::EqualForReplacement(copy->GetLValueType()->GetBaseType(), from)) {
lvalueType = new PointerType(to, lvalueType->GetVariability(), copy->lvalueType = new PointerType(to, copy->lvalueType->GetVariability(),
lvalueType->IsConstType()); copy->lvalueType->IsConstType());
} }
return this; return copy;
} }
@@ -5338,15 +5393,15 @@ MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (expr == NULL) if (expr == NULL)
return NULL; return NULL;
if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) { MemberExpr *copy = getValueID() == StructMemberExprID ?
type = PolyType::ReplaceType(type, to); (MemberExpr*) new StructMemberExpr(*(StructMemberExpr*)this) :
(MemberExpr*) new VectorMemberExpr(*(VectorMemberExpr*)this);
if (Type::EqualForReplacement(copy->GetType()->GetBaseType(), from)) {
copy->type = PolyType::ReplaceType(copy->type, to);
} }
if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) { return copy;
lvalueType = PolyType::ReplaceType(lvalueType, lvalueType);
}
return this;
} }
@@ -6292,6 +6347,12 @@ TypeCastExpr::TypeCastExpr(const Type *t, Expr *e, SourcePos p)
expr = e; expr = e;
} }
TypeCastExpr::TypeCastExpr(TypeCastExpr *base)
: Expr(base->pos, TypeCastExprID) {
type = base->type;
expr = base->expr;
}
/** Handle all the grungy details of type conversion between atomic types. /** Handle all the grungy details of type conversion between atomic types.
Given an input value in exprVal of type fromType, convert it to the Given an input value in exprVal of type fromType, convert it to the
@@ -7389,11 +7450,13 @@ TypeCastExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (type == NULL) if (type == NULL)
return NULL; return NULL;
if (Type::EqualForReplacement(type->GetBaseType(), from)) { TypeCastExpr *copy = new TypeCastExpr(this);
type = PolyType::ReplaceType(type, to);
if (Type::EqualForReplacement(copy->type->GetBaseType(), from)) {
copy->type = PolyType::ReplaceType(copy->type, to);
} }
return this; return copy;
} }
@@ -8002,6 +8065,11 @@ SymbolExpr::SymbolExpr(Symbol *s, SourcePos p)
symbol = s; symbol = s;
} }
SymbolExpr::SymbolExpr(SymbolExpr *base)
: Expr(base->pos, SymbolExprID) {
symbol = base->symbol;
}
llvm::Value * llvm::Value *
SymbolExpr::GetValue(FunctionEmitContext *ctx) const { SymbolExpr::GetValue(FunctionEmitContext *ctx) const {
@@ -8071,11 +8139,15 @@ SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (!symbol) if (!symbol)
return NULL; return NULL;
SymbolExpr *copy = new SymbolExpr(this);
copy->symbol = new Symbol(*symbol);
if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) { if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) {
symbol->type = PolyType::ReplaceType(symbol->type, to); copy->symbol->type = PolyType::ReplaceType(symbol->type, to);
} }
return this; return copy;
} }
@@ -8679,6 +8751,14 @@ NewExpr::NewExpr(int typeQual, const Type *t, Expr *init, Expr *count,
allocType = allocType->ResolveUnboundVariability(Variability::Uniform); allocType = allocType->ResolveUnboundVariability(Variability::Uniform);
} }
NewExpr::NewExpr(NewExpr *base)
: Expr(base->pos, NewExprID) {
allocType = base->allocType;
initExpr = base->initExpr;
countExpr = base->countExpr;
isVarying = base->isVarying;
}
llvm::Value * llvm::Value *
NewExpr::GetValue(FunctionEmitContext *ctx) const { NewExpr::GetValue(FunctionEmitContext *ctx) const {
@@ -8881,11 +8961,13 @@ NewExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (!allocType) if (!allocType)
return this; return this;
NewExpr *copy = new NewExpr(this);
if (Type::EqualForReplacement(allocType->GetBaseType(), from)) { if (Type::EqualForReplacement(allocType->GetBaseType(), from)) {
allocType = PolyType::ReplaceType(allocType, to); copy->allocType = PolyType::ReplaceType(allocType, to);
} }
return this; return copy;
} }

6
expr.h
View File

@@ -334,6 +334,7 @@ public:
Expr *baseExpr, *index; Expr *baseExpr, *index;
private: private:
IndexExpr(IndexExpr *base);
mutable const Type *type; mutable const Type *type;
mutable const PointerType *lvalueType; mutable const PointerType *lvalueType;
}; };
@@ -535,6 +536,8 @@ public:
const Type *type; const Type *type;
Expr *expr; Expr *expr;
private:
TypeCastExpr(TypeCastExpr *base);
}; };
@@ -693,6 +696,7 @@ public:
int EstimateCost() const; int EstimateCost() const;
private: private:
SymbolExpr(SymbolExpr *base);
Symbol *symbol; Symbol *symbol;
}; };
@@ -835,6 +839,8 @@ public:
instance, or whether a single allocation is performed for the instance, or whether a single allocation is performed for the
entire gang of program instances.) */ entire gang of program instances.) */
bool isVarying; bool isVarying;
private:
NewExpr(NewExpr *base);
}; };

View File

@@ -79,7 +79,70 @@ Stmt::Optimize() {
Stmt * Stmt *
Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) { Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
return this; Stmt *copy;
switch (getValueID()) {
case AssertStmtID:
copy = (Stmt*)new AssertStmt(*(AssertStmt*)this);
break;
case BreakStmtID:
copy = (Stmt*)new BreakStmt(*(BreakStmt*)this);
break;
case CaseStmtID:
copy = (Stmt*)new CaseStmt(*(CaseStmt*)this);
break;
case ContinueStmtID:
copy = (Stmt*)new ContinueStmt(*(ContinueStmt*)this);
break;
case DefaultStmtID:
copy = (Stmt*)new DefaultStmt(*(DefaultStmt*)this);
break;
case DeleteStmtID:
copy = (Stmt*)new DeleteStmt(*(DeleteStmt*)this);
break;
case DoStmtID:
copy = (Stmt*)new DoStmt(*(DoStmt*)this);
break;
case ExprStmtID:
copy = (Stmt*)new ExprStmt(*(ExprStmt*)this);
break;
case ForeachActiveStmtID:
copy = (Stmt*)new ForeachActiveStmt(*(ForeachActiveStmt*)this);
break;
case ForeachUniqueStmtID:
copy = (Stmt*)new ForeachUniqueStmt(*(ForeachUniqueStmt*)this);
break;
case ForStmtID:
copy = (Stmt*)new ForStmt(*(ForStmt*)this);
break;
case GotoStmtID:
copy = (Stmt*)new GotoStmt(*(GotoStmt*)this);
break;
case IfStmtID:
copy = (Stmt*)new IfStmt(*(IfStmt*)this);
break;
case LabeledStmtID:
copy = (Stmt*)new LabeledStmt(*(LabeledStmt*)this);
break;
case PrintStmtID:
copy = (Stmt*)new PrintStmt(*(PrintStmt*)this);
break;
case ReturnStmtID:
copy = (Stmt*)new ReturnStmt(*(ReturnStmt*)this);
break;
case StmtListID:
copy = (Stmt*)new StmtList(*(StmtList*)this);
break;
case SwitchStmtID:
copy = (Stmt*)new SwitchStmt(*(SwitchStmt*)this);
break;
case UnmaskedStmtID:
copy = (Stmt*)new UnmaskedStmt(*(UnmaskedStmt*)this);
break;
default:
FATAL("Unmatched case in ReplacePolyType (stmt)");
copy = this; // just to silence the compiler
}
return copy;
} }
@@ -134,6 +197,11 @@ DeclStmt::DeclStmt(const std::vector<VariableDeclaration> &v, SourcePos p)
: Stmt(p, DeclStmtID), vars(v) { : Stmt(p, DeclStmtID), vars(v) {
} }
DeclStmt::DeclStmt(DeclStmt *base)
: Stmt(base->pos, DeclStmtID) {
vars = base->vars;
}
static bool static bool
lHasUnsizedArrays(const Type *type) { lHasUnsizedArrays(const Type *type) {
@@ -501,14 +569,16 @@ DeclStmt::TypeCheck() {
Stmt * Stmt *
DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) { DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) {
DeclStmt *copy = new DeclStmt(this);
for (size_t i = 0; i < vars.size(); i++) { for (size_t i = 0; i < vars.size(); i++) {
Symbol *s = vars[i].sym; Symbol *s = copy->vars[i].sym;
if (Type::EqualForReplacement(s->type->GetBaseType(), from)) { if (Type::EqualForReplacement(s->type->GetBaseType(), from)) {
s->type = PolyType::ReplaceType(s->type, to); s->type = PolyType::ReplaceType(s->type, to);
} }
} }
return this; return copy;
} }
@@ -1487,6 +1557,15 @@ ForeachStmt::ForeachStmt(const std::vector<Symbol *> &lvs,
stmts(s) { stmts(s) {
} }
ForeachStmt::ForeachStmt(ForeachStmt *base)
: Stmt(base->pos, ForeachStmtID) {
dimVariables = base->dimVariables;
startExprs = base->startExprs;
endExprs = base->endExprs;
isTiled = base->isTiled;
stmts = base->stmts;
}
/* Given a uniform counter value in the memory location pointed to by /* Given a uniform counter value in the memory location pointed to by
uniformCounterPtr, compute the corresponding set of varying counter uniformCounterPtr, compute the corresponding set of varying counter
@@ -2196,14 +2275,16 @@ ForeachStmt::ReplacePolyType(const PolyType *from, const Type *to) {
if (!stmts) if (!stmts)
return NULL; return NULL;
ForeachStmt *copy = new ForeachStmt(this);
for (size_t i=0; i<dimVariables.size(); i++) { for (size_t i=0; i<dimVariables.size(); i++) {
const Type *t = dimVariables[i]->type; const Type *t = copy->dimVariables[i]->type;
if (Type::EqualForReplacement(t->GetBaseType(), from)) { if (Type::EqualForReplacement(t->GetBaseType(), from)) {
t = PolyType::ReplaceType(t, to); t = PolyType::ReplaceType(t, to);
} }
} }
return this; return copy;
} }

4
stmt.h
View File

@@ -122,6 +122,8 @@ public:
int EstimateCost() const; int EstimateCost() const;
std::vector<VariableDeclaration> vars; std::vector<VariableDeclaration> vars;
private:
DeclStmt(DeclStmt *base);
}; };
@@ -291,6 +293,8 @@ public:
std::vector<Expr *> endExprs; std::vector<Expr *> endExprs;
bool isTiled; bool isTiled;
Stmt *stmts; Stmt *stmts;
private:
ForeachStmt(ForeachStmt *base);
}; };