Actually copy the AST.

Type replacement works except for function parameters.
This commit is contained in:
2017-05-11 03:09:38 -04:00
parent f65b3e6300
commit bfe723e1b7
7 changed files with 190 additions and 219 deletions

120
ast.cpp
View File

@@ -86,7 +86,7 @@ AST::GenerateIR() {
ASTNode * ASTNode *
WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
void *data) { void *data, ASTPostCallBackFunc preUpdate) {
if (node == NULL) if (node == NULL)
return node; return node;
@@ -97,6 +97,10 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
return node; return node;
} }
if (preUpdate != NULL) {
node = preUpdate(node, data);
}
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Handle Statements // Handle Statements
if (llvm::dyn_cast<Stmt>(node) != NULL) { if (llvm::dyn_cast<Stmt>(node) != NULL) {
@@ -120,54 +124,54 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
UnmaskedStmt *ums; UnmaskedStmt *ums;
if ((es = llvm::dyn_cast<ExprStmt>(node)) != NULL) if ((es = llvm::dyn_cast<ExprStmt>(node)) != NULL)
es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data); es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data, preUpdate);
else if ((ds = llvm::dyn_cast<DeclStmt>(node)) != NULL) { else if ((ds = llvm::dyn_cast<DeclStmt>(node)) != NULL) {
for (unsigned int i = 0; i < ds->vars.size(); ++i) for (unsigned int i = 0; i < ds->vars.size(); ++i)
ds->vars[i].init = (Expr *)WalkAST(ds->vars[i].init, preFunc, ds->vars[i].init = (Expr *)WalkAST(ds->vars[i].init, preFunc,
postFunc, data); postFunc, data, preUpdate);
} }
else if ((is = llvm::dyn_cast<IfStmt>(node)) != NULL) { else if ((is = llvm::dyn_cast<IfStmt>(node)) != NULL) {
is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data); is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data, preUpdate);
is->trueStmts = (Stmt *)WalkAST(is->trueStmts, preFunc, is->trueStmts = (Stmt *)WalkAST(is->trueStmts, preFunc,
postFunc, data); postFunc, data, preUpdate);
is->falseStmts = (Stmt *)WalkAST(is->falseStmts, preFunc, is->falseStmts = (Stmt *)WalkAST(is->falseStmts, preFunc,
postFunc, data); postFunc, data, preUpdate);
} }
else if ((dos = llvm::dyn_cast<DoStmt>(node)) != NULL) { else if ((dos = llvm::dyn_cast<DoStmt>(node)) != NULL) {
dos->testExpr = (Expr *)WalkAST(dos->testExpr, preFunc, dos->testExpr = (Expr *)WalkAST(dos->testExpr, preFunc,
postFunc, data); postFunc, data, preUpdate);
dos->bodyStmts = (Stmt *)WalkAST(dos->bodyStmts, preFunc, dos->bodyStmts = (Stmt *)WalkAST(dos->bodyStmts, preFunc,
postFunc, data); postFunc, data, preUpdate);
} }
else if ((fs = llvm::dyn_cast<ForStmt>(node)) != NULL) { else if ((fs = llvm::dyn_cast<ForStmt>(node)) != NULL) {
fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data); fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data, preUpdate);
fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data); fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data, preUpdate);
fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data); fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data, preUpdate);
fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data); fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data, preUpdate);
} }
else if ((fes = llvm::dyn_cast<ForeachStmt>(node)) != NULL) { else if ((fes = llvm::dyn_cast<ForeachStmt>(node)) != NULL) {
for (unsigned int i = 0; i < fes->startExprs.size(); ++i) for (unsigned int i = 0; i < fes->startExprs.size(); ++i)
fes->startExprs[i] = (Expr *)WalkAST(fes->startExprs[i], preFunc, fes->startExprs[i] = (Expr *)WalkAST(fes->startExprs[i], preFunc,
postFunc, data); postFunc, data, preUpdate);
for (unsigned int i = 0; i < fes->endExprs.size(); ++i) for (unsigned int i = 0; i < fes->endExprs.size(); ++i)
fes->endExprs[i] = (Expr *)WalkAST(fes->endExprs[i], preFunc, fes->endExprs[i] = (Expr *)WalkAST(fes->endExprs[i], preFunc,
postFunc, data); postFunc, data, preUpdate);
fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data); fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data, preUpdate);
} }
else if ((fas = llvm::dyn_cast<ForeachActiveStmt>(node)) != NULL) { else if ((fas = llvm::dyn_cast<ForeachActiveStmt>(node)) != NULL) {
fas->stmts = (Stmt *)WalkAST(fas->stmts, preFunc, postFunc, data); fas->stmts = (Stmt *)WalkAST(fas->stmts, preFunc, postFunc, data, preUpdate);
} }
else if ((fus = llvm::dyn_cast<ForeachUniqueStmt>(node)) != NULL) { else if ((fus = llvm::dyn_cast<ForeachUniqueStmt>(node)) != NULL) {
fus->expr = (Expr *)WalkAST(fus->expr, preFunc, postFunc, data); fus->expr = (Expr *)WalkAST(fus->expr, preFunc, postFunc, data, preUpdate);
fus->stmts = (Stmt *)WalkAST(fus->stmts, preFunc, postFunc, data); fus->stmts = (Stmt *)WalkAST(fus->stmts, preFunc, postFunc, data, preUpdate);
} }
else if ((cs = llvm::dyn_cast<CaseStmt>(node)) != NULL) else if ((cs = llvm::dyn_cast<CaseStmt>(node)) != NULL)
cs->stmts = (Stmt *)WalkAST(cs->stmts, preFunc, postFunc, data); cs->stmts = (Stmt *)WalkAST(cs->stmts, preFunc, postFunc, data, preUpdate);
else if ((defs = llvm::dyn_cast<DefaultStmt>(node)) != NULL) else if ((defs = llvm::dyn_cast<DefaultStmt>(node)) != NULL)
defs->stmts = (Stmt *)WalkAST(defs->stmts, preFunc, postFunc, data); defs->stmts = (Stmt *)WalkAST(defs->stmts, preFunc, postFunc, data, preUpdate);
else if ((ss = llvm::dyn_cast<SwitchStmt>(node)) != NULL) { else if ((ss = llvm::dyn_cast<SwitchStmt>(node)) != NULL) {
ss->expr = (Expr *)WalkAST(ss->expr, preFunc, postFunc, data); ss->expr = (Expr *)WalkAST(ss->expr, preFunc, postFunc, data, preUpdate);
ss->stmts = (Stmt *)WalkAST(ss->stmts, preFunc, postFunc, data); ss->stmts = (Stmt *)WalkAST(ss->stmts, preFunc, postFunc, data, preUpdate);
} }
else if (llvm::dyn_cast<BreakStmt>(node) != NULL || else if (llvm::dyn_cast<BreakStmt>(node) != NULL ||
llvm::dyn_cast<ContinueStmt>(node) != NULL || llvm::dyn_cast<ContinueStmt>(node) != NULL ||
@@ -175,22 +179,22 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
// nothing // nothing
} }
else if ((ls = llvm::dyn_cast<LabeledStmt>(node)) != NULL) else if ((ls = llvm::dyn_cast<LabeledStmt>(node)) != NULL)
ls->stmt = (Stmt *)WalkAST(ls->stmt, preFunc, postFunc, data); ls->stmt = (Stmt *)WalkAST(ls->stmt, preFunc, postFunc, data, preUpdate);
else if ((rs = llvm::dyn_cast<ReturnStmt>(node)) != NULL) else if ((rs = llvm::dyn_cast<ReturnStmt>(node)) != NULL)
rs->expr = (Expr *)WalkAST(rs->expr, preFunc, postFunc, data); rs->expr = (Expr *)WalkAST(rs->expr, preFunc, postFunc, data, preUpdate);
else if ((sl = llvm::dyn_cast<StmtList>(node)) != NULL) { else if ((sl = llvm::dyn_cast<StmtList>(node)) != NULL) {
std::vector<Stmt *> &sls = sl->stmts; std::vector<Stmt *> &sls = sl->stmts;
for (unsigned int i = 0; i < sls.size(); ++i) for (unsigned int i = 0; i < sls.size(); ++i)
sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data); sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data, preUpdate);
} }
else if ((ps = llvm::dyn_cast<PrintStmt>(node)) != NULL) else if ((ps = llvm::dyn_cast<PrintStmt>(node)) != NULL)
ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data); ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data, preUpdate);
else if ((as = llvm::dyn_cast<AssertStmt>(node)) != NULL) else if ((as = llvm::dyn_cast<AssertStmt>(node)) != NULL)
as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data); as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data, preUpdate);
else if ((dels = llvm::dyn_cast<DeleteStmt>(node)) != NULL) else if ((dels = llvm::dyn_cast<DeleteStmt>(node)) != NULL)
dels->expr = (Expr *)WalkAST(dels->expr, preFunc, postFunc, data); dels->expr = (Expr *)WalkAST(dels->expr, preFunc, postFunc, data, preUpdate);
else if ((ums = llvm::dyn_cast<UnmaskedStmt>(node)) != NULL) else if ((ums = llvm::dyn_cast<UnmaskedStmt>(node)) != NULL)
ums->stmts = (Stmt *)WalkAST(ums->stmts, preFunc, postFunc, data); ums->stmts = (Stmt *)WalkAST(ums->stmts, preFunc, postFunc, data, preUpdate);
else else
FATAL("Unhandled statement type in WalkAST()"); FATAL("Unhandled statement type in WalkAST()");
} }
@@ -215,57 +219,57 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc,
NewExpr *newe; NewExpr *newe;
if ((ue = llvm::dyn_cast<UnaryExpr>(node)) != NULL) if ((ue = llvm::dyn_cast<UnaryExpr>(node)) != NULL)
ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data); ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data, preUpdate);
else if ((be = llvm::dyn_cast<BinaryExpr>(node)) != NULL) { else if ((be = llvm::dyn_cast<BinaryExpr>(node)) != NULL) {
be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data); be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data, preUpdate);
be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data); be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data, preUpdate);
} }
else if ((ae = llvm::dyn_cast<AssignExpr>(node)) != NULL) { else if ((ae = llvm::dyn_cast<AssignExpr>(node)) != NULL) {
ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data); ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data, preUpdate);
ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data); ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data, preUpdate);
} }
else if ((se = llvm::dyn_cast<SelectExpr>(node)) != NULL) { else if ((se = llvm::dyn_cast<SelectExpr>(node)) != NULL) {
se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data); se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data, preUpdate);
se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data); se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data, preUpdate);
se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data); se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data, preUpdate);
} }
else if ((el = llvm::dyn_cast<ExprList>(node)) != NULL) { else if ((el = llvm::dyn_cast<ExprList>(node)) != NULL) {
for (unsigned int i = 0; i < el->exprs.size(); ++i) for (unsigned int i = 0; i < el->exprs.size(); ++i)
el->exprs[i] = (Expr *)WalkAST(el->exprs[i], preFunc, el->exprs[i] = (Expr *)WalkAST(el->exprs[i], preFunc,
postFunc, data); postFunc, data, preUpdate);
} }
else if ((fce = llvm::dyn_cast<FunctionCallExpr>(node)) != NULL) { else if ((fce = llvm::dyn_cast<FunctionCallExpr>(node)) != NULL) {
fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data); fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data, preUpdate);
fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data); fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data, preUpdate);
for (int k = 0; k < 3; k++) for (int k = 0; k < 3; k++)
fce->launchCountExpr[0] = (Expr *)WalkAST(fce->launchCountExpr[0], preFunc, fce->launchCountExpr[0] = (Expr *)WalkAST(fce->launchCountExpr[0], preFunc,
postFunc, data); postFunc, data, preUpdate);
} }
else if ((ie = llvm::dyn_cast<IndexExpr>(node)) != NULL) { else if ((ie = llvm::dyn_cast<IndexExpr>(node)) != NULL) {
ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data); ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data, preUpdate);
ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data); ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data, preUpdate);
} }
else if ((me = llvm::dyn_cast<MemberExpr>(node)) != NULL) else if ((me = llvm::dyn_cast<MemberExpr>(node)) != NULL)
me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data); me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data, preUpdate);
else if ((tce = llvm::dyn_cast<TypeCastExpr>(node)) != NULL) else if ((tce = llvm::dyn_cast<TypeCastExpr>(node)) != NULL)
tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data); tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data, preUpdate);
else if ((re = llvm::dyn_cast<ReferenceExpr>(node)) != NULL) else if ((re = llvm::dyn_cast<ReferenceExpr>(node)) != NULL)
re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data); re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data, preUpdate);
else if ((ptrderef = llvm::dyn_cast<PtrDerefExpr>(node)) != NULL) else if ((ptrderef = llvm::dyn_cast<PtrDerefExpr>(node)) != NULL)
ptrderef->expr = (Expr *)WalkAST(ptrderef->expr, preFunc, postFunc, ptrderef->expr = (Expr *)WalkAST(ptrderef->expr, preFunc, postFunc,
data); data, preUpdate);
else if ((refderef = llvm::dyn_cast<RefDerefExpr>(node)) != NULL) else if ((refderef = llvm::dyn_cast<RefDerefExpr>(node)) != NULL)
refderef->expr = (Expr *)WalkAST(refderef->expr, preFunc, postFunc, refderef->expr = (Expr *)WalkAST(refderef->expr, preFunc, postFunc,
data); data, preUpdate);
else if ((soe = llvm::dyn_cast<SizeOfExpr>(node)) != NULL) else if ((soe = llvm::dyn_cast<SizeOfExpr>(node)) != NULL)
soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data); soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data, preUpdate);
else if ((aoe = llvm::dyn_cast<AddressOfExpr>(node)) != NULL) else if ((aoe = llvm::dyn_cast<AddressOfExpr>(node)) != NULL)
aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data); aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data, preUpdate);
else if ((newe = llvm::dyn_cast<NewExpr>(node)) != NULL) { else if ((newe = llvm::dyn_cast<NewExpr>(node)) != NULL) {
newe->countExpr = (Expr *)WalkAST(newe->countExpr, preFunc, newe->countExpr = (Expr *)WalkAST(newe->countExpr, preFunc,
postFunc, data); postFunc, data, preUpdate);
newe->initExpr = (Expr *)WalkAST(newe->initExpr, preFunc, newe->initExpr = (Expr *)WalkAST(newe->initExpr, preFunc,
postFunc, data); postFunc, data, preUpdate);
} }
else if (llvm::dyn_cast<SymbolExpr>(node) != NULL || else if (llvm::dyn_cast<SymbolExpr>(node) != NULL ||
llvm::dyn_cast<ConstExpr>(node) != NULL || llvm::dyn_cast<ConstExpr>(node) != NULL ||
@@ -536,11 +540,21 @@ lTranslatePolyNode(ASTNode *node, void *d) {
return node->ReplacePolyType(data->polyType, data->replacement); return node->ReplacePolyType(data->polyType, data->replacement);
} }
static ASTNode *
lCopyNode(ASTNode *node, void *) {
return node->Copy();
}
ASTNode * ASTNode *
TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement) { TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement) {
struct PolyData data; struct PolyData data;
data.polyType = polyType; data.polyType = polyType;
data.replacement = replacement; data.replacement = replacement;
return WalkAST(root, NULL, lTranslatePolyNode, &data); return WalkAST(root, NULL, lTranslatePolyNode, &data, lCopyNode);
}
ASTNode *
CopyAST(ASTNode *root) {
return WalkAST(root, NULL, NULL, NULL, lCopyNode);
} }

7
ast.h
View File

@@ -68,6 +68,8 @@ public:
pointer in place of the original ASTNode *. */ pointer in place of the original ASTNode *. */
virtual ASTNode *TypeCheck() = 0; virtual ASTNode *TypeCheck() = 0;
virtual ASTNode *Copy() = 0;
virtual ASTNode *ReplacePolyType(const PolyType *, const Type *) = 0; virtual ASTNode *ReplacePolyType(const PolyType *, const Type *) = 0;
/** Estimate the execution cost of the node (not including the cost of /** Estimate the execution cost of the node (not including the cost of
@@ -177,7 +179,8 @@ typedef ASTNode * (* ASTPostCallBackFunc)(ASTNode *node, void *data);
doing so, calls postFunc, at the node. The return value from the doing so, calls postFunc, at the node. The return value from the
postFunc call is ignored. */ postFunc call is ignored. */
extern ASTNode *WalkAST(ASTNode *root, ASTPreCallBackFunc preFunc, extern ASTNode *WalkAST(ASTNode *root, ASTPreCallBackFunc preFunc,
ASTPostCallBackFunc postFunc, void *data); ASTPostCallBackFunc postFunc, void *data,
ASTPostCallBackFunc preUpdate = NULL);
/** Perform simple optimizations on the AST or portion thereof passed to /** Perform simple optimizations on the AST or portion thereof passed to
this function, returning the resulting AST. */ this function, returning the resulting AST. */
@@ -209,6 +212,8 @@ extern int EstimateCost(ASTNode *root);
extern ASTNode * TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement); extern ASTNode * TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement);
extern ASTNode * CopyAST(ASTNode *root);
/** Returns true if it would be safe to run the given code with an "all /** Returns true if it would be safe to run the given code with an "all
off" mask. */ off" mask. */
extern bool SafeToRunWithMaskAllOff(ASTNode *root); extern bool SafeToRunWithMaskAllOff(ASTNode *root);

156
expr.cpp
View File

@@ -113,7 +113,7 @@ Expr::GetBaseSymbol() const {
} }
Expr * Expr *
Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) { Expr::Copy() {
Expr *copy; Expr *copy;
switch (getValueID()) { switch (getValueID()) {
case AddressOfExprID: case AddressOfExprID:
@@ -128,9 +128,6 @@ Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
case ConstExprID: case ConstExprID:
copy = (Expr*)new ConstExpr(*(ConstExpr*)this); copy = (Expr*)new ConstExpr(*(ConstExpr*)this);
break; break;
case SymbolExprID:
copy = (Expr*)new SymbolExpr(*(SymbolExpr*)this);
break;
case PtrDerefExprID: case PtrDerefExprID:
copy = (Expr*)new PtrDerefExpr(*(PtrDerefExpr*)this); copy = (Expr*)new PtrDerefExpr(*(PtrDerefExpr*)this);
break; break;
@@ -143,6 +140,21 @@ Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
case FunctionCallExprID: case FunctionCallExprID:
copy = (Expr*)new FunctionCallExpr(*(FunctionCallExpr*)this); copy = (Expr*)new FunctionCallExpr(*(FunctionCallExpr*)this);
break; break;
case FunctionSymbolExprID:
copy = (Expr*)new FunctionSymbolExpr(*(FunctionSymbolExpr*)this);
break;
case IndexExprID:
copy = (Expr*)new IndexExpr(*(IndexExpr*)this);
break;
case StructMemberExprID:
copy = (Expr*)new StructMemberExpr(*(StructMemberExpr*)this);
break;
case VectorMemberExprID:
copy = (Expr*)new VectorMemberExpr(*(VectorMemberExpr*)this);
break;
case NewExprID:
copy = (Expr*)new NewExpr(*(NewExpr*)this);
break;
case NullPointerExprID: case NullPointerExprID:
copy = (Expr*)new NullPointerExpr(*(NullPointerExpr*)this); copy = (Expr*)new NullPointerExpr(*(NullPointerExpr*)this);
break; break;
@@ -155,16 +167,30 @@ Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
case SizeOfExprID: case SizeOfExprID:
copy = (Expr*)new SizeOfExpr(*(SizeOfExpr*)this); copy = (Expr*)new SizeOfExpr(*(SizeOfExpr*)this);
break; break;
case SymbolExprID:
copy = (Expr*)new SymbolExpr(*(SymbolExpr*)this);
break;
case SyncExprID:
copy = (Expr*)new SyncExpr(*(SyncExpr*)this);
break;
case TypeCastExprID:
copy = (Expr*)new TypeCastExpr(*(TypeCastExpr*)this);
break;
case UnaryExprID: case UnaryExprID:
copy = (Expr*)new UnaryExpr(*(UnaryExpr*)this); copy = (Expr*)new UnaryExpr(*(UnaryExpr*)this);
break; break;
default: default:
FATAL("Unmatched case in ReplacePolyType (expr)"); FATAL("Unmatched case in Expr::Copy");
copy = this; // just to silence the compiler copy = this; // just to silence the compiler
} }
return copy; return copy;
} }
Expr *
Expr::ReplacePolyType(const PolyType *, const Type *) {
return this;
}
#if 0 #if 0
/** If a conversion from 'fromAtomicType' to 'toAtomicType' may cause lost /** If a conversion from 'fromAtomicType' to 'toAtomicType' may cause lost
@@ -4196,14 +4222,6 @@ IndexExpr::IndexExpr(Expr *a, Expr *i, SourcePos p)
type = lvalueType = NULL; type = lvalueType = NULL;
} }
IndexExpr::IndexExpr(IndexExpr *base)
: Expr(base->pos, IndexExprID) {
baseExpr = base->baseExpr;
index = base->index;
type = base->type;
lvalueType = base->lvalueType;
}
/** When computing pointer values, we need to apply a per-lane offset when /** When computing pointer values, we need to apply a per-lane offset when
we have a varying pointer that is itself indexing into varying data. we have a varying pointer that is itself indexing into varying data.
@@ -4736,18 +4754,16 @@ IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (index == NULL || baseExpr == NULL) if (index == NULL || baseExpr == NULL)
return NULL; return NULL;
IndexExpr *copy = new IndexExpr(this); if (Type::EqualForReplacement(GetType()->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
if (Type::EqualForReplacement(copy->GetType()->GetBaseType(), from)) {
copy->type = PolyType::ReplaceType(type, to);
} }
if (Type::EqualForReplacement(copy->GetLValueType()->GetBaseType(), from)) { if (Type::EqualForReplacement(GetLValueType()->GetBaseType(), from)) {
copy->lvalueType = new PointerType(to, copy->lvalueType->GetVariability(), lvalueType = new PointerType(to, lvalueType->GetVariability(),
copy->lvalueType->IsConstType()); lvalueType->IsConstType());
} }
return copy; return this;
} }
@@ -4815,27 +4831,6 @@ lIdentifierToVectorElement(char id) {
////////////////////////////////////////////////// //////////////////////////////////////////////////
// StructMemberExpr // StructMemberExpr
class StructMemberExpr : public MemberExpr
{
public:
StructMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, bool derefLValue);
static inline bool classof(StructMemberExpr const*) { return true; }
static inline bool classof(ASTNode const* N) {
return N->getValueID() == StructMemberExprID;
}
const Type *GetType() const;
const Type *GetLValueType() const;
int getElementNumber() const;
const Type *getElementType() const;
private:
const StructType *getStructType() const;
};
StructMemberExpr::StructMemberExpr(Expr *e, const char *id, SourcePos p, StructMemberExpr::StructMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, bool derefLValue) SourcePos idpos, bool derefLValue)
: MemberExpr(e, id, p, idpos, derefLValue, StructMemberExprID) { : MemberExpr(e, id, p, idpos, derefLValue, StructMemberExprID) {
@@ -4987,31 +4982,6 @@ StructMemberExpr::getStructType() const {
////////////////////////////////////////////////// //////////////////////////////////////////////////
// VectorMemberExpr // VectorMemberExpr
class VectorMemberExpr : public MemberExpr
{
public:
VectorMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, bool derefLValue);
static inline bool classof(VectorMemberExpr const*) { return true; }
static inline bool classof(ASTNode const* N) {
return N->getValueID() == VectorMemberExprID;
}
llvm::Value *GetValue(FunctionEmitContext* ctx) const;
llvm::Value *GetLValue(FunctionEmitContext* ctx) const;
const Type *GetType() const;
const Type *GetLValueType() const;
int getElementNumber() const;
const Type *getElementType() const;
private:
const VectorType *exprVectorType;
const VectorType *memberType;
};
VectorMemberExpr::VectorMemberExpr(Expr *e, const char *id, SourcePos p, VectorMemberExpr::VectorMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, bool derefLValue) SourcePos idpos, bool derefLValue)
: MemberExpr(e, id, p, idpos, derefLValue, VectorMemberExprID) { : MemberExpr(e, id, p, idpos, derefLValue, VectorMemberExprID) {
@@ -5397,15 +5367,11 @@ MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (expr == NULL) if (expr == NULL)
return NULL; return NULL;
MemberExpr *copy = getValueID() == StructMemberExprID ? if (Type::EqualForReplacement(GetType()->GetBaseType(), from)) {
(MemberExpr*) new StructMemberExpr(*(StructMemberExpr*)this) : type = PolyType::ReplaceType(type, to);
(MemberExpr*) new VectorMemberExpr(*(VectorMemberExpr*)this);
if (Type::EqualForReplacement(copy->GetType()->GetBaseType(), from)) {
copy->type = PolyType::ReplaceType(copy->type, to);
} }
return copy; return this;
} }
@@ -6351,12 +6317,6 @@ TypeCastExpr::TypeCastExpr(const Type *t, Expr *e, SourcePos p)
expr = e; expr = e;
} }
TypeCastExpr::TypeCastExpr(TypeCastExpr *base)
: Expr(base->pos, TypeCastExprID) {
type = base->type;
expr = base->expr;
}
/** Handle all the grungy details of type conversion between atomic types. /** Handle all the grungy details of type conversion between atomic types.
Given an input value in exprVal of type fromType, convert it to the Given an input value in exprVal of type fromType, convert it to the
@@ -7457,13 +7417,11 @@ TypeCastExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (type == NULL) if (type == NULL)
return NULL; return NULL;
TypeCastExpr *copy = new TypeCastExpr(this); if (Type::EqualForReplacement(type->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
if (Type::EqualForReplacement(copy->type->GetBaseType(), from)) {
copy->type = PolyType::ReplaceType(copy->type, to);
} }
return copy; return this;
} }
@@ -8072,11 +8030,6 @@ SymbolExpr::SymbolExpr(Symbol *s, SourcePos p)
symbol = s; symbol = s;
} }
SymbolExpr::SymbolExpr(SymbolExpr *base)
: Expr(base->pos, SymbolExprID) {
symbol = base->symbol;
}
llvm::Value * llvm::Value *
SymbolExpr::GetValue(FunctionEmitContext *ctx) const { SymbolExpr::GetValue(FunctionEmitContext *ctx) const {
@@ -8141,23 +8094,18 @@ SymbolExpr::Optimize() {
return this; return this;
} }
/*
Expr * Expr *
SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) { SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (!symbol) if (!symbol)
return NULL; return NULL;
SymbolExpr *copy = new SymbolExpr(this);
//copy->symbol = new Symbol(*symbol);
if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) { if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) {
copy->symbol->type = PolyType::ReplaceType(symbol->type, to); symbol->type = PolyType::ReplaceType(symbol->type, to);
} }
return copy; return this;
} }
*/
int int
@@ -8760,14 +8708,6 @@ NewExpr::NewExpr(int typeQual, const Type *t, Expr *init, Expr *count,
allocType = allocType->ResolveUnboundVariability(Variability::Uniform); allocType = allocType->ResolveUnboundVariability(Variability::Uniform);
} }
NewExpr::NewExpr(NewExpr *base)
: Expr(base->pos, NewExprID) {
allocType = base->allocType;
initExpr = base->initExpr;
countExpr = base->countExpr;
isVarying = base->isVarying;
}
llvm::Value * llvm::Value *
NewExpr::GetValue(FunctionEmitContext *ctx) const { NewExpr::GetValue(FunctionEmitContext *ctx) const {
@@ -8970,13 +8910,11 @@ NewExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (!allocType) if (!allocType)
return this; return this;
NewExpr *copy = new NewExpr(this);
if (Type::EqualForReplacement(allocType->GetBaseType(), from)) { if (Type::EqualForReplacement(allocType->GetBaseType(), from)) {
copy->allocType = PolyType::ReplaceType(allocType, to); allocType = PolyType::ReplaceType(allocType, to);
} }
return copy; return this;
} }

54
expr.h
View File

@@ -96,6 +96,7 @@ public:
encountered, NULL should be returned. */ encountered, NULL should be returned. */
virtual Expr *TypeCheck() = 0; virtual Expr *TypeCheck() = 0;
Expr *Copy();
/** This method replaces a polymorphic type with a specific atomic type */ /** This method replaces a polymorphic type with a specific atomic type */
Expr *ReplacePolyType(const PolyType *polyType, const Type *replacement); Expr *ReplacePolyType(const PolyType *polyType, const Type *replacement);
@@ -334,7 +335,6 @@ public:
Expr *baseExpr, *index; Expr *baseExpr, *index;
private: private:
IndexExpr(IndexExpr *base);
mutable const Type *type; mutable const Type *type;
mutable const PointerType *lvalueType; mutable const PointerType *lvalueType;
}; };
@@ -386,6 +386,51 @@ protected:
mutable const Type *type, *lvalueType; mutable const Type *type, *lvalueType;
}; };
class StructMemberExpr : public MemberExpr
{
public:
StructMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, bool derefLValue);
static inline bool classof(StructMemberExpr const*) { return true; }
static inline bool classof(ASTNode const* N) {
return N->getValueID() == StructMemberExprID;
}
const Type *GetType() const;
const Type *GetLValueType() const;
int getElementNumber() const;
const Type *getElementType() const;
private:
const StructType *getStructType() const;
};
class VectorMemberExpr : public MemberExpr
{
public:
VectorMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, bool derefLValue);
static inline bool classof(VectorMemberExpr const*) { return true; }
static inline bool classof(ASTNode const* N) {
return N->getValueID() == VectorMemberExprID;
}
llvm::Value *GetValue(FunctionEmitContext* ctx) const;
llvm::Value *GetLValue(FunctionEmitContext* ctx) const;
const Type *GetType() const;
const Type *GetLValueType() const;
int getElementNumber() const;
const Type *getElementType() const;
private:
const VectorType *exprVectorType;
const VectorType *memberType;
};
/** @brief Expression representing a compile-time constant value. /** @brief Expression representing a compile-time constant value.
@@ -536,8 +581,6 @@ public:
const Type *type; const Type *type;
Expr *expr; Expr *expr;
private:
TypeCastExpr(TypeCastExpr *base);
}; };
@@ -691,12 +734,11 @@ public:
Symbol *GetBaseSymbol() const; Symbol *GetBaseSymbol() const;
Expr *TypeCheck(); Expr *TypeCheck();
Expr *Optimize(); Expr *Optimize();
//Expr *ReplacePolyType(const PolyType *from, const Type *to); Expr *ReplacePolyType(const PolyType *from, const Type *to);
void Print() const; void Print() const;
int EstimateCost() const; int EstimateCost() const;
private: private:
SymbolExpr(SymbolExpr *base);
Symbol *symbol; Symbol *symbol;
}; };
@@ -839,8 +881,6 @@ public:
instance, or whether a single allocation is performed for the instance, or whether a single allocation is performed for the
entire gang of program instances.) */ entire gang of program instances.) */
bool isVarying; bool isVarying;
private:
NewExpr(NewExpr *base);
}; };

View File

@@ -127,7 +127,6 @@ Function::Function(Symbol *s, Stmt *c) {
const FunctionType *type = CastType<FunctionType>(sym->type); const FunctionType *type = CastType<FunctionType>(sym->type);
Assert(type != NULL); Assert(type != NULL);
printf("Function %s symbol types: ", sym->name.c_str());
for (int i = 0; i < type->GetNumParameters(); ++i) { for (int i = 0; i < type->GetNumParameters(); ++i) {
const char *paramName = type->GetParameterName(i).c_str(); const char *paramName = type->GetParameterName(i).c_str();
Symbol *sym = m->symbolTable->LookupVariable(paramName); Symbol *sym = m->symbolTable->LookupVariable(paramName);
@@ -136,14 +135,10 @@ Function::Function(Symbol *s, Stmt *c) {
args.push_back(sym); args.push_back(sym);
const Type *t = type->GetParameterType(i); const Type *t = type->GetParameterType(i);
printf(" %s: %s==%s, ", sym->name.c_str(),
t->GetString().c_str(),
sym->type->GetString().c_str());
if (sym != NULL && CastType<ReferenceType>(t) == NULL) if (sym != NULL && CastType<ReferenceType>(t) == NULL)
sym->parentFunction = this; sym->parentFunction = this;
} }
printf("\n");
if (type->isTask if (type->isTask
#ifdef ISPC_NVPTX_ENABLED #ifdef ISPC_NVPTX_ENABLED
@@ -657,14 +652,14 @@ Function::ExpandPolyArguments(SymbolTable *symbolTable) const {
const FunctionType *func = CastType<FunctionType>(sym->type); const FunctionType *func = CastType<FunctionType>(sym->type);
if (g->debugPrint) {
printf("%s before replacing anything:\n", sym->name.c_str());
code->Print(0);
}
for (size_t i=0; i<versions.size(); i++) { for (size_t i=0; i<versions.size(); i++) {
if (g->debugPrint) {
printf("%s before replacing anything:\n", sym->name.c_str());
code->Print(0);
}
const FunctionType *ft = CastType<FunctionType>(versions[i]->type); const FunctionType *ft = CastType<FunctionType>(versions[i]->type);
Stmt *ncode = code;
Stmt *ncode = (Stmt*)CopyAST(code);
for (int j=0; j<ft->GetNumParameters(); j++) { for (int j=0; j<ft->GetNumParameters(); j++) {
if (func->GetParameterType(j)->IsPolymorphicType()) { if (func->GetParameterType(j)->IsPolymorphicType()) {

View File

@@ -35,6 +35,7 @@
@brief File with definitions classes related to statements in the language @brief File with definitions classes related to statements in the language
*/ */
#include "ast.h"
#include "stmt.h" #include "stmt.h"
#include "ctx.h" #include "ctx.h"
#include "util.h" #include "util.h"
@@ -78,7 +79,7 @@ Stmt::Optimize() {
} }
Stmt * Stmt *
Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) { Stmt::Copy() {
Stmt *copy; Stmt *copy;
switch (getValueID()) { switch (getValueID()) {
case AssertStmtID: case AssertStmtID:
@@ -93,6 +94,9 @@ Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
case ContinueStmtID: case ContinueStmtID:
copy = (Stmt*)new ContinueStmt(*(ContinueStmt*)this); copy = (Stmt*)new ContinueStmt(*(ContinueStmt*)this);
break; break;
case DeclStmtID:
copy = (Stmt*)new DeclStmt(*(DeclStmt*)this);
break;
case DefaultStmtID: case DefaultStmtID:
copy = (Stmt*)new DefaultStmt(*(DefaultStmt*)this); copy = (Stmt*)new DefaultStmt(*(DefaultStmt*)this);
break; break;
@@ -105,12 +109,12 @@ Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
case ExprStmtID: case ExprStmtID:
copy = (Stmt*)new ExprStmt(*(ExprStmt*)this); copy = (Stmt*)new ExprStmt(*(ExprStmt*)this);
break; break;
case ForeachStmtID:
copy = (Stmt*)new ForeachStmt(*(ForeachStmt*)this);
break;
case ForeachActiveStmtID: case ForeachActiveStmtID:
copy = (Stmt*)new ForeachActiveStmt(*(ForeachActiveStmt*)this); copy = (Stmt*)new ForeachActiveStmt(*(ForeachActiveStmt*)this);
break; break;
case ForeachStmtID:
copy = (Stmt*)new ForeachStmt(*(ForeachStmt*)this);
break;
case ForeachUniqueStmtID: case ForeachUniqueStmtID:
copy = (Stmt*)new ForeachUniqueStmt(*(ForeachUniqueStmt*)this); copy = (Stmt*)new ForeachUniqueStmt(*(ForeachUniqueStmt*)this);
break; break;
@@ -142,12 +146,17 @@ Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
copy = (Stmt*)new UnmaskedStmt(*(UnmaskedStmt*)this); copy = (Stmt*)new UnmaskedStmt(*(UnmaskedStmt*)this);
break; break;
default: default:
FATAL("Unmatched case in ReplacePolyType (stmt)"); FATAL("Unmatched case in Stmt::Copy");
copy = this; // just to silence the compiler copy = this; // just to silence the compiler
} }
return copy; return copy;
} }
Stmt *
Stmt::ReplacePolyType(const PolyType *, const Type *) {
return this;
}
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// ExprStmt // ExprStmt
@@ -200,11 +209,6 @@ DeclStmt::DeclStmt(const std::vector<VariableDeclaration> &v, SourcePos p)
: Stmt(p, DeclStmtID), vars(v) { : Stmt(p, DeclStmtID), vars(v) {
} }
DeclStmt::DeclStmt(DeclStmt *base)
: Stmt(base->pos, DeclStmtID) {
vars = base->vars;
}
static bool static bool
lHasUnsizedArrays(const Type *type) { lHasUnsizedArrays(const Type *type) {
@@ -572,16 +576,14 @@ DeclStmt::TypeCheck() {
Stmt * Stmt *
DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) { DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) {
DeclStmt *copy = new DeclStmt(this);
for (size_t i = 0; i < vars.size(); i++) { for (size_t i = 0; i < vars.size(); i++) {
Symbol *s = copy->vars[i].sym; Symbol *s = vars[i].sym;
if (Type::EqualForReplacement(s->type->GetBaseType(), from)) { if (Type::EqualForReplacement(s->type->GetBaseType(), from)) {
s->type = PolyType::ReplaceType(s->type, to); s->type = PolyType::ReplaceType(s->type, to);
} }
} }
return copy; return this;
} }
@@ -2277,25 +2279,6 @@ ForeachStmt::TypeCheck() {
return anyErrors ? NULL : this; return anyErrors ? NULL : this;
} }
/*
Stmt *
ForeachStmt::ReplacePolyType(const PolyType *from, const Type *to) {
if (!stmts)
return NULL;
ForeachStmt *copy = new ForeachStmt(this);
for (size_t i=0; i<dimVariables.size(); i++) {
const Type *t = copy->dimVariables[i]->type;
if (Type::EqualForReplacement(t->GetBaseType(), from)) {
copy->dimVariables[i]->type = PolyType::ReplaceType(t, to);
}
}
return copy;
}
*/
int int
ForeachStmt::EstimateCost() const { ForeachStmt::EstimateCost() const {

6
stmt.h
View File

@@ -70,6 +70,7 @@ public:
// Stmts don't have anything to do here. // Stmts don't have anything to do here.
virtual Stmt *Optimize(); virtual Stmt *Optimize();
virtual Stmt *TypeCheck() = 0; virtual Stmt *TypeCheck() = 0;
Stmt *Copy();
Stmt *ReplacePolyType(const PolyType *polyType, const Type *replacement); Stmt *ReplacePolyType(const PolyType *polyType, const Type *replacement);
}; };
@@ -122,8 +123,6 @@ public:
int EstimateCost() const; int EstimateCost() const;
std::vector<VariableDeclaration> vars; std::vector<VariableDeclaration> vars;
private:
DeclStmt(DeclStmt *base);
}; };
@@ -285,7 +284,6 @@ public:
void Print(int indent) const; void Print(int indent) const;
Stmt *TypeCheck(); Stmt *TypeCheck();
//Stmt *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const; int EstimateCost() const;
std::vector<Symbol *> dimVariables; std::vector<Symbol *> dimVariables;
@@ -293,8 +291,6 @@ public:
std::vector<Expr *> endExprs; std::vector<Expr *> endExprs;
bool isTiled; bool isTiled;
Stmt *stmts; Stmt *stmts;
private:
//ForeachStmt(ForeachStmt *base);
}; };