[WIP] replace polymorphic types from expressions

This commit is contained in:
2017-05-09 01:46:36 -04:00
parent 9c0f9be022
commit aeb4c0b6f9
5 changed files with 182 additions and 33 deletions

View File

@@ -45,6 +45,7 @@
#include "sym.h"
#include "util.h"
#include <stdio.h>
#include <set>
#if ISPC_LLVM_VERSION == ISPC_LLVM_3_2 // 3.2
#ifdef ISPC_NVPTX_ENABLED
@@ -639,41 +640,87 @@ Function::IsPolyFunction() const {
return false;
}
static bool
lPolyTypeLess(const Type *a, const Type *b) {
const PolyType *pa = CastType<PolyType>(a->GetBaseType());
const PolyType *pb = CastType<PolyType>(b->GetBaseType());
if (!pa || !pb) {
char buf[1024];
snprintf(buf, 1024, "Calling lPolyTypeLess on non-polymorphic types"
"\"%s\" and \"%s\"\n",
a->GetString().c_str(), b->GetString().c_str());
FATAL(buf);
}
if (pa->restriction < pb->restriction)
return true;
if (pa->restriction > pb->restriction)
return false;
if (pa->GetQuant() < pb->GetQuant())
return true;
return false;
}
std::vector<Function *> *
Function::ExpandPolyArguments() const {
std::vector<const Type *> toExpand;
std::set<const Type *, bool(*)(const Type *, const Type *)> toExpand(&lPolyTypeLess);
std::vector<Function *> *expanded = new std::vector<Function *>();
for (size_t i = 0; i < args.size(); i++) {
if (args[i]->type->IsPolymorphicType()) {
toExpand.push_back(args[i]->type);
if (args[i]->type->IsPolymorphicType() &&
!toExpand.count(args[i]->type)) {
toExpand.insert(args[i]->type);
}
}
for (size_t i = 0; i < toExpand.size(); i++) {
const PolyType *pt = CastType<PolyType>(toExpand[i]->GetBaseType());
std::set<const Type *>::iterator te;
for (te = toExpand.begin(); te != toExpand.end(); te++) {
const PolyType *pt = CastType<PolyType>((*te)->GetBaseType());
std::vector<AtomicType *>::iterator expanded;
expanded = pt->ExpandBegin();
for (; expanded != pt->ExpandEnd(); expanded++) {
Type *replacement = *expanded;
std::vector<AtomicType *>::iterator expand;
expand = pt->ExpandBegin();
for (; expand != pt->ExpandEnd(); expand++) {
const Type *replacement = *expand;
Stmt *code_r = code->ReplacePolyType(pt, replacement);
if (toExpand[i]->IsPointerType())
replacement = new PointerType(replacement,
toExpand[i]->GetVariability(),
toExpand[i]->IsConstType());
else if (toExpand[i]->IsArrayType())
replacement = new ArrayType(replacement,
(CastType<ArrayType>(toExpand[i]))->GetElementCount());
else if (toExpand[i]->IsReferenceType())
replacement = new ReferenceType(replacement);
const FunctionType *ft = CastType<FunctionType>(sym->type);
llvm::SmallVector<const Type *, 8> nargs;
llvm::SmallVector<std::string, 8> nargsn;
llvm::SmallVector<Expr *, 8> nargsd;
llvm::SmallVector<SourcePos, 8> nargsp;
for (size_t i = 0; i < args.size(); i++) {
if (Type::EqualIgnoringConst(args[i]->type->GetBaseType(), pt)) {
nargs.push_back(PolyType::ReplaceType(args[i]->type, replacement));
} else {
nargs.push_back(args[i]->type);
}
nargsn.push_back(ft->GetParameterName(i));
nargsd.push_back(ft->GetParameterDefault(i));
nargsp.push_back(ft->GetParameterSourcePos(i));
}
printf("pretend I'm replacing %s with %s\n",
toExpand[i]->GetString().c_str(),
replacement->GetString().c_str());
Symbol *nsym = new Symbol(sym->name, sym->pos,
new FunctionType(ft->GetReturnType(),
nargs,
nargsn,
nargsd,
nargsp,
ft->isTask,
ft->isExported,
ft->isExternC,
ft->isUnmasked));
nsym->function = sym->function;
nsym->exportedFunction = sym->exportedFunction;
expanded->push_back(new Function(nsym, code_r));
replacement = PolyType::ReplaceType(*te, replacement);
}
}
return expanded;
}