[WIP] Plumbing to expand polymorphic functions

This commit is contained in:
2017-05-04 21:26:43 -04:00
parent 93c563e073
commit 46ed9bdb3c
10 changed files with 130 additions and 7 deletions

29
ast.cpp
View File

@@ -42,8 +42,11 @@
#include "func.h"
#include "stmt.h"
#include "sym.h"
#include "type.h"
#include "util.h"
#include <map>
///////////////////////////////////////////////////////////////////////////
// ASTNode
@@ -62,7 +65,9 @@ AST::AddFunction(Symbol *sym, Stmt *code) {
Function *f = new Function(sym, code);
if (f->IsPolyFunction()) {
FATAL("This is a good start, but implement me!");
std::vector<Function *> *expanded = f->ExpandPolyArguments();
for (size_t i=0; i<expanded->size(); i++)
functions.push_back((*expanded)[i]);
} else {
functions.push_back(f);
}
@@ -515,3 +520,25 @@ SafeToRunWithMaskAllOff(ASTNode *root) {
WalkAST(root, lCheckAllOffSafety, NULL, &safe);
return safe;
}
struct PolyData {
const PolyType *polyType;
const Type *replacement;
};
static ASTNode *
lTranslatePolyNode(ASTNode *node, void *d) {
struct PolyData *data = (struct PolyData*)d;
return node->ReplacePolyType(data->polyType, data->replacement);
}
ASTNode *
TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement) {
struct PolyData data;
data.polyType = polyType;
data.replacement = replacement;
return WalkAST(root, NULL, lTranslatePolyNode, &replacement);
}

3
ast.h
View File

@@ -39,6 +39,7 @@
#define ISPC_AST_H 1
#include "ispc.h"
#include "type.h"
#include <vector>
/** @brief Abstract base class for nodes in the abstract syntax tree (AST).
@@ -67,6 +68,8 @@ public:
pointer in place of the original ASTNode *. */
virtual ASTNode *TypeCheck() = 0;
virtual ASTNode *ReplacePolyType(const PolyType *, const Type *) = 0;
/** Estimate the execution cost of the node (not including the cost of
the children. The value returned should be based on the COST_*
enumerant values defined in ispc.h. */

View File

@@ -112,6 +112,11 @@ Expr::GetBaseSymbol() const {
return NULL;
}
Expr *
Expr::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
return this;
}
#if 0
/** If a conversion from 'fromAtomicType' to 'toAtomicType' may cause lost

4
expr.h
View File

@@ -96,6 +96,10 @@ public:
encountered, NULL should be returned. */
virtual Expr *TypeCheck() = 0;
/** This method replaces a polymorphic type with a specific atomic type */
Expr *ReplacePolyType(const PolyType *polyType, const Type *replacement);
/** Prints the expression to standard output (used for debugging). */
virtual void Print() const = 0;
};

View File

@@ -638,3 +638,42 @@ Function::IsPolyFunction() const {
return false;
}
std::vector<Function *> *
Function::ExpandPolyArguments() const {
std::vector<const Type *> toExpand;
std::vector<Function *> *expanded = new std::vector<Function *>();
for (size_t i = 0; i < args.size(); i++) {
if (args[i]->type->IsPolymorphicType()) {
toExpand.push_back(args[i]->type);
}
}
for (size_t i = 0; i < toExpand.size(); i++) {
const PolyType *pt = CastType<PolyType>(toExpand[i]->GetBaseType());
std::vector<AtomicType *>::iterator expanded;
expanded = pt->ExpandBegin();
for (; expanded != pt->ExpandEnd(); expanded++) {
Type *replacement = *expanded;
if (toExpand[i]->IsPointerType())
replacement = new PointerType(replacement,
toExpand[i]->GetVariability(),
toExpand[i]->IsConstType());
else if (toExpand[i]->IsArrayType())
replacement = new ArrayType(replacement,
(CastType<ArrayType>(toExpand[i]))->GetElementCount());
else if (toExpand[i]->IsReferenceType())
replacement = new ReferenceType(replacement);
printf("pretend I'm replacing %s with %s\n",
toExpand[i]->GetString().c_str(),
replacement->GetString().c_str());
}
}
return expanded;
}

2
func.h
View File

@@ -54,6 +54,8 @@ public:
/** Checks if the function has polymorphic parameters */
const bool IsPolyFunction() const;
std::vector<Function *> *ExpandPolyArguments() const;
private:
void emitCode(FunctionEmitContext *ctx, llvm::Function *function,
SourcePos firstStmtPos);

View File

@@ -77,6 +77,11 @@ Stmt::Optimize() {
return this;
}
Stmt *
Stmt::ReplacePolyType(const PolyType *polyType, const Type *replacement) {
return this;
}
///////////////////////////////////////////////////////////////////////////
// ExprStmt

1
stmt.h
View File

@@ -70,6 +70,7 @@ public:
// Stmts don't have anything to do here.
virtual Stmt *Optimize();
virtual Stmt *TypeCheck() = 0;
Stmt *ReplacePolyType(const PolyType *polyType, const Type *replacement);
};

View File

@@ -698,12 +698,14 @@ PolyType::PolyType(PolyRestriction r, Variability v, bool ic)
: Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(-1) {
asOtherConstType = NULL;
asUniformType = asVaryingType = NULL;
expandedTypes = NULL;
}
PolyType::PolyType(PolyRestriction r, Variability v, bool ic, int q)
: Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(q) {
asOtherConstType = NULL;
asUniformType = asVaryingType = NULL;
expandedTypes = NULL;
}
@@ -814,6 +816,39 @@ PolyType::GetAsUniformType() const {
return asUniformType;
}
const std::vector<AtomicType *>::iterator
PolyType::ExpandBegin() const {
if (expandedTypes)
return expandedTypes->begin();
expandedTypes = new std::vector<AtomicType *>();
if (restriction == TYPE_INTEGER || restriction == TYPE_NUMBER) {
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT8, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT8, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT16, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT16, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT32, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT32, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_INT64, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_UINT64, variability, isConst));
}
if (restriction == TYPE_FLOATING || restriction == TYPE_NUMBER) {
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_FLOAT, variability, isConst));
expandedTypes->push_back(new AtomicType(AtomicType::TYPE_DOUBLE, variability, isConst));
}
return expandedTypes->begin();
}
const std::vector<AtomicType *>::iterator
PolyType::ExpandEnd() const {
Assert(expandedTypes != NULL);
return expandedTypes->end();
}
const PolyType *
PolyType::GetAsUnboundVariabilityType() const {

6
type.h
View File

@@ -360,10 +360,10 @@ public:
static const AtomicType *UniformDouble, *VaryingDouble;
static const AtomicType *Void;
AtomicType(BasicType basicType, Variability v, bool isConst);
private:
const Variability variability;
const bool isConst;
AtomicType(BasicType basicType, Variability v, bool isConst);
mutable const AtomicType *asOtherConstType, *asUniformType, *asVaryingType;
};
@@ -420,7 +420,8 @@ public:
// Returns the list of AtomicTypes that are valid instantiations of the
// polymorphic type
const std::vector<AtomicType *> GetEnumeratedTypes() const;
const std::vector<AtomicType *>::iterator ExpandBegin() const;
const std::vector<AtomicType *>::iterator ExpandEnd() const;
private:
const Variability variability;
@@ -430,6 +431,7 @@ private:
PolyType(PolyRestriction type, Variability v, bool isConst, int quant);
mutable const PolyType *asOtherConstType, *asUniformType, *asVaryingType;
mutable std::vector<AtomicType *> *expandedTypes;
};