diff --git a/module.cpp b/module.cpp index 3888d923..9e1188c1 100644 --- a/module.cpp +++ b/module.cpp @@ -544,27 +544,37 @@ Module::AddGlobalVariable(const std::string &name, const Type *type, Expr *initE types are found and returns false otherwise. */ static bool -lRecursiveCheckValidParamType(const Type *t) { +lRecursiveCheckValidParamType(const Type *t, bool vectorOk) { const StructType *st = CastType(t); if (st != NULL) { for (int i = 0; i < st->GetElementCount(); ++i) - if (lRecursiveCheckValidParamType(st->GetElementType(i))) - return true; - return false; + if (!lRecursiveCheckValidParamType(st->GetElementType(i), + vectorOk)) + return false; + return true; } + // Vector types are also not supported, pending ispc properly + // supporting the platform ABI. (Pointers to vector types are ok, + // though.) (https://github.com/ispc/ispc/issues/363)... + if (vectorOk == false && CastType(t) != NULL) + return false; + const SequentialType *seqt = CastType(t); if (seqt != NULL) - return lRecursiveCheckValidParamType(seqt->GetElementType()); + return lRecursiveCheckValidParamType(seqt->GetElementType(), vectorOk); const PointerType *pt = CastType(t); if (pt != NULL) { if (pt->IsSlice() || pt->IsVaryingType()) - return true; - return lRecursiveCheckValidParamType(pt->GetBaseType()); + return false; + return lRecursiveCheckValidParamType(pt->GetBaseType(), true); } - return t->IsVaryingType(); + if (t->IsVaryingType()) + return false; + else + return true; } @@ -574,15 +584,18 @@ lRecursiveCheckValidParamType(const Type *t) { varying parameters is illegal. */ static void -lCheckForVaryingParameter(const Type *type, const std::string &name, - SourcePos pos) { - if (lRecursiveCheckValidParamType(type)) { +lCheckExportedParameterTypes(const Type *type, const std::string &name, + SourcePos pos) { + if (lRecursiveCheckValidParamType(type, false) == false) { if (CastType(type)) Error(pos, "Varying pointer type parameter \"%s\" is illegal " "in an exported function.", name.c_str()); else if (CastType(type->GetBaseType())) - Error(pos, "Struct parameter \"%s\" with varying member(s) is illegal " - "in an exported function.", name.c_str()); + Error(pos, "Struct parameter \"%s\" with varying or vector typed " + "member(s) is illegal in an exported function.", name.c_str()); + else if (CastType(type)) + Error(pos, "Vector-typed parameter \"%s\" is illegal in an exported " + "function.", name.c_str()); else Error(pos, "Varying parameter \"%s\" is illegal in an exported function.", name.c_str()); @@ -747,12 +760,12 @@ Module::AddFunctionDeclaration(const std::string &name, // This also applies transitively to members I think? function->setDoesNotAlias(1, true); - // Make sure that the return type isn't 'varying' if the function is - // 'export'ed. + // Make sure that the return type isn't 'varying' or vector typed if + // the function is 'export'ed. if (functionType->isExported && - lRecursiveCheckValidParamType(functionType->GetReturnType())) - Error(pos, "Illegal to return a \"varying\" type from exported " - "function \"%s\"", name.c_str()); + lRecursiveCheckValidParamType(functionType->GetReturnType(), false) == false) + Error(pos, "Illegal to return a \"varying\" or vector type from " + "exported function \"%s\"", name.c_str()); if (functionType->isTask && Type::Equal(functionType->GetReturnType(), AtomicType::Void) == false) @@ -774,7 +787,7 @@ Module::AddFunctionDeclaration(const std::string &name, // If the function is exported, make sure that the parameter // doesn't have any varying stuff going on in it. if (functionType->isExported) - lCheckForVaryingParameter(argType, argName, argPos); + lCheckExportedParameterTypes(argType, argName, argPos); // ISPC assumes that no pointers alias. (It should be possible to // specify when this is not the case, but this should be the diff --git a/tests_errors/export-vector-param.ispc b/tests_errors/export-vector-param.ispc index b0df1ef9..72da253a 100644 --- a/tests_errors/export-vector-param.ispc +++ b/tests_errors/export-vector-param.ispc @@ -1,4 +1,4 @@ -// dsgsdhg +// Vector-typed parameter "x" is illegal export void foo(uniform float<3> x) { }