diff --git a/.gitignore b/.gitignore index 6cc46644..0a2eac86 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ tests*/*cpp tests*/*run tests*/*.o tests_ispcpp/*.h +tests_ispcpp/*.out tests_ispcpp/*pre* logs/ notify_log.log diff --git a/ctx.cpp b/ctx.cpp index 98c8ec5c..b041c03d 100644 --- a/ctx.cpp +++ b/ctx.cpp @@ -1927,6 +1927,11 @@ FunctionEmitContext::BinaryOperator(llvm::Instruction::BinaryOps inst, return NULL; } + if (v0->getType() != v1->getType()) { + v0->dump(); + printf("\n\n"); + v1->dump(); + } AssertPos(currentPos, v0->getType() == v1->getType()); llvm::Type *type = v0->getType(); int arraySize = lArrayVectorWidth(type); diff --git a/expr.cpp b/expr.cpp index ad62e08b..c7cd1522 100644 --- a/expr.cpp +++ b/expr.cpp @@ -606,6 +606,7 @@ lDoTypeConv(const Type *fromType, const Type *toType, Expr **expr, "\"%s\" for %s", fromType->GetString().c_str(), toPolyType->GetString().c_str(), errorMsgBase); } + return false; } } @@ -7209,8 +7210,11 @@ TypeCastExpr::GetValue(FunctionEmitContext *ctx) const { return NULL; return ctx->IntToPtrInst(exprVal, llvmToType, "int_to_ptr"); - } - else { + } else if (CastType(toType)) { + Error(pos, "Unexpected polymorphic type cast to \"%s\"", + toType->GetString().c_str()); + return NULL; + } else { const AtomicType *toAtomic = CastType(toType); // typechecking should ensure this is the case if (!toAtomic) { diff --git a/func.cpp b/func.cpp index fa1e030d..8776a9fc 100644 --- a/func.cpp +++ b/func.cpp @@ -651,8 +651,10 @@ Function::ExpandPolyArguments(SymbolTable *symbolTable) const { const FunctionType *func = CastType(sym->type); - printf("%s before replacing anything:\n", sym->name.c_str()); - code->Print(0); + if (g->debugPrint) { + printf("%s before replacing anything:\n", sym->name.c_str()); + code->Print(0); + } for (size_t i=0; i(versions[i]->type); @@ -665,13 +667,15 @@ Function::ExpandPolyArguments(SymbolTable *symbolTable) const { ncode = (Stmt*)TranslatePoly(ncode, from, ft->GetParameterType(j)->GetBaseType()); - printf("%s after replacing %s with %s:\n\n", - sym->name.c_str(), from->GetString().c_str(), - ft->GetParameterType(j)->GetBaseType()->GetString().c_str()); + if (g->debugPrint) { + printf("%s after replacing %s with %s:\n\n", + sym->name.c_str(), from->GetString().c_str(), + ft->GetParameterType(j)->GetBaseType()->GetString().c_str()); - ncode->Print(0); + ncode->Print(0); - printf("------------------------------------------\n\n"); + printf("------------------------------------------\n\n"); + } } } diff --git a/module.cpp b/module.cpp index 85b119ff..4d885200 100644 --- a/module.cpp +++ b/module.cpp @@ -1032,8 +1032,7 @@ Module::AddFunctionDeclaration(const std::string &name, } std::vector nextExpanded; - std::set::iterator iter; - for (iter = toExpand.begin(); iter != toExpand.end(); iter++) { + for (auto iter = toExpand.begin(); iter != toExpand.end(); iter++) { for (size_t j=0; jGetParameterSourcePos(k)); } - nextExpanded.push_back(new FunctionType(eft->GetReturnType(), + const Type *ret = eft->GetReturnType(); + if (Type::EqualForReplacement(ret, pt)) { + printf("Replaced return type %s\n", + ret->GetString().c_str()); + ret = PolyType::ReplaceType(ret, *te); + } + + nextExpanded.push_back(new FunctionType(ret, nargs, nargsn, nargsd, @@ -1078,6 +1084,11 @@ Module::AddFunctionDeclaration(const std::string &name, if (expanded.size() > 1) { for (size_t i=0; iGetReturnType()->IsPolymorphicType()) { + Error(pos, "Unexpected polymorphic return type \"%s\"", + expanded[i]->GetReturnType()->GetString().c_str()); + return; + } std::string nname = name; if (functionType->isExported || functionType->isExternC) { for (int j=0; jGetNumParameters(); j++) { @@ -1977,6 +1988,27 @@ lPrintFunctionDeclarations(FILE *file, const std::vector &funcs, // fprintf(file, "#ifdef __cplusplus\n} /* end extern C */\n#endif // __cplusplus\n"); } +static void +lPrintPolyFunctionWrappers(FILE *file, const std::vector &funcs) { + fprintf(file, "#if defined(__cplusplus)\n"); + + for (size_t i=0; i poly = m->symbolTable->LookupPolyFunction(funcs[i].c_str()); + + for (size_t j=0; j(poly[j]->type); + Assert(ftype); + std::string decl = ftype->GetCDeclaration(funcs[i]); + fprintf(file, " %s {\n", decl.c_str()); + + std::string call = ftype->GetCCall(poly[j]->name); + fprintf(file, " return %s;\n }\n", call.c_str()); + } + } + + fprintf(file, "#endif // __cplusplus\n"); +} + @@ -2353,8 +2385,10 @@ Module::writeHeader(const char *fn) { // Collect single linear arrays of the exported and extern "C" // functions std::vector exportedFuncs, externCFuncs; + std::vector polyFuncs; m->symbolTable->GetMatchingFunctions(lIsExported, &exportedFuncs); m->symbolTable->GetMatchingFunctions(lIsExternC, &externCFuncs); + m->symbolTable->GetPolyFunctions(&polyFuncs); // Get all of the struct, vector, and enumerant types used as function // parameters. These vectors may have repeats. @@ -2391,6 +2425,16 @@ Module::writeHeader(const char *fn) { fprintf(f, "///////////////////////////////////////////////////////////////////////////\n"); lPrintFunctionDeclarations(f, exportedFuncs); } + + // emit wrappers for polymorphic functions + if (polyFuncs.size() > 0) { + fprintf(f, "\n"); + fprintf(f, "///////////////////////////////////////////////////////////////////////////\n"); + fprintf(f, "// Polymorphic function wrappers\n"); + fprintf(f, "///////////////////////////////////////////////////////////////////////////\n"); + lPrintPolyFunctionWrappers(f, polyFuncs); + } + #if 0 if (externCFuncs.size() > 0) { fprintf(f, "\n"); diff --git a/sym.cpp b/sym.cpp index 88777d5a..48ee06f7 100644 --- a/sym.cpp +++ b/sym.cpp @@ -147,7 +147,7 @@ bool SymbolTable::AddFunction(Symbol *symbol) { const FunctionType *ft = CastType(symbol->type); Assert(ft != NULL); - if (LookupFunction(symbol->name.c_str(), ft) != NULL) + if (LookupFunction(symbol->name.c_str(), ft, true) != NULL) // A function of the same name and type has already been added to // the symbol table return false; @@ -183,7 +183,8 @@ SymbolTable::LookupFunction(const char *name, std::vector *matches) { Symbol * -SymbolTable::LookupFunction(const char *name, const FunctionType *type) { +SymbolTable::LookupFunction(const char *name, const FunctionType *type, + bool ignorePoly) { FunctionMapType::iterator iter = functions.find(name); if (iter != functions.end()) { std::vector funcs = iter->second; @@ -193,7 +194,7 @@ SymbolTable::LookupFunction(const char *name, const FunctionType *type) { } } // Try looking for a polymorphic function - if (polyFunctions[name].size() > 0) { + if (!ignorePoly && polyFunctions[name].size() > 0) { std::string n = name; return new Symbol(name, polyFunctions[name][0]->pos, type); } @@ -206,6 +207,14 @@ SymbolTable::LookupPolyFunction(const char *name) { return polyFunctions[name]; } +void +SymbolTable::GetPolyFunctions(std::vector *funcs) { + FunctionMapType::iterator it = polyFunctions.begin(); + for (; it != polyFunctions.end(); it++) { + funcs->push_back(it->first); + } +} + bool SymbolTable::AddType(const char *name, const Type *type, SourcePos pos) { diff --git a/sym.h b/sym.h index 802ad58c..41973c72 100644 --- a/sym.h +++ b/sym.h @@ -181,10 +181,13 @@ public: in the symbol table. @return pointer to matching Symbol; NULL if none is found. */ - Symbol *LookupFunction(const char *name, const FunctionType *type); + Symbol *LookupFunction(const char *name, const FunctionType *type, + bool ignorePoly = false); std::vector& LookupPolyFunction(const char *name); + void GetPolyFunctions(std::vector *funcs); + /** Returns all of the functions in the symbol table that match the given predicate. diff --git a/tests_ispcpp/Makefile b/tests_ispcpp/Makefile new file mode 100644 index 00000000..e30bfc2d --- /dev/null +++ b/tests_ispcpp/Makefile @@ -0,0 +1,13 @@ +CXX=g++ +CXXFLAGS=-std=c++11 + +ISPC=../ispc +ISPCFLAGS=--target=sse4-x2 -O2 --arch=x86-64 + +%.out : %.cpp %.o + $(CXX) $(CXXFLAGS) -o $@ $^ + +$ : $.o + +%.o : %.ispc + $(ISPC) $(ISPCFLAGS) -h $*.h -o $*.o $< diff --git a/tests_ispcpp/simple.cpp b/tests_ispcpp/simple.cpp new file mode 100644 index 00000000..5312e57f --- /dev/null +++ b/tests_ispcpp/simple.cpp @@ -0,0 +1,20 @@ +#include +#include + +#include "simple.h" + +int main() { + double A[256]; + + for (int i=0; i<256; i++) { + A[i] = i / 11.; + } + + ispc::foo(256, (double*)&A); + + for (int i=0; i<256; i++) { + printf("%f\n", A[i]); + } + + return 0; +} diff --git a/type.cpp b/type.cpp index da8851b1..3c421bf4 100644 --- a/type.cpp +++ b/type.cpp @@ -249,6 +249,15 @@ Type::IsVoidType() const { bool Type::IsPolymorphicType() const { + const FunctionType *ft = CastType(this); + if (ft) { + for (int i=0; iGetNumParameters(); i++) { + if (ft->GetParameterType(i)->IsPolymorphicType()) + return true; + } + + return false; + } return (CastType(GetBaseType()) != NULL); } @@ -3585,6 +3594,34 @@ FunctionType::GetCDeclaration(const std::string &fname) const { return ret; } +std::string +FunctionType::GetCCall(const std::string &fname) const { + std::string ret; + ret += fname; + ret += "("; + for (unsigned int i = 0; i < paramTypes.size(); ++i) { + const Type *type = paramTypes[i]; + + // Convert pointers to arrays to unsized arrays, which are more clear + // to print out for multidimensional arrays (i.e. "float foo[][4] " + // versus "float (foo *)[4]"). + const PointerType *pt = CastType(type); + if (pt != NULL && + CastType(pt->GetBaseType()) != NULL) { + type = new ArrayType(pt->GetBaseType(), 0); + } + + if (paramNames[i] != "") + ret += paramNames[i]; + else + FATAL("Exporting a polymorphic function with incomplete arguments"); + if (i != paramTypes.size() - 1) + ret += ", "; + } + ret += ")"; + return ret; +} + std::string FunctionType::GetCDeclarationForDispatch(const std::string &fname) const { @@ -4041,7 +4078,8 @@ bool Type::IsBasicType(const Type *type) { return (CastType(type) != NULL || CastType(type) != NULL || - CastType(type) != NULL); + CastType(type) != NULL || + CastType(type) != NULL); } diff --git a/type.h b/type.h index 69d739d2..634d29b5 100644 --- a/type.h +++ b/type.h @@ -988,6 +988,7 @@ public: std::string GetString() const; std::string Mangle() const; std::string GetCDeclaration(const std::string &fname) const; + std::string GetCCall(const std::string &fname) const; std::string GetCDeclarationForDispatch(const std::string &fname) const; llvm::Type *LLVMType(llvm::LLVMContext *ctx) const;