[WIP] Plumbing to expand polymorphic functions

This commit is contained in:
2017-05-04 21:26:43 -04:00
parent 93c563e073
commit 46ed9bdb3c
10 changed files with 130 additions and 7 deletions

29
ast.cpp
View File

@@ -42,8 +42,11 @@
#include "func.h"
#include "stmt.h"
#include "sym.h"
#include "type.h"
#include "util.h"
#include <map>
///////////////////////////////////////////////////////////////////////////
// ASTNode
@@ -62,7 +65,9 @@ AST::AddFunction(Symbol *sym, Stmt *code) {
Function *f = new Function(sym, code);
if (f->IsPolyFunction()) {
FATAL("This is a good start, but implement me!");
std::vector<Function *> *expanded = f->ExpandPolyArguments();
for (size_t i=0; i<expanded->size(); i++)
functions.push_back((*expanded)[i]);
} else {
functions.push_back(f);
}
@@ -515,3 +520,25 @@ SafeToRunWithMaskAllOff(ASTNode *root) {
WalkAST(root, lCheckAllOffSafety, NULL, &safe);
return safe;
}
struct PolyData {
const PolyType *polyType;
const Type *replacement;
};
static ASTNode *
lTranslatePolyNode(ASTNode *node, void *d) {
struct PolyData *data = (struct PolyData*)d;
return node->ReplacePolyType(data->polyType, data->replacement);
}
ASTNode *
TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement) {
struct PolyData data;
data.polyType = polyType;
data.replacement = replacement;
return WalkAST(root, NULL, lTranslatePolyNode, &replacement);
}

11
ast.h
View File

@@ -39,6 +39,7 @@
#define ISPC_AST_H 1
#include "ispc.h"
#include "type.h"
#include <vector>
/** @brief Abstract base class for nodes in the abstract syntax tree (AST).
@@ -67,6 +68,8 @@ public:
pointer in place of the original ASTNode *. */
virtual ASTNode *TypeCheck() = 0;
virtual ASTNode *ReplacePolyType(const PolyType *, const Type *) = 0;
/** Estimate the execution cost of the node (not including the cost of
the children. The value returned should be based on the COST_*
enumerant values defined in ispc.h. */
@@ -75,8 +78,8 @@ public:
/** All AST nodes must track the file position where they are
defined. */
SourcePos pos;
/** An enumeration for keeping track of the concrete subclass of Value
/** An enumeration for keeping track of the concrete subclass of Value
that is actually instantiated.*/
enum ASTNodeTy {
/* For classes inherited from Expr */
@@ -127,9 +130,9 @@ public:
SwitchStmtID,
UnmaskedStmtID
};
/** Return an ID for the concrete type of this object. This is used to
implement the classof checks. This should not be used for any
implement the classof checks. This should not be used for any
other purpose, as the values may change as ISPC evolves */
unsigned getValueID() const {
return SubclassID;

View File

@@ -112,6 +112,11 @@ Expr::GetBaseSymbol() const {
return NULL;
}
Expr *
Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
return this;
}
#if 0
/** If a conversion from 'fromAtomicType' to 'toAtomicType' may cause lost

4
expr.h
View File

@@ -96,6 +96,10 @@ public:
encountered, NULL should be returned. */
virtual Expr *TypeCheck() = 0;
/** This method replaces a polymorphic type with a specific atomic type */
Expr *ReplacePolyType(const PolyType *polyType, const Type *replacement);
/** Prints the expression to standard output (used for debugging). */
virtual void Print() const = 0;
};

View File

@@ -638,3 +638,42 @@ Function::IsPolyFunction() const {
return false;
}
std::vector<Function *> *
Function::ExpandPolyArguments() const {
std::vector<const Type *> toExpand;
std::vector<Function *> *expanded = new std::vector<Function *>();
for (size_t i = 0; i < args.size(); i++) {
if (args[i]->type->IsPolymorphicType()) {
toExpand.push_back(args[i]->type);
}
}
for (size_t i = 0; i < toExpand.size(); i++) {
const PolyType *pt = CastType<PolyType>(toExpand[i]->GetBaseType());
std::vector<AtomicType *>::iterator expanded;
expanded = pt->ExpandBegin();
for (; expanded != pt->ExpandEnd(); expanded++) {
Type *replacement = *expanded;
if (toExpand[i]->IsPointerType())
replacement = new PointerType(replacement,
toExpand[i]->GetVariability(),
toExpand[i]->IsConstType());
else if (toExpand[i]->IsArrayType())
replacement = new ArrayType(replacement,
(CastType<ArrayType>(toExpand[i]))->GetElementCount());
else if (toExpand[i]->IsReferenceType())
replacement = new ReferenceType(replacement);
printf("pretend I'm replacing %s with %s\n",
toExpand[i]->GetString().c_str(),
replacement->GetString().c_str());
}
}
return expanded;
}

2
func.h
View File

@@ -54,6 +54,8 @@ public:
/** Checks if the function has polymorphic parameters */
const bool IsPolyFunction() const;
std::vector<Function *> *ExpandPolyArguments() const;
private:
void emitCode(FunctionEmitContext *ctx, llvm::Function *function,
SourcePos firstStmtPos);

View File

@@ -77,6 +77,11 @@ Stmt::Optimize() {
return this;
}
Stmt *
Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
return this;
}
///////////////////////////////////////////////////////////////////////////
// ExprStmt

1
stmt.h
View File

@@ -70,6 +70,7 @@ public:
// Stmts don't have anything to do here.
virtual Stmt *Optimize();
virtual Stmt *TypeCheck() = 0;
Stmt *ReplacePolyType(const PolyType *polyType, const Type *replacement);
};

View File

@@ -698,12 +698,14 @@ PolyType::PolyType(PolyRestriction r, Variability v, bool ic)
: Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(-1) {
asOtherConstType = NULL;
asUniformType = asVaryingType = NULL;
expandedTypes = NULL;
}
PolyType::PolyType(PolyRestriction r, Variability v, bool ic, int q)
: Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(q) {
asOtherConstType = NULL;
asUniformType = asVaryingType = NULL;
expandedTypes = NULL;
}
@@ -814,6 +816,39 @@ PolyType::GetAsUniformType() const {
return asUniformType;
}
const std::vector<AtomicType *>::iterator
PolyType::ExpandBegin() const {
if (expandedTypes)
return expandedTypes->begin();
expandedTypes = new std::vector<AtomicType *>();
if (restriction == TYPE_INTEGER || restriction == TYPE_NUMBER) {
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT8, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT8, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT16, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT16, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT32, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT32, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT64, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT64, variability, isConst));
}
if (restriction == TYPE_FLOATING || restriction == TYPE_NUMBER) {
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_FLOAT, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_DOUBLE, variability, isConst));
}
return expandedTypes->begin();
}
const std::vector<AtomicType *>::iterator
PolyType::ExpandEnd() const {
Assert(expandedTypes != NULL);
return expandedTypes->end();
}
const PolyType *
PolyType::GetAsUnboundVariabilityType() const {

6
type.h
View File

@@ -360,10 +360,10 @@ public:
static const AtomicType *UniformDouble, *VaryingDouble;
static const AtomicType *Void;
AtomicType(BasicType basicType, Variability v, bool isConst);
private:
const Variability variability;
const bool isConst;
AtomicType(BasicType basicType, Variability v, bool isConst);
mutable const AtomicType *asOtherConstType, *asUniformType, *asVaryingType;
};
@@ -420,7 +420,8 @@ public:
// Returns the list of AtomicTypes that are valid instantiations of the
// polymorphic type
const std::vector<AtomicType *> GetEnumeratedTypes() const;
const std::vector<AtomicType *>::iterator ExpandBegin() const;
const std::vector<AtomicType *>::iterator ExpandEnd() const;
private:
const Variability variability;
@@ -430,6 +431,7 @@ private:
PolyType(PolyRestriction type, Variability v, bool isConst, int quant);
mutable const PolyType *asOtherConstType, *asUniformType, *asVaryingType;
mutable std::vector<AtomicType *> *expandedTypes;
};