Add patterns to better-match code generated when accessing SOA data.
In particular, LLVMVectorIsLinear() and LLVMVectorValuesAllEqual() are able to reason a bit about the effects of the shifts and the ANDs that are generated from SOA indexing calculations, so that they can detect more cases where a linear sequence of locations are in fact being accessed in the presence of SOA data.
This commit is contained in:
314
llvmutil.cpp
314
llvmutil.cpp
@@ -687,6 +687,253 @@ lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
|||||||
std::vector<llvm::PHINode *> &seenPhis);
|
std::vector<llvm::PHINode *> &seenPhis);
|
||||||
|
|
||||||
|
|
||||||
|
/** This function checks to see if the given (scalar or vector) value is an
|
||||||
|
exact multiple of baseValue. It returns true if so, and false if not
|
||||||
|
(or if it's not able to determine if it is). Any vector value passed
|
||||||
|
in is required to have the same value in all elements (so that we can
|
||||||
|
just check the first element to be a multiple of the given value.)
|
||||||
|
*/
|
||||||
|
static bool
|
||||||
|
lIsExactMultiple(llvm::Value *val, int baseValue, int vectorLength,
|
||||||
|
std::vector<llvm::PHINode *> &seenPhis) {
|
||||||
|
if (llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(val->getType()) == false) {
|
||||||
|
// If we've worked down to a constant int, then the moment of truth
|
||||||
|
// has arrived...
|
||||||
|
llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(val);
|
||||||
|
if (ci != NULL)
|
||||||
|
return (ci->getZExtValue() % baseValue) == 0;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
Assert(LLVMVectorValuesAllEqual(val));
|
||||||
|
|
||||||
|
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(val);
|
||||||
|
if (ie != NULL) {
|
||||||
|
llvm::Value *elts[ISPC_MAX_NVEC];
|
||||||
|
LLVMFlattenInsertChain(ie, g->target.vectorWidth, elts);
|
||||||
|
// We just need to check the scalar first value, since we know that
|
||||||
|
// all elements are equal
|
||||||
|
return lIsExactMultiple(elts[0], baseValue, vectorLength,
|
||||||
|
seenPhis);
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(val);
|
||||||
|
if (phi != NULL) {
|
||||||
|
for (unsigned int i = 0; i < seenPhis.size(); ++i)
|
||||||
|
if (phi == seenPhis[i])
|
||||||
|
return true;
|
||||||
|
|
||||||
|
seenPhis.push_back(phi);
|
||||||
|
unsigned int numIncoming = phi->getNumIncomingValues();
|
||||||
|
|
||||||
|
// Check all of the incoming values: if all of them pass, then
|
||||||
|
// we're good.
|
||||||
|
for (unsigned int i = 0; i < numIncoming; ++i) {
|
||||||
|
llvm::Value *incoming = phi->getIncomingValue(i);
|
||||||
|
bool mult = lIsExactMultiple(incoming, baseValue, vectorLength,
|
||||||
|
seenPhis);
|
||||||
|
if (mult == false) {
|
||||||
|
seenPhis.pop_back();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
seenPhis.pop_back();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(val);
|
||||||
|
if (bop != NULL && bop->getOpcode() == llvm::Instruction::Add) {
|
||||||
|
llvm::Value *op0 = bop->getOperand(0);
|
||||||
|
llvm::Value *op1 = bop->getOperand(1);
|
||||||
|
|
||||||
|
bool be0 = lIsExactMultiple(op0, baseValue, vectorLength, seenPhis);
|
||||||
|
bool be1 = lIsExactMultiple(op1, baseValue, vectorLength, seenPhis);
|
||||||
|
return (be0 && be1);
|
||||||
|
}
|
||||||
|
// FIXME: mul? casts? ... ?
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/** Returns the next power of two greater than or equal to the given
|
||||||
|
value. */
|
||||||
|
static int
|
||||||
|
lRoundUpPow2(int v) {
|
||||||
|
v--;
|
||||||
|
v |= v >> 1;
|
||||||
|
v |= v >> 2;
|
||||||
|
v |= v >> 4;
|
||||||
|
v |= v >> 8;
|
||||||
|
v |= v >> 16;
|
||||||
|
return v+1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/** Try to determine if all of the elements of the given vector value have
|
||||||
|
the same value when divided by the given baseValue. The function
|
||||||
|
returns true if this can be determined to be the case, and false
|
||||||
|
otherwise. (This function may fail to identify some cases where it
|
||||||
|
does in fact have this property, but should never report a given value
|
||||||
|
as being a multiple if it isn't!)
|
||||||
|
*/
|
||||||
|
static bool
|
||||||
|
lAllDivBaseEqual(llvm::Value *val, int baseValue, int vectorLength,
|
||||||
|
std::vector<llvm::PHINode *> &seenPhis,
|
||||||
|
bool &canAdd) {
|
||||||
|
Assert(llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(val->getType()));
|
||||||
|
// Make sure the base value is a positive power of 2
|
||||||
|
Assert(baseValue > 0 && (baseValue & (baseValue-1)) == 0);
|
||||||
|
|
||||||
|
// The easy case
|
||||||
|
if (lVectorValuesAllEqual(val, vectorLength, seenPhis))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
int64_t vecVals[ISPC_MAX_NVEC];
|
||||||
|
int nElts;
|
||||||
|
if (llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(val->getType()) &&
|
||||||
|
LLVMExtractVectorInts(val, vecVals, &nElts)) {
|
||||||
|
// If we have a vector of compile-time constant integer values,
|
||||||
|
// then go ahead and check them directly..
|
||||||
|
int64_t firstDiv = vecVals[0] / baseValue;
|
||||||
|
for (int i = 1; i < nElts; ++i)
|
||||||
|
if ((vecVals[i] / baseValue) != firstDiv)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(val);
|
||||||
|
if (phi != NULL) {
|
||||||
|
for (unsigned int i = 0; i < seenPhis.size(); ++i)
|
||||||
|
if (phi == seenPhis[i])
|
||||||
|
return true;
|
||||||
|
|
||||||
|
seenPhis.push_back(phi);
|
||||||
|
unsigned int numIncoming = phi->getNumIncomingValues();
|
||||||
|
|
||||||
|
// Check all of the incoming values: if all of them pass, then
|
||||||
|
// we're good.
|
||||||
|
for (unsigned int i = 0; i < numIncoming; ++i) {
|
||||||
|
llvm::Value *incoming = phi->getIncomingValue(i);
|
||||||
|
bool ca = canAdd;
|
||||||
|
bool mult = lAllDivBaseEqual(incoming, baseValue, vectorLength,
|
||||||
|
seenPhis, ca);
|
||||||
|
if (mult == false) {
|
||||||
|
seenPhis.pop_back();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
seenPhis.pop_back();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(val);
|
||||||
|
if (bop != NULL && bop->getOpcode() == llvm::Instruction::Add &&
|
||||||
|
canAdd == true) {
|
||||||
|
llvm::Value *op0 = bop->getOperand(0);
|
||||||
|
llvm::Value *op1 = bop->getOperand(1);
|
||||||
|
|
||||||
|
// Otherwise we're only going to worry about the following case,
|
||||||
|
// which comes up often when looping over SOA data:
|
||||||
|
// ashr %val, <constant shift>
|
||||||
|
// where %val = add %smear, <0,1,2,3...>
|
||||||
|
// and where the maximum of the <0,...> vector in the add is less than
|
||||||
|
// 1<<(constant shift),
|
||||||
|
// and where %smear is a smear of a value that is a multiple of
|
||||||
|
// baseValue.
|
||||||
|
|
||||||
|
int64_t addConstants[ISPC_MAX_NVEC];
|
||||||
|
if (LLVMExtractVectorInts(op1, addConstants, &nElts) == false)
|
||||||
|
return false;
|
||||||
|
Assert(nElts == vectorLength);
|
||||||
|
|
||||||
|
// Do all of them give the same value when divided by baseValue?
|
||||||
|
int64_t firstConstDiv = addConstants[0] / baseValue;
|
||||||
|
for (int i = 1; i < vectorLength; ++i)
|
||||||
|
if ((addConstants[i] / baseValue) != firstConstDiv)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (lVectorValuesAllEqual(op0, vectorLength, seenPhis) == false)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// Note that canAdd is a reference parameter; setting this ensures
|
||||||
|
// that we don't allow multiple adds in other parts of the chain of
|
||||||
|
// dependent values from here.
|
||||||
|
canAdd = false;
|
||||||
|
|
||||||
|
// Now we need to figure out the required alignment (in numbers of
|
||||||
|
// elements of the underlying type being indexed) of the value to
|
||||||
|
// which these integer addConstant[] values are being added to. We
|
||||||
|
// know that we have addConstant[] values that all give the same
|
||||||
|
// value when divided by baseValue, but we may need a less-strict
|
||||||
|
// alignment than baseValue depending on the actual values.
|
||||||
|
//
|
||||||
|
// As an example, consider a case where the baseValue alignment is
|
||||||
|
// 16, but the addConstants here are <0,1,2,3>. In that case, the
|
||||||
|
// value to which addConstants is added to only needs to be a
|
||||||
|
// multiple of 4. Conversely, if addConstants are <4,5,6,7>, then
|
||||||
|
// we need a multiple of 8 to ensure that the final added result
|
||||||
|
// will still have the same value for all vector elements when
|
||||||
|
// divided by baseValue.
|
||||||
|
//
|
||||||
|
// All that said, here we need to find the maximum value of any of
|
||||||
|
// the addConstants[], mod baseValue. If we round that up to the
|
||||||
|
// next power of 2, we'll have a value that will be no greater than
|
||||||
|
// baseValue and sometimes less.
|
||||||
|
int maxMod = addConstants[0] % baseValue;
|
||||||
|
for (int i = 1; i < vectorLength; ++i)
|
||||||
|
maxMod = std::max(maxMod, int(addConstants[i] % baseValue));
|
||||||
|
int requiredAlignment = lRoundUpPow2(maxMod);
|
||||||
|
|
||||||
|
std::vector<llvm::PHINode *> seenPhisEEM;
|
||||||
|
return lIsExactMultiple(op0, requiredAlignment, vectorLength,
|
||||||
|
seenPhisEEM);
|
||||||
|
}
|
||||||
|
// TODO: could handle mul by a vector of equal constant integer values
|
||||||
|
// and the like here and adjust the 'baseValue' value when it evenly
|
||||||
|
// divides, but unclear if it's worthwhile...
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/** Given a vector shift right of some value by some amount, try to
|
||||||
|
determine if all of the elements of the final result have the same
|
||||||
|
value (i.e. whether the high bits are all equal, disregarding the low
|
||||||
|
bits that are shifted out.) Returns true if so, and false otherwise.
|
||||||
|
*/
|
||||||
|
static bool
|
||||||
|
lVectorShiftRightAllEqual(llvm::Value *val, llvm::Value *shift,
|
||||||
|
int vectorLength) {
|
||||||
|
// Are we shifting all elements by a compile-time constant amount? If
|
||||||
|
// not, give up.
|
||||||
|
int64_t shiftAmount[ISPC_MAX_NVEC];
|
||||||
|
int nElts;
|
||||||
|
if (LLVMExtractVectorInts(shift, shiftAmount, &nElts) == false)
|
||||||
|
return false;
|
||||||
|
Assert(nElts == vectorLength);
|
||||||
|
|
||||||
|
// Is it the same amount for all elements?
|
||||||
|
for (int i = 0; i < vectorLength; ++i)
|
||||||
|
if (shiftAmount[i] != shiftAmount[0])
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// Now see if the value divided by (1 << shift) can be determined to
|
||||||
|
// have the same value for all vector elements.
|
||||||
|
int pow2 = 1 << shiftAmount[0];
|
||||||
|
bool canAdd = true;
|
||||||
|
std::vector<llvm::PHINode *> seenPhis;
|
||||||
|
bool eq = lAllDivBaseEqual(val, pow2, vectorLength, seenPhis, canAdd);
|
||||||
|
#if 0
|
||||||
|
fprintf(stderr, "check all div base equal:\n");
|
||||||
|
LLVMDumpValue(shift);
|
||||||
|
LLVMDumpValue(val);
|
||||||
|
fprintf(stderr, "----> %s\n\n", eq ? "true" : "false");
|
||||||
|
#endif
|
||||||
|
return eq;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
static bool
|
static bool
|
||||||
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
||||||
std::vector<llvm::PHINode *> &seenPhis) {
|
std::vector<llvm::PHINode *> &seenPhis) {
|
||||||
@@ -707,11 +954,23 @@ lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
|
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
|
||||||
if (bop != NULL)
|
if (bop != NULL) {
|
||||||
return (LLVMVectorValuesAllEqual(bop->getOperand(0), vectorLength,
|
// Easy case: both operands are all equal -> return true
|
||||||
seenPhis) &&
|
if (lVectorValuesAllEqual(bop->getOperand(0), vectorLength,
|
||||||
LLVMVectorValuesAllEqual(bop->getOperand(1), vectorLength,
|
seenPhis) &&
|
||||||
seenPhis));
|
lVectorValuesAllEqual(bop->getOperand(1), vectorLength,
|
||||||
|
seenPhis))
|
||||||
|
return true;
|
||||||
|
|
||||||
|
// If it's a shift, take a special path that tries to check if the
|
||||||
|
// high (surviving) bits of the values are equal.
|
||||||
|
if (bop->getOpcode() == llvm::Instruction::AShr ||
|
||||||
|
bop->getOpcode() == llvm::Instruction::LShr)
|
||||||
|
return lVectorShiftRightAllEqual(bop->getOperand(0),
|
||||||
|
bop->getOperand(1), vectorLength);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
|
llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
|
||||||
if (cast != NULL)
|
if (cast != NULL)
|
||||||
@@ -923,6 +1182,45 @@ lCheckMulForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength,
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/** Given (op0 AND op1), try and see if we can determine if the result is a
|
||||||
|
linear sequence with a step of "stride" between values. Returns true
|
||||||
|
if so and false otherwise. This pattern comes up when accessing SOA
|
||||||
|
data.
|
||||||
|
*/
|
||||||
|
static bool
|
||||||
|
lCheckAndForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength,
|
||||||
|
int stride, std::vector<llvm::PHINode *> &seenPhis) {
|
||||||
|
// Require op1 to be a compile-time constant
|
||||||
|
int64_t maskValue[ISPC_MAX_NVEC];
|
||||||
|
int nElts;
|
||||||
|
if (LLVMExtractVectorInts(op1, maskValue, &nElts) == false)
|
||||||
|
return false;
|
||||||
|
Assert(nElts == vectorLength);
|
||||||
|
|
||||||
|
// Is op1 a smear of the same value across all lanes? Give up if not.
|
||||||
|
for (int i = 1; i < vectorLength; ++i)
|
||||||
|
if (maskValue[i] != maskValue[0])
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// If the op1 value isn't power of 2 minus one, then also give up.
|
||||||
|
int64_t maskPlusOne = maskValue[0] + 1;
|
||||||
|
bool isPowTwo = (maskPlusOne & (maskPlusOne - 1)) == 0;
|
||||||
|
if (isPowTwo == false)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
// The case we'll covert here is op0 being a linear vector with desired
|
||||||
|
// stride, and where all of the values of op0, when divided by
|
||||||
|
// maskPlusOne, have the same value.
|
||||||
|
if (lVectorIsLinear(op0, vectorLength, stride, seenPhis) == false)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
bool canAdd = true;
|
||||||
|
bool isMult = lAllDivBaseEqual(op0, maskPlusOne, vectorLength, seenPhis,
|
||||||
|
canAdd);
|
||||||
|
return isMult;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
static bool
|
static bool
|
||||||
lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
|
lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
|
||||||
std::vector<llvm::PHINode *> &seenPhis) {
|
std::vector<llvm::PHINode *> &seenPhis) {
|
||||||
@@ -971,6 +1269,12 @@ lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
|
|||||||
bool m1 = lCheckMulForLinear(op1, op0, vectorLength, stride, seenPhis);
|
bool m1 = lCheckMulForLinear(op1, op0, vectorLength, stride, seenPhis);
|
||||||
return m1;
|
return m1;
|
||||||
}
|
}
|
||||||
|
else if (bop->getOpcode() == llvm::Instruction::And) {
|
||||||
|
// Special case for some AND-related patterns that come up when
|
||||||
|
// looping over SOA data
|
||||||
|
bool linear = lCheckAndForLinear(op0, op1, vectorLength, stride, seenPhis);
|
||||||
|
return linear;
|
||||||
|
}
|
||||||
else
|
else
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|||||||
25
tests/soa-27.ispc
Normal file
25
tests/soa-27.ispc
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
|
||||||
|
struct Point { float x, y, z; };
|
||||||
|
|
||||||
|
export uniform int width() { return programCount; }
|
||||||
|
|
||||||
|
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
|
||||||
|
float a = aFOO[programIndex];
|
||||||
|
|
||||||
|
soa<8> Point pts[10];
|
||||||
|
|
||||||
|
foreach (i = 1 ... 80) {
|
||||||
|
pts[i].x = b*i;
|
||||||
|
pts[i].y = 2*b*i;
|
||||||
|
pts[i].z = 3*b*i;
|
||||||
|
}
|
||||||
|
pts[0].x = pts[0].y = pts[0].z = 0;
|
||||||
|
|
||||||
|
uniform Point up = pts[4];
|
||||||
|
|
||||||
|
RET[programIndex] = up.z;
|
||||||
|
}
|
||||||
|
|
||||||
|
export void result(uniform float RET[]) {
|
||||||
|
RET[programIndex] = 60;
|
||||||
|
}
|
||||||
24
tests/soa-28.ispc
Normal file
24
tests/soa-28.ispc
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
|
||||||
|
struct Point { float x, y, z; };
|
||||||
|
|
||||||
|
export uniform int width() { return programCount; }
|
||||||
|
|
||||||
|
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
|
||||||
|
float a = aFOO[programIndex];
|
||||||
|
|
||||||
|
soa<8> Point pts[10];
|
||||||
|
|
||||||
|
foreach (i = b-5 ... 80) {
|
||||||
|
pts[i].x = b*i;
|
||||||
|
pts[i].y = 2*b*i;
|
||||||
|
pts[i].z = 3*b*i;
|
||||||
|
}
|
||||||
|
|
||||||
|
uniform Point up = pts[4];
|
||||||
|
|
||||||
|
RET[programIndex] = pts[2*programIndex].x;
|
||||||
|
}
|
||||||
|
|
||||||
|
export void result(uniform float RET[]) {
|
||||||
|
RET[programIndex] = 10 * programIndex;
|
||||||
|
}
|
||||||
24
tests/soa-29.ispc
Normal file
24
tests/soa-29.ispc
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
|
||||||
|
struct Point { float x, y; int8 zzz0; float z; double aa[3]; int8 zzz; };
|
||||||
|
|
||||||
|
export uniform int width() { return programCount; }
|
||||||
|
|
||||||
|
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
|
||||||
|
float a = aFOO[programIndex];
|
||||||
|
|
||||||
|
soa<8> Point pts[10];
|
||||||
|
|
||||||
|
for (int i = programIndex; i < 16*b; i += programCount) {
|
||||||
|
pts[i].x = b*i;
|
||||||
|
pts[i].y = 2*b*i;
|
||||||
|
pts[i].z = 3*b*i;
|
||||||
|
}
|
||||||
|
|
||||||
|
uniform Point up = pts[4];
|
||||||
|
|
||||||
|
RET[programIndex] = pts[2*programIndex].x;
|
||||||
|
}
|
||||||
|
|
||||||
|
export void result(uniform float RET[]) {
|
||||||
|
RET[programIndex] = 10 * programIndex;
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user