[WIP] implement ReplacePolyType for stmts

This commit is contained in:
2017-05-09 15:30:39 -04:00
parent aeb4c0b6f9
commit 7bb1741b9a
3 changed files with 158 additions and 43 deletions

View File

@@ -349,7 +349,7 @@ lStripUnusedDebugInfo(llvm::Module *module) {
// And now we can go and stuff it into the unit with some // And now we can go and stuff it into the unit with some
// confidence... // confidence...
llvm::MDNode *replNode = llvm::MDNode::get(module->getContext(), llvm::MDNode *replNode = llvm::MDNode::get(module->getContext(),
llvm::ArrayRef<llvm::Metadata *>(usedSubprograms)); llvm::ArrayRef<llvm::Metadata *>(usedSubprograms));
cu.replaceSubprograms(llvm::DIArray(replNode)); cu.replaceSubprograms(llvm::DIArray(replNode));
#else // LLVM 3.7+ #else // LLVM 3.7+
@@ -589,7 +589,7 @@ Module::AddGlobalVariable(const std::string &name, const Type *type, Expr *initE
} }
#ifdef ISPC_NVPTX_ENABLED #ifdef ISPC_NVPTX_ENABLED
if (g->target->getISA() == Target::NVPTX && if (g->target->getISA() == Target::NVPTX &&
#if 0 #if 0
!type->IsConstType() && !type->IsConstType() &&
#endif #endif
@@ -609,7 +609,7 @@ Module::AddGlobalVariable(const std::string &name, const Type *type, Expr *initE
* or 128 threads. * or 128 threads.
* ***note-to-me***:please define these value (128threads/4warps) * ***note-to-me***:please define these value (128threads/4warps)
* in nvptx-target definition * in nvptx-target definition
* instead of compile-time constants * instead of compile-time constants
*/ */
nel *= at->GetElementCount(); nel *= at->GetElementCount();
assert (!type->IsSOAType()); assert (!type->IsSOAType());
@@ -830,7 +830,7 @@ lRecursiveCheckValidParamType(const Type *t, bool vectorOk) {
if (pt != NULL) { if (pt != NULL) {
// Only allow exported uniform pointers // Only allow exported uniform pointers
// Uniform pointers to varying data, however, are ok. // Uniform pointers to varying data, however, are ok.
if (pt->IsVaryingType()) if (pt->IsVaryingType())
return false; return false;
else else
return lRecursiveCheckValidParamType(pt->GetBaseType(), true); return lRecursiveCheckValidParamType(pt->GetBaseType(), true);
@@ -838,7 +838,7 @@ lRecursiveCheckValidParamType(const Type *t, bool vectorOk) {
if (t->IsVaryingType() && !vectorOk) if (t->IsVaryingType() && !vectorOk)
return false; return false;
else else
return true; return true;
} }
@@ -871,7 +871,7 @@ lCheckExportedParameterTypes(const Type *type, const std::string &name,
static void static void
lCheckTaskParameterTypes(const Type *type, const std::string &name, lCheckTaskParameterTypes(const Type *type, const std::string &name,
SourcePos pos) { SourcePos pos) {
if (g->target->getISA() != Target::NVPTX) if (g->target->getISA() != Target::NVPTX)
return; return;
if (lRecursiveCheckValidParamType(type, false) == false) { if (lRecursiveCheckValidParamType(type, false) == false) {
if (CastType<VectorType>(type)) if (CastType<VectorType>(type))
@@ -915,6 +915,8 @@ Module::AddFunctionDeclaration(const std::string &name,
SourcePos pos) { SourcePos pos) {
Assert(functionType != NULL); Assert(functionType != NULL);
fprintf(stderr, "Adding %s\n", name.c_str());
// If a global variable with the same name has already been declared // If a global variable with the same name has already been declared
// issue an error. // issue an error.
if (symbolTable->LookupVariable(name.c_str()) != NULL) { if (symbolTable->LookupVariable(name.c_str()) != NULL) {
@@ -1009,6 +1011,90 @@ Module::AddFunctionDeclaration(const std::string &name,
} }
} }
/* Handle Polymorphic functions
* a function
* int foo(number n, floating, f)
* will produce versions such as
* int foo(int n, float f)
*
* these functions will be overloaded if they are not exported, or mangled
* if exported */
std::vector<int> toExpand;
std::vector<const FunctionType *> expanded;
expanded.push_back(functionType);
for (int i=0; i<functionType->GetNumParameters(); i++) {
if (functionType->GetParameterType(i)->IsPolymorphicType()) {
fprintf(stderr, "Expanding polymorphic function \"%s\"\n",
name.c_str());
toExpand.push_back(i);
}
}
std::vector<const FunctionType *> nextExpanded;
for (size_t i=0; i<toExpand.size(); i++) {
for (size_t j=0; j<expanded.size(); j++) {
const FunctionType *eft = expanded[j];
const PolyType *pt=CastType<PolyType>(
eft->GetParameterType(toExpand[i])->GetBaseType());
std::vector<AtomicType *>::iterator te;
for (te = pt->ExpandBegin(); te != pt->ExpandEnd(); te++) {
llvm::SmallVector<const Type *, 8> nargs;
llvm::SmallVector<std::string, 8> nargsn;
llvm::SmallVector<Expr *, 8> nargsd;
llvm::SmallVector<SourcePos, 8> nargsp;
for (size_t k=0; k<eft->GetNumParameters(); k++) {
if (k == toExpand[i]) {
const Type *r;
r = PolyType::ReplaceType(eft->GetParameterType(j),*te);
nargs.push_back(r);
} else {
nargs.push_back(eft->GetParameterType(k));
}
nargsn.push_back(eft->GetParameterName(k));
nargsd.push_back(eft->GetParameterDefault(k));
nargsp.push_back(eft->GetParameterSourcePos(k));
}
nextExpanded.push_back(new FunctionType(eft->GetReturnType(),
nargs,
nargsn,
nargsd,
nargsp,
eft->isTask,
eft->isExported,
eft->isExternC,
eft->isUnmasked));
}
}
expanded.swap(nextExpanded);
nextExpanded.clear();
}
if (expanded.size() > 1) {
for (size_t i=0; i<expanded.size(); i++) {
std::string nname = name;
if (functionType->isExported || functionType->isExternC) {
for (int j=0; j<expanded[i]->GetNumParameters(); j++) {
nname += "_";
nname += expanded[i]->GetParameterType(j)->Mangle();
}
}
fprintf(stderr, "Adding expanded function %s\n", nname.c_str());
AddFunctionDeclaration(nname, expanded[i], storageClass,
isInline, pos);
}
return;
}
// Get the LLVM FunctionType // Get the LLVM FunctionType
bool disableMask = (storageClass == SC_EXTERN_C); bool disableMask = (storageClass == SC_EXTERN_C);
llvm::FunctionType *llvmFunctionType = llvm::FunctionType *llvmFunctionType =
@@ -1026,7 +1112,7 @@ Module::AddFunctionDeclaration(const std::string &name,
functionName += functionType->Mangle(); functionName += functionType->Mangle();
// If we treat generic as smth, we should have appropriate mangling // If we treat generic as smth, we should have appropriate mangling
if (g->mangleFunctionsWithTarget) { if (g->mangleFunctionsWithTarget) {
if (g->target->getISA() == Target::GENERIC && if (g->target->getISA() == Target::GENERIC &&
!g->target->getTreatGenericAsSmth().empty()) !g->target->getTreatGenericAsSmth().empty())
functionName += g->target->getTreatGenericAsSmth(); functionName += g->target->getTreatGenericAsSmth();
else else
@@ -1326,7 +1412,7 @@ Module::writeOutput(OutputType outputType, const char *outFileName,
#ifdef ISPC_NVPTX_ENABLED #ifdef ISPC_NVPTX_ENABLED
typedef std::vector<std::string> vecString_t; typedef std::vector<std::string> vecString_t;
static vecString_t static vecString_t
lSplitString(const std::string &s) lSplitString(const std::string &s)
{ {
std::stringstream ss(s); std::stringstream ss(s);
@@ -1335,7 +1421,7 @@ lSplitString(const std::string &s)
return vecString_t(begin,end); return vecString_t(begin,end);
} }
static void static void
lFixAttributes(const vecString_t &src, vecString_t &dst) lFixAttributes(const vecString_t &src, vecString_t &dst)
{ {
dst.clear(); dst.clear();
@@ -1434,7 +1520,7 @@ Module::writeBitcode(llvm::Module *module, const char *outFileName) {
#ifdef ISPC_NVPTX_ENABLED #ifdef ISPC_NVPTX_ENABLED
if (g->target->getISA() == Target::NVPTX) if (g->target->getISA() == Target::NVPTX)
{ {
/* when using "nvptx" target, emit patched/hacked assembly /* when using "nvptx" target, emit patched/hacked assembly
* NVPTX only accepts 3.2-style LLVM assembly, where attributes * NVPTX only accepts 3.2-style LLVM assembly, where attributes
* must be inlined, rather then referenced by #attribute_d * must be inlined, rather then referenced by #attribute_d
* As soon as NVVM support 3.3,3.4 style assembly this fix won't be needed * As soon as NVVM support 3.3,3.4 style assembly this fix won't be needed
@@ -1506,7 +1592,7 @@ Module::writeObjectFileOrAssembly(llvm::TargetMachine *targetMachine,
#if ISPC_LLVM_VERSION <= ISPC_LLVM_3_5 #if ISPC_LLVM_VERSION <= ISPC_LLVM_3_5
std::string error; std::string error;
#else // LLVM 3.6+ #else // LLVM 3.6+
std::error_code error; std::error_code error;
#endif #endif
@@ -1518,7 +1604,7 @@ Module::writeObjectFileOrAssembly(llvm::TargetMachine *targetMachine,
#if ISPC_LLVM_VERSION <= ISPC_LLVM_3_5 #if ISPC_LLVM_VERSION <= ISPC_LLVM_3_5
if (error.size()) { if (error.size()) {
#else // LLVM 3.6+ #else // LLVM 3.6+
if (error) { if (error) {
#endif #endif
@@ -1603,7 +1689,7 @@ static void
lEmitStructDecl(const StructType *st, std::vector<const StructType *> *emittedStructs, lEmitStructDecl(const StructType *st, std::vector<const StructType *> *emittedStructs,
FILE *file, bool emitUnifs=true) { FILE *file, bool emitUnifs=true) {
// if we're emitting this for a generic dispatch header file and it's // if we're emitting this for a generic dispatch header file and it's
// struct that only contains uniforms, don't bother if we're emitting uniforms // struct that only contains uniforms, don't bother if we're emitting uniforms
if (!emitUnifs && !lContainsPtrToVarying(st)) { if (!emitUnifs && !lContainsPtrToVarying(st)) {
return; return;
@@ -1626,7 +1712,7 @@ lEmitStructDecl(const StructType *st, std::vector<const StructType *> *emittedSt
// And now it's safe to declare this one // And now it's safe to declare this one
emittedStructs->push_back(st); emittedStructs->push_back(st);
fprintf(file, "#ifndef __ISPC_STRUCT_%s__\n",st->GetCStructName().c_str()); fprintf(file, "#ifndef __ISPC_STRUCT_%s__\n",st->GetCStructName().c_str());
fprintf(file, "#define __ISPC_STRUCT_%s__\n",st->GetCStructName().c_str()); fprintf(file, "#define __ISPC_STRUCT_%s__\n",st->GetCStructName().c_str());
@@ -1848,7 +1934,7 @@ lGetExportedTypes(const Type *type,
lGetExportedTypes(ftype->GetParameterType(j), exportedStructTypes, lGetExportedTypes(ftype->GetParameterType(j), exportedStructTypes,
exportedEnumTypes, exportedVectorTypes); exportedEnumTypes, exportedVectorTypes);
} }
else else
Assert(CastType<AtomicType>(type) != NULL); Assert(CastType<AtomicType>(type) != NULL);
} }
@@ -2349,7 +2435,7 @@ struct DispatchHeaderInfo {
bool bool
Module::writeDispatchHeader(DispatchHeaderInfo *DHI) { Module::writeDispatchHeader(DispatchHeaderInfo *DHI) {
FILE *f = DHI->file; FILE *f = DHI->file;
if (DHI->EmitFrontMatter) { if (DHI->EmitFrontMatter) {
fprintf(f, "//\n// %s\n// (Header automatically generated by the ispc compiler.)\n", DHI->fn); fprintf(f, "//\n// %s\n// (Header automatically generated by the ispc compiler.)\n", DHI->fn);
fprintf(f, "// DO NOT EDIT THIS FILE.\n//\n\n"); fprintf(f, "// DO NOT EDIT THIS FILE.\n//\n\n");
@@ -2392,10 +2478,10 @@ Module::writeDispatchHeader(DispatchHeaderInfo *DHI) {
std::vector<Symbol *> exportedFuncs, externCFuncs; std::vector<Symbol *> exportedFuncs, externCFuncs;
m->symbolTable->GetMatchingFunctions(lIsExported, &exportedFuncs); m->symbolTable->GetMatchingFunctions(lIsExported, &exportedFuncs);
m->symbolTable->GetMatchingFunctions(lIsExternC, &externCFuncs); m->symbolTable->GetMatchingFunctions(lIsExternC, &externCFuncs);
int programCount = g->target->getVectorWidth(); int programCount = g->target->getVectorWidth();
if ((DHI->Emit4 && (programCount == 4)) || if ((DHI->Emit4 && (programCount == 4)) ||
(DHI->Emit8 && (programCount == 8)) || (DHI->Emit8 && (programCount == 8)) ||
(DHI->Emit16 && (programCount == 16))) { (DHI->Emit16 && (programCount == 16))) {
// Get all of the struct, vector, and enumerant types used as function // Get all of the struct, vector, and enumerant types used as function
@@ -2407,7 +2493,7 @@ Module::writeDispatchHeader(DispatchHeaderInfo *DHI) {
&exportedEnumTypes, &exportedVectorTypes); &exportedEnumTypes, &exportedVectorTypes);
lGetExportedParamTypes(externCFuncs, &exportedStructTypes, lGetExportedParamTypes(externCFuncs, &exportedStructTypes,
&exportedEnumTypes, &exportedVectorTypes); &exportedEnumTypes, &exportedVectorTypes);
// Go through the explicitly exported types // Go through the explicitly exported types
for (int i = 0; i < (int)exportedTypes.size(); ++i) { for (int i = 0; i < (int)exportedTypes.size(); ++i) {
if (const StructType *st = CastType<StructType>(exportedTypes[i].first)) if (const StructType *st = CastType<StructType>(exportedTypes[i].first))
@@ -2420,19 +2506,19 @@ Module::writeDispatchHeader(DispatchHeaderInfo *DHI) {
FATAL("Unexpected type in export list"); FATAL("Unexpected type in export list");
} }
// And print them // And print them
if (DHI->EmitUnifs) { if (DHI->EmitUnifs) {
lEmitVectorTypedefs(exportedVectorTypes, f); lEmitVectorTypedefs(exportedVectorTypes, f);
lEmitEnumDecls(exportedEnumTypes, f); lEmitEnumDecls(exportedEnumTypes, f);
} }
lEmitStructDecls(exportedStructTypes, f, DHI->EmitUnifs); lEmitStructDecls(exportedStructTypes, f, DHI->EmitUnifs);
// Update flags // Update flags
DHI->EmitUnifs = false; DHI->EmitUnifs = false;
if (programCount == 4) { if (programCount == 4) {
DHI->Emit4 = false; DHI->Emit4 = false;
} }
else if (programCount == 8) { else if (programCount == 8) {
DHI->Emit8 = false; DHI->Emit8 = false;
} }
@@ -2457,12 +2543,12 @@ Module::writeDispatchHeader(DispatchHeaderInfo *DHI) {
// end namespace // end namespace
fprintf(f, "\n"); fprintf(f, "\n");
fprintf(f, "\n#ifdef __cplusplus\n} /* namespace */\n#endif // __cplusplus\n"); fprintf(f, "\n#ifdef __cplusplus\n} /* namespace */\n#endif // __cplusplus\n");
// end guard // end guard
fprintf(f, "\n#endif // %s\n", guard.c_str()); fprintf(f, "\n#endif // %s\n", guard.c_str());
DHI->EmitBackMatter = false; DHI->EmitBackMatter = false;
} }
return true; return true;
} }
@@ -2477,17 +2563,17 @@ Module::execPreprocessor(const char *infilename, llvm::raw_string_ostream *ostre
clang::DiagnosticOptions *diagOptions = new clang::DiagnosticOptions(); clang::DiagnosticOptions *diagOptions = new clang::DiagnosticOptions();
clang::TextDiagnosticPrinter *diagPrinter = clang::TextDiagnosticPrinter *diagPrinter =
new clang::TextDiagnosticPrinter(stderrRaw, diagOptions); new clang::TextDiagnosticPrinter(stderrRaw, diagOptions);
llvm::IntrusiveRefCntPtr<clang::DiagnosticIDs> diagIDs(new clang::DiagnosticIDs); llvm::IntrusiveRefCntPtr<clang::DiagnosticIDs> diagIDs(new clang::DiagnosticIDs);
clang::DiagnosticsEngine *diagEngine = clang::DiagnosticsEngine *diagEngine =
new clang::DiagnosticsEngine(diagIDs, diagOptions, diagPrinter); new clang::DiagnosticsEngine(diagIDs, diagOptions, diagPrinter);
inst.setDiagnostics(diagEngine); inst.setDiagnostics(diagEngine);
#if ISPC_LLVM_VERSION <= ISPC_LLVM_3_4 // 3.2, 3.3, 3.4 #if ISPC_LLVM_VERSION <= ISPC_LLVM_3_4 // 3.2, 3.3, 3.4
clang::TargetOptions &options = inst.getTargetOpts(); clang::TargetOptions &options = inst.getTargetOpts();
#else // LLVM 3.5+ #else // LLVM 3.5+
const std::shared_ptr< clang::TargetOptions > &options = const std::shared_ptr< clang::TargetOptions > &options =
std::make_shared< clang::TargetOptions >(inst.getTargetOpts()); std::make_shared< clang::TargetOptions >(inst.getTargetOpts());
#endif #endif
@@ -2654,7 +2740,7 @@ lGetTargetFileName(const char *outFileName, const char *isaString, bool forceCXX
strcpy(targetOutFileName, outFileName); strcpy(targetOutFileName, outFileName);
strcat(targetOutFileName, "_"); strcat(targetOutFileName, "_");
strcat(targetOutFileName, isaString); strcat(targetOutFileName, isaString);
// Append ".cpp" suffix to the original file if it is *-generic target // Append ".cpp" suffix to the original file if it is *-generic target
if (forceCXX) if (forceCXX)
strcat(targetOutFileName, ".cpp"); strcat(targetOutFileName, ".cpp");
@@ -2760,11 +2846,11 @@ lGetVaryingDispatchType(FunctionTargetVariants &funcs) {
} }
} }
} }
// We should've found at least one variant here // We should've found at least one variant here
// or else something fishy is going on. // or else something fishy is going on.
Assert(resultFuncTy); Assert(resultFuncTy);
return resultFuncTy; return resultFuncTy;
} }
@@ -2847,7 +2933,7 @@ lCreateDispatchFunction(llvm::Module *module, llvm::Function *setISAFunc,
// dispatchNum is needed to separate generic from *-generic target // dispatchNum is needed to separate generic from *-generic target
int dispatchNum = i; int dispatchNum = i;
if ((Target::ISA)(i == Target::GENERIC) && if ((Target::ISA)(i == Target::GENERIC) &&
!g->target->getTreatGenericAsSmth().empty()) { !g->target->getTreatGenericAsSmth().empty()) {
if (g->target->getTreatGenericAsSmth() == "knl_generic") if (g->target->getTreatGenericAsSmth() == "knl_generic")
dispatchNum = Target::KNL_AVX512; dispatchNum = Target::KNL_AVX512;
@@ -2879,7 +2965,7 @@ lCreateDispatchFunction(llvm::Module *module, llvm::Function *setISAFunc,
args.push_back(&*argIter); args.push_back(&*argIter);
} }
else { else {
llvm::CastInst *argCast = llvm::CastInst *argCast =
llvm::CastInst::CreatePointerCast(&*argIter, targsIter->getType(), llvm::CastInst::CreatePointerCast(&*argIter, targsIter->getType(),
"dpatch_arg_bitcast", callBBlock); "dpatch_arg_bitcast", callBBlock);
args.push_back(argCast); args.push_back(argCast);
@@ -3053,7 +3139,7 @@ lExtractOrCheckGlobals(llvm::Module *msrc, llvm::Module *mdst, bool check) {
} }
#ifdef ISPC_NVPTX_ENABLED #ifdef ISPC_NVPTX_ENABLED
static std::string lCBEMangle(const std::string &S) static std::string lCBEMangle(const std::string &S)
{ {
std::string Result; std::string Result;
@@ -3102,7 +3188,7 @@ Module::CompileAndOutput(const char *srcFile,
if (m->CompileFile() == 0) { if (m->CompileFile() == 0) {
#ifdef ISPC_NVPTX_ENABLED #ifdef ISPC_NVPTX_ENABLED
/* NVPTX: /* NVPTX:
* for PTX target replace '.' with '_' in all global variables * for PTX target replace '.' with '_' in all global variables
* a PTX identifier name must match [a-zA-Z$_][a-zA-Z$_0-9]* * a PTX identifier name must match [a-zA-Z$_][a-zA-Z$_0-9]*
*/ */
if (g->target->getISA() == Target::NVPTX) if (g->target->getISA() == Target::NVPTX)
@@ -3135,7 +3221,7 @@ Module::CompileAndOutput(const char *srcFile,
} }
} }
else if (outputType == Asm || outputType == Object) { else if (outputType == Asm || outputType == Object) {
if (target != NULL && if (target != NULL &&
(strncmp(target, "generic-", 8) == 0 || strstr(target, "-generic-") != NULL)) { (strncmp(target, "generic-", 8) == 0 || strstr(target, "-generic-") != NULL)) {
Error(SourcePos(), "When using a \"generic-*\" compilation target, " Error(SourcePos(), "When using a \"generic-*\" compilation target, "
"%s output can not be used.", "%s output can not be used.",
@@ -3212,7 +3298,7 @@ Module::CompileAndOutput(const char *srcFile,
std::map<std::string, FunctionTargetVariants> exportedFunctions; std::map<std::string, FunctionTargetVariants> exportedFunctions;
int errorCount = 0; int errorCount = 0;
// Handle creating a "generic" header file for multiple targets // Handle creating a "generic" header file for multiple targets
// that use exported varyings // that use exported varyings
DispatchHeaderInfo DHI; DispatchHeaderInfo DHI;
@@ -3234,7 +3320,7 @@ Module::CompileAndOutput(const char *srcFile,
} }
// Variable is needed later for approptiate dispatch function. // Variable is needed later for approptiate dispatch function.
// It indicates if we have *-generic target. // It indicates if we have *-generic target.
std::string treatGenericAsSmth = ""; std::string treatGenericAsSmth = "";
for (unsigned int i = 0; i < targets.size(); ++i) { for (unsigned int i = 0; i < targets.size(); ++i) {
@@ -3272,9 +3358,9 @@ Module::CompileAndOutput(const char *srcFile,
if (outFileName != NULL) { if (outFileName != NULL) {
std::string targetOutFileName; std::string targetOutFileName;
// We always generate cpp file for *-generic target during multitarget compilation // We always generate cpp file for *-generic target during multitarget compilation
if (g->target->getISA() == Target::GENERIC && if (g->target->getISA() == Target::GENERIC &&
!g->target->getTreatGenericAsSmth().empty()) { !g->target->getTreatGenericAsSmth().empty()) {
targetOutFileName = lGetTargetFileName(outFileName, targetOutFileName = lGetTargetFileName(outFileName,
g->target->getTreatGenericAsSmth().c_str(), true); g->target->getTreatGenericAsSmth().c_str(), true);
if (!m->writeOutput(CXX, targetOutFileName.c_str(), includeFileName)) if (!m->writeOutput(CXX, targetOutFileName.c_str(), includeFileName))
return 1; return 1;
@@ -3299,14 +3385,14 @@ Module::CompileAndOutput(const char *srcFile,
// only print backmatter on the last target. // only print backmatter on the last target.
DHI.EmitBackMatter = true; DHI.EmitBackMatter = true;
} }
const char *isaName; const char *isaName;
if (g->target->getISA() == Target::GENERIC && if (g->target->getISA() == Target::GENERIC &&
!g->target->getTreatGenericAsSmth().empty()) !g->target->getTreatGenericAsSmth().empty())
isaName = g->target->getTreatGenericAsSmth().c_str(); isaName = g->target->getTreatGenericAsSmth().c_str();
else else
isaName = g->target->GetISAString(); isaName = g->target->GetISAString();
std::string targetHeaderFileName = std::string targetHeaderFileName =
lGetTargetFileName(headerFileName, isaName, false); lGetTargetFileName(headerFileName, isaName, false);
// write out a header w/o target name for the first target only // write out a header w/o target name for the first target only
if (!m->writeOutput(Module::Header, headerFileName, "", &DHI)) { if (!m->writeOutput(Module::Header, headerFileName, "", &DHI)) {

View File

@@ -499,6 +499,18 @@ DeclStmt::TypeCheck() {
return encounteredError ? NULL : this; return encounteredError ? NULL : this;
} }
Stmt *
DeclStmt::ReplacePolyType(const PolyType *from, const Type *to) {
for (size_t i = 0; i < vars.size(); i++) {
Symbol *s = vars[i].sym;
if (Type::EqualIgnoringConst(s->type->GetBaseType(), from)) {
s->type = PolyType::ReplaceType(s->type, to);
}
}
return this;
}
void void
DeclStmt::Print(int indent) const { DeclStmt::Print(int indent) const {
@@ -2179,6 +2191,21 @@ ForeachStmt::TypeCheck() {
return anyErrors ? NULL : this; return anyErrors ? NULL : this;
} }
Stmt *
ForeachStmt::ReplacePolyType(const PolyType *from, const Type *to) {
if (!stmts)
return NULL;
for (size_t i=0; i<dimVariables.size(); i++) {
const Type *t = dimVariables[i]->type;
if (Type::EqualIgnoringConst(t->GetBaseType(), from)) {
t = PolyType::ReplaceType(t, to);
}
}
return this;
}
int int
ForeachStmt::EstimateCost() const { ForeachStmt::EstimateCost() const {

2
stmt.h
View File

@@ -118,6 +118,7 @@ public:
Stmt *Optimize(); Stmt *Optimize();
Stmt *TypeCheck(); Stmt *TypeCheck();
Stmt *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const; int EstimateCost() const;
std::vector<VariableDeclaration> vars; std::vector<VariableDeclaration> vars;
@@ -282,6 +283,7 @@ public:
void Print(int indent) const; void Print(int indent) const;
Stmt *TypeCheck(); Stmt *TypeCheck();
Stmt *ReplacePolyType(const PolyType *from, const Type *to);
int EstimateCost() const; int EstimateCost() const;
std::vector<Symbol *> dimVariables; std::vector<Symbol *> dimVariables;