diff --git a/llvmutil.cpp b/llvmutil.cpp index 1a065ac3..4088ad7f 100644 --- a/llvmutil.cpp +++ b/llvmutil.cpp @@ -687,6 +687,253 @@ lVectorValuesAllEqual(llvm::Value *v, int vectorLength, std::vector &seenPhis); +/** This function checks to see if the given (scalar or vector) value is an + exact multiple of baseValue. It returns true if so, and false if not + (or if it's not able to determine if it is). Any vector value passed + in is required to have the same value in all elements (so that we can + just check the first element to be a multiple of the given value.) + */ +static bool +lIsExactMultiple(llvm::Value *val, int baseValue, int vectorLength, + std::vector &seenPhis) { + if (llvm::isa(val->getType()) == false) { + // If we've worked down to a constant int, then the moment of truth + // has arrived... + llvm::ConstantInt *ci = llvm::dyn_cast(val); + if (ci != NULL) + return (ci->getZExtValue() % baseValue) == 0; + } + else + Assert(LLVMVectorValuesAllEqual(val)); + + llvm::InsertElementInst *ie = llvm::dyn_cast(val); + if (ie != NULL) { + llvm::Value *elts[ISPC_MAX_NVEC]; + LLVMFlattenInsertChain(ie, g->target.vectorWidth, elts); + // We just need to check the scalar first value, since we know that + // all elements are equal + return lIsExactMultiple(elts[0], baseValue, vectorLength, + seenPhis); + } + + llvm::PHINode *phi = llvm::dyn_cast(val); + if (phi != NULL) { + for (unsigned int i = 0; i < seenPhis.size(); ++i) + if (phi == seenPhis[i]) + return true; + + seenPhis.push_back(phi); + unsigned int numIncoming = phi->getNumIncomingValues(); + + // Check all of the incoming values: if all of them pass, then + // we're good. + for (unsigned int i = 0; i < numIncoming; ++i) { + llvm::Value *incoming = phi->getIncomingValue(i); + bool mult = lIsExactMultiple(incoming, baseValue, vectorLength, + seenPhis); + if (mult == false) { + seenPhis.pop_back(); + return false; + } + } + seenPhis.pop_back(); + return true; + } + + llvm::BinaryOperator *bop = llvm::dyn_cast(val); + if (bop != NULL && bop->getOpcode() == llvm::Instruction::Add) { + llvm::Value *op0 = bop->getOperand(0); + llvm::Value *op1 = bop->getOperand(1); + + bool be0 = lIsExactMultiple(op0, baseValue, vectorLength, seenPhis); + bool be1 = lIsExactMultiple(op1, baseValue, vectorLength, seenPhis); + return (be0 && be1); + } + // FIXME: mul? casts? ... ? + + return false; +} + + +/** Returns the next power of two greater than or equal to the given + value. */ +static int +lRoundUpPow2(int v) { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + return v+1; +} + + +/** Try to determine if all of the elements of the given vector value have + the same value when divided by the given baseValue. The function + returns true if this can be determined to be the case, and false + otherwise. (This function may fail to identify some cases where it + does in fact have this property, but should never report a given value + as being a multiple if it isn't!) + */ +static bool +lAllDivBaseEqual(llvm::Value *val, int baseValue, int vectorLength, + std::vector &seenPhis, + bool &canAdd) { + Assert(llvm::isa(val->getType())); + // Make sure the base value is a positive power of 2 + Assert(baseValue > 0 && (baseValue & (baseValue-1)) == 0); + + // The easy case + if (lVectorValuesAllEqual(val, vectorLength, seenPhis)) + return true; + + int64_t vecVals[ISPC_MAX_NVEC]; + int nElts; + if (llvm::isa(val->getType()) && + LLVMExtractVectorInts(val, vecVals, &nElts)) { + // If we have a vector of compile-time constant integer values, + // then go ahead and check them directly.. + int64_t firstDiv = vecVals[0] / baseValue; + for (int i = 1; i < nElts; ++i) + if ((vecVals[i] / baseValue) != firstDiv) + return false; + + return true; + } + + llvm::PHINode *phi = llvm::dyn_cast(val); + if (phi != NULL) { + for (unsigned int i = 0; i < seenPhis.size(); ++i) + if (phi == seenPhis[i]) + return true; + + seenPhis.push_back(phi); + unsigned int numIncoming = phi->getNumIncomingValues(); + + // Check all of the incoming values: if all of them pass, then + // we're good. + for (unsigned int i = 0; i < numIncoming; ++i) { + llvm::Value *incoming = phi->getIncomingValue(i); + bool ca = canAdd; + bool mult = lAllDivBaseEqual(incoming, baseValue, vectorLength, + seenPhis, ca); + if (mult == false) { + seenPhis.pop_back(); + return false; + } + } + seenPhis.pop_back(); + return true; + } + + llvm::BinaryOperator *bop = llvm::dyn_cast(val); + if (bop != NULL && bop->getOpcode() == llvm::Instruction::Add && + canAdd == true) { + llvm::Value *op0 = bop->getOperand(0); + llvm::Value *op1 = bop->getOperand(1); + + // Otherwise we're only going to worry about the following case, + // which comes up often when looping over SOA data: + // ashr %val, + // where %val = add %smear, <0,1,2,3...> + // and where the maximum of the <0,...> vector in the add is less than + // 1<<(constant shift), + // and where %smear is a smear of a value that is a multiple of + // baseValue. + + int64_t addConstants[ISPC_MAX_NVEC]; + if (LLVMExtractVectorInts(op1, addConstants, &nElts) == false) + return false; + Assert(nElts == vectorLength); + + // Do all of them give the same value when divided by baseValue? + int64_t firstConstDiv = addConstants[0] / baseValue; + for (int i = 1; i < vectorLength; ++i) + if ((addConstants[i] / baseValue) != firstConstDiv) + return false; + + if (lVectorValuesAllEqual(op0, vectorLength, seenPhis) == false) + return false; + + // Note that canAdd is a reference parameter; setting this ensures + // that we don't allow multiple adds in other parts of the chain of + // dependent values from here. + canAdd = false; + + // Now we need to figure out the required alignment (in numbers of + // elements of the underlying type being indexed) of the value to + // which these integer addConstant[] values are being added to. We + // know that we have addConstant[] values that all give the same + // value when divided by baseValue, but we may need a less-strict + // alignment than baseValue depending on the actual values. + // + // As an example, consider a case where the baseValue alignment is + // 16, but the addConstants here are <0,1,2,3>. In that case, the + // value to which addConstants is added to only needs to be a + // multiple of 4. Conversely, if addConstants are <4,5,6,7>, then + // we need a multiple of 8 to ensure that the final added result + // will still have the same value for all vector elements when + // divided by baseValue. + // + // All that said, here we need to find the maximum value of any of + // the addConstants[], mod baseValue. If we round that up to the + // next power of 2, we'll have a value that will be no greater than + // baseValue and sometimes less. + int maxMod = addConstants[0] % baseValue; + for (int i = 1; i < vectorLength; ++i) + maxMod = std::max(maxMod, int(addConstants[i] % baseValue)); + int requiredAlignment = lRoundUpPow2(maxMod); + + std::vector seenPhisEEM; + return lIsExactMultiple(op0, requiredAlignment, vectorLength, + seenPhisEEM); + } + // TODO: could handle mul by a vector of equal constant integer values + // and the like here and adjust the 'baseValue' value when it evenly + // divides, but unclear if it's worthwhile... + + return false; +} + + +/** Given a vector shift right of some value by some amount, try to + determine if all of the elements of the final result have the same + value (i.e. whether the high bits are all equal, disregarding the low + bits that are shifted out.) Returns true if so, and false otherwise. + */ +static bool +lVectorShiftRightAllEqual(llvm::Value *val, llvm::Value *shift, + int vectorLength) { + // Are we shifting all elements by a compile-time constant amount? If + // not, give up. + int64_t shiftAmount[ISPC_MAX_NVEC]; + int nElts; + if (LLVMExtractVectorInts(shift, shiftAmount, &nElts) == false) + return false; + Assert(nElts == vectorLength); + + // Is it the same amount for all elements? + for (int i = 0; i < vectorLength; ++i) + if (shiftAmount[i] != shiftAmount[0]) + return false; + + // Now see if the value divided by (1 << shift) can be determined to + // have the same value for all vector elements. + int pow2 = 1 << shiftAmount[0]; + bool canAdd = true; + std::vector seenPhis; + bool eq = lAllDivBaseEqual(val, pow2, vectorLength, seenPhis, canAdd); +#if 0 + fprintf(stderr, "check all div base equal:\n"); + LLVMDumpValue(shift); + LLVMDumpValue(val); + fprintf(stderr, "----> %s\n\n", eq ? "true" : "false"); +#endif + return eq; +} + + static bool lVectorValuesAllEqual(llvm::Value *v, int vectorLength, std::vector &seenPhis) { @@ -707,11 +954,23 @@ lVectorValuesAllEqual(llvm::Value *v, int vectorLength, #endif llvm::BinaryOperator *bop = llvm::dyn_cast(v); - if (bop != NULL) - return (LLVMVectorValuesAllEqual(bop->getOperand(0), vectorLength, - seenPhis) && - LLVMVectorValuesAllEqual(bop->getOperand(1), vectorLength, - seenPhis)); + if (bop != NULL) { + // Easy case: both operands are all equal -> return true + if (lVectorValuesAllEqual(bop->getOperand(0), vectorLength, + seenPhis) && + lVectorValuesAllEqual(bop->getOperand(1), vectorLength, + seenPhis)) + return true; + + // If it's a shift, take a special path that tries to check if the + // high (surviving) bits of the values are equal. + if (bop->getOpcode() == llvm::Instruction::AShr || + bop->getOpcode() == llvm::Instruction::LShr) + return lVectorShiftRightAllEqual(bop->getOperand(0), + bop->getOperand(1), vectorLength); + + return false; + } llvm::CastInst *cast = llvm::dyn_cast(v); if (cast != NULL) @@ -923,6 +1182,45 @@ lCheckMulForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength, } +/** Given (op0 AND op1), try and see if we can determine if the result is a + linear sequence with a step of "stride" between values. Returns true + if so and false otherwise. This pattern comes up when accessing SOA + data. + */ +static bool +lCheckAndForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength, + int stride, std::vector &seenPhis) { + // Require op1 to be a compile-time constant + int64_t maskValue[ISPC_MAX_NVEC]; + int nElts; + if (LLVMExtractVectorInts(op1, maskValue, &nElts) == false) + return false; + Assert(nElts == vectorLength); + + // Is op1 a smear of the same value across all lanes? Give up if not. + for (int i = 1; i < vectorLength; ++i) + if (maskValue[i] != maskValue[0]) + return false; + + // If the op1 value isn't power of 2 minus one, then also give up. + int64_t maskPlusOne = maskValue[0] + 1; + bool isPowTwo = (maskPlusOne & (maskPlusOne - 1)) == 0; + if (isPowTwo == false) + return false; + + // The case we'll covert here is op0 being a linear vector with desired + // stride, and where all of the values of op0, when divided by + // maskPlusOne, have the same value. + if (lVectorIsLinear(op0, vectorLength, stride, seenPhis) == false) + return false; + + bool canAdd = true; + bool isMult = lAllDivBaseEqual(op0, maskPlusOne, vectorLength, seenPhis, + canAdd); + return isMult; +} + + static bool lVectorIsLinear(llvm::Value *v, int vectorLength, int stride, std::vector &seenPhis) { @@ -971,6 +1269,12 @@ lVectorIsLinear(llvm::Value *v, int vectorLength, int stride, bool m1 = lCheckMulForLinear(op1, op0, vectorLength, stride, seenPhis); return m1; } + else if (bop->getOpcode() == llvm::Instruction::And) { + // Special case for some AND-related patterns that come up when + // looping over SOA data + bool linear = lCheckAndForLinear(op0, op1, vectorLength, stride, seenPhis); + return linear; + } else return false; } diff --git a/tests/soa-27.ispc b/tests/soa-27.ispc new file mode 100644 index 00000000..ee353a0b --- /dev/null +++ b/tests/soa-27.ispc @@ -0,0 +1,25 @@ + +struct Point { float x, y, z; }; + +export uniform int width() { return programCount; } + +export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) { + float a = aFOO[programIndex]; + + soa<8> Point pts[10]; + + foreach (i = 1 ... 80) { + pts[i].x = b*i; + pts[i].y = 2*b*i; + pts[i].z = 3*b*i; + } + pts[0].x = pts[0].y = pts[0].z = 0; + + uniform Point up = pts[4]; + + RET[programIndex] = up.z; +} + +export void result(uniform float RET[]) { + RET[programIndex] = 60; +} diff --git a/tests/soa-28.ispc b/tests/soa-28.ispc new file mode 100644 index 00000000..92f3c4a3 --- /dev/null +++ b/tests/soa-28.ispc @@ -0,0 +1,24 @@ + +struct Point { float x, y, z; }; + +export uniform int width() { return programCount; } + +export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) { + float a = aFOO[programIndex]; + + soa<8> Point pts[10]; + + foreach (i = b-5 ... 80) { + pts[i].x = b*i; + pts[i].y = 2*b*i; + pts[i].z = 3*b*i; + } + + uniform Point up = pts[4]; + + RET[programIndex] = pts[2*programIndex].x; +} + +export void result(uniform float RET[]) { + RET[programIndex] = 10 * programIndex; +} diff --git a/tests/soa-29.ispc b/tests/soa-29.ispc new file mode 100644 index 00000000..e9a5a069 --- /dev/null +++ b/tests/soa-29.ispc @@ -0,0 +1,24 @@ + +struct Point { float x, y; int8 zzz0; float z; double aa[3]; int8 zzz; }; + +export uniform int width() { return programCount; } + +export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) { + float a = aFOO[programIndex]; + + soa<8> Point pts[10]; + + for (int i = programIndex; i < 16*b; i += programCount) { + pts[i].x = b*i; + pts[i].y = 2*b*i; + pts[i].z = 3*b*i; + } + + uniform Point up = pts[4]; + + RET[programIndex] = pts[2*programIndex].x; +} + +export void result(uniform float RET[]) { + RET[programIndex] = 10 * programIndex; +}