diff --git a/expr.cpp b/expr.cpp index 04a28d40..ad62e08b 100644 --- a/expr.cpp +++ b/expr.cpp @@ -114,7 +114,52 @@ Expr::GetBaseSymbol() const { Expr * Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) { - return this; + Expr *copy; + switch (getValueID()) { + case AddressOfExprID: + copy = (Expr*)new AddressOfExpr(*(AddressOfExpr*)this); + break; + case AssignExprID: + copy = (Expr*)new AssignExpr(*(AssignExpr*)this); + break; + case BinaryExprID: + copy = (Expr*)new BinaryExpr(*(BinaryExpr*)this); + break; + case ConstExprID: + copy = (Expr*)new ConstExpr(*(ConstExpr*)this); + break; + case PtrDerefExprID: + copy = (Expr*)new PtrDerefExpr(*(PtrDerefExpr*)this); + break; + case RefDerefExprID: + copy = (Expr*)new RefDerefExpr(*(RefDerefExpr*)this); + break; + case ExprListID: + copy = (Expr*)new ExprList(*(ExprList*)this); + break; + case FunctionCallExprID: + copy = (Expr*)new FunctionCallExpr(*(FunctionCallExpr*)this); + break; + case NullPointerExprID: + copy = (Expr*)new NullPointerExpr(*(NullPointerExpr*)this); + break; + case ReferenceExprID: + copy = (Expr*)new ReferenceExpr(*(ReferenceExpr*)this); + break; + case SelectExprID: + copy = (Expr*)new SelectExpr(*(SelectExpr*)this); + break; + case SizeOfExprID: + copy = (Expr*)new SizeOfExpr(*(SizeOfExpr*)this); + break; + case UnaryExprID: + copy = (Expr*)new UnaryExpr(*(UnaryExpr*)this); + break; + default: + FATAL("Unmatched case in ReplacePolyType (expr)"); + copy = this; // just to silence the compiler + } + return copy; } @@ -4147,6 +4192,14 @@ IndexExpr::IndexExpr(Expr *a, Expr *i, SourcePos p) type = lvalueType = NULL; } +IndexExpr::IndexExpr(IndexExpr *base) + : Expr(base->pos, IndexExprID) { + baseExpr = base->baseExpr; + index = base->index; + type = base->type; + lvalueType = base->lvalueType; +} + /** When computing pointer values, we need to apply a per-lane offset when we have a varying pointer that is itself indexing into varying data. @@ -4679,16 +4732,18 @@ IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (index == NULL || baseExpr == NULL) return NULL; - if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) { - type = PolyType::ReplaceType(type, to); + IndexExpr *copy = new IndexExpr(this); + + if (Type::EqualForReplacement(copy->GetType()->GetBaseType(), from)) { + copy->type = PolyType::ReplaceType(type, to); } - if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) { - lvalueType = new PointerType(to, lvalueType->GetVariability(), - lvalueType->IsConstType()); + if (Type::EqualForReplacement(copy->GetLValueType()->GetBaseType(), from)) { + copy->lvalueType = new PointerType(to, copy->lvalueType->GetVariability(), + copy->lvalueType->IsConstType()); } - return this; + return copy; } @@ -5338,15 +5393,15 @@ MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (expr == NULL) return NULL; - if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) { - type = PolyType::ReplaceType(type, to); + MemberExpr *copy = getValueID() == StructMemberExprID ? + (MemberExpr*) new StructMemberExpr(*(StructMemberExpr*)this) : + (MemberExpr*) new VectorMemberExpr(*(VectorMemberExpr*)this); + + if (Type::EqualForReplacement(copy->GetType()->GetBaseType(), from)) { + copy->type = PolyType::ReplaceType(copy->type, to); } - if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) { - lvalueType = PolyType::ReplaceType(lvalueType, lvalueType); - } - - return this; + return copy; } @@ -6292,6 +6347,12 @@ TypeCastExpr::TypeCastExpr(const Type *t, Expr *e, SourcePos p) expr = e; } +TypeCastExpr::TypeCastExpr(TypeCastExpr *base) + : Expr(base->pos, TypeCastExprID) { + type = base->type; + expr = base->expr; +} + /** Handle all the grungy details of type conversion between atomic types. Given an input value in exprVal of type fromType, convert it to the @@ -7389,11 +7450,13 @@ TypeCastExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (type == NULL) return NULL; - if (Type::EqualForReplacement(type->GetBaseType(), from)) { - type = PolyType::ReplaceType(type, to); + TypeCastExpr *copy = new TypeCastExpr(this); + + if (Type::EqualForReplacement(copy->type->GetBaseType(), from)) { + copy->type = PolyType::ReplaceType(copy->type, to); } - return this; + return copy; } @@ -8002,6 +8065,11 @@ SymbolExpr::SymbolExpr(Symbol *s, SourcePos p) symbol = s; } +SymbolExpr::SymbolExpr(SymbolExpr *base) + : Expr(base->pos, SymbolExprID) { + symbol = base->symbol; +} + llvm::Value * SymbolExpr::GetValue(FunctionEmitContext *ctx) const { @@ -8071,11 +8139,15 @@ SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (!symbol) return NULL; + SymbolExpr *copy = new SymbolExpr(this); + + copy->symbol = new Symbol(*symbol); + if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) { - symbol->type = PolyType::ReplaceType(symbol->type, to); + copy->symbol->type = PolyType::ReplaceType(symbol->type, to); } - return this; + return copy; } @@ -8679,6 +8751,14 @@ NewExpr::NewExpr(int typeQual, const Type *t, Expr *init, Expr *count, allocType = allocType->ResolveUnboundVariability(Variability::Uniform); } +NewExpr::NewExpr(NewExpr *base) + : Expr(base->pos, NewExprID) { + allocType = base->allocType; + initExpr = base->initExpr; + countExpr = base->countExpr; + isVarying = base->isVarying; +} + llvm::Value * NewExpr::GetValue(FunctionEmitContext *ctx) const { @@ -8881,11 +8961,13 @@ NewExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (!allocType) return this; + NewExpr *copy = new NewExpr(this); + if (Type::EqualForReplacement(allocType->GetBaseType(), from)) { - allocType = PolyType::ReplaceType(allocType, to); + copy->allocType = PolyType::ReplaceType(allocType, to); } - return this; + return copy; } diff --git a/expr.h b/expr.h index 6b655614..65dc1236 100644 --- a/expr.h +++ b/expr.h @@ -334,6 +334,7 @@ public: Expr *baseExpr, *index; private: + IndexExpr(IndexExpr *base); mutable const Type *type; mutable const PointerType *lvalueType; }; @@ -535,6 +536,8 @@ public: const Type *type; Expr *expr; +private: + TypeCastExpr(TypeCastExpr *base); }; @@ -693,6 +696,7 @@ public: int EstimateCost() const; private: + SymbolExpr(SymbolExpr *base); Symbol *symbol; }; @@ -835,6 +839,8 @@ public: instance, or whether a single allocation is performed for the entire gang of program instances.) */ bool isVarying; +private: + NewExpr(NewExpr *base); }; diff --git a/stmt.cpp b/stmt.cpp index f03944e8..e8841a87 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -79,7 +79,70 @@ Stmt::Optimize() { Stmt * Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) { - return this; + Stmt *copy; + switch (getValueID()) { + case AssertStmtID: + copy = (Stmt*)new AssertStmt(*(AssertStmt*)this); + break; + case BreakStmtID: + copy = (Stmt*)new BreakStmt(*(BreakStmt*)this); + break; + case CaseStmtID: + copy = (Stmt*)new CaseStmt(*(CaseStmt*)this); + break; + case ContinueStmtID: + copy = (Stmt*)new ContinueStmt(*(ContinueStmt*)this); + break; + case DefaultStmtID: + copy = (Stmt*)new DefaultStmt(*(DefaultStmt*)this); + break; + case DeleteStmtID: + copy = (Stmt*)new DeleteStmt(*(DeleteStmt*)this); + break; + case DoStmtID: + copy = (Stmt*)new DoStmt(*(DoStmt*)this); + break; + case ExprStmtID: + copy = (Stmt*)new ExprStmt(*(ExprStmt*)this); + break; + case ForeachActiveStmtID: + copy = (Stmt*)new ForeachActiveStmt(*(ForeachActiveStmt*)this); + break; + case ForeachUniqueStmtID: + copy = (Stmt*)new ForeachUniqueStmt(*(ForeachUniqueStmt*)this); + break; + case ForStmtID: + copy = (Stmt*)new ForStmt(*(ForStmt*)this); + break; + case GotoStmtID: + copy = (Stmt*)new GotoStmt(*(GotoStmt*)this); + break; + case IfStmtID: + copy = (Stmt*)new IfStmt(*(IfStmt*)this); + break; + case LabeledStmtID: + copy = (Stmt*)new LabeledStmt(*(LabeledStmt*)this); + break; + case PrintStmtID: + copy = (Stmt*)new PrintStmt(*(PrintStmt*)this); + break; + case ReturnStmtID: + copy = (Stmt*)new ReturnStmt(*(ReturnStmt*)this); + break; + case StmtListID: + copy = (Stmt*)new StmtList(*(StmtList*)this); + break; + case SwitchStmtID: + copy = (Stmt*)new SwitchStmt(*(SwitchStmt*)this); + break; + case UnmaskedStmtID: + copy = (Stmt*)new UnmaskedStmt(*(UnmaskedStmt*)this); + break; + default: + FATAL("Unmatched case in ReplacePolyType (stmt)"); + copy = this; // just to silence the compiler + } + return copy; } @@ -134,6 +197,11 @@ DeclStmt::DeclStmt(const std::vector &v, SourcePos p) : Stmt(p, DeclStmtID), vars(v) { } +DeclStmt::DeclStmt(DeclStmt *base) + : Stmt(base->pos, DeclStmtID) { + vars = base->vars; +} + static bool lHasUnsizedArrays(const Type *type) { @@ -501,14 +569,16 @@ DeclStmt::TypeCheck() { Stmt * DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) { + DeclStmt *copy = new DeclStmt(this); + for (size_t i = 0; i < vars.size(); i++) { - Symbol *s = vars[i].sym; + Symbol *s = copy->vars[i].sym; if (Type::EqualForReplacement(s->type->GetBaseType(), from)) { s->type = PolyType::ReplaceType(s->type, to); } } - return this; + return copy; } @@ -1487,6 +1557,15 @@ ForeachStmt::ForeachStmt(const std::vector &lvs, stmts(s) { } +ForeachStmt::ForeachStmt(ForeachStmt *base) + : Stmt(base->pos, ForeachStmtID) { + dimVariables = base->dimVariables; + startExprs = base->startExprs; + endExprs = base->endExprs; + isTiled = base->isTiled; + stmts = base->stmts; +} + /* Given a uniform counter value in the memory location pointed to by uniformCounterPtr, compute the corresponding set of varying counter @@ -2196,14 +2275,16 @@ ForeachStmt::ReplacePolyType(const PolyType *from, const Type *to) { if (!stmts) return NULL; + ForeachStmt *copy = new ForeachStmt(this); + for (size_t i=0; itype; + const Type *t = copy->dimVariables[i]->type; if (Type::EqualForReplacement(t->GetBaseType(), from)) { t = PolyType::ReplaceType(t, to); } } - return this; + return copy; } diff --git a/stmt.h b/stmt.h index 237acaaf..7d5bb0d2 100644 --- a/stmt.h +++ b/stmt.h @@ -122,6 +122,8 @@ public: int EstimateCost() const; std::vector vars; +private: + DeclStmt(DeclStmt *base); }; @@ -291,6 +293,8 @@ public: std::vector endExprs; bool isTiled; Stmt *stmts; +private: + ForeachStmt(ForeachStmt *base); };