Attempt to replicate AST when expanding polytypes
This commit is contained in:
124
expr.cpp
124
expr.cpp
@@ -114,7 +114,52 @@ Expr::GetBaseSymbol() const {
|
|||||||
|
|
||||||
Expr *
|
Expr *
|
||||||
Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
|
Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
|
||||||
return this;
|
Expr *copy;
|
||||||
|
switch (getValueID()) {
|
||||||
|
case AddressOfExprID:
|
||||||
|
copy = (Expr*)new AddressOfExpr(*(AddressOfExpr*)this);
|
||||||
|
break;
|
||||||
|
case AssignExprID:
|
||||||
|
copy = (Expr*)new AssignExpr(*(AssignExpr*)this);
|
||||||
|
break;
|
||||||
|
case BinaryExprID:
|
||||||
|
copy = (Expr*)new BinaryExpr(*(BinaryExpr*)this);
|
||||||
|
break;
|
||||||
|
case ConstExprID:
|
||||||
|
copy = (Expr*)new ConstExpr(*(ConstExpr*)this);
|
||||||
|
break;
|
||||||
|
case PtrDerefExprID:
|
||||||
|
copy = (Expr*)new PtrDerefExpr(*(PtrDerefExpr*)this);
|
||||||
|
break;
|
||||||
|
case RefDerefExprID:
|
||||||
|
copy = (Expr*)new RefDerefExpr(*(RefDerefExpr*)this);
|
||||||
|
break;
|
||||||
|
case ExprListID:
|
||||||
|
copy = (Expr*)new ExprList(*(ExprList*)this);
|
||||||
|
break;
|
||||||
|
case FunctionCallExprID:
|
||||||
|
copy = (Expr*)new FunctionCallExpr(*(FunctionCallExpr*)this);
|
||||||
|
break;
|
||||||
|
case NullPointerExprID:
|
||||||
|
copy = (Expr*)new NullPointerExpr(*(NullPointerExpr*)this);
|
||||||
|
break;
|
||||||
|
case ReferenceExprID:
|
||||||
|
copy = (Expr*)new ReferenceExpr(*(ReferenceExpr*)this);
|
||||||
|
break;
|
||||||
|
case SelectExprID:
|
||||||
|
copy = (Expr*)new SelectExpr(*(SelectExpr*)this);
|
||||||
|
break;
|
||||||
|
case SizeOfExprID:
|
||||||
|
copy = (Expr*)new SizeOfExpr(*(SizeOfExpr*)this);
|
||||||
|
break;
|
||||||
|
case UnaryExprID:
|
||||||
|
copy = (Expr*)new UnaryExpr(*(UnaryExpr*)this);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
FATAL("Unmatched case in ReplacePolyType (expr)");
|
||||||
|
copy = this; // just to silence the compiler
|
||||||
|
}
|
||||||
|
return copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -4147,6 +4192,14 @@ IndexExpr::IndexExpr(Expr *a, Expr *i, SourcePos p)
|
|||||||
type = lvalueType = NULL;
|
type = lvalueType = NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
IndexExpr::IndexExpr(IndexExpr *base)
|
||||||
|
: Expr(base->pos, IndexExprID) {
|
||||||
|
baseExpr = base->baseExpr;
|
||||||
|
index = base->index;
|
||||||
|
type = base->type;
|
||||||
|
lvalueType = base->lvalueType;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/** When computing pointer values, we need to apply a per-lane offset when
|
/** When computing pointer values, we need to apply a per-lane offset when
|
||||||
we have a varying pointer that is itself indexing into varying data.
|
we have a varying pointer that is itself indexing into varying data.
|
||||||
@@ -4679,16 +4732,18 @@ IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) {
|
|||||||
if (index == NULL || baseExpr == NULL)
|
if (index == NULL || baseExpr == NULL)
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) {
|
IndexExpr *copy = new IndexExpr(this);
|
||||||
type = PolyType::ReplaceType(type, to);
|
|
||||||
|
if (Type::EqualForReplacement(copy->GetType()->GetBaseType(), from)) {
|
||||||
|
copy->type = PolyType::ReplaceType(type, to);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) {
|
if (Type::EqualForReplacement(copy->GetLValueType()->GetBaseType(), from)) {
|
||||||
lvalueType = new PointerType(to, lvalueType->GetVariability(),
|
copy->lvalueType = new PointerType(to, copy->lvalueType->GetVariability(),
|
||||||
lvalueType->IsConstType());
|
copy->lvalueType->IsConstType());
|
||||||
}
|
}
|
||||||
|
|
||||||
return this;
|
return copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -5338,15 +5393,15 @@ MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) {
|
|||||||
if (expr == NULL)
|
if (expr == NULL)
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) {
|
MemberExpr *copy = getValueID() == StructMemberExprID ?
|
||||||
type = PolyType::ReplaceType(type, to);
|
(MemberExpr*) new StructMemberExpr(*(StructMemberExpr*)this) :
|
||||||
|
(MemberExpr*) new VectorMemberExpr(*(VectorMemberExpr*)this);
|
||||||
|
|
||||||
|
if (Type::EqualForReplacement(copy->GetType()->GetBaseType(), from)) {
|
||||||
|
copy->type = PolyType::ReplaceType(copy->type, to);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) {
|
return copy;
|
||||||
lvalueType = PolyType::ReplaceType(lvalueType, lvalueType);
|
|
||||||
}
|
|
||||||
|
|
||||||
return this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -6292,6 +6347,12 @@ TypeCastExpr::TypeCastExpr(const Type *t, Expr *e, SourcePos p)
|
|||||||
expr = e;
|
expr = e;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TypeCastExpr::TypeCastExpr(TypeCastExpr *base)
|
||||||
|
: Expr(base->pos, TypeCastExprID) {
|
||||||
|
type = base->type;
|
||||||
|
expr = base->expr;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/** Handle all the grungy details of type conversion between atomic types.
|
/** Handle all the grungy details of type conversion between atomic types.
|
||||||
Given an input value in exprVal of type fromType, convert it to the
|
Given an input value in exprVal of type fromType, convert it to the
|
||||||
@@ -7389,11 +7450,13 @@ TypeCastExpr::ReplacePolyType(const PolyType *from, const Type *to) {
|
|||||||
if (type == NULL)
|
if (type == NULL)
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
if (Type::EqualForReplacement(type->GetBaseType(), from)) {
|
TypeCastExpr *copy = new TypeCastExpr(this);
|
||||||
type = PolyType::ReplaceType(type, to);
|
|
||||||
|
if (Type::EqualForReplacement(copy->type->GetBaseType(), from)) {
|
||||||
|
copy->type = PolyType::ReplaceType(copy->type, to);
|
||||||
}
|
}
|
||||||
|
|
||||||
return this;
|
return copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -8002,6 +8065,11 @@ SymbolExpr::SymbolExpr(Symbol *s, SourcePos p)
|
|||||||
symbol = s;
|
symbol = s;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SymbolExpr::SymbolExpr(SymbolExpr *base)
|
||||||
|
: Expr(base->pos, SymbolExprID) {
|
||||||
|
symbol = base->symbol;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
llvm::Value *
|
llvm::Value *
|
||||||
SymbolExpr::GetValue(FunctionEmitContext *ctx) const {
|
SymbolExpr::GetValue(FunctionEmitContext *ctx) const {
|
||||||
@@ -8071,11 +8139,15 @@ SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) {
|
|||||||
if (!symbol)
|
if (!symbol)
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
|
SymbolExpr *copy = new SymbolExpr(this);
|
||||||
|
|
||||||
|
copy->symbol = new Symbol(*symbol);
|
||||||
|
|
||||||
if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) {
|
if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) {
|
||||||
symbol->type = PolyType::ReplaceType(symbol->type, to);
|
copy->symbol->type = PolyType::ReplaceType(symbol->type, to);
|
||||||
}
|
}
|
||||||
|
|
||||||
return this;
|
return copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -8679,6 +8751,14 @@ NewExpr::NewExpr(int typeQual, const Type *t, Expr *init, Expr *count,
|
|||||||
allocType = allocType->ResolveUnboundVariability(Variability::Uniform);
|
allocType = allocType->ResolveUnboundVariability(Variability::Uniform);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NewExpr::NewExpr(NewExpr *base)
|
||||||
|
: Expr(base->pos, NewExprID) {
|
||||||
|
allocType = base->allocType;
|
||||||
|
initExpr = base->initExpr;
|
||||||
|
countExpr = base->countExpr;
|
||||||
|
isVarying = base->isVarying;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
llvm::Value *
|
llvm::Value *
|
||||||
NewExpr::GetValue(FunctionEmitContext *ctx) const {
|
NewExpr::GetValue(FunctionEmitContext *ctx) const {
|
||||||
@@ -8881,11 +8961,13 @@ NewExpr::ReplacePolyType(const PolyType *from, const Type *to) {
|
|||||||
if (!allocType)
|
if (!allocType)
|
||||||
return this;
|
return this;
|
||||||
|
|
||||||
|
NewExpr *copy = new NewExpr(this);
|
||||||
|
|
||||||
if (Type::EqualForReplacement(allocType->GetBaseType(), from)) {
|
if (Type::EqualForReplacement(allocType->GetBaseType(), from)) {
|
||||||
allocType = PolyType::ReplaceType(allocType, to);
|
copy->allocType = PolyType::ReplaceType(allocType, to);
|
||||||
}
|
}
|
||||||
|
|
||||||
return this;
|
return copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
6
expr.h
6
expr.h
@@ -334,6 +334,7 @@ public:
|
|||||||
Expr *baseExpr, *index;
|
Expr *baseExpr, *index;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
IndexExpr(IndexExpr *base);
|
||||||
mutable const Type *type;
|
mutable const Type *type;
|
||||||
mutable const PointerType *lvalueType;
|
mutable const PointerType *lvalueType;
|
||||||
};
|
};
|
||||||
@@ -535,6 +536,8 @@ public:
|
|||||||
|
|
||||||
const Type *type;
|
const Type *type;
|
||||||
Expr *expr;
|
Expr *expr;
|
||||||
|
private:
|
||||||
|
TypeCastExpr(TypeCastExpr *base);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@@ -693,6 +696,7 @@ public:
|
|||||||
int EstimateCost() const;
|
int EstimateCost() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
SymbolExpr(SymbolExpr *base);
|
||||||
Symbol *symbol;
|
Symbol *symbol;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -835,6 +839,8 @@ public:
|
|||||||
instance, or whether a single allocation is performed for the
|
instance, or whether a single allocation is performed for the
|
||||||
entire gang of program instances.) */
|
entire gang of program instances.) */
|
||||||
bool isVarying;
|
bool isVarying;
|
||||||
|
private:
|
||||||
|
NewExpr(NewExpr *base);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
91
stmt.cpp
91
stmt.cpp
@@ -79,7 +79,70 @@ Stmt::Optimize() {
|
|||||||
|
|
||||||
Stmt *
|
Stmt *
|
||||||
Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
|
Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
|
||||||
return this;
|
Stmt *copy;
|
||||||
|
switch (getValueID()) {
|
||||||
|
case AssertStmtID:
|
||||||
|
copy = (Stmt*)new AssertStmt(*(AssertStmt*)this);
|
||||||
|
break;
|
||||||
|
case BreakStmtID:
|
||||||
|
copy = (Stmt*)new BreakStmt(*(BreakStmt*)this);
|
||||||
|
break;
|
||||||
|
case CaseStmtID:
|
||||||
|
copy = (Stmt*)new CaseStmt(*(CaseStmt*)this);
|
||||||
|
break;
|
||||||
|
case ContinueStmtID:
|
||||||
|
copy = (Stmt*)new ContinueStmt(*(ContinueStmt*)this);
|
||||||
|
break;
|
||||||
|
case DefaultStmtID:
|
||||||
|
copy = (Stmt*)new DefaultStmt(*(DefaultStmt*)this);
|
||||||
|
break;
|
||||||
|
case DeleteStmtID:
|
||||||
|
copy = (Stmt*)new DeleteStmt(*(DeleteStmt*)this);
|
||||||
|
break;
|
||||||
|
case DoStmtID:
|
||||||
|
copy = (Stmt*)new DoStmt(*(DoStmt*)this);
|
||||||
|
break;
|
||||||
|
case ExprStmtID:
|
||||||
|
copy = (Stmt*)new ExprStmt(*(ExprStmt*)this);
|
||||||
|
break;
|
||||||
|
case ForeachActiveStmtID:
|
||||||
|
copy = (Stmt*)new ForeachActiveStmt(*(ForeachActiveStmt*)this);
|
||||||
|
break;
|
||||||
|
case ForeachUniqueStmtID:
|
||||||
|
copy = (Stmt*)new ForeachUniqueStmt(*(ForeachUniqueStmt*)this);
|
||||||
|
break;
|
||||||
|
case ForStmtID:
|
||||||
|
copy = (Stmt*)new ForStmt(*(ForStmt*)this);
|
||||||
|
break;
|
||||||
|
case GotoStmtID:
|
||||||
|
copy = (Stmt*)new GotoStmt(*(GotoStmt*)this);
|
||||||
|
break;
|
||||||
|
case IfStmtID:
|
||||||
|
copy = (Stmt*)new IfStmt(*(IfStmt*)this);
|
||||||
|
break;
|
||||||
|
case LabeledStmtID:
|
||||||
|
copy = (Stmt*)new LabeledStmt(*(LabeledStmt*)this);
|
||||||
|
break;
|
||||||
|
case PrintStmtID:
|
||||||
|
copy = (Stmt*)new PrintStmt(*(PrintStmt*)this);
|
||||||
|
break;
|
||||||
|
case ReturnStmtID:
|
||||||
|
copy = (Stmt*)new ReturnStmt(*(ReturnStmt*)this);
|
||||||
|
break;
|
||||||
|
case StmtListID:
|
||||||
|
copy = (Stmt*)new StmtList(*(StmtList*)this);
|
||||||
|
break;
|
||||||
|
case SwitchStmtID:
|
||||||
|
copy = (Stmt*)new SwitchStmt(*(SwitchStmt*)this);
|
||||||
|
break;
|
||||||
|
case UnmaskedStmtID:
|
||||||
|
copy = (Stmt*)new UnmaskedStmt(*(UnmaskedStmt*)this);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
FATAL("Unmatched case in ReplacePolyType (stmt)");
|
||||||
|
copy = this; // just to silence the compiler
|
||||||
|
}
|
||||||
|
return copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -134,6 +197,11 @@ DeclStmt::DeclStmt(const std::vector<VariableDeclaration> &v, SourcePos p)
|
|||||||
: Stmt(p, DeclStmtID), vars(v) {
|
: Stmt(p, DeclStmtID), vars(v) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
DeclStmt::DeclStmt(DeclStmt *base)
|
||||||
|
: Stmt(base->pos, DeclStmtID) {
|
||||||
|
vars = base->vars;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
static bool
|
static bool
|
||||||
lHasUnsizedArrays(const Type *type) {
|
lHasUnsizedArrays(const Type *type) {
|
||||||
@@ -501,14 +569,16 @@ DeclStmt::TypeCheck() {
|
|||||||
|
|
||||||
Stmt *
|
Stmt *
|
||||||
DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) {
|
DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) {
|
||||||
|
DeclStmt *copy = new DeclStmt(this);
|
||||||
|
|
||||||
for (size_t i = 0; i < vars.size(); i++) {
|
for (size_t i = 0; i < vars.size(); i++) {
|
||||||
Symbol *s = vars[i].sym;
|
Symbol *s = copy->vars[i].sym;
|
||||||
if (Type::EqualForReplacement(s->type->GetBaseType(), from)) {
|
if (Type::EqualForReplacement(s->type->GetBaseType(), from)) {
|
||||||
s->type = PolyType::ReplaceType(s->type, to);
|
s->type = PolyType::ReplaceType(s->type, to);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return this;
|
return copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -1487,6 +1557,15 @@ ForeachStmt::ForeachStmt(const std::vector<Symbol *> &lvs,
|
|||||||
stmts(s) {
|
stmts(s) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ForeachStmt::ForeachStmt(ForeachStmt *base)
|
||||||
|
: Stmt(base->pos, ForeachStmtID) {
|
||||||
|
dimVariables = base->dimVariables;
|
||||||
|
startExprs = base->startExprs;
|
||||||
|
endExprs = base->endExprs;
|
||||||
|
isTiled = base->isTiled;
|
||||||
|
stmts = base->stmts;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/* Given a uniform counter value in the memory location pointed to by
|
/* Given a uniform counter value in the memory location pointed to by
|
||||||
uniformCounterPtr, compute the corresponding set of varying counter
|
uniformCounterPtr, compute the corresponding set of varying counter
|
||||||
@@ -2196,14 +2275,16 @@ ForeachStmt::ReplacePolyType(const PolyType *from, const Type *to) {
|
|||||||
if (!stmts)
|
if (!stmts)
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
|
ForeachStmt *copy = new ForeachStmt(this);
|
||||||
|
|
||||||
for (size_t i=0; i<dimVariables.size(); i++) {
|
for (size_t i=0; i<dimVariables.size(); i++) {
|
||||||
const Type *t = dimVariables[i]->type;
|
const Type *t = copy->dimVariables[i]->type;
|
||||||
if (Type::EqualForReplacement(t->GetBaseType(), from)) {
|
if (Type::EqualForReplacement(t->GetBaseType(), from)) {
|
||||||
t = PolyType::ReplaceType(t, to);
|
t = PolyType::ReplaceType(t, to);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return this;
|
return copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
4
stmt.h
4
stmt.h
@@ -122,6 +122,8 @@ public:
|
|||||||
int EstimateCost() const;
|
int EstimateCost() const;
|
||||||
|
|
||||||
std::vector<VariableDeclaration> vars;
|
std::vector<VariableDeclaration> vars;
|
||||||
|
private:
|
||||||
|
DeclStmt(DeclStmt *base);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@@ -291,6 +293,8 @@ public:
|
|||||||
std::vector<Expr *> endExprs;
|
std::vector<Expr *> endExprs;
|
||||||
bool isTiled;
|
bool isTiled;
|
||||||
Stmt *stmts;
|
Stmt *stmts;
|
||||||
|
private:
|
||||||
|
ForeachStmt(ForeachStmt *base);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user