Translates polymorphic function to a single instance

This commit is contained in:
2017-05-09 23:41:36 -04:00
parent 871af918ad
commit 192b99f21d
12 changed files with 142 additions and 110 deletions

10
ast.cpp
View File

@@ -58,16 +58,18 @@ ASTNode::~ASTNode() {
// AST
void
AST::AddFunction(Symbol *sym, Stmt *code) {
AST::AddFunction(Symbol *sym, Stmt *code, SymbolTable *symbolTable) {
if (sym == NULL)
return;
Function *f = new Function(sym, code);
if (f->IsPolyFunction()) {
std::vector<Function *> *expanded = f->ExpandPolyArguments();
for (size_t i=0; i<expanded->size(); i++)
std::vector<Function *> *expanded = f->ExpandPolyArguments(symbolTable);
for (size_t i=0; i<expanded->size(); i++) {
functions.push_back((*expanded)[i]);
}
delete expanded;
} else {
functions.push_back(f);
}
@@ -540,5 +542,5 @@ TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement)
data.polyType = polyType;
data.replacement = replacement;
return WalkAST(root, NULL, lTranslatePolyNode, &replacement);
return WalkAST(root, NULL, lTranslatePolyNode, &data);
}

4
ast.h
View File

@@ -148,7 +148,7 @@ class AST {
public:
/** Add the AST for a function described by the given declaration
information and source code. */
void AddFunction(Symbol *sym, Stmt *code);
void AddFunction(Symbol *sym, Stmt *code, SymbolTable *symbolTable=NULL);
/** Generate LLVM IR for all of the functions into the current
module. */
@@ -207,6 +207,8 @@ extern Stmt *TypeCheck(Stmt *);
the given root. */
extern int EstimateCost(ASTNode *root);
extern ASTNode * TranslatePoly(ASTNode *root, const PolyType *polyType, const Type *replacement);
/** Returns true if it would be safe to run the given code with an "all
off" mask. */
extern bool SafeToRunWithMaskAllOff(ASTNode *root);

View File

@@ -4679,11 +4679,11 @@ IndexExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (index == NULL || baseExpr == NULL)
return NULL;
if (Type::EqualIgnoringConst(this->GetType()->GetBaseType(), from)) {
if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
}
if (Type::EqualIgnoringConst(this->GetLValueType()->GetBaseType(), from)) {
if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) {
lvalueType = new PointerType(to, lvalueType->GetVariability(),
lvalueType->IsConstType());
}
@@ -5338,11 +5338,11 @@ MemberExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (expr == NULL)
return NULL;
if (Type::EqualIgnoringConst(this->GetType()->GetBaseType(), from)) {
if (Type::EqualForReplacement(this->GetType()->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
}
if (Type::EqualIgnoringConst(this->GetLValueType()->GetBaseType(), from)) {
if (Type::EqualForReplacement(this->GetLValueType()->GetBaseType(), from)) {
lvalueType = PolyType::ReplaceType(lvalueType, lvalueType);
}
@@ -7386,10 +7386,10 @@ TypeCastExpr::Optimize() {
Expr *
TypeCastExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (expr == NULL)
if (type == NULL)
return NULL;
if (Type::EqualIgnoringConst(type->GetBaseType(), from)) {
if (Type::EqualForReplacement(type->GetBaseType(), from)) {
type = PolyType::ReplaceType(type, to);
}
@@ -8071,7 +8071,7 @@ SymbolExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (!symbol)
return NULL;
if (Type::EqualIgnoringConst(symbol->type->GetBaseType(), from)) {
if (Type::EqualForReplacement(symbol->type->GetBaseType(), from)) {
symbol->type = PolyType::ReplaceType(symbol->type, to);
}
@@ -8881,7 +8881,7 @@ NewExpr::ReplacePolyType(const PolyType *from, const Type *to) {
if (!allocType)
return this;
if (Type::EqualIgnoringConst(allocType->GetBaseType(), from)) {
if (Type::EqualForReplacement(allocType->GetBaseType(), from)) {
allocType = PolyType::ReplaceType(allocType, to);
}

102
func.cpp
View File

@@ -640,87 +640,45 @@ Function::IsPolyFunction() const {
return false;
}
static bool
lPolyTypeLess(const Type *a, const Type *b) {
const PolyType *pa = CastType<PolyType>(a->GetBaseType());
const PolyType *pb = CastType<PolyType>(b->GetBaseType());
if (!pa || !pb) {
char buf[1024];
snprintf(buf, 1024, "Calling lPolyTypeLess on non-polymorphic types"
"\"%s\" and \"%s\"\n",
a->GetString().c_str(), b->GetString().c_str());
FATAL(buf);
}
if (pa->restriction < pb->restriction)
return true;
if (pa->restriction > pb->restriction)
return false;
if (pa->GetQuant() < pb->GetQuant())
return true;
return false;
}
std::vector<Function *> *
Function::ExpandPolyArguments() const {
std::set<const Type *, bool(*)(const Type *, const Type *)> toExpand(&lPolyTypeLess);
Function::ExpandPolyArguments(SymbolTable *symbolTable) const {
Assert(symbolTable != NULL);
std::vector<Function *> *expanded = new std::vector<Function *>();
for (size_t i = 0; i < args.size(); i++) {
if (args[i]->type->IsPolymorphicType() &&
!toExpand.count(args[i]->type)) {
toExpand.insert(args[i]->type);
}
}
std::vector<Symbol *> versions = symbolTable->LookupPolyFunction(sym->name.c_str());
std::set<const Type *>::iterator te;
for (te = toExpand.begin(); te != toExpand.end(); te++) {
const PolyType *pt = CastType<PolyType>((*te)->GetBaseType());
const FunctionType *func = CastType<FunctionType>(sym->type);
std::vector<AtomicType *>::iterator expand;
expand = pt->ExpandBegin();
for (; expand != pt->ExpandEnd(); expand++) {
const Type *replacement = *expand;
Stmt *code_r = code->ReplacePolyType(pt, replacement);
printf("%s before replacing anything:\n", sym->name.c_str());
code->Print(0);
const FunctionType *ft = CastType<FunctionType>(sym->type);
llvm::SmallVector<const Type *, 8> nargs;
llvm::SmallVector<std::string, 8> nargsn;
llvm::SmallVector<Expr *, 8> nargsd;
llvm::SmallVector<SourcePos, 8> nargsp;
for (size_t i = 0; i < args.size(); i++) {
if (Type::EqualIgnoringConst(args[i]->type->GetBaseType(), pt)) {
nargs.push_back(PolyType::ReplaceType(args[i]->type, replacement));
} else {
nargs.push_back(args[i]->type);
}
nargsn.push_back(ft->GetParameterName(i));
nargsd.push_back(ft->GetParameterDefault(i));
nargsp.push_back(ft->GetParameterSourcePos(i));
for (size_t i=0; i<versions.size(); i++) {
const FunctionType *ft = CastType<FunctionType>(versions[i]->type);
Stmt *ncode = code;
for (int j=0; j<ft->GetNumParameters(); j++) {
if (func->GetParameterType(j)->IsPolymorphicType()) {
const PolyType *from = CastType<PolyType>(
func->GetParameterType(j)->GetBaseType());
ncode = (Stmt*)TranslatePoly(ncode, from,
ft->GetParameterType(j)->GetBaseType());
printf("%s after replacing %s with %s:\n\n",
sym->name.c_str(), from->GetString().c_str(),
ft->GetParameterType(j)->GetBaseType()->GetString().c_str());
ncode->Print(0);
printf("------------------------------------------\n\n");
}
Symbol *nsym = new Symbol(sym->name, sym->pos,
new FunctionType(ft->GetReturnType(),
nargs,
nargsn,
nargsd,
nargsp,
ft->isTask,
ft->isExported,
ft->isExternC,
ft->isUnmasked));
nsym->function = sym->function;
nsym->exportedFunction = sym->exportedFunction;
expanded->push_back(new Function(nsym, code_r));
replacement = PolyType::ReplaceType(*te, replacement);
}
Symbol *s = symbolTable->LookupFunction(versions[i]->name.c_str(), ft);
expanded->push_back(new Function(s, ncode));
}
return expanded;
}

3
func.h
View File

@@ -39,6 +39,7 @@
#define ISPC_FUNC_H 1
#include "ispc.h"
#include "sym.h"
#include <vector>
class Function {
@@ -54,7 +55,7 @@ public:
/** Checks if the function has polymorphic parameters */
const bool IsPolyFunction() const;
std::vector<Function *> *ExpandPolyArguments() const;
std::vector<Function *> *ExpandPolyArguments(SymbolTable *symbolTable) const;
private:
void emitCode(FunctionEmitContext *ctx, llvm::Function *function,

View File

@@ -915,8 +915,6 @@ Module::AddFunctionDeclaration(const std::string &name,
SourcePos pos) {
Assert(functionType != NULL);
fprintf(stderr, "Adding %s\n", name.c_str());
// If a global variable with the same name has already been declared
// issue an error.
if (symbolTable->LookupVariable(name.c_str()) != NULL) {
@@ -1020,26 +1018,26 @@ Module::AddFunctionDeclaration(const std::string &name,
* these functions will be overloaded if they are not exported, or mangled
* if exported */
std::vector<int> toExpand;
std::set<const Type *, bool(*)(const Type*, const Type*)> toExpand(&PolyType::Less);
std::vector<const FunctionType *> expanded;
expanded.push_back(functionType);
for (int i=0; i<functionType->GetNumParameters(); i++) {
if (functionType->GetParameterType(i)->IsPolymorphicType()) {
fprintf(stderr, "Expanding polymorphic function \"%s\"\n",
name.c_str());
const Type *param = functionType->GetParameterType(i);
if (param->IsPolymorphicType() &&
!toExpand.count(param->GetBaseType())) {
toExpand.push_back(i);
toExpand.insert(param->GetBaseType());
}
}
std::vector<const FunctionType *> nextExpanded;
for (size_t i=0; i<toExpand.size(); i++) {
std::set<const Type*>::iterator iter;
for (iter = toExpand.begin(); iter != toExpand.end(); iter++) {
for (size_t j=0; j<expanded.size(); j++) {
const FunctionType *eft = expanded[j];
const PolyType *pt=CastType<PolyType>(
eft->GetParameterType(toExpand[i])->GetBaseType());
const PolyType *pt=CastType<PolyType>(*iter);
std::vector<AtomicType *>::iterator te;
for (te = pt->ExpandBegin(); te != pt->ExpandEnd(); te++) {
@@ -1048,9 +1046,10 @@ Module::AddFunctionDeclaration(const std::string &name,
llvm::SmallVector<Expr *, 8> nargsd;
llvm::SmallVector<SourcePos, 8> nargsp;
for (size_t k=0; k<eft->GetNumParameters(); k++) {
if (k == toExpand[i]) {
if (Type::Equal(eft->GetParameterType(k)->GetBaseType(),
pt)) {
const Type *r;
r = PolyType::ReplaceType(eft->GetParameterType(j),*te);
r = PolyType::ReplaceType(eft->GetParameterType(k),*te);
nargs.push_back(r);
} else {
nargs.push_back(eft->GetParameterType(k));
@@ -1087,8 +1086,8 @@ Module::AddFunctionDeclaration(const std::string &name,
}
}
fprintf(stderr, "Adding expanded function %s\n", nname.c_str());
symbolTable->MapPolyFunction(name, nname, expanded[i]);
AddFunctionDeclaration(nname, expanded[i], storageClass,
isInline, pos);
}
@@ -1263,14 +1262,7 @@ Module::AddFunctionDefinition(const std::string &name, const FunctionType *type,
sym->pos = code->pos;
// FIXME: because we encode the parameter names in the function type,
// we need to override the function type here in case the function had
// earlier been declared with anonymous parameter names but is now
// defined with actual names. This is yet another reason we shouldn't
// include the names in FunctionType...
sym->type = type;
ast->AddFunction(sym, code);
ast->AddFunction(sym, code, symbolTable);
}

View File

@@ -503,7 +503,7 @@ Stmt *
DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) {
for (size_t i = 0; i < vars.size(); i++) {
Symbol *s = vars[i].sym;
if (Type::EqualIgnoringConst(s->type->GetBaseType(), from)) {
if (Type::EqualForReplacement(s->type->GetBaseType(), from)) {
s->type = PolyType::ReplaceType(s->type, to);
}
}
@@ -2198,7 +2198,7 @@ ForeachStmt::ReplacePolyType(const PolyType *from, const Type *to) {
for (size_t i=0; i<dimVariables.size(); i++) {
const Type *t = dimVariables[i]->type;
if (Type::EqualIgnoringConst(t->GetBaseType(), from)) {
if (Type::EqualForReplacement(t->GetBaseType(), from)) {
t = PolyType::ReplaceType(t, to);
}
}

19
sym.cpp
View File

@@ -157,6 +157,14 @@ SymbolTable::AddFunction(Symbol *symbol) {
return true;
}
void
SymbolTable::MapPolyFunction(std::string name, std::string polyname,
const FunctionType *type) {
std::vector<Symbol *> &polyExpansions = polyFunctions[name];
SourcePos p;
polyExpansions.push_back(new Symbol(polyname, p, type, SC_NONE));
}
bool
SymbolTable::LookupFunction(const char *name, std::vector<Symbol *> *matches) {
@@ -184,9 +192,20 @@ SymbolTable::LookupFunction(const char *name, const FunctionType *type) {
return funcs[j];
}
}
// Try looking for a polymorphic function
if (polyFunctions[name].size() > 0) {
std::string n = name;
return new Symbol(name, polyFunctions[name][0]->pos, type);
}
return NULL;
}
std::vector<Symbol *>&
SymbolTable::LookupPolyFunction(const char *name) {
return polyFunctions[name];
}
bool
SymbolTable::AddType(const char *name, const Type *type, SourcePos pos) {

13
sym.h
View File

@@ -108,6 +108,7 @@ public:
};
/** @brief Symbol table that holds all known symbols during parsing and compilation.
A single instance of a SymbolTable is stored in the Module class
@@ -159,6 +160,14 @@ public:
already present in the symbol table. */
bool AddFunction(Symbol *symbol);
/** Adds the given function to the list of polymorphic definitions for the
given name
@param name The name of the original function
@param type The expanded FunctionType */
void MapPolyFunction(std::string name, std::string polyname,
const FunctionType *type);
/** Looks for the function or functions with the given name in the
symbol name. If a function has been overloaded and multiple
definitions are present for a given function name, all of them will
@@ -174,6 +183,8 @@ public:
@return pointer to matching Symbol; NULL if none is found. */
Symbol *LookupFunction(const char *name, const FunctionType *type);
std::vector<Symbol *>& LookupPolyFunction(const char *name);
/** Returns all of the functions in the symbol table that match the given
predicate.
@@ -276,6 +287,8 @@ private:
typedef std::map<std::string, std::vector<Symbol *> > FunctionMapType;
FunctionMapType functions;
FunctionMapType polyFunctions;
/** Type definitions can't currently be scoped.
*/
typedef std::map<std::string, const Type *> TypeMapType;

View File

@@ -1,4 +1,4 @@
export void foo(uniform int N, floating$1 X[])
export void foo(uniform int N, uniform floating$1 X[])
{
foreach (i = 0 ... N) {
X[i] = X[i] + 1.0;

View File

@@ -709,6 +709,9 @@ PolyType::ReplaceType(const Type *from, const Type *to) {
t = new ReferenceType(to);
}
if (from->IsVaryingType())
t = t->GetAsVaryingType();
fprintf(stderr, "Replacing type \"%s\" with \"%s\"\n",
from->GetString().c_str(),
t->GetString().c_str());
@@ -716,6 +719,31 @@ PolyType::ReplaceType(const Type *from, const Type *to) {
return t;
}
bool
PolyType::Less(const Type *a, const Type *b) {
const PolyType *pa = CastType<PolyType>(a->GetBaseType());
const PolyType *pb = CastType<PolyType>(b->GetBaseType());
if (!pa || !pb) {
char buf[1024];
snprintf(buf, 1024, "Calling lPolyTypeLess on non-polymorphic types"
"\"%s\" and \"%s\"\n",
a->GetString().c_str(), b->GetString().c_str());
FATAL(buf);
}
if (pa->restriction < pb->restriction)
return true;
if (pa->restriction > pb->restriction)
return false;
if (pa->GetQuant() < pb->GetQuant())
return true;
return false;
}
PolyType::PolyType(PolyRestriction r, Variability v, bool ic)
: Type(POLY_TYPE), restriction(r), variability(v), isConst(ic), quant(-1) {
asOtherConstType = NULL;
@@ -4137,3 +4165,16 @@ bool
Type::EqualIgnoringConst(const Type *a, const Type *b) {
return lCheckTypeEquality(a, b, true);
}
bool
Type::EqualForReplacement(const Type *a, const Type *b) {
const PolyType *pa = CastType<PolyType>(a);
const PolyType *pb = CastType<PolyType>(b);
if (!pa || !pb)
return false;
return pa->restriction == pb->restriction &&
pa->GetQuant() == pb->GetQuant();
}

4
type.h
View File

@@ -244,6 +244,8 @@ public:
the same (ignoring const-ness of the type), false otherwise. */
static bool EqualIgnoringConst(const Type *a, const Type *b);
static bool EqualForReplacement(const Type *a, const Type *b);
/** Given two types, returns the least general Type that is more general
than both of them. (i.e. that can represent their values without
any loss of data.) If there is no such Type, return NULL.
@@ -415,6 +417,8 @@ public:
static const Type * ReplaceType(const Type *from, const Type *to);
static bool Less(const Type *a, const Type *b);
static const PolyType *UniformInteger, *VaryingInteger;
static const PolyType *UniformFloating, *VaryingFloating;