From 46ed9bdb3c9050b3e3bd5865f77d2ea28dbb5c46 Mon Sep 17 00:00:00 2001 From: Aaron Gutierrez Date: Thu, 4 May 2017 21:26:43 -0400 Subject: [PATCH] [WIP] Plumbing to expand polymorphic functions --- ast.cpp | 29 ++++++++++++++++++++++++++++- ast.h | 11 +++++++---- expr.cpp | 5 +++++ expr.h | 4 ++++ func.cpp | 39 +++++++++++++++++++++++++++++++++++++++ func.h | 2 ++ stmt.cpp | 5 +++++ stmt.h | 1 + type.cpp | 35 +++++++++++++++++++++++++++++++++++ type.h | 6 ++++-- 10 files changed, 130 insertions(+), 7 deletions(-) diff --git a/ast.cpp b/ast.cpp index 097439bd..5acfb651 100644 --- a/ast.cpp +++ b/ast.cpp @@ -42,8 +42,11 @@ #include "func.h" #include "stmt.h" #include "sym.h" +#include "type.h" #include "util.h" +#include + /////////////////////////////////////////////////////////////////////////// // ASTNode @@ -62,7 +65,9 @@ AST::AddFunction(Symbol *sym, Stmt *code) { Function *f = new Function(sym, code); if (f->IsPolyFunction()) { - FATAL("This is a good start, but implement me!"); + std::vector *expanded = f->ExpandPolyArguments(); + for (size_t i=0; isize(); i++) + functions.push_back((*expanded)[i]); } else { functions.push_back(f); } @@ -515,3 +520,25 @@ SafeToRunWithMaskAllOff(ASTNode *root) { WalkAST(root, lCheckAllOffSafety, NULL, &safe); return safe; } + +struct PolyData { + const PolyType *polyType; + const Type *replacement; +}; + + +static ASTNode * +lTranslatePolyNode(ASTNode *node, void *d) { + struct PolyData *data = (struct PolyData*)d; + + return node->ReplacePolyType(data->polyType, data->replacement); +} + +ASTNode * +TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement) { + struct PolyData data; + data.polyType = polyType; + data.replacement = replacement; + + return WalkAST(root, NULL, lTranslatePolyNode, &replacement); +} diff --git a/ast.h b/ast.h index 0a2a2edf..b13120d8 100644 --- a/ast.h +++ b/ast.h @@ -39,6 +39,7 @@ #define ISPC_AST_H 1 #include "ispc.h" +#include "type.h" #include /** @brief Abstract base class for nodes in the abstract syntax tree (AST). @@ -67,6 +68,8 @@ public: pointer in place of the original ASTNode *. */ virtual ASTNode *TypeCheck() = 0; + virtual ASTNode *ReplacePolyType(const PolyType *, const Type *) = 0; + /** Estimate the execution cost of the node (not including the cost of the children. The value returned should be based on the COST_* enumerant values defined in ispc.h. */ @@ -75,8 +78,8 @@ public: /** All AST nodes must track the file position where they are defined. */ SourcePos pos; - - /** An enumeration for keeping track of the concrete subclass of Value + + /** An enumeration for keeping track of the concrete subclass of Value that is actually instantiated.*/ enum ASTNodeTy { /* For classes inherited from Expr */ @@ -127,9 +130,9 @@ public: SwitchStmtID, UnmaskedStmtID }; - + /** Return an ID for the concrete type of this object. This is used to - implement the classof checks. This should not be used for any + implement the classof checks. This should not be used for any other purpose, as the values may change as ISPC evolves */ unsigned getValueID() const { return SubclassID; diff --git a/expr.cpp b/expr.cpp index 51e7dd56..423c5e60 100644 --- a/expr.cpp +++ b/expr.cpp @@ -112,6 +112,11 @@ Expr::GetBaseSymbol() const { return NULL; } +Expr * +Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) { + return this; +} + #if 0 /** If a conversion from 'fromAtomicType' to 'toAtomicType' may cause lost diff --git a/expr.h b/expr.h index 90bf3db8..446b7bd1 100644 --- a/expr.h +++ b/expr.h @@ -96,6 +96,10 @@ public: encountered, NULL should be returned. */ virtual Expr *TypeCheck() = 0; + + /** This method replaces a polymorphic type with a specific atomic type */ + Expr *ReplacePolyType(const PolyType *polyType, const Type *replacement); + /** Prints the expression to standard output (used for debugging). */ virtual void Print() const = 0; }; diff --git a/func.cpp b/func.cpp index 0d68fb5e..8a4d9779 100644 --- a/func.cpp +++ b/func.cpp @@ -638,3 +638,42 @@ Function::IsPolyFunction() const { return false; } + +std::vector * +Function::ExpandPolyArguments() const { + std::vector toExpand; + std::vector *expanded = new std::vector(); + + for (size_t i = 0; i < args.size(); i++) { + if (args[i]->type->IsPolymorphicType()) { + toExpand.push_back(args[i]->type); + } + } + + for (size_t i = 0; i < toExpand.size(); i++) { + const PolyType *pt = CastType(toExpand[i]->GetBaseType()); + + std::vector::iterator expanded; + expanded = pt->ExpandBegin(); + for (; expanded != pt->ExpandEnd(); expanded++) { + Type *replacement = *expanded; + + if (toExpand[i]->IsPointerType()) + replacement = new PointerType(replacement, + toExpand[i]->GetVariability(), + toExpand[i]->IsConstType()); + else if (toExpand[i]->IsArrayType()) + replacement = new ArrayType(replacement, + (CastType(toExpand[i]))->GetElementCount()); + else if (toExpand[i]->IsReferenceType()) + replacement = new ReferenceType(replacement); + + + printf("pretend I'm replacing %s with %s\n", + toExpand[i]->GetString().c_str(), + replacement->GetString().c_str()); + } + } + + return expanded; +} diff --git a/func.h b/func.h index 94f012a7..2ac9cc90 100644 --- a/func.h +++ b/func.h @@ -54,6 +54,8 @@ public: /** Checks if the function has polymorphic parameters */ const bool IsPolyFunction() const; + std::vector *ExpandPolyArguments() const; + private: void emitCode(FunctionEmitContext *ctx, llvm::Function *function, SourcePos firstStmtPos); diff --git a/stmt.cpp b/stmt.cpp index 7f72ad33..03d89e38 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -77,6 +77,11 @@ Stmt::Optimize() { return this; } +Stmt * +Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) { + return this; +} + /////////////////////////////////////////////////////////////////////////// // ExprStmt diff --git a/stmt.h b/stmt.h index ab0f666b..270fcd91 100644 --- a/stmt.h +++ b/stmt.h @@ -70,6 +70,7 @@ public: // Stmts don't have anything to do here. virtual Stmt *Optimize(); virtual Stmt *TypeCheck() = 0; + Stmt *ReplacePolyType(const PolyType *polyType, const Type *replacement); }; diff --git a/type.cpp b/type.cpp index 9da21a0e..f94ac266 100644 --- a/type.cpp +++ b/type.cpp @@ -698,12 +698,14 @@ PolyType::PolyType(PolyRestriction r, Variability v, bool ic) : Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(-1) { asOtherConstType = NULL; asUniformType = asVaryingType = NULL; + expandedTypes = NULL; } PolyType::PolyType(PolyRestriction r, Variability v, bool ic, int q) : Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(q) { asOtherConstType = NULL; asUniformType = asVaryingType = NULL; + expandedTypes = NULL; } @@ -814,6 +816,39 @@ PolyType::GetAsUniformType() const { return asUniformType; } +const std::vector::iterator +PolyType::ExpandBegin() const { + if (expandedTypes) + return expandedTypes->begin(); + + expandedTypes = new std::vector(); + + if (restriction == TYPE_INTEGER || restriction == TYPE_NUMBER) { + expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT8, variability, isConst)); + expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT8, variability, isConst)); + expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT16, variability, isConst)); + expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT16, variability, isConst)); + expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT32, variability, isConst)); + expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT32, variability, isConst)); + expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT64, variability, isConst)); + expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT64, variability, isConst)); + } + if (restriction == TYPE_FLOATING || restriction == TYPE_NUMBER) { + expandedTypes->push_back(new AtomicType(AtomicType::TYPE_FLOAT, variability, isConst)); + expandedTypes->push_back(new AtomicType(AtomicType::TYPE_DOUBLE, variability, isConst)); + } + + return expandedTypes->begin(); +} + +const std::vector::iterator +PolyType::ExpandEnd() const { + Assert(expandedTypes != NULL); + + return expandedTypes->end(); +} + + const PolyType * PolyType::GetAsUnboundVariabilityType() const { diff --git a/type.h b/type.h index 07831e01..cdb6300f 100644 --- a/type.h +++ b/type.h @@ -360,10 +360,10 @@ public: static const AtomicType *UniformDouble, *VaryingDouble; static const AtomicType *Void; + AtomicType(BasicType basicType, Variability v, bool isConst); private: const Variability variability; const bool isConst; - AtomicType(BasicType basicType, Variability v, bool isConst); mutable const AtomicType *asOtherConstType, *asUniformType, *asVaryingType; }; @@ -420,7 +420,8 @@ public: // Returns the list of AtomicTypes that are valid instantiations of the // polymorphic type - const std::vector GetEnumeratedTypes() const; + const std::vector::iterator ExpandBegin() const; + const std::vector::iterator ExpandEnd() const; private: const Variability variability; @@ -430,6 +431,7 @@ private: PolyType(PolyRestriction type, Variability v, bool isConst, int quant); mutable const PolyType *asOtherConstType, *asUniformType, *asVaryingType; + mutable std::vector *expandedTypes; };