diff --git a/ast.cpp b/ast.cpp index 064226b4..5e71b9b6 100644 --- a/ast.cpp +++ b/ast.cpp @@ -86,7 +86,7 @@ AST::GenerateIR() { ASTNode * WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, - void *data) { + void *data, ASTPostCallBackFunc preUpdate) { if (node == NULL) return node; @@ -97,6 +97,10 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, return node; } + if (preUpdate != NULL) { + node = preUpdate(node, data); + } + //////////////////////////////////////////////////////////////////////////// // Handle Statements if (llvm::dyn_cast(node) != NULL) { @@ -120,54 +124,54 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, UnmaskedStmt *ums; if ((es = llvm::dyn_cast(node)) != NULL) - es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data); + es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data, preUpdate); else if ((ds = llvm::dyn_cast(node)) != NULL) { for (unsigned int i = 0; i < ds->vars.size(); ++i) ds->vars[i].init = (Expr *)WalkAST(ds->vars[i].init, preFunc, - postFunc, data); + postFunc, data, preUpdate); } else if ((is = llvm::dyn_cast(node)) != NULL) { - is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data); + is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data, preUpdate); is->trueStmts = (Stmt *)WalkAST(is->trueStmts, preFunc, - postFunc, data); + postFunc, data, preUpdate); is->falseStmts = (Stmt *)WalkAST(is->falseStmts, preFunc, - postFunc, data); + postFunc, data, preUpdate); } else if ((dos = llvm::dyn_cast(node)) != NULL) { dos->testExpr = (Expr *)WalkAST(dos->testExpr, preFunc, - postFunc, data); + postFunc, data, preUpdate); dos->bodyStmts = (Stmt *)WalkAST(dos->bodyStmts, preFunc, - postFunc, data); + postFunc, data, preUpdate); } else if ((fs = llvm::dyn_cast(node)) != NULL) { - fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data); - fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data); - fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data); - fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data); + fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data, preUpdate); + fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data, preUpdate); + fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data, preUpdate); + fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data, preUpdate); } else if ((fes = llvm::dyn_cast(node)) != NULL) { for (unsigned int i = 0; i < fes->startExprs.size(); ++i) fes->startExprs[i] = (Expr *)WalkAST(fes->startExprs[i], preFunc, - postFunc, data); + postFunc, data, preUpdate); for (unsigned int i = 0; i < fes->endExprs.size(); ++i) fes->endExprs[i] = (Expr *)WalkAST(fes->endExprs[i], preFunc, - postFunc, data); - fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data); + postFunc, data, preUpdate); + fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data, preUpdate); } else if ((fas = llvm::dyn_cast(node)) != NULL) { - fas->stmts = (Stmt *)WalkAST(fas->stmts, preFunc, postFunc, data); + fas->stmts = (Stmt *)WalkAST(fas->stmts, preFunc, postFunc, data, preUpdate); } else if ((fus = llvm::dyn_cast(node)) != NULL) { - fus->expr = (Expr *)WalkAST(fus->expr, preFunc, postFunc, data); - fus->stmts = (Stmt *)WalkAST(fus->stmts, preFunc, postFunc, data); + fus->expr = (Expr *)WalkAST(fus->expr, preFunc, postFunc, data, preUpdate); + fus->stmts = (Stmt *)WalkAST(fus->stmts, preFunc, postFunc, data, preUpdate); } else if ((cs = llvm::dyn_cast(node)) != NULL) - cs->stmts = (Stmt *)WalkAST(cs->stmts, preFunc, postFunc, data); + cs->stmts = (Stmt *)WalkAST(cs->stmts, preFunc, postFunc, data, preUpdate); else if ((defs = llvm::dyn_cast(node)) != NULL) - defs->stmts = (Stmt *)WalkAST(defs->stmts, preFunc, postFunc, data); + defs->stmts = (Stmt *)WalkAST(defs->stmts, preFunc, postFunc, data, preUpdate); else if ((ss = llvm::dyn_cast(node)) != NULL) { - ss->expr = (Expr *)WalkAST(ss->expr, preFunc, postFunc, data); - ss->stmts = (Stmt *)WalkAST(ss->stmts, preFunc, postFunc, data); + ss->expr = (Expr *)WalkAST(ss->expr, preFunc, postFunc, data, preUpdate); + ss->stmts = (Stmt *)WalkAST(ss->stmts, preFunc, postFunc, data, preUpdate); } else if (llvm::dyn_cast(node) != NULL || llvm::dyn_cast(node) != NULL || @@ -175,22 +179,22 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, // nothing } else if ((ls = llvm::dyn_cast(node)) != NULL) - ls->stmt = (Stmt *)WalkAST(ls->stmt, preFunc, postFunc, data); + ls->stmt = (Stmt *)WalkAST(ls->stmt, preFunc, postFunc, data, preUpdate); else if ((rs = llvm::dyn_cast(node)) != NULL) - rs->expr = (Expr *)WalkAST(rs->expr, preFunc, postFunc, data); + rs->expr = (Expr *)WalkAST(rs->expr, preFunc, postFunc, data, preUpdate); else if ((sl = llvm::dyn_cast(node)) != NULL) { std::vector &sls = sl->stmts; for (unsigned int i = 0; i < sls.size(); ++i) - sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data); + sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data, preUpdate); } else if ((ps = llvm::dyn_cast(node)) != NULL) - ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data); + ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data, preUpdate); else if ((as = llvm::dyn_cast(node)) != NULL) - as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data); + as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data, preUpdate); else if ((dels = llvm::dyn_cast(node)) != NULL) - dels->expr = (Expr *)WalkAST(dels->expr, preFunc, postFunc, data); + dels->expr = (Expr *)WalkAST(dels->expr, preFunc, postFunc, data, preUpdate); else if ((ums = llvm::dyn_cast(node)) != NULL) - ums->stmts = (Stmt *)WalkAST(ums->stmts, preFunc, postFunc, data); + ums->stmts = (Stmt *)WalkAST(ums->stmts, preFunc, postFunc, data, preUpdate); else FATAL("Unhandled statement type in WalkAST()"); } @@ -215,57 +219,57 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, NewExpr *newe; if ((ue = llvm::dyn_cast(node)) != NULL) - ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data); + ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data, preUpdate); else if ((be = llvm::dyn_cast(node)) != NULL) { - be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data); - be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data); + be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data, preUpdate); + be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data, preUpdate); } else if ((ae = llvm::dyn_cast(node)) != NULL) { - ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data); - ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data); + ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data, preUpdate); + ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data, preUpdate); } else if ((se = llvm::dyn_cast(node)) != NULL) { - se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data); - se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data); - se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data); + se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data, preUpdate); + se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data, preUpdate); + se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data, preUpdate); } else if ((el = llvm::dyn_cast(node)) != NULL) { for (unsigned int i = 0; i < el->exprs.size(); ++i) el->exprs[i] = (Expr *)WalkAST(el->exprs[i], preFunc, - postFunc, data); + postFunc, data, preUpdate); } else if ((fce = llvm::dyn_cast(node)) != NULL) { - fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data); - fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data); + fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data, preUpdate); + fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data, preUpdate); for (int k = 0; k < 3; k++) fce->launchCountExpr[0] = (Expr *)WalkAST(fce->launchCountExpr[0], preFunc, - postFunc, data); + postFunc, data, preUpdate); } else if ((ie = llvm::dyn_cast(node)) != NULL) { - ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data); - ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data); + ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data, preUpdate); + ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data, preUpdate); } else if ((me = llvm::dyn_cast(node)) != NULL) - me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data); + me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data, preUpdate); else if ((tce = llvm::dyn_cast(node)) != NULL) - tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data); + tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data, preUpdate); else if ((re = llvm::dyn_cast(node)) != NULL) - re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data); + re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data, preUpdate); else if ((ptrderef = llvm::dyn_cast(node)) != NULL) ptrderef->expr = (Expr *)WalkAST(ptrderef->expr, preFunc, postFunc, - data); + data, preUpdate); else if ((refderef = llvm::dyn_cast(node)) != NULL) refderef->expr = (Expr *)WalkAST(refderef->expr, preFunc, postFunc, - data); + data, preUpdate); else if ((soe = llvm::dyn_cast(node)) != NULL) - soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data); + soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data, preUpdate); else if ((aoe = llvm::dyn_cast(node)) != NULL) - aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data); + aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data, preUpdate); else if ((newe = llvm::dyn_cast(node)) != NULL) { newe->countExpr = (Expr *)WalkAST(newe->countExpr, preFunc, - postFunc, data); + postFunc, data, preUpdate); newe->initExpr = (Expr *)WalkAST(newe->initExpr, preFunc, - postFunc, data); + postFunc, data, preUpdate); } else if (llvm::dyn_cast(node) != NULL || llvm::dyn_cast(node) != NULL || @@ -536,11 +540,21 @@ lTranslatePolyNode(ASTNode *node, void *d) { return node->ReplacePolyType(data->polyType, data->replacement); } +static ASTNode * +lCopyNode(ASTNode *node, void *) { + return node->Copy(); +} + ASTNode * TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement) { struct PolyData data; data.polyType = polyType; data.replacement = replacement; - return WalkAST(root, NULL, lTranslatePolyNode, &data); + return WalkAST(root, NULL, lTranslatePolyNode, &data, lCopyNode); +} + +ASTNode * +CopyAST(ASTNode *root) { + return WalkAST(root, NULL, NULL, NULL, lCopyNode); } diff --git a/ast.h b/ast.h index 91d3dd5b..5c40454f 100644 --- a/ast.h +++ b/ast.h @@ -68,6 +68,8 @@ public: pointer in place of the original ASTNode *. */ virtual ASTNode *TypeCheck() = 0; + virtual ASTNode *Copy() = 0; + virtual ASTNode *ReplacePolyType(const PolyType *, const Type *) = 0; /** Estimate the execution cost of the node (not including the cost of @@ -177,7 +179,8 @@ typedef ASTNode * (* ASTPostCallBackFunc)(ASTNode *node, void *data); doing so, calls postFunc, at the node. The return value from the postFunc call is ignored. */ extern ASTNode *WalkAST(ASTNode *root, ASTPreCallBackFunc preFunc, - ASTPostCallBackFunc postFunc, void *data); + ASTPostCallBackFunc postFunc, void *data, + ASTPostCallBackFunc preUpdate = NULL); /** Perform simple optimizations on the AST or portion thereof passed to this function, returning the resulting AST. */ @@ -209,6 +212,8 @@ extern int EstimateCost(ASTNode *root); extern ASTNode * TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement); +extern ASTNode * CopyAST(ASTNode *root); + /** Returns true if it would be safe to run the given code with an "all off" mask. */ extern bool SafeToRunWithMaskAllOff(ASTNode *root); diff --git a/expr.cpp b/expr.cpp index aed5d45c..8fe870c1 100644 --- a/expr.cpp +++ b/expr.cpp @@ -113,7 +113,7 @@ Expr::GetBaseSymbol() const { } Expr * -Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) { +Expr::Copy() { Expr *copy; switch (getValueID()) { case AddressOfExprID: @@ -128,9 +128,6 @@ Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) { case ConstExprID: copy = (Expr*)new ConstExpr(*(ConstExpr*)this); break; - case SymbolExprID: - copy = (Expr*)new SymbolExpr(*(SymbolExpr*)this); - break; case PtrDerefExprID: copy = (Expr*)new PtrDerefExpr(*(PtrDerefExpr*)this); break; @@ -143,6 +140,21 @@ Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) { case FunctionCallExprID: copy = (Expr*)new FunctionCallExpr(*(FunctionCallExpr*)this); break; + case FunctionSymbolExprID: + copy = (Expr*)new FunctionSymbolExpr(*(FunctionSymbolExpr*)this); + break; + case IndexExprID: + copy = (Expr*)new IndexExpr(*(IndexExpr*)this); + break; + case StructMemberExprID: + copy = (Expr*)new StructMemberExpr(*(StructMemberExpr*)this); + break; + case VectorMemberExprID: + copy = (Expr*)new VectorMemberExpr(*(VectorMemberExpr*)this); + break; + case NewExprID: + copy = (Expr*)new NewExpr(*(NewExpr*)this); + break; case NullPointerExprID: copy = (Expr*)new NullPointerExpr(*(NullPointerExpr*)this); break; @@ -155,16 +167,30 @@ Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) { case SizeOfExprID: copy = (Expr*)new SizeOfExpr(*(SizeOfExpr*)this); break; + case SymbolExprID: + copy = (Expr*)new SymbolExpr(*(SymbolExpr*)this); + break; + case SyncExprID: + copy = (Expr*)new SyncExpr(*(SyncExpr*)this); + break; + case TypeCastExprID: + copy = (Expr*)new TypeCastExpr(*(TypeCastExpr*)this); + break; case UnaryExprID: copy = (Expr*)new UnaryExpr(*(UnaryExpr*)this); break; default: - FATAL("Unmatched case in ReplacePolyType (expr)"); + FATAL("Unmatched case in Expr::Copy"); copy = this; // just to silence the compiler } return copy; } +Expr * +Expr::ReplacePolyType(const PolyType *, const Type *) { + return this; +} + #if 0 /** If a conversion from 'fromAtomicType' to 'toAtomicType' may cause lost @@ -4196,14 +4222,6 @@ IndexExpr::IndexExpr(Expr *a, Expr *i, SourcePos p) type = lvalueType = NULL; } -IndexExpr::IndexExpr(IndexExpr *base) - : Expr(base->pos, IndexExprID) { - baseExpr = base->baseExpr; - index = base->index; - type = base->type; - lvalueType = base->lvalueType; -} - /** When computing pointer values, we need to apply a per-lane offset when we have a varying pointer that is itself indexing into varying data. @@ -4736,18 +4754,16 @@ IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (index == NULL || baseExpr == NULL) return NULL; - IndexExpr *copy = new IndexExpr(this); - - if (Type::EqualForReplacement(copy->GetType()->GetBaseType(), from)) { - copy->type = PolyType::ReplaceType(type, to); + if (Type::EqualForReplacement(GetType()->GetBaseType(), from)) { + type = PolyType::ReplaceType(type, to); } - if (Type::EqualForReplacement(copy->GetLValueType()->GetBaseType(), from)) { - copy->lvalueType = new PointerType(to, copy->lvalueType->GetVariability(), - copy->lvalueType->IsConstType()); + if (Type::EqualForReplacement(GetLValueType()->GetBaseType(), from)) { + lvalueType = new PointerType(to, lvalueType->GetVariability(), + lvalueType->IsConstType()); } - return copy; + return this; } @@ -4815,27 +4831,6 @@ lIdentifierToVectorElement(char id) { ////////////////////////////////////////////////// // StructMemberExpr -class StructMemberExpr : public MemberExpr -{ -public: - StructMemberExpr(Expr *e, const char *id, SourcePos p, - SourcePos idpos, bool derefLValue); - - static inline bool classof(StructMemberExpr const*) { return true; } - static inline bool classof(ASTNode const* N) { - return N->getValueID() == StructMemberExprID; - } - - const Type *GetType() const; - const Type *GetLValueType() const; - int getElementNumber() const; - const Type *getElementType() const; - -private: - const StructType *getStructType() const; -}; - - StructMemberExpr::StructMemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue) : MemberExpr(e, id, p, idpos, derefLValue, StructMemberExprID) { @@ -4987,31 +4982,6 @@ StructMemberExpr::getStructType() const { ////////////////////////////////////////////////// // VectorMemberExpr -class VectorMemberExpr : public MemberExpr -{ -public: - VectorMemberExpr(Expr *e, const char *id, SourcePos p, - SourcePos idpos, bool derefLValue); - - static inline bool classof(VectorMemberExpr const*) { return true; } - static inline bool classof(ASTNode const* N) { - return N->getValueID() == VectorMemberExprID; - } - - llvm::Value *GetValue(FunctionEmitContext* ctx) const; - llvm::Value *GetLValue(FunctionEmitContext* ctx) const; - const Type *GetType() const; - const Type *GetLValueType() const; - - int getElementNumber() const; - const Type *getElementType() const; - -private: - const VectorType *exprVectorType; - const VectorType *memberType; -}; - - VectorMemberExpr::VectorMemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue) : MemberExpr(e, id, p, idpos, derefLValue, VectorMemberExprID) { @@ -5397,15 +5367,11 @@ MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (expr == NULL) return NULL; - MemberExpr *copy = getValueID() == StructMemberExprID ? - (MemberExpr*) new StructMemberExpr(*(StructMemberExpr*)this) : - (MemberExpr*) new VectorMemberExpr(*(VectorMemberExpr*)this); - - if (Type::EqualForReplacement(copy->GetType()->GetBaseType(), from)) { - copy->type = PolyType::ReplaceType(copy->type, to); + if (Type::EqualForReplacement(GetType()->GetBaseType(), from)) { + type = PolyType::ReplaceType(type, to); } - return copy; + return this; } @@ -6351,12 +6317,6 @@ TypeCastExpr::TypeCastExpr(const Type *t, Expr *e, SourcePos p) expr = e; } -TypeCastExpr::TypeCastExpr(TypeCastExpr *base) - : Expr(base->pos, TypeCastExprID) { - type = base->type; - expr = base->expr; -} - /** Handle all the grungy details of type conversion between atomic types. Given an input value in exprVal of type fromType, convert it to the @@ -7457,13 +7417,11 @@ TypeCastExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (type == NULL) return NULL; - TypeCastExpr *copy = new TypeCastExpr(this); - - if (Type::EqualForReplacement(copy->type->GetBaseType(), from)) { - copy->type = PolyType::ReplaceType(copy->type, to); + if (Type::EqualForReplacement(type->GetBaseType(), from)) { + type = PolyType::ReplaceType(type, to); } - return copy; + return this; } @@ -8072,11 +8030,6 @@ SymbolExpr::SymbolExpr(Symbol *s, SourcePos p) symbol = s; } -SymbolExpr::SymbolExpr(SymbolExpr *base) - : Expr(base->pos, SymbolExprID) { - symbol = base->symbol; -} - llvm::Value * SymbolExpr::GetValue(FunctionEmitContext *ctx) const { @@ -8141,23 +8094,18 @@ SymbolExpr::Optimize() { return this; } -/* Expr * SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (!symbol) return NULL; - SymbolExpr *copy = new SymbolExpr(this); - - //copy->symbol = new Symbol(*symbol); if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) { - copy->symbol->type = PolyType::ReplaceType(symbol->type, to); + symbol->type = PolyType::ReplaceType(symbol->type, to); } - return copy; + return this; } -*/ int @@ -8760,14 +8708,6 @@ NewExpr::NewExpr(int typeQual, const Type *t, Expr *init, Expr *count, allocType = allocType->ResolveUnboundVariability(Variability::Uniform); } -NewExpr::NewExpr(NewExpr *base) - : Expr(base->pos, NewExprID) { - allocType = base->allocType; - initExpr = base->initExpr; - countExpr = base->countExpr; - isVarying = base->isVarying; -} - llvm::Value * NewExpr::GetValue(FunctionEmitContext *ctx) const { @@ -8970,13 +8910,11 @@ NewExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (!allocType) return this; - NewExpr *copy = new NewExpr(this); - if (Type::EqualForReplacement(allocType->GetBaseType(), from)) { - copy->allocType = PolyType::ReplaceType(allocType, to); + allocType = PolyType::ReplaceType(allocType, to); } - return copy; + return this; } diff --git a/expr.h b/expr.h index 9ee9710d..e80926e9 100644 --- a/expr.h +++ b/expr.h @@ -96,6 +96,7 @@ public: encountered, NULL should be returned. */ virtual Expr *TypeCheck() = 0; + Expr *Copy(); /** This method replaces a polymorphic type with a specific atomic type */ Expr *ReplacePolyType(const PolyType *polyType, const Type *replacement); @@ -334,7 +335,6 @@ public: Expr *baseExpr, *index; private: - IndexExpr(IndexExpr *base); mutable const Type *type; mutable const PointerType *lvalueType; }; @@ -386,6 +386,51 @@ protected: mutable const Type *type, *lvalueType; }; +class StructMemberExpr : public MemberExpr +{ +public: + StructMemberExpr(Expr *e, const char *id, SourcePos p, + SourcePos idpos, bool derefLValue); + + static inline bool classof(StructMemberExpr const*) { return true; } + static inline bool classof(ASTNode const* N) { + return N->getValueID() == StructMemberExprID; + } + + const Type *GetType() const; + const Type *GetLValueType() const; + int getElementNumber() const; + const Type *getElementType() const; + +private: + const StructType *getStructType() const; +}; + + +class VectorMemberExpr : public MemberExpr +{ +public: + VectorMemberExpr(Expr *e, const char *id, SourcePos p, + SourcePos idpos, bool derefLValue); + + static inline bool classof(VectorMemberExpr const*) { return true; } + static inline bool classof(ASTNode const* N) { + return N->getValueID() == VectorMemberExprID; + } + + llvm::Value *GetValue(FunctionEmitContext* ctx) const; + llvm::Value *GetLValue(FunctionEmitContext* ctx) const; + const Type *GetType() const; + const Type *GetLValueType() const; + + int getElementNumber() const; + const Type *getElementType() const; + +private: + const VectorType *exprVectorType; + const VectorType *memberType; +}; + /** @brief Expression representing a compile-time constant value. @@ -536,8 +581,6 @@ public: const Type *type; Expr *expr; -private: - TypeCastExpr(TypeCastExpr *base); }; @@ -691,12 +734,11 @@ public: Symbol *GetBaseSymbol() const; Expr *TypeCheck(); Expr *Optimize(); - //Expr *ReplacePolyType(const PolyType *from, const Type *to); + Expr *ReplacePolyType(const PolyType *from, const Type *to); void Print() const; int EstimateCost() const; private: - SymbolExpr(SymbolExpr *base); Symbol *symbol; }; @@ -839,8 +881,6 @@ public: instance, or whether a single allocation is performed for the entire gang of program instances.) */ bool isVarying; -private: - NewExpr(NewExpr *base); }; diff --git a/func.cpp b/func.cpp index da7ba0ca..9dca2b34 100644 --- a/func.cpp +++ b/func.cpp @@ -127,7 +127,6 @@ Function::Function(Symbol *s, Stmt *c) { const FunctionType *type = CastType(sym->type); Assert(type != NULL); - printf("Function %s symbol types: ", sym->name.c_str()); for (int i = 0; i < type->GetNumParameters(); ++i) { const char *paramName = type->GetParameterName(i).c_str(); Symbol *sym = m->symbolTable->LookupVariable(paramName); @@ -136,14 +135,10 @@ Function::Function(Symbol *s, Stmt *c) { args.push_back(sym); const Type *t = type->GetParameterType(i); - printf(" %s: %s==%s, ", sym->name.c_str(), - t->GetString().c_str(), - sym->type->GetString().c_str()); if (sym != NULL && CastType(t) == NULL) sym->parentFunction = this; } - printf("\n"); if (type->isTask #ifdef ISPC_NVPTX_ENABLED @@ -657,14 +652,14 @@ Function::ExpandPolyArguments(SymbolTable *symbolTable) const { const FunctionType *func = CastType(sym->type); - if (g->debugPrint) { - printf("%s before replacing anything:\n", sym->name.c_str()); - code->Print(0); - } - for (size_t i=0; idebugPrint) { + printf("%s before replacing anything:\n", sym->name.c_str()); + code->Print(0); + } const FunctionType *ft = CastType(versions[i]->type); - Stmt *ncode = code; + + Stmt *ncode = (Stmt*)CopyAST(code); for (int j=0; jGetNumParameters(); j++) { if (func->GetParameterType(j)->IsPolymorphicType()) { diff --git a/stmt.cpp b/stmt.cpp index 0e116dba..0d84c704 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -35,6 +35,7 @@ @brief File with definitions classes related to statements in the language */ +#include "ast.h" #include "stmt.h" #include "ctx.h" #include "util.h" @@ -78,7 +79,7 @@ Stmt::Optimize() { } Stmt * -Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) { +Stmt::Copy() { Stmt *copy; switch (getValueID()) { case AssertStmtID: @@ -93,6 +94,9 @@ Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) { case ContinueStmtID: copy = (Stmt*)new ContinueStmt(*(ContinueStmt*)this); break; + case DeclStmtID: + copy = (Stmt*)new DeclStmt(*(DeclStmt*)this); + break; case DefaultStmtID: copy = (Stmt*)new DefaultStmt(*(DefaultStmt*)this); break; @@ -105,12 +109,12 @@ Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) { case ExprStmtID: copy = (Stmt*)new ExprStmt(*(ExprStmt*)this); break; - case ForeachStmtID: - copy = (Stmt*)new ForeachStmt(*(ForeachStmt*)this); - break; case ForeachActiveStmtID: copy = (Stmt*)new ForeachActiveStmt(*(ForeachActiveStmt*)this); break; + case ForeachStmtID: + copy = (Stmt*)new ForeachStmt(*(ForeachStmt*)this); + break; case ForeachUniqueStmtID: copy = (Stmt*)new ForeachUniqueStmt(*(ForeachUniqueStmt*)this); break; @@ -142,12 +146,17 @@ Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) { copy = (Stmt*)new UnmaskedStmt(*(UnmaskedStmt*)this); break; default: - FATAL("Unmatched case in ReplacePolyType (stmt)"); + FATAL("Unmatched case in Stmt::Copy"); copy = this; // just to silence the compiler } return copy; } +Stmt * +Stmt::ReplacePolyType(const PolyType *, const Type *) { + return this; +} + /////////////////////////////////////////////////////////////////////////// // ExprStmt @@ -200,11 +209,6 @@ DeclStmt::DeclStmt(const std::vector &v, SourcePos p) : Stmt(p, DeclStmtID), vars(v) { } -DeclStmt::DeclStmt(DeclStmt *base) - : Stmt(base->pos, DeclStmtID) { - vars = base->vars; -} - static bool lHasUnsizedArrays(const Type *type) { @@ -572,16 +576,14 @@ DeclStmt::TypeCheck() { Stmt * DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) { - DeclStmt *copy = new DeclStmt(this); - for (size_t i = 0; i < vars.size(); i++) { - Symbol *s = copy->vars[i].sym; + Symbol *s = vars[i].sym; if (Type::EqualForReplacement(s->type->GetBaseType(), from)) { s->type = PolyType::ReplaceType(s->type, to); } } - return copy; + return this; } @@ -2277,25 +2279,6 @@ ForeachStmt::TypeCheck() { return anyErrors ? NULL : this; } -/* -Stmt * -ForeachStmt::ReplacePolyType(const PolyType *from, const Type *to) { - if (!stmts) - return NULL; - - ForeachStmt *copy = new ForeachStmt(this); - - for (size_t i=0; idimVariables[i]->type; - if (Type::EqualForReplacement(t->GetBaseType(), from)) { - copy->dimVariables[i]->type = PolyType::ReplaceType(t, to); - } - } - - return copy; -} -*/ - int ForeachStmt::EstimateCost() const { diff --git a/stmt.h b/stmt.h index ee2cd8fc..371f5915 100644 --- a/stmt.h +++ b/stmt.h @@ -70,6 +70,7 @@ public: // Stmts don't have anything to do here. virtual Stmt *Optimize(); virtual Stmt *TypeCheck() = 0; + Stmt *Copy(); Stmt *ReplacePolyType(const PolyType *polyType, const Type *replacement); }; @@ -122,8 +123,6 @@ public: int EstimateCost() const; std::vector vars; -private: - DeclStmt(DeclStmt *base); }; @@ -285,7 +284,6 @@ public: void Print(int indent) const; Stmt *TypeCheck(); - //Stmt *ReplacePolyType(const PolyType *from, const Type *to); int EstimateCost() const; std::vector dimVariables; @@ -293,8 +291,6 @@ public: std::vector endExprs; bool isTiled; Stmt *stmts; -private: - //ForeachStmt(ForeachStmt *base); };