diff --git a/expr.cpp b/expr.cpp index ec242fea..8925de2b 100644 --- a/expr.cpp +++ b/expr.cpp @@ -2924,10 +2924,77 @@ SelectExpr::Optimize() { // Varying test: see if all of the values are the same; if so, then // return the corresponding expression bool first = bv[0]; + bool mismatch = false; for (int i = 0; i < count; ++i) - if (bv[i] != first) - return this; - return (bv[0] == true) ? expr1 : expr2; + if (bv[i] != first) { + mismatch = true; + break; + } + if (mismatch == false) + return (bv[0] == true) ? expr1 : expr2; + + // Last chance: see if the two expressions are constants; if so, + // then we can do an element-wise selection based on the constant + // condition.. + ConstExpr *constExpr1 = dynamic_cast(expr1); + ConstExpr *constExpr2 = dynamic_cast(expr2); + if (constExpr1 == NULL || constExpr2 == NULL) + return this; + + Assert(constExpr1->GetType() == constExpr2->GetType()); + const Type *exprType = constExpr1->GetType()->GetAsNonConstType(); + Assert(exprType->IsVaryingType()); + + // FIXME: it's annoying to have to have all of this replicated code. + if (exprType == AtomicType::VaryingInt32 || + exprType == AtomicType::VaryingUInt32) { + int32_t v1[ISPC_MAX_NVEC], v2[ISPC_MAX_NVEC]; + int32_t result[ISPC_MAX_NVEC]; + constExpr1->AsInt32(v1); + constExpr2->AsInt32(v2); + for (int i = 0; i < count; ++i) + result[i] = bv[i] ? v1[i] : v2[i]; + return new ConstExpr(exprType, result, pos); + } + else if (exprType == AtomicType::VaryingInt64 || + exprType == AtomicType::VaryingUInt64) { + int64_t v1[ISPC_MAX_NVEC], v2[ISPC_MAX_NVEC]; + int64_t result[ISPC_MAX_NVEC]; + constExpr1->AsInt64(v1); + constExpr2->AsInt64(v2); + for (int i = 0; i < count; ++i) + result[i] = bv[i] ? v1[i] : v2[i]; + return new ConstExpr(exprType, result, pos); + } + else if (exprType == AtomicType::VaryingFloat) { + float v1[ISPC_MAX_NVEC], v2[ISPC_MAX_NVEC]; + float result[ISPC_MAX_NVEC]; + constExpr1->AsFloat(v1); + constExpr2->AsFloat(v2); + for (int i = 0; i < count; ++i) + result[i] = bv[i] ? v1[i] : v2[i]; + return new ConstExpr(exprType, result, pos); + } + else if (exprType == AtomicType::VaryingDouble) { + double v1[ISPC_MAX_NVEC], v2[ISPC_MAX_NVEC]; + double result[ISPC_MAX_NVEC]; + constExpr1->AsDouble(v1); + constExpr2->AsDouble(v2); + for (int i = 0; i < count; ++i) + result[i] = bv[i] ? v1[i] : v2[i]; + return new ConstExpr(exprType, result, pos); + } + else if (exprType == AtomicType::VaryingBool) { + bool v1[ISPC_MAX_NVEC], v2[ISPC_MAX_NVEC]; + bool result[ISPC_MAX_NVEC]; + constExpr1->AsBool(v1); + constExpr2->AsBool(v2); + for (int i = 0; i < count; ++i) + result[i] = bv[i] ? v1[i] : v2[i]; + return new ConstExpr(exprType, result, pos); + } + + return this; } }