Add patterns to better-match code generated when accessing SOA data.

In particular, LLVMVectorIsLinear() and LLVMVectorValuesAllEqual() are able
to reason a bit about the effects of the shifts and the ANDs that are
generated from SOA indexing calculations, so that they can detect more cases
where a linear sequence of locations are in fact being accessed in
the presence of SOA data.
This commit is contained in:
Matt Pharr
2012-03-19 12:02:07 -07:00
parent 57af0eb64f
commit a062653743
4 changed files with 382 additions and 5 deletions

View File

@@ -687,6 +687,253 @@ lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis);
/** This function checks to see if the given (scalar or vector) value is an
exact multiple of baseValue. It returns true if so, and false if not
(or if it's not able to determine if it is). Any vector value passed
in is required to have the same value in all elements (so that we can
just check the first element to be a multiple of the given value.)
*/
static bool
lIsExactMultiple(llvm::Value *val, int baseValue, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis) {
if (llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(val->getType()) == false) {
// If we've worked down to a constant int, then the moment of truth
// has arrived...
llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(val);
if (ci != NULL)
return (ci->getZExtValue() % baseValue) == 0;
}
else
Assert(LLVMVectorValuesAllEqual(val));
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(val);
if (ie != NULL) {
llvm::Value *elts[ISPC_MAX_NVEC];
LLVMFlattenInsertChain(ie, g->target.vectorWidth, elts);
// We just need to check the scalar first value, since we know that
// all elements are equal
return lIsExactMultiple(elts[0], baseValue, vectorLength,
seenPhis);
}
llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(val);
if (phi != NULL) {
for (unsigned int i = 0; i < seenPhis.size(); ++i)
if (phi == seenPhis[i])
return true;
seenPhis.push_back(phi);
unsigned int numIncoming = phi->getNumIncomingValues();
// Check all of the incoming values: if all of them pass, then
// we're good.
for (unsigned int i = 0; i < numIncoming; ++i) {
llvm::Value *incoming = phi->getIncomingValue(i);
bool mult = lIsExactMultiple(incoming, baseValue, vectorLength,
seenPhis);
if (mult == false) {
seenPhis.pop_back();
return false;
}
}
seenPhis.pop_back();
return true;
}
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(val);
if (bop != NULL && bop->getOpcode() == llvm::Instruction::Add) {
llvm::Value *op0 = bop->getOperand(0);
llvm::Value *op1 = bop->getOperand(1);
bool be0 = lIsExactMultiple(op0, baseValue, vectorLength, seenPhis);
bool be1 = lIsExactMultiple(op1, baseValue, vectorLength, seenPhis);
return (be0 && be1);
}
// FIXME: mul? casts? ... ?
return false;
}
/** Returns the next power of two greater than or equal to the given
value. */
static int
lRoundUpPow2(int v) {
v--;
v |= v >> 1;
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
v |= v >> 16;
return v+1;
}
/** Try to determine if all of the elements of the given vector value have
the same value when divided by the given baseValue. The function
returns true if this can be determined to be the case, and false
otherwise. (This function may fail to identify some cases where it
does in fact have this property, but should never report a given value
as being a multiple if it isn't!)
*/
static bool
lAllDivBaseEqual(llvm::Value *val, int baseValue, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis,
bool &canAdd) {
Assert(llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(val->getType()));
// Make sure the base value is a positive power of 2
Assert(baseValue > 0 && (baseValue & (baseValue-1)) == 0);
// The easy case
if (lVectorValuesAllEqual(val, vectorLength, seenPhis))
return true;
int64_t vecVals[ISPC_MAX_NVEC];
int nElts;
if (llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(val->getType()) &&
LLVMExtractVectorInts(val, vecVals, &nElts)) {
// If we have a vector of compile-time constant integer values,
// then go ahead and check them directly..
int64_t firstDiv = vecVals[0] / baseValue;
for (int i = 1; i < nElts; ++i)
if ((vecVals[i] / baseValue) != firstDiv)
return false;
return true;
}
llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(val);
if (phi != NULL) {
for (unsigned int i = 0; i < seenPhis.size(); ++i)
if (phi == seenPhis[i])
return true;
seenPhis.push_back(phi);
unsigned int numIncoming = phi->getNumIncomingValues();
// Check all of the incoming values: if all of them pass, then
// we're good.
for (unsigned int i = 0; i < numIncoming; ++i) {
llvm::Value *incoming = phi->getIncomingValue(i);
bool ca = canAdd;
bool mult = lAllDivBaseEqual(incoming, baseValue, vectorLength,
seenPhis, ca);
if (mult == false) {
seenPhis.pop_back();
return false;
}
}
seenPhis.pop_back();
return true;
}
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(val);
if (bop != NULL && bop->getOpcode() == llvm::Instruction::Add &&
canAdd == true) {
llvm::Value *op0 = bop->getOperand(0);
llvm::Value *op1 = bop->getOperand(1);
// Otherwise we're only going to worry about the following case,
// which comes up often when looping over SOA data:
// ashr %val, <constant shift>
// where %val = add %smear, <0,1,2,3...>
// and where the maximum of the <0,...> vector in the add is less than
// 1<<(constant shift),
// and where %smear is a smear of a value that is a multiple of
// baseValue.
int64_t addConstants[ISPC_MAX_NVEC];
if (LLVMExtractVectorInts(op1, addConstants, &nElts) == false)
return false;
Assert(nElts == vectorLength);
// Do all of them give the same value when divided by baseValue?
int64_t firstConstDiv = addConstants[0] / baseValue;
for (int i = 1; i < vectorLength; ++i)
if ((addConstants[i] / baseValue) != firstConstDiv)
return false;
if (lVectorValuesAllEqual(op0, vectorLength, seenPhis) == false)
return false;
// Note that canAdd is a reference parameter; setting this ensures
// that we don't allow multiple adds in other parts of the chain of
// dependent values from here.
canAdd = false;
// Now we need to figure out the required alignment (in numbers of
// elements of the underlying type being indexed) of the value to
// which these integer addConstant[] values are being added to. We
// know that we have addConstant[] values that all give the same
// value when divided by baseValue, but we may need a less-strict
// alignment than baseValue depending on the actual values.
//
// As an example, consider a case where the baseValue alignment is
// 16, but the addConstants here are <0,1,2,3>. In that case, the
// value to which addConstants is added to only needs to be a
// multiple of 4. Conversely, if addConstants are <4,5,6,7>, then
// we need a multiple of 8 to ensure that the final added result
// will still have the same value for all vector elements when
// divided by baseValue.
//
// All that said, here we need to find the maximum value of any of
// the addConstants[], mod baseValue. If we round that up to the
// next power of 2, we'll have a value that will be no greater than
// baseValue and sometimes less.
int maxMod = addConstants[0] % baseValue;
for (int i = 1; i < vectorLength; ++i)
maxMod = std::max(maxMod, int(addConstants[i] % baseValue));
int requiredAlignment = lRoundUpPow2(maxMod);
std::vector<llvm::PHINode *> seenPhisEEM;
return lIsExactMultiple(op0, requiredAlignment, vectorLength,
seenPhisEEM);
}
// TODO: could handle mul by a vector of equal constant integer values
// and the like here and adjust the 'baseValue' value when it evenly
// divides, but unclear if it's worthwhile...
return false;
}
/** Given a vector shift right of some value by some amount, try to
determine if all of the elements of the final result have the same
value (i.e. whether the high bits are all equal, disregarding the low
bits that are shifted out.) Returns true if so, and false otherwise.
*/
static bool
lVectorShiftRightAllEqual(llvm::Value *val, llvm::Value *shift,
int vectorLength) {
// Are we shifting all elements by a compile-time constant amount? If
// not, give up.
int64_t shiftAmount[ISPC_MAX_NVEC];
int nElts;
if (LLVMExtractVectorInts(shift, shiftAmount, &nElts) == false)
return false;
Assert(nElts == vectorLength);
// Is it the same amount for all elements?
for (int i = 0; i < vectorLength; ++i)
if (shiftAmount[i] != shiftAmount[0])
return false;
// Now see if the value divided by (1 << shift) can be determined to
// have the same value for all vector elements.
int pow2 = 1 << shiftAmount[0];
bool canAdd = true;
std::vector<llvm::PHINode *> seenPhis;
bool eq = lAllDivBaseEqual(val, pow2, vectorLength, seenPhis, canAdd);
#if 0
fprintf(stderr, "check all div base equal:\n");
LLVMDumpValue(shift);
LLVMDumpValue(val);
fprintf(stderr, "----> %s\n\n", eq ? "true" : "false");
#endif
return eq;
}
static bool
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis) {
@@ -707,11 +954,23 @@ lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
#endif
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
if (bop != NULL)
return (LLVMVectorValuesAllEqual(bop->getOperand(0), vectorLength,
seenPhis) &&
LLVMVectorValuesAllEqual(bop->getOperand(1), vectorLength,
seenPhis));
if (bop != NULL) {
// Easy case: both operands are all equal -> return true
if (lVectorValuesAllEqual(bop->getOperand(0), vectorLength,
seenPhis) &&
lVectorValuesAllEqual(bop->getOperand(1), vectorLength,
seenPhis))
return true;
// If it's a shift, take a special path that tries to check if the
// high (surviving) bits of the values are equal.
if (bop->getOpcode() == llvm::Instruction::AShr ||
bop->getOpcode() == llvm::Instruction::LShr)
return lVectorShiftRightAllEqual(bop->getOperand(0),
bop->getOperand(1), vectorLength);
return false;
}
llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
if (cast != NULL)
@@ -923,6 +1182,45 @@ lCheckMulForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength,
}
/** Given (op0 AND op1), try and see if we can determine if the result is a
linear sequence with a step of "stride" between values. Returns true
if so and false otherwise. This pattern comes up when accessing SOA
data.
*/
static bool
lCheckAndForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength,
int stride, std::vector<llvm::PHINode *> &seenPhis) {
// Require op1 to be a compile-time constant
int64_t maskValue[ISPC_MAX_NVEC];
int nElts;
if (LLVMExtractVectorInts(op1, maskValue, &nElts) == false)
return false;
Assert(nElts == vectorLength);
// Is op1 a smear of the same value across all lanes? Give up if not.
for (int i = 1; i < vectorLength; ++i)
if (maskValue[i] != maskValue[0])
return false;
// If the op1 value isn't power of 2 minus one, then also give up.
int64_t maskPlusOne = maskValue[0] + 1;
bool isPowTwo = (maskPlusOne & (maskPlusOne - 1)) == 0;
if (isPowTwo == false)
return false;
// The case we'll covert here is op0 being a linear vector with desired
// stride, and where all of the values of op0, when divided by
// maskPlusOne, have the same value.
if (lVectorIsLinear(op0, vectorLength, stride, seenPhis) == false)
return false;
bool canAdd = true;
bool isMult = lAllDivBaseEqual(op0, maskPlusOne, vectorLength, seenPhis,
canAdd);
return isMult;
}
static bool
lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
std::vector<llvm::PHINode *> &seenPhis) {
@@ -971,6 +1269,12 @@ lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
bool m1 = lCheckMulForLinear(op1, op0, vectorLength, stride, seenPhis);
return m1;
}
else if (bop->getOpcode() == llvm::Instruction::And) {
// Special case for some AND-related patterns that come up when
// looping over SOA data
bool linear = lCheckAndForLinear(op0, op1, vectorLength, stride, seenPhis);
return linear;
}
else
return false;
}

25
tests/soa-27.ispc Normal file
View File

@@ -0,0 +1,25 @@
struct Point { float x, y, z; };
export uniform int width() { return programCount; }
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
float a = aFOO[programIndex];
soa<8> Point pts[10];
foreach (i = 1 ... 80) {
pts[i].x = b*i;
pts[i].y = 2*b*i;
pts[i].z = 3*b*i;
}
pts[0].x = pts[0].y = pts[0].z = 0;
uniform Point up = pts[4];
RET[programIndex] = up.z;
}
export void result(uniform float RET[]) {
RET[programIndex] = 60;
}

24
tests/soa-28.ispc Normal file
View File

@@ -0,0 +1,24 @@
struct Point { float x, y, z; };
export uniform int width() { return programCount; }
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
float a = aFOO[programIndex];
soa<8> Point pts[10];
foreach (i = b-5 ... 80) {
pts[i].x = b*i;
pts[i].y = 2*b*i;
pts[i].z = 3*b*i;
}
uniform Point up = pts[4];
RET[programIndex] = pts[2*programIndex].x;
}
export void result(uniform float RET[]) {
RET[programIndex] = 10 * programIndex;
}

24
tests/soa-29.ispc Normal file
View File

@@ -0,0 +1,24 @@
struct Point { float x, y; int8 zzz0; float z; double aa[3]; int8 zzz; };
export uniform int width() { return programCount; }
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
float a = aFOO[programIndex];
soa<8> Point pts[10];
for (int i = programIndex; i < 16*b; i += programCount) {
pts[i].x = b*i;
pts[i].y = 2*b*i;
pts[i].z = 3*b*i;
}
uniform Point up = pts[4];
RET[programIndex] = pts[2*programIndex].x;
}
export void result(uniform float RET[]) {
RET[programIndex] = 10 * programIndex;
}