Translates polymorphic function to a single instance
This commit is contained in:
10
ast.cpp
10
ast.cpp
@@ -58,16 +58,18 @@ ASTNode::~ASTNode() {
|
||||
// AST
|
||||
|
||||
void
|
||||
AST::AddFunction(Symbol *sym, Stmt *code) {
|
||||
AST::AddFunction(Symbol *sym, Stmt *code, SymbolTable *symbolTable) {
|
||||
if (sym == NULL)
|
||||
return;
|
||||
|
||||
Function *f = new Function(sym, code);
|
||||
|
||||
if (f->IsPolyFunction()) {
|
||||
std::vector<Function *> *expanded = f->ExpandPolyArguments();
|
||||
for (size_t i=0; i<expanded->size(); i++)
|
||||
std::vector<Function *> *expanded = f->ExpandPolyArguments(symbolTable);
|
||||
for (size_t i=0; i<expanded->size(); i++) {
|
||||
functions.push_back((*expanded)[i]);
|
||||
}
|
||||
delete expanded;
|
||||
} else {
|
||||
functions.push_back(f);
|
||||
}
|
||||
@@ -540,5 +542,5 @@ TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement)
|
||||
data.polyType = polyType;
|
||||
data.replacement = replacement;
|
||||
|
||||
return WalkAST(root, NULL, lTranslatePolyNode, &replacement);
|
||||
return WalkAST(root, NULL, lTranslatePolyNode, &data);
|
||||
}
|
||||
|
||||
4
ast.h
4
ast.h
@@ -148,7 +148,7 @@ class AST {
|
||||
public:
|
||||
/** Add the AST for a function described by the given declaration
|
||||
information and source code. */
|
||||
void AddFunction(Symbol *sym, Stmt *code);
|
||||
void AddFunction(Symbol *sym, Stmt *code, SymbolTable *symbolTable=NULL);
|
||||
|
||||
/** Generate LLVM IR for all of the functions into the current
|
||||
module. */
|
||||
@@ -207,6 +207,8 @@ extern Stmt *TypeCheck(Stmt *);
|
||||
the given root. */
|
||||
extern int EstimateCost(ASTNode *root);
|
||||
|
||||
extern ASTNode * TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement);
|
||||
|
||||
/** Returns true if it would be safe to run the given code with an "all
|
||||
off" mask. */
|
||||
extern bool SafeToRunWithMaskAllOff(ASTNode *root);
|
||||
|
||||
16
expr.cpp
16
expr.cpp
@@ -4679,11 +4679,11 @@ IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) {
|
||||
if (index == NULL || baseExpr == NULL)
|
||||
return NULL;
|
||||
|
||||
if (Type::EqualIgnoringConst(this->GetType()->GetBaseType(), from)) {
|
||||
if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) {
|
||||
type = PolyType::ReplaceType(type, to);
|
||||
}
|
||||
|
||||
if (Type::EqualIgnoringConst(this->GetLValueType()->GetBaseType(), from)) {
|
||||
if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) {
|
||||
lvalueType = new PointerType(to, lvalueType->GetVariability(),
|
||||
lvalueType->IsConstType());
|
||||
}
|
||||
@@ -5338,11 +5338,11 @@ MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) {
|
||||
if (expr == NULL)
|
||||
return NULL;
|
||||
|
||||
if (Type::EqualIgnoringConst(this->GetType()->GetBaseType(), from)) {
|
||||
if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) {
|
||||
type = PolyType::ReplaceType(type, to);
|
||||
}
|
||||
|
||||
if (Type::EqualIgnoringConst(this->GetLValueType()->GetBaseType(), from)) {
|
||||
if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) {
|
||||
lvalueType = PolyType::ReplaceType(lvalueType, lvalueType);
|
||||
}
|
||||
|
||||
@@ -7386,10 +7386,10 @@ TypeCastExpr::Optimize() {
|
||||
|
||||
Expr *
|
||||
TypeCastExpr::ReplacePolyType(const PolyType *from, const Type *to) {
|
||||
if (expr == NULL)
|
||||
if (type == NULL)
|
||||
return NULL;
|
||||
|
||||
if (Type::EqualIgnoringConst(type->GetBaseType(), from)) {
|
||||
if (Type::EqualForReplacement(type->GetBaseType(), from)) {
|
||||
type = PolyType::ReplaceType(type, to);
|
||||
}
|
||||
|
||||
@@ -8071,7 +8071,7 @@ SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) {
|
||||
if (!symbol)
|
||||
return NULL;
|
||||
|
||||
if (Type::EqualIgnoringConst(symbol->type->GetBaseType(), from)) {
|
||||
if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) {
|
||||
symbol->type = PolyType::ReplaceType(symbol->type, to);
|
||||
}
|
||||
|
||||
@@ -8881,7 +8881,7 @@ NewExpr::ReplacePolyType(const PolyType *from, const Type *to) {
|
||||
if (!allocType)
|
||||
return this;
|
||||
|
||||
if (Type::EqualIgnoringConst(allocType->GetBaseType(), from)) {
|
||||
if (Type::EqualForReplacement(allocType->GetBaseType(), from)) {
|
||||
allocType = PolyType::ReplaceType(allocType, to);
|
||||
}
|
||||
|
||||
|
||||
102
func.cpp
102
func.cpp
@@ -640,87 +640,45 @@ Function::IsPolyFunction() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool
|
||||
lPolyTypeLess(const Type *a, const Type *b) {
|
||||
const PolyType *pa = CastType<PolyType>(a->GetBaseType());
|
||||
const PolyType *pb = CastType<PolyType>(b->GetBaseType());
|
||||
|
||||
if (!pa || !pb) {
|
||||
char buf[1024];
|
||||
snprintf(buf, 1024, "Calling lPolyTypeLess on non-polymorphic types"
|
||||
"\"%s\" and \"%s\"\n",
|
||||
a->GetString().c_str(), b->GetString().c_str());
|
||||
FATAL(buf);
|
||||
}
|
||||
|
||||
|
||||
if (pa->restriction < pb->restriction)
|
||||
return true;
|
||||
if (pa->restriction > pb->restriction)
|
||||
return false;
|
||||
|
||||
if (pa->GetQuant() < pb->GetQuant())
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<Function *> *
|
||||
Function::ExpandPolyArguments() const {
|
||||
std::set<const Type *, bool(*)(const Type *, const Type *)> toExpand(&lPolyTypeLess);
|
||||
Function::ExpandPolyArguments(SymbolTable *symbolTable) const {
|
||||
Assert(symbolTable != NULL);
|
||||
|
||||
std::vector<Function *> *expanded = new std::vector<Function *>();
|
||||
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (args[i]->type->IsPolymorphicType() &&
|
||||
!toExpand.count(args[i]->type)) {
|
||||
toExpand.insert(args[i]->type);
|
||||
}
|
||||
}
|
||||
std::vector<Symbol *> versions = symbolTable->LookupPolyFunction(sym->name.c_str());
|
||||
|
||||
std::set<const Type *>::iterator te;
|
||||
for (te = toExpand.begin(); te != toExpand.end(); te++) {
|
||||
const PolyType *pt = CastType<PolyType>((*te)->GetBaseType());
|
||||
const FunctionType *func = CastType<FunctionType>(sym->type);
|
||||
|
||||
std::vector<AtomicType *>::iterator expand;
|
||||
expand = pt->ExpandBegin();
|
||||
for (; expand != pt->ExpandEnd(); expand++) {
|
||||
const Type *replacement = *expand;
|
||||
Stmt *code_r = code->ReplacePolyType(pt, replacement);
|
||||
printf("%s before replacing anything:\n", sym->name.c_str());
|
||||
code->Print(0);
|
||||
|
||||
const FunctionType *ft = CastType<FunctionType>(sym->type);
|
||||
llvm::SmallVector<const Type *, 8> nargs;
|
||||
llvm::SmallVector<std::string, 8> nargsn;
|
||||
llvm::SmallVector<Expr *, 8> nargsd;
|
||||
llvm::SmallVector<SourcePos, 8> nargsp;
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
if (Type::EqualIgnoringConst(args[i]->type->GetBaseType(), pt)) {
|
||||
nargs.push_back(PolyType::ReplaceType(args[i]->type, replacement));
|
||||
} else {
|
||||
nargs.push_back(args[i]->type);
|
||||
}
|
||||
nargsn.push_back(ft->GetParameterName(i));
|
||||
nargsd.push_back(ft->GetParameterDefault(i));
|
||||
nargsp.push_back(ft->GetParameterSourcePos(i));
|
||||
for (size_t i=0; i<versions.size(); i++) {
|
||||
const FunctionType *ft = CastType<FunctionType>(versions[i]->type);
|
||||
Stmt *ncode = code;
|
||||
|
||||
for (int j=0; j<ft->GetNumParameters(); j++) {
|
||||
if (func->GetParameterType(j)->IsPolymorphicType()) {
|
||||
const PolyType *from = CastType<PolyType>(
|
||||
func->GetParameterType(j)->GetBaseType());
|
||||
|
||||
ncode = (Stmt*)TranslatePoly(ncode, from,
|
||||
ft->GetParameterType(j)->GetBaseType());
|
||||
printf("%s after replacing %s with %s:\n\n",
|
||||
sym->name.c_str(), from->GetString().c_str(),
|
||||
ft->GetParameterType(j)->GetBaseType()->GetString().c_str());
|
||||
|
||||
ncode->Print(0);
|
||||
|
||||
printf("------------------------------------------\n\n");
|
||||
}
|
||||
|
||||
|
||||
Symbol *nsym = new Symbol(sym->name, sym->pos,
|
||||
new FunctionType(ft->GetReturnType(),
|
||||
nargs,
|
||||
nargsn,
|
||||
nargsd,
|
||||
nargsp,
|
||||
ft->isTask,
|
||||
ft->isExported,
|
||||
ft->isExternC,
|
||||
ft->isUnmasked));
|
||||
nsym->function = sym->function;
|
||||
nsym->exportedFunction = sym->exportedFunction;
|
||||
|
||||
expanded->push_back(new Function(nsym, code_r));
|
||||
|
||||
replacement = PolyType::ReplaceType(*te, replacement);
|
||||
}
|
||||
|
||||
Symbol *s = symbolTable->LookupFunction(versions[i]->name.c_str(), ft);
|
||||
|
||||
expanded->push_back(new Function(s, ncode));
|
||||
}
|
||||
|
||||
return expanded;
|
||||
}
|
||||
|
||||
3
func.h
3
func.h
@@ -39,6 +39,7 @@
|
||||
#define ISPC_FUNC_H 1
|
||||
|
||||
#include "ispc.h"
|
||||
#include "sym.h"
|
||||
#include <vector>
|
||||
|
||||
class Function {
|
||||
@@ -54,7 +55,7 @@ public:
|
||||
/** Checks if the function has polymorphic parameters */
|
||||
const bool IsPolyFunction() const;
|
||||
|
||||
std::vector<Function *> *ExpandPolyArguments() const;
|
||||
std::vector<Function *> *ExpandPolyArguments(SymbolTable *symbolTable) const;
|
||||
|
||||
private:
|
||||
void emitCode(FunctionEmitContext *ctx, llvm::Function *function,
|
||||
|
||||
34
module.cpp
34
module.cpp
@@ -915,8 +915,6 @@ Module::AddFunctionDeclaration(const std::string &name,
|
||||
SourcePos pos) {
|
||||
Assert(functionType != NULL);
|
||||
|
||||
fprintf(stderr, "Adding %s\n", name.c_str());
|
||||
|
||||
// If a global variable with the same name has already been declared
|
||||
// issue an error.
|
||||
if (symbolTable->LookupVariable(name.c_str()) != NULL) {
|
||||
@@ -1020,26 +1018,26 @@ Module::AddFunctionDeclaration(const std::string &name,
|
||||
* these functions will be overloaded if they are not exported, or mangled
|
||||
* if exported */
|
||||
|
||||
std::vector<int> toExpand;
|
||||
std::set<const Type *, bool(*)(const Type*, const Type*)> toExpand(&PolyType::Less);
|
||||
std::vector<const FunctionType *> expanded;
|
||||
expanded.push_back(functionType);
|
||||
|
||||
for (int i=0; i<functionType->GetNumParameters(); i++) {
|
||||
if (functionType->GetParameterType(i)->IsPolymorphicType()) {
|
||||
fprintf(stderr, "Expanding polymorphic function \"%s\"\n",
|
||||
name.c_str());
|
||||
const Type *param = functionType->GetParameterType(i);
|
||||
if (param->IsPolymorphicType() &&
|
||||
!toExpand.count(param->GetBaseType())) {
|
||||
|
||||
toExpand.push_back(i);
|
||||
toExpand.insert(param->GetBaseType());
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const FunctionType *> nextExpanded;
|
||||
for (size_t i=0; i<toExpand.size(); i++) {
|
||||
std::set<const Type*>::iterator iter;
|
||||
for (iter = toExpand.begin(); iter != toExpand.end(); iter++) {
|
||||
for (size_t j=0; j<expanded.size(); j++) {
|
||||
const FunctionType *eft = expanded[j];
|
||||
|
||||
const PolyType *pt=CastType<PolyType>(
|
||||
eft->GetParameterType(toExpand[i])->GetBaseType());
|
||||
const PolyType *pt=CastType<PolyType>(*iter);
|
||||
|
||||
std::vector<AtomicType *>::iterator te;
|
||||
for (te = pt->ExpandBegin(); te != pt->ExpandEnd(); te++) {
|
||||
@@ -1048,9 +1046,10 @@ Module::AddFunctionDeclaration(const std::string &name,
|
||||
llvm::SmallVector<Expr *, 8> nargsd;
|
||||
llvm::SmallVector<SourcePos, 8> nargsp;
|
||||
for (size_t k=0; k<eft->GetNumParameters(); k++) {
|
||||
if (k == toExpand[i]) {
|
||||
if (Type::Equal(eft->GetParameterType(k)->GetBaseType(),
|
||||
pt)) {
|
||||
const Type *r;
|
||||
r = PolyType::ReplaceType(eft->GetParameterType(j),*te);
|
||||
r = PolyType::ReplaceType(eft->GetParameterType(k),*te);
|
||||
nargs.push_back(r);
|
||||
} else {
|
||||
nargs.push_back(eft->GetParameterType(k));
|
||||
@@ -1087,8 +1086,8 @@ Module::AddFunctionDeclaration(const std::string &name,
|
||||
}
|
||||
}
|
||||
|
||||
fprintf(stderr, "Adding expanded function %s\n", nname.c_str());
|
||||
|
||||
symbolTable->MapPolyFunction(name, nname, expanded[i]);
|
||||
AddFunctionDeclaration(nname, expanded[i], storageClass,
|
||||
isInline, pos);
|
||||
}
|
||||
@@ -1263,14 +1262,7 @@ Module::AddFunctionDefinition(const std::string &name, const FunctionType *type,
|
||||
|
||||
sym->pos = code->pos;
|
||||
|
||||
// FIXME: because we encode the parameter names in the function type,
|
||||
// we need to override the function type here in case the function had
|
||||
// earlier been declared with anonymous parameter names but is now
|
||||
// defined with actual names. This is yet another reason we shouldn't
|
||||
// include the names in FunctionType...
|
||||
sym->type = type;
|
||||
|
||||
ast->AddFunction(sym, code);
|
||||
ast->AddFunction(sym, code, symbolTable);
|
||||
}
|
||||
|
||||
|
||||
|
||||
4
stmt.cpp
4
stmt.cpp
@@ -503,7 +503,7 @@ Stmt *
|
||||
DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) {
|
||||
for (size_t i = 0; i < vars.size(); i++) {
|
||||
Symbol *s = vars[i].sym;
|
||||
if (Type::EqualIgnoringConst(s->type->GetBaseType(), from)) {
|
||||
if (Type::EqualForReplacement(s->type->GetBaseType(), from)) {
|
||||
s->type = PolyType::ReplaceType(s->type, to);
|
||||
}
|
||||
}
|
||||
@@ -2198,7 +2198,7 @@ ForeachStmt::ReplacePolyType(const PolyType *from, const Type *to) {
|
||||
|
||||
for (size_t i=0; i<dimVariables.size(); i++) {
|
||||
const Type *t = dimVariables[i]->type;
|
||||
if (Type::EqualIgnoringConst(t->GetBaseType(), from)) {
|
||||
if (Type::EqualForReplacement(t->GetBaseType(), from)) {
|
||||
t = PolyType::ReplaceType(t, to);
|
||||
}
|
||||
}
|
||||
|
||||
19
sym.cpp
19
sym.cpp
@@ -157,6 +157,14 @@ SymbolTable::AddFunction(Symbol *symbol) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void
|
||||
SymbolTable::MapPolyFunction(std::string name, std::string polyname,
|
||||
const FunctionType *type) {
|
||||
std::vector<Symbol *> &polyExpansions = polyFunctions[name];
|
||||
SourcePos p;
|
||||
polyExpansions.push_back(new Symbol(polyname, p, type, SC_NONE));
|
||||
}
|
||||
|
||||
|
||||
bool
|
||||
SymbolTable::LookupFunction(const char *name, std::vector<Symbol *> *matches) {
|
||||
@@ -184,9 +192,20 @@ SymbolTable::LookupFunction(const char *name, const FunctionType *type) {
|
||||
return funcs[j];
|
||||
}
|
||||
}
|
||||
// Try looking for a polymorphic function
|
||||
if (polyFunctions[name].size() > 0) {
|
||||
std::string n = name;
|
||||
return new Symbol(name, polyFunctions[name][0]->pos, type);
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
std::vector<Symbol *>&
|
||||
SymbolTable::LookupPolyFunction(const char *name) {
|
||||
return polyFunctions[name];
|
||||
}
|
||||
|
||||
|
||||
bool
|
||||
SymbolTable::AddType(const char *name, const Type *type, SourcePos pos) {
|
||||
|
||||
13
sym.h
13
sym.h
@@ -108,6 +108,7 @@ public:
|
||||
};
|
||||
|
||||
|
||||
|
||||
/** @brief Symbol table that holds all known symbols during parsing and compilation.
|
||||
|
||||
A single instance of a SymbolTable is stored in the Module class
|
||||
@@ -159,6 +160,14 @@ public:
|
||||
already present in the symbol table. */
|
||||
bool AddFunction(Symbol *symbol);
|
||||
|
||||
/** Adds the given function to the list of polymorphic definitions for the
|
||||
given name
|
||||
@param name The name of the original function
|
||||
@param type The expanded FunctionType */
|
||||
|
||||
void MapPolyFunction(std::string name, std::string polyname,
|
||||
const FunctionType *type);
|
||||
|
||||
/** Looks for the function or functions with the given name in the
|
||||
symbol name. If a function has been overloaded and multiple
|
||||
definitions are present for a given function name, all of them will
|
||||
@@ -174,6 +183,8 @@ public:
|
||||
@return pointer to matching Symbol; NULL if none is found. */
|
||||
Symbol *LookupFunction(const char *name, const FunctionType *type);
|
||||
|
||||
std::vector<Symbol *>& LookupPolyFunction(const char *name);
|
||||
|
||||
/** Returns all of the functions in the symbol table that match the given
|
||||
predicate.
|
||||
|
||||
@@ -276,6 +287,8 @@ private:
|
||||
typedef std::map<std::string, std::vector<Symbol *> > FunctionMapType;
|
||||
FunctionMapType functions;
|
||||
|
||||
FunctionMapType polyFunctions;
|
||||
|
||||
/** Type definitions can't currently be scoped.
|
||||
*/
|
||||
typedef std::map<std::string, const Type *> TypeMapType;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
export void foo(uniform int N, floating$1 X[])
|
||||
export void foo(uniform int N, uniform floating$1 X[])
|
||||
{
|
||||
foreach (i = 0 ... N) {
|
||||
X[i] = X[i] + 1.0;
|
||||
|
||||
41
type.cpp
41
type.cpp
@@ -709,6 +709,9 @@ PolyType::ReplaceType(const Type *from, const Type *to) {
|
||||
t = new ReferenceType(to);
|
||||
}
|
||||
|
||||
if (from->IsVaryingType())
|
||||
t = t->GetAsVaryingType();
|
||||
|
||||
fprintf(stderr, "Replacing type \"%s\" with \"%s\"\n",
|
||||
from->GetString().c_str(),
|
||||
t->GetString().c_str());
|
||||
@@ -716,6 +719,31 @@ PolyType::ReplaceType(const Type *from, const Type *to) {
|
||||
return t;
|
||||
}
|
||||
|
||||
bool
|
||||
PolyType::Less(const Type *a, const Type *b) {
|
||||
const PolyType *pa = CastType<PolyType>(a->GetBaseType());
|
||||
const PolyType *pb = CastType<PolyType>(b->GetBaseType());
|
||||
|
||||
if (!pa || !pb) {
|
||||
char buf[1024];
|
||||
snprintf(buf, 1024, "Calling lPolyTypeLess on non-polymorphic types"
|
||||
"\"%s\" and \"%s\"\n",
|
||||
a->GetString().c_str(), b->GetString().c_str());
|
||||
FATAL(buf);
|
||||
}
|
||||
|
||||
|
||||
if (pa->restriction < pb->restriction)
|
||||
return true;
|
||||
if (pa->restriction > pb->restriction)
|
||||
return false;
|
||||
|
||||
if (pa->GetQuant() < pb->GetQuant())
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
PolyType::PolyType(PolyRestriction r, Variability v, bool ic)
|
||||
: Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(-1) {
|
||||
asOtherConstType = NULL;
|
||||
@@ -4137,3 +4165,16 @@ bool
|
||||
Type::EqualIgnoringConst(const Type *a, const Type *b) {
|
||||
return lCheckTypeEquality(a, b, true);
|
||||
}
|
||||
|
||||
bool
|
||||
Type::EqualForReplacement(const Type *a, const Type *b) {
|
||||
const PolyType *pa = CastType<PolyType>(a);
|
||||
const PolyType *pb = CastType<PolyType>(b);
|
||||
|
||||
|
||||
if (!pa || !pb)
|
||||
return false;
|
||||
|
||||
return pa->restriction == pb->restriction &&
|
||||
pa->GetQuant() == pb->GetQuant();
|
||||
}
|
||||
|
||||
4
type.h
4
type.h
@@ -244,6 +244,8 @@ public:
|
||||
the same (ignoring const-ness of the type), false otherwise. */
|
||||
static bool EqualIgnoringConst(const Type *a, const Type *b);
|
||||
|
||||
static bool EqualForReplacement(const Type *a, const Type *b);
|
||||
|
||||
/** Given two types, returns the least general Type that is more general
|
||||
than both of them. (i.e. that can represent their values without
|
||||
any loss of data.) If there is no such Type, return NULL.
|
||||
@@ -415,6 +417,8 @@ public:
|
||||
|
||||
static const Type * ReplaceType(const Type *from, const Type *to);
|
||||
|
||||
static bool Less(const Type *a, const Type *b);
|
||||
|
||||
|
||||
static const PolyType *UniformInteger, *VaryingInteger;
|
||||
static const PolyType *UniformFloating, *VaryingFloating;
|
||||
|
||||
Reference in New Issue
Block a user