diff --git a/expr.cpp b/expr.cpp index dc56aca7..436d9e8e 100644 --- a/expr.cpp +++ b/expr.cpp @@ -3274,15 +3274,19 @@ SelectExpr::TypeCheck() { if (!type1 || !type2) return NULL; - if (CastType(type1)) { - Error(pos, "Array type \"%s\" can't be used in select expression", - type1->GetString().c_str()); - return NULL; + if (const ArrayType *at1 = CastType(type1)) { + expr1 = TypeConvertExpr(expr1, PointerType::GetUniform(at1->GetBaseType()), + "select"); + if (expr1 == NULL) + return NULL; + type1 = expr1->GetType(); } - if (CastType(type2)) { - Error(pos, "Array type \"%s\" can't be used in select expression", - type2->GetString().c_str()); - return NULL; + if (const ArrayType *at2 = CastType(type2)) { + expr2 = TypeConvertExpr(expr2, PointerType::GetUniform(at2->GetBaseType()), + "select"); + if (expr2 == NULL) + return NULL; + type2 = expr2->GetType(); } const Type *testType = test->GetType(); diff --git a/tests/select-array-ptr-typeconv.ispc b/tests/select-array-ptr-typeconv.ispc new file mode 100644 index 00000000..a7d1e51d --- /dev/null +++ b/tests/select-array-ptr-typeconv.ispc @@ -0,0 +1,18 @@ + +uniform float a[1234]; + +float * uniform func(uniform bool x) { + return x ? a : NULL; +} + +export uniform int width() { return programCount; } + +export void f_f(uniform float RET[], uniform float aFOO[]) { + a[programIndex] = aFOO[programIndex]; + float * uniform ptr = func(aFOO[0] == 1); + RET[programIndex] = ptr[programIndex]; +} + +export void result(uniform float RET[]) { + RET[programIndex] = 1 + programIndex; +}