diff --git a/ast.cpp b/ast.cpp index 064226b4..5e71b9b6 100644 --- a/ast.cpp +++ b/ast.cpp @@ -86,7 +86,7 @@ AST::GenerateIR() { ASTNode * WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, - void *data) { + void *data, ASTPostCallBackFunc preUpdate) { if (node == NULL) return node; @@ -97,6 +97,10 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, return node; } + if (preUpdate != NULL) { + node = preUpdate(node, data); + } + //////////////////////////////////////////////////////////////////////////// // Handle Statements if (llvm::dyn_cast(node) != NULL) { @@ -120,54 +124,54 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, UnmaskedStmt *ums; if ((es = llvm::dyn_cast(node)) != NULL) - es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data); + es->expr = (Expr *)WalkAST(es->expr, preFunc, postFunc, data, preUpdate); else if ((ds = llvm::dyn_cast(node)) != NULL) { for (unsigned int i = 0; i < ds->vars.size(); ++i) ds->vars[i].init = (Expr *)WalkAST(ds->vars[i].init, preFunc, - postFunc, data); + postFunc, data, preUpdate); } else if ((is = llvm::dyn_cast(node)) != NULL) { - is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data); + is->test = (Expr *)WalkAST(is->test, preFunc, postFunc, data, preUpdate); is->trueStmts = (Stmt *)WalkAST(is->trueStmts, preFunc, - postFunc, data); + postFunc, data, preUpdate); is->falseStmts = (Stmt *)WalkAST(is->falseStmts, preFunc, - postFunc, data); + postFunc, data, preUpdate); } else if ((dos = llvm::dyn_cast(node)) != NULL) { dos->testExpr = (Expr *)WalkAST(dos->testExpr, preFunc, - postFunc, data); + postFunc, data, preUpdate); dos->bodyStmts = (Stmt *)WalkAST(dos->bodyStmts, preFunc, - postFunc, data); + postFunc, data, preUpdate); } else if ((fs = llvm::dyn_cast(node)) != NULL) { - fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data); - fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data); - fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data); - fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data); + fs->init = (Stmt *)WalkAST(fs->init, preFunc, postFunc, data, preUpdate); + fs->test = (Expr *)WalkAST(fs->test, preFunc, postFunc, data, preUpdate); + fs->step = (Stmt *)WalkAST(fs->step, preFunc, postFunc, data, preUpdate); + fs->stmts = (Stmt *)WalkAST(fs->stmts, preFunc, postFunc, data, preUpdate); } else if ((fes = llvm::dyn_cast(node)) != NULL) { for (unsigned int i = 0; i < fes->startExprs.size(); ++i) fes->startExprs[i] = (Expr *)WalkAST(fes->startExprs[i], preFunc, - postFunc, data); + postFunc, data, preUpdate); for (unsigned int i = 0; i < fes->endExprs.size(); ++i) fes->endExprs[i] = (Expr *)WalkAST(fes->endExprs[i], preFunc, - postFunc, data); - fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data); + postFunc, data, preUpdate); + fes->stmts = (Stmt *)WalkAST(fes->stmts, preFunc, postFunc, data, preUpdate); } else if ((fas = llvm::dyn_cast(node)) != NULL) { - fas->stmts = (Stmt *)WalkAST(fas->stmts, preFunc, postFunc, data); + fas->stmts = (Stmt *)WalkAST(fas->stmts, preFunc, postFunc, data, preUpdate); } else if ((fus = llvm::dyn_cast(node)) != NULL) { - fus->expr = (Expr *)WalkAST(fus->expr, preFunc, postFunc, data); - fus->stmts = (Stmt *)WalkAST(fus->stmts, preFunc, postFunc, data); + fus->expr = (Expr *)WalkAST(fus->expr, preFunc, postFunc, data, preUpdate); + fus->stmts = (Stmt *)WalkAST(fus->stmts, preFunc, postFunc, data, preUpdate); } else if ((cs = llvm::dyn_cast(node)) != NULL) - cs->stmts = (Stmt *)WalkAST(cs->stmts, preFunc, postFunc, data); + cs->stmts = (Stmt *)WalkAST(cs->stmts, preFunc, postFunc, data, preUpdate); else if ((defs = llvm::dyn_cast(node)) != NULL) - defs->stmts = (Stmt *)WalkAST(defs->stmts, preFunc, postFunc, data); + defs->stmts = (Stmt *)WalkAST(defs->stmts, preFunc, postFunc, data, preUpdate); else if ((ss = llvm::dyn_cast(node)) != NULL) { - ss->expr = (Expr *)WalkAST(ss->expr, preFunc, postFunc, data); - ss->stmts = (Stmt *)WalkAST(ss->stmts, preFunc, postFunc, data); + ss->expr = (Expr *)WalkAST(ss->expr, preFunc, postFunc, data, preUpdate); + ss->stmts = (Stmt *)WalkAST(ss->stmts, preFunc, postFunc, data, preUpdate); } else if (llvm::dyn_cast(node) != NULL || llvm::dyn_cast(node) != NULL || @@ -175,22 +179,22 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, // nothing } else if ((ls = llvm::dyn_cast(node)) != NULL) - ls->stmt = (Stmt *)WalkAST(ls->stmt, preFunc, postFunc, data); + ls->stmt = (Stmt *)WalkAST(ls->stmt, preFunc, postFunc, data, preUpdate); else if ((rs = llvm::dyn_cast(node)) != NULL) - rs->expr = (Expr *)WalkAST(rs->expr, preFunc, postFunc, data); + rs->expr = (Expr *)WalkAST(rs->expr, preFunc, postFunc, data, preUpdate); else if ((sl = llvm::dyn_cast(node)) != NULL) { std::vector &sls = sl->stmts; for (unsigned int i = 0; i < sls.size(); ++i) - sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data); + sls[i] = (Stmt *)WalkAST(sls[i], preFunc, postFunc, data, preUpdate); } else if ((ps = llvm::dyn_cast(node)) != NULL) - ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data); + ps->values = (Expr *)WalkAST(ps->values, preFunc, postFunc, data, preUpdate); else if ((as = llvm::dyn_cast(node)) != NULL) - as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data); + as->expr = (Expr *)WalkAST(as->expr, preFunc, postFunc, data, preUpdate); else if ((dels = llvm::dyn_cast(node)) != NULL) - dels->expr = (Expr *)WalkAST(dels->expr, preFunc, postFunc, data); + dels->expr = (Expr *)WalkAST(dels->expr, preFunc, postFunc, data, preUpdate); else if ((ums = llvm::dyn_cast(node)) != NULL) - ums->stmts = (Stmt *)WalkAST(ums->stmts, preFunc, postFunc, data); + ums->stmts = (Stmt *)WalkAST(ums->stmts, preFunc, postFunc, data, preUpdate); else FATAL("Unhandled statement type in WalkAST()"); } @@ -215,57 +219,57 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, NewExpr *newe; if ((ue = llvm::dyn_cast(node)) != NULL) - ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data); + ue->expr = (Expr *)WalkAST(ue->expr, preFunc, postFunc, data, preUpdate); else if ((be = llvm::dyn_cast(node)) != NULL) { - be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data); - be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data); + be->arg0 = (Expr *)WalkAST(be->arg0, preFunc, postFunc, data, preUpdate); + be->arg1 = (Expr *)WalkAST(be->arg1, preFunc, postFunc, data, preUpdate); } else if ((ae = llvm::dyn_cast(node)) != NULL) { - ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data); - ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data); + ae->lvalue = (Expr *)WalkAST(ae->lvalue, preFunc, postFunc, data, preUpdate); + ae->rvalue = (Expr *)WalkAST(ae->rvalue, preFunc, postFunc, data, preUpdate); } else if ((se = llvm::dyn_cast(node)) != NULL) { - se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data); - se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data); - se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data); + se->test = (Expr *)WalkAST(se->test, preFunc, postFunc, data, preUpdate); + se->expr1 = (Expr *)WalkAST(se->expr1, preFunc, postFunc, data, preUpdate); + se->expr2 = (Expr *)WalkAST(se->expr2, preFunc, postFunc, data, preUpdate); } else if ((el = llvm::dyn_cast(node)) != NULL) { for (unsigned int i = 0; i < el->exprs.size(); ++i) el->exprs[i] = (Expr *)WalkAST(el->exprs[i], preFunc, - postFunc, data); + postFunc, data, preUpdate); } else if ((fce = llvm::dyn_cast(node)) != NULL) { - fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data); - fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data); + fce->func = (Expr *)WalkAST(fce->func, preFunc, postFunc, data, preUpdate); + fce->args = (ExprList *)WalkAST(fce->args, preFunc, postFunc, data, preUpdate); for (int k = 0; k < 3; k++) fce->launchCountExpr[0] = (Expr *)WalkAST(fce->launchCountExpr[0], preFunc, - postFunc, data); + postFunc, data, preUpdate); } else if ((ie = llvm::dyn_cast(node)) != NULL) { - ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data); - ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data); + ie->baseExpr = (Expr *)WalkAST(ie->baseExpr, preFunc, postFunc, data, preUpdate); + ie->index = (Expr *)WalkAST(ie->index, preFunc, postFunc, data, preUpdate); } else if ((me = llvm::dyn_cast(node)) != NULL) - me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data); + me->expr = (Expr *)WalkAST(me->expr, preFunc, postFunc, data, preUpdate); else if ((tce = llvm::dyn_cast(node)) != NULL) - tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data); + tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data, preUpdate); else if ((re = llvm::dyn_cast(node)) != NULL) - re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data); + re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data, preUpdate); else if ((ptrderef = llvm::dyn_cast(node)) != NULL) ptrderef->expr = (Expr *)WalkAST(ptrderef->expr, preFunc, postFunc, - data); + data, preUpdate); else if ((refderef = llvm::dyn_cast(node)) != NULL) refderef->expr = (Expr *)WalkAST(refderef->expr, preFunc, postFunc, - data); + data, preUpdate); else if ((soe = llvm::dyn_cast(node)) != NULL) - soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data); + soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data, preUpdate); else if ((aoe = llvm::dyn_cast(node)) != NULL) - aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data); + aoe->expr = (Expr *)WalkAST(aoe->expr, preFunc, postFunc, data, preUpdate); else if ((newe = llvm::dyn_cast(node)) != NULL) { newe->countExpr = (Expr *)WalkAST(newe->countExpr, preFunc, - postFunc, data); + postFunc, data, preUpdate); newe->initExpr = (Expr *)WalkAST(newe->initExpr, preFunc, - postFunc, data); + postFunc, data, preUpdate); } else if (llvm::dyn_cast(node) != NULL || llvm::dyn_cast(node) != NULL || @@ -536,11 +540,21 @@ lTranslatePolyNode(ASTNode *node, void *d) { return node->ReplacePolyType(data->polyType, data->replacement); } +static ASTNode * +lCopyNode(ASTNode *node, void *) { + return node->Copy(); +} + ASTNode * TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement) { struct PolyData data; data.polyType = polyType; data.replacement = replacement; - return WalkAST(root, NULL, lTranslatePolyNode, &data); + return WalkAST(root, NULL, lTranslatePolyNode, &data, lCopyNode); +} + +ASTNode * +CopyAST(ASTNode *root) { + return WalkAST(root, NULL, NULL, NULL, lCopyNode); } diff --git a/ast.h b/ast.h index 91d3dd5b..5c40454f 100644 --- a/ast.h +++ b/ast.h @@ -68,6 +68,8 @@ public: pointer in place of the original ASTNode *. */ virtual ASTNode *TypeCheck() = 0; + virtual ASTNode *Copy() = 0; + virtual ASTNode *ReplacePolyType(const PolyType *, const Type *) = 0; /** Estimate the execution cost of the node (not including the cost of @@ -177,7 +179,8 @@ typedef ASTNode * (* ASTPostCallBackFunc)(ASTNode *node, void *data); doing so, calls postFunc, at the node. The return value from the postFunc call is ignored. */ extern ASTNode *WalkAST(ASTNode *root, ASTPreCallBackFunc preFunc, - ASTPostCallBackFunc postFunc, void *data); + ASTPostCallBackFunc postFunc, void *data, + ASTPostCallBackFunc preUpdate = NULL); /** Perform simple optimizations on the AST or portion thereof passed to this function, returning the resulting AST. */ @@ -209,6 +212,8 @@ extern int EstimateCost(ASTNode *root); extern ASTNode * TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement); +extern ASTNode * CopyAST(ASTNode *root); + /** Returns true if it would be safe to run the given code with an "all off" mask. */ extern bool SafeToRunWithMaskAllOff(ASTNode *root); diff --git a/expr.cpp b/expr.cpp index 1e8748e2..4d04a298 100644 --- a/expr.cpp +++ b/expr.cpp @@ -113,7 +113,81 @@ Expr::GetBaseSymbol() const { } Expr * -Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) { +Expr::Copy() { + Expr *copy; + switch (getValueID()) { + case AddressOfExprID: + copy = (Expr*)new AddressOfExpr(*(AddressOfExpr*)this); + break; + case AssignExprID: + copy = (Expr*)new AssignExpr(*(AssignExpr*)this); + break; + case BinaryExprID: + copy = (Expr*)new BinaryExpr(*(BinaryExpr*)this); + break; + case ConstExprID: + copy = (Expr*)new ConstExpr(*(ConstExpr*)this); + break; + case PtrDerefExprID: + copy = (Expr*)new PtrDerefExpr(*(PtrDerefExpr*)this); + break; + case RefDerefExprID: + copy = (Expr*)new RefDerefExpr(*(RefDerefExpr*)this); + break; + case ExprListID: + copy = (Expr*)new ExprList(*(ExprList*)this); + break; + case FunctionCallExprID: + copy = (Expr*)new FunctionCallExpr(*(FunctionCallExpr*)this); + break; + case FunctionSymbolExprID: + copy = (Expr*)new FunctionSymbolExpr(*(FunctionSymbolExpr*)this); + break; + case IndexExprID: + copy = (Expr*)new IndexExpr(*(IndexExpr*)this); + break; + case StructMemberExprID: + copy = (Expr*)new StructMemberExpr(*(StructMemberExpr*)this); + break; + case VectorMemberExprID: + copy = (Expr*)new VectorMemberExpr(*(VectorMemberExpr*)this); + break; + case NewExprID: + copy = (Expr*)new NewExpr(*(NewExpr*)this); + break; + case NullPointerExprID: + copy = (Expr*)new NullPointerExpr(*(NullPointerExpr*)this); + break; + case ReferenceExprID: + copy = (Expr*)new ReferenceExpr(*(ReferenceExpr*)this); + break; + case SelectExprID: + copy = (Expr*)new SelectExpr(*(SelectExpr*)this); + break; + case SizeOfExprID: + copy = (Expr*)new SizeOfExpr(*(SizeOfExpr*)this); + break; + case SymbolExprID: + copy = (Expr*)new SymbolExpr(*(SymbolExpr*)this); + break; + case SyncExprID: + copy = (Expr*)new SyncExpr(*(SyncExpr*)this); + break; + case TypeCastExprID: + copy = (Expr*)new TypeCastExpr(*(TypeCastExpr*)this); + break; + case UnaryExprID: + copy = (Expr*)new UnaryExpr(*(UnaryExpr*)this); + break; + default: + FATAL("Unmatched case in Expr::Copy"); + copy = this; // just to silence the compiler + } + return copy; +} + +Expr * +Expr::ReplacePolyType(const PolyType *, const Type *) { return this; } @@ -4680,11 +4754,11 @@ IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (index == NULL || baseExpr == NULL) return NULL; - if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) { + if (Type::EqualForReplacement(GetType()->GetBaseType(), from)) { type = PolyType::ReplaceType(type, to); } - if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) { + if (Type::EqualForReplacement(GetLValueType()->GetBaseType(), from)) { lvalueType = new PointerType(to, lvalueType->GetVariability(), lvalueType->IsConstType()); } @@ -4757,27 +4831,6 @@ lIdentifierToVectorElement(char id) { ////////////////////////////////////////////////// // StructMemberExpr -class StructMemberExpr : public MemberExpr -{ -public: - StructMemberExpr(Expr *e, const char *id, SourcePos p, - SourcePos idpos, bool derefLValue); - - static inline bool classof(StructMemberExpr const*) { return true; } - static inline bool classof(ASTNode const* N) { - return N->getValueID() == StructMemberExprID; - } - - const Type *GetType() const; - const Type *GetLValueType() const; - int getElementNumber() const; - const Type *getElementType() const; - -private: - const StructType *getStructType() const; -}; - - StructMemberExpr::StructMemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue) : MemberExpr(e, id, p, idpos, derefLValue, StructMemberExprID) { @@ -4929,31 +4982,6 @@ StructMemberExpr::getStructType() const { ////////////////////////////////////////////////// // VectorMemberExpr -class VectorMemberExpr : public MemberExpr -{ -public: - VectorMemberExpr(Expr *e, const char *id, SourcePos p, - SourcePos idpos, bool derefLValue); - - static inline bool classof(VectorMemberExpr const*) { return true; } - static inline bool classof(ASTNode const* N) { - return N->getValueID() == VectorMemberExprID; - } - - llvm::Value *GetValue(FunctionEmitContext* ctx) const; - llvm::Value *GetLValue(FunctionEmitContext* ctx) const; - const Type *GetType() const; - const Type *GetLValueType() const; - - int getElementNumber() const; - const Type *getElementType() const; - -private: - const VectorType *exprVectorType; - const VectorType *memberType; -}; - - VectorMemberExpr::VectorMemberExpr(Expr *e, const char *id, SourcePos p, SourcePos idpos, bool derefLValue) : MemberExpr(e, id, p, idpos, derefLValue, VectorMemberExprID) { @@ -5339,14 +5367,10 @@ MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (expr == NULL) return NULL; - if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) { + if (Type::EqualForReplacement(GetType()->GetBaseType(), from)) { type = PolyType::ReplaceType(type, to); } - if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) { - lvalueType = PolyType::ReplaceType(lvalueType, lvalueType); - } - return this; } @@ -8075,6 +8099,12 @@ SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) { if (!symbol) return NULL; + Symbol *tmp = m->symbolTable->LookupVariable(symbol->name.c_str()); + if (tmp) { + tmp->parentFunction = symbol->parentFunction; + symbol = tmp; + } + if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) { symbol->type = PolyType::ReplaceType(symbol->type, to); } @@ -8151,6 +8181,14 @@ FunctionSymbolExpr::Optimize() { return this; } +Expr * +FunctionSymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) { + // force re-evaluation of overloaded type + this->triedToResolve = false; + + return this; +} + int FunctionSymbolExpr::EstimateCost() const { @@ -8397,6 +8435,16 @@ FunctionSymbolExpr::computeOverloadCost(const FunctionType *ftype, cost[i] += 8 * costScale; continue; } + if (callTypeNC->IsPolymorphicType()) { + const PolyType *callTypeP = + CastType(callTypeNC->GetBaseType()); + if (callTypeP->CanBeType(fargTypeNC->GetBaseType()) && + callTypeNC->IsArrayType() == fargTypeNC->IsArrayType() && + callTypeNC->IsPointerType() == fargTypeNC->IsPointerType()){ + cost[i] += 8 * costScale; + continue; + } + } if (fargType->IsVaryingType() && callType->IsUniformType()) { // Here we deal with brodcasting uniform to varying. // callType - varying and fargType - uniform is forbidden. @@ -8523,6 +8571,12 @@ FunctionSymbolExpr::ResolveOverloads(SourcePos argPos, return true; } else if (matches.size() > 1) { + for (size_t i=0; iIsPolymorphicType()) { + matchingFunc = matches[0]; + return true; + } + } // Multiple matches: ambiguous std::string candidateMessage = lGetOverloadCandidateMessage(matches, argTypes, argCouldBeNULL); diff --git a/expr.h b/expr.h index 6b655614..0fd44158 100644 --- a/expr.h +++ b/expr.h @@ -96,6 +96,7 @@ public: encountered, NULL should be returned. */ virtual Expr *TypeCheck() = 0; + Expr *Copy(); /** This method replaces a polymorphic type with a specific atomic type */ Expr *ReplacePolyType(const PolyType *polyType, const Type *replacement); @@ -385,6 +386,51 @@ protected: mutable const Type *type, *lvalueType; }; +class StructMemberExpr : public MemberExpr +{ +public: + StructMemberExpr(Expr *e, const char *id, SourcePos p, + SourcePos idpos, bool derefLValue); + + static inline bool classof(StructMemberExpr const*) { return true; } + static inline bool classof(ASTNode const* N) { + return N->getValueID() == StructMemberExprID; + } + + const Type *GetType() const; + const Type *GetLValueType() const; + int getElementNumber() const; + const Type *getElementType() const; + +private: + const StructType *getStructType() const; +}; + + +class VectorMemberExpr : public MemberExpr +{ +public: + VectorMemberExpr(Expr *e, const char *id, SourcePos p, + SourcePos idpos, bool derefLValue); + + static inline bool classof(VectorMemberExpr const*) { return true; } + static inline bool classof(ASTNode const* N) { + return N->getValueID() == VectorMemberExprID; + } + + llvm::Value *GetValue(FunctionEmitContext* ctx) const; + llvm::Value *GetLValue(FunctionEmitContext* ctx) const; + const Type *GetType() const; + const Type *GetLValueType() const; + + int getElementNumber() const; + const Type *getElementType() const; + +private: + const VectorType *exprVectorType; + const VectorType *memberType; +}; + /** @brief Expression representing a compile-time constant value. @@ -715,6 +761,7 @@ public: Symbol *GetBaseSymbol() const; Expr *TypeCheck(); Expr *Optimize(); + Expr *ReplacePolyType(const PolyType *from, const Type *to); void Print() const; int EstimateCost() const; llvm::Constant *GetConstant(const Type *type) const; diff --git a/func.cpp b/func.cpp index 8776a9fc..6d6fdbb9 100644 --- a/func.cpp +++ b/func.cpp @@ -90,7 +90,7 @@ #endif #include -Function::Function(Symbol *s, Stmt *c) { +Function::Function(Symbol *s, Stmt *c, bool typecheck) { sym = s; code = c; @@ -98,13 +98,15 @@ Function::Function(Symbol *s, Stmt *c) { Assert(maskSymbol != NULL); if (code != NULL) { - code = TypeCheck(code); + if (typecheck) { + code = TypeCheck(code); - if (code != NULL && g->debugPrint) { - printf("After typechecking function \"%s\":\n", - sym->name.c_str()); - code->Print(0); - printf("---------------------\n"); + if (code != NULL && g->debugPrint) { + printf("After typechecking function \"%s\":\n", + sym->name.c_str()); + code->Print(0); + printf("---------------------\n"); + } } if (code != NULL) { @@ -135,6 +137,7 @@ Function::Function(Symbol *s, Stmt *c) { args.push_back(sym); const Type *t = type->GetParameterType(i); + if (sym != NULL && CastType(t) == NULL) sym->parentFunction = this; } @@ -640,7 +643,6 @@ Function::IsPolyFunction() const { return false; } - std::vector * Function::ExpandPolyArguments(SymbolTable *symbolTable) const { Assert(symbolTable != NULL); @@ -651,37 +653,64 @@ Function::ExpandPolyArguments(SymbolTable *symbolTable) const { const FunctionType *func = CastType(sym->type); - if (g->debugPrint) { - printf("%s before replacing anything:\n", sym->name.c_str()); - code->Print(0); - } - for (size_t i=0; idebugPrint) { + printf("%s before replacing anything:\n", sym->name.c_str()); + code->Print(0); + } const FunctionType *ft = CastType(versions[i]->type); - Stmt *ncode = code; + + symbolTable->PushScope(); + + Symbol *s = symbolTable->LookupFunction(versions[i]->name.c_str(), ft); + Stmt *ncode = (Stmt*)CopyAST(code); + + Function *f = new Function(s, ncode, false); + + for (size_t j=0; jargs[j] = new Symbol(*args[j]); + symbolTable->AddVariable(f->args[j], false); + } for (int j=0; jGetNumParameters(); j++) { if (func->GetParameterType(j)->IsPolymorphicType()) { const PolyType *from = CastType( func->GetParameterType(j)->GetBaseType()); - ncode = (Stmt*)TranslatePoly(ncode, from, + f->code = (Stmt*)TranslatePoly(f->code, from, ft->GetParameterType(j)->GetBaseType()); if (g->debugPrint) { printf("%s after replacing %s with %s:\n\n", sym->name.c_str(), from->GetString().c_str(), ft->GetParameterType(j)->GetBaseType()->GetString().c_str()); - ncode->Print(0); + f->code->Print(0); printf("------------------------------------------\n\n"); } } } - Symbol *s = symbolTable->LookupFunction(versions[i]->name.c_str(), ft); + // we didn't typecheck before, now we can + f->code = TypeCheck(f->code); - expanded->push_back(new Function(s, ncode)); + f->code = Optimize(f->code); + + if (g->debugPrint) { + printf("After optimizing expanded function \"%s\":\n", + f->sym->name.c_str()); + f->code->Print(0); + printf("---------------------\n"); + } + + + + symbolTable->PopScope(); + + + + + expanded->push_back(f); } return expanded; diff --git a/func.h b/func.h index 86d801f4..87013e2b 100644 --- a/func.h +++ b/func.h @@ -44,7 +44,7 @@ class Function { public: - Function(Symbol *sym, Stmt *code); + Function(Symbol *sym, Stmt *code, bool typecheck=true); const Type *GetReturnType() const; const FunctionType *GetType() const; diff --git a/stmt.cpp b/stmt.cpp index f03944e8..bab93ecc 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -35,6 +35,7 @@ @brief File with definitions classes related to statements in the language */ +#include "ast.h" #include "stmt.h" #include "ctx.h" #include "util.h" @@ -78,7 +79,81 @@ Stmt::Optimize() { } Stmt * -Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) { +Stmt::Copy() { + Stmt *copy; + switch (getValueID()) { + case AssertStmtID: + copy = (Stmt*)new AssertStmt(*(AssertStmt*)this); + break; + case BreakStmtID: + copy = (Stmt*)new BreakStmt(*(BreakStmt*)this); + break; + case CaseStmtID: + copy = (Stmt*)new CaseStmt(*(CaseStmt*)this); + break; + case ContinueStmtID: + copy = (Stmt*)new ContinueStmt(*(ContinueStmt*)this); + break; + case DeclStmtID: + copy = (Stmt*)new DeclStmt(*(DeclStmt*)this); + break; + case DefaultStmtID: + copy = (Stmt*)new DefaultStmt(*(DefaultStmt*)this); + break; + case DeleteStmtID: + copy = (Stmt*)new DeleteStmt(*(DeleteStmt*)this); + break; + case DoStmtID: + copy = (Stmt*)new DoStmt(*(DoStmt*)this); + break; + case ExprStmtID: + copy = (Stmt*)new ExprStmt(*(ExprStmt*)this); + break; + case ForeachActiveStmtID: + copy = (Stmt*)new ForeachActiveStmt(*(ForeachActiveStmt*)this); + break; + case ForeachStmtID: + copy = (Stmt*)new ForeachStmt(*(ForeachStmt*)this); + break; + case ForeachUniqueStmtID: + copy = (Stmt*)new ForeachUniqueStmt(*(ForeachUniqueStmt*)this); + break; + case ForStmtID: + copy = (Stmt*)new ForStmt(*(ForStmt*)this); + break; + case GotoStmtID: + copy = (Stmt*)new GotoStmt(*(GotoStmt*)this); + break; + case IfStmtID: + copy = (Stmt*)new IfStmt(*(IfStmt*)this); + break; + case LabeledStmtID: + copy = (Stmt*)new LabeledStmt(*(LabeledStmt*)this); + break; + case PrintStmtID: + copy = (Stmt*)new PrintStmt(*(PrintStmt*)this); + break; + case ReturnStmtID: + copy = (Stmt*)new ReturnStmt(*(ReturnStmt*)this); + break; + case StmtListID: + copy = (Stmt*)new StmtList(*(StmtList*)this); + break; + case SwitchStmtID: + copy = (Stmt*)new SwitchStmt(*(SwitchStmt*)this); + break; + case UnmaskedStmtID: + copy = (Stmt*)new UnmaskedStmt(*(UnmaskedStmt*)this); + break; + default: + FATAL("Unmatched case in Stmt::Copy"); + copy = this; // just to silence the compiler + } + return copy; +} + +Stmt * +Stmt::ReplacePolyType(const PolyType *, const Type *) { return this; } @@ -484,7 +559,8 @@ DeclStmt::TypeCheck() { // an int as the constValue later... const Type *type = vars[i].sym->type; if (CastType(type) != NULL || - CastType(type) != NULL) { + CastType(type) != NULL || + CastType(type) != NULL) { // If it's an expr list with an atomic type, we'll later issue // an error. Need to leave vars[i].init as is in that case so // it is in fact caught later, though. @@ -502,9 +578,15 @@ DeclStmt::TypeCheck() { Stmt * DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) { for (size_t i = 0; i < vars.size(); i++) { + vars[i].sym = new Symbol(*vars[i].sym); + m->symbolTable->AddVariable(vars[i].sym, false); Symbol *s = vars[i].sym; if (Type::EqualForReplacement(s->type->GetBaseType(), from)) { s->type = PolyType::ReplaceType(s->type, to); + + // this typecast *should* be valid after typechecking + vars[i].init = TypeConvertExpr(vars[i].init, s->type, + "initializer"); } } @@ -1487,6 +1569,17 @@ ForeachStmt::ForeachStmt(const std::vector &lvs, stmts(s) { } +/* +ForeachStmt::ForeachStmt(ForeachStmt *base) + : Stmt(base->pos, ForeachStmtID) { + dimVariables = base->dimVariables; + startExprs = base->startExprs; + endExprs = base->endExprs; + isTiled = base->isTiled; + stmts = base->stmts; +} +*/ + /* Given a uniform counter value in the memory location pointed to by uniformCounterPtr, compute the corresponding set of varying counter @@ -1730,8 +1823,10 @@ ForeachStmt::EmitCode(FunctionEmitContext *ctx) const { // Start and end value for this loop dimension llvm::Value *sv = startExprs[i]->GetValue(ctx); llvm::Value *ev = endExprs[i]->GetValue(ctx); - if (sv == NULL || ev == NULL) + if (sv == NULL || ev == NULL) { + fprintf(stderr, "ev is NULL again :(\n"); return; + } startVals.push_back(sv); endVals.push_back(ev); @@ -2191,21 +2286,6 @@ ForeachStmt::TypeCheck() { return anyErrors ? NULL : this; } -Stmt * -ForeachStmt::ReplacePolyType(const PolyType *from, const Type *to) { - if (!stmts) - return NULL; - - for (size_t i=0; itype; - if (Type::EqualForReplacement(t->GetBaseType(), from)) { - t = PolyType::ReplaceType(t, to); - } - } - - return this; -} - int ForeachStmt::EstimateCost() const { diff --git a/stmt.h b/stmt.h index 237acaaf..371f5915 100644 --- a/stmt.h +++ b/stmt.h @@ -70,6 +70,7 @@ public: // Stmts don't have anything to do here. virtual Stmt *Optimize(); virtual Stmt *TypeCheck() = 0; + Stmt *Copy(); Stmt *ReplacePolyType(const PolyType *polyType, const Type *replacement); }; @@ -283,7 +284,6 @@ public: void Print(int indent) const; Stmt *TypeCheck(); - Stmt *ReplacePolyType(const PolyType *from, const Type *to); int EstimateCost() const; std::vector dimVariables; diff --git a/sym.cpp b/sym.cpp index 48ee06f7..796f5f02 100644 --- a/sym.cpp +++ b/sym.cpp @@ -95,14 +95,14 @@ SymbolTable::PopScope() { bool -SymbolTable::AddVariable(Symbol *symbol) { +SymbolTable::AddVariable(Symbol *symbol, bool issueScopeWarning) { Assert(symbol != NULL); // Check to see if a symbol of the same name has already been declared. for (int i = (int)variables.size() - 1; i >= 0; --i) { SymbolMapType &sm = *(variables[i]); if (sm.find(symbol->name) != sm.end()) { - if (i == (int)variables.size()-1) { + if (i == (int)variables.size()-1 && issueScopeWarning) { // If a symbol of the same name was declared in the // same scope, it's an error. Error(symbol->pos, "Ignoring redeclaration of symbol \"%s\".", @@ -112,9 +112,11 @@ SymbolTable::AddVariable(Symbol *symbol) { else { // Otherwise it's just shadowing something else, which // is legal but dangerous.. - Warning(symbol->pos, - "Symbol \"%s\" shadows symbol declared in outer scope.", - symbol->name.c_str()); + if (issueScopeWarning) { + Warning(symbol->pos, + "Symbol \"%s\" shadows symbol declared in outer scope.", + symbol->name.c_str()); + } (*variables.back())[symbol->name] = symbol; return true; } diff --git a/sym.h b/sym.h index 41973c72..620fb172 100644 --- a/sym.h +++ b/sym.h @@ -141,7 +141,7 @@ public: with a symbol defined at the same scope. (Symbols may shaodow symbols in outer scopes; a warning is issued in this case, but this method still returns true.) */ - bool AddVariable(Symbol *symbol); + bool AddVariable(Symbol *symbol, bool issueScopeWarning=true); /** Looks for a variable with the given name in the symbol table. This method searches outward from the innermost scope to the outermost, diff --git a/tests_ispcpp/Makefile b/tests_ispcpp/Makefile index e30bfc2d..7881a592 100644 --- a/tests_ispcpp/Makefile +++ b/tests_ispcpp/Makefile @@ -1,5 +1,5 @@ CXX=g++ -CXXFLAGS=-std=c++11 +CXXFLAGS=-std=c++11 -O2 ISPC=../ispc ISPCFLAGS=--target=sse4-x2 -O2 --arch=x86-64 diff --git a/tests_ispcpp/hello.cpp b/tests_ispcpp/hello.cpp index 219c7da8..06e0cd51 100644 --- a/tests_ispcpp/hello.cpp +++ b/tests_ispcpp/hello.cpp @@ -1,7 +1,7 @@ #include #include -#include "hello.ispc.h" +#include "hello.h" int main() { float A[100]; diff --git a/tests_ispcpp/varying.cpp b/tests_ispcpp/varying.cpp new file mode 100644 index 00000000..d56159a6 --- /dev/null +++ b/tests_ispcpp/varying.cpp @@ -0,0 +1,27 @@ +#include +#include + +#include "varying.h" + +int main() { + float A[256]; + double B[256]; + double outA[256]; + double outB[256]; + + + for (int i=0; i<256; i++) { + A[i] = 1. / (i+1); + B[i] = 1. / (i+1); + } + + ispc::square(256, (float*)&A, (double*)&outA); + + ispc::square(256, (double*)&B, (double*)&outB); + + for (int i=0; i<256; i++) { + printf("float: %.16f\tdouble: %.16f\n", outA[i], outB[i]); + } + + return 0; +} diff --git a/tests_ispcpp/varying.ispc b/tests_ispcpp/varying.ispc new file mode 100644 index 00000000..3657fee6 --- /dev/null +++ b/tests_ispcpp/varying.ispc @@ -0,0 +1,14 @@ +floating foo(const uniform int a, floating b) { + floating out = b; + for (int i = 1; iIsVaryingType()) t = t->GetAsVaryingType(); - fprintf(stderr, "Replacing type \"%s\" with \"%s\"\n", - from->GetString().c_str(), - t->GetString().c_str()); + if (g->debugPrint) { + fprintf(stderr, "Replacing type \"%s\" with \"%s\"\n", + from->GetString().c_str(), + t->GetString().c_str()); + } return t; }