diff --git a/.gitignore b/.gitignore index 6cc46644..0a2eac86 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ tests*/*cpp tests*/*run tests*/*.o tests_ispcpp/*.h +tests_ispcpp/*.out tests_ispcpp/*pre* logs/ notify_log.log diff --git a/module.cpp b/module.cpp index 85b119ff..4a5dcbe5 100644 --- a/module.cpp +++ b/module.cpp @@ -1977,6 +1977,27 @@ lPrintFunctionDeclarations(FILE *file, const std::vector &funcs, // fprintf(file, "#ifdef __cplusplus\n} /* end extern C */\n#endif // __cplusplus\n"); } +static void +lPrintPolyFunctionWrappers(FILE *file, const std::vector &funcs) { + fprintf(file, "#if defined(__cplusplus)\n"); + + for (size_t i=0; i poly = m->symbolTable->LookupPolyFunction(funcs[i].c_str()); + + for (size_t j=0; j(poly[j]->type); + Assert(ftype); + std::string decl = ftype->GetCDeclaration(funcs[i]); + fprintf(file, " %s {\n", decl.c_str()); + + std::string call = ftype->GetCCall(poly[j]->name); + fprintf(file, " return %s;\n }\n", call.c_str()); + } + } + + fprintf(file, "#endif // __cplusplus\n"); +} + @@ -2353,8 +2374,10 @@ Module::writeHeader(const char *fn) { // Collect single linear arrays of the exported and extern "C" // functions std::vector exportedFuncs, externCFuncs; + std::vector polyFuncs; m->symbolTable->GetMatchingFunctions(lIsExported, &exportedFuncs); m->symbolTable->GetMatchingFunctions(lIsExternC, &externCFuncs); + m->symbolTable->GetPolyFunctions(&polyFuncs); // Get all of the struct, vector, and enumerant types used as function // parameters. These vectors may have repeats. @@ -2391,6 +2414,16 @@ Module::writeHeader(const char *fn) { fprintf(f, "///////////////////////////////////////////////////////////////////////////\n"); lPrintFunctionDeclarations(f, exportedFuncs); } + + // emit wrappers for polymorphic functions + if (polyFuncs.size() > 0) { + fprintf(f, "\n"); + fprintf(f, "///////////////////////////////////////////////////////////////////////////\n"); + fprintf(f, "// Polymorphic function wrappers\n"); + fprintf(f, "///////////////////////////////////////////////////////////////////////////\n"); + lPrintPolyFunctionWrappers(f, polyFuncs); + } + #if 0 if (externCFuncs.size() > 0) { fprintf(f, "\n"); diff --git a/sym.cpp b/sym.cpp index 88777d5a..e22105eb 100644 --- a/sym.cpp +++ b/sym.cpp @@ -206,6 +206,14 @@ SymbolTable::LookupPolyFunction(const char *name) { return polyFunctions[name]; } +void +SymbolTable::GetPolyFunctions(std::vector *funcs) { + FunctionMapType::iterator it = polyFunctions.begin(); + for (; it != polyFunctions.end(); it++) { + funcs->push_back(it->first); + } +} + bool SymbolTable::AddType(const char *name, const Type *type, SourcePos pos) { diff --git a/sym.h b/sym.h index 802ad58c..46c6fe9e 100644 --- a/sym.h +++ b/sym.h @@ -185,6 +185,8 @@ public: std::vector& LookupPolyFunction(const char *name); + void GetPolyFunctions(std::vector *funcs); + /** Returns all of the functions in the symbol table that match the given predicate. diff --git a/tests_ispcpp/Makefile b/tests_ispcpp/Makefile new file mode 100644 index 00000000..e30bfc2d --- /dev/null +++ b/tests_ispcpp/Makefile @@ -0,0 +1,13 @@ +CXX=g++ +CXXFLAGS=-std=c++11 + +ISPC=../ispc +ISPCFLAGS=--target=sse4-x2 -O2 --arch=x86-64 + +%.out : %.cpp %.o + $(CXX) $(CXXFLAGS) -o $@ $^ + +$ : $.o + +%.o : %.ispc + $(ISPC) $(ISPCFLAGS) -h $*.h -o $*.o $< diff --git a/type.cpp b/type.cpp index da8851b1..2f102dfc 100644 --- a/type.cpp +++ b/type.cpp @@ -249,6 +249,15 @@ Type::IsVoidType() const { bool Type::IsPolymorphicType() const { + const FunctionType *ft = CastType(this); + if (ft) { + for (int i=0; iGetNumParameters(); i++) { + if (ft->GetParameterType(i)->IsPolymorphicType()) + return true; + } + + return false; + } return (CastType(GetBaseType()) != NULL); } @@ -3585,6 +3594,34 @@ FunctionType::GetCDeclaration(const std::string &fname) const { return ret; } +std::string +FunctionType::GetCCall(const std::string &fname) const { + std::string ret; + ret += fname; + ret += "("; + for (unsigned int i = 0; i < paramTypes.size(); ++i) { + const Type *type = paramTypes[i]; + + // Convert pointers to arrays to unsized arrays, which are more clear + // to print out for multidimensional arrays (i.e. "float foo[][4] " + // versus "float (foo *)[4]"). + const PointerType *pt = CastType(type); + if (pt != NULL && + CastType(pt->GetBaseType()) != NULL) { + type = new ArrayType(pt->GetBaseType(), 0); + } + + if (paramNames[i] != "") + ret += paramNames[i]; + else + FATAL("Exporting a polymorphic function with incomplete arguments"); + if (i != paramTypes.size() - 1) + ret += ", "; + } + ret += ")"; + return ret; +} + std::string FunctionType::GetCDeclarationForDispatch(const std::string &fname) const { diff --git a/type.h b/type.h index 69d739d2..634d29b5 100644 --- a/type.h +++ b/type.h @@ -988,6 +988,7 @@ public: std::string GetString() const; std::string Mangle() const; std::string GetCDeclaration(const std::string &fname) const; + std::string GetCCall(const std::string &fname) const; std::string GetCDeclarationForDispatch(const std::string &fname) const; llvm::Type *LLVMType(llvm::LLVMContext *ctx) const;