Add patterns to better-match code generated when accessing SOA data.
In particular, LLVMVectorIsLinear() and LLVMVectorValuesAllEqual() are able to reason a bit about the effects of the shifts and the ANDs that are generated from SOA indexing calculations, so that they can detect more cases where a linear sequence of locations are in fact being accessed in the presence of SOA data.
This commit is contained in:
314
llvmutil.cpp
314
llvmutil.cpp
@@ -687,6 +687,253 @@ lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
||||
std::vector<llvm::PHINode *> &seenPhis);
|
||||
|
||||
|
||||
/** This function checks to see if the given (scalar or vector) value is an
|
||||
exact multiple of baseValue. It returns true if so, and false if not
|
||||
(or if it's not able to determine if it is). Any vector value passed
|
||||
in is required to have the same value in all elements (so that we can
|
||||
just check the first element to be a multiple of the given value.)
|
||||
*/
|
||||
static bool
|
||||
lIsExactMultiple(llvm::Value *val, int baseValue, int vectorLength,
|
||||
std::vector<llvm::PHINode *> &seenPhis) {
|
||||
if (llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(val->getType()) == false) {
|
||||
// If we've worked down to a constant int, then the moment of truth
|
||||
// has arrived...
|
||||
llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(val);
|
||||
if (ci != NULL)
|
||||
return (ci->getZExtValue() % baseValue) == 0;
|
||||
}
|
||||
else
|
||||
Assert(LLVMVectorValuesAllEqual(val));
|
||||
|
||||
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(val);
|
||||
if (ie != NULL) {
|
||||
llvm::Value *elts[ISPC_MAX_NVEC];
|
||||
LLVMFlattenInsertChain(ie, g->target.vectorWidth, elts);
|
||||
// We just need to check the scalar first value, since we know that
|
||||
// all elements are equal
|
||||
return lIsExactMultiple(elts[0], baseValue, vectorLength,
|
||||
seenPhis);
|
||||
}
|
||||
|
||||
llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(val);
|
||||
if (phi != NULL) {
|
||||
for (unsigned int i = 0; i < seenPhis.size(); ++i)
|
||||
if (phi == seenPhis[i])
|
||||
return true;
|
||||
|
||||
seenPhis.push_back(phi);
|
||||
unsigned int numIncoming = phi->getNumIncomingValues();
|
||||
|
||||
// Check all of the incoming values: if all of them pass, then
|
||||
// we're good.
|
||||
for (unsigned int i = 0; i < numIncoming; ++i) {
|
||||
llvm::Value *incoming = phi->getIncomingValue(i);
|
||||
bool mult = lIsExactMultiple(incoming, baseValue, vectorLength,
|
||||
seenPhis);
|
||||
if (mult == false) {
|
||||
seenPhis.pop_back();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
seenPhis.pop_back();
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(val);
|
||||
if (bop != NULL && bop->getOpcode() == llvm::Instruction::Add) {
|
||||
llvm::Value *op0 = bop->getOperand(0);
|
||||
llvm::Value *op1 = bop->getOperand(1);
|
||||
|
||||
bool be0 = lIsExactMultiple(op0, baseValue, vectorLength, seenPhis);
|
||||
bool be1 = lIsExactMultiple(op1, baseValue, vectorLength, seenPhis);
|
||||
return (be0 && be1);
|
||||
}
|
||||
// FIXME: mul? casts? ... ?
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
/** Returns the next power of two greater than or equal to the given
|
||||
value. */
|
||||
static int
|
||||
lRoundUpPow2(int v) {
|
||||
v--;
|
||||
v |= v >> 1;
|
||||
v |= v >> 2;
|
||||
v |= v >> 4;
|
||||
v |= v >> 8;
|
||||
v |= v >> 16;
|
||||
return v+1;
|
||||
}
|
||||
|
||||
|
||||
/** Try to determine if all of the elements of the given vector value have
|
||||
the same value when divided by the given baseValue. The function
|
||||
returns true if this can be determined to be the case, and false
|
||||
otherwise. (This function may fail to identify some cases where it
|
||||
does in fact have this property, but should never report a given value
|
||||
as being a multiple if it isn't!)
|
||||
*/
|
||||
static bool
|
||||
lAllDivBaseEqual(llvm::Value *val, int baseValue, int vectorLength,
|
||||
std::vector<llvm::PHINode *> &seenPhis,
|
||||
bool &canAdd) {
|
||||
Assert(llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(val->getType()));
|
||||
// Make sure the base value is a positive power of 2
|
||||
Assert(baseValue > 0 && (baseValue & (baseValue-1)) == 0);
|
||||
|
||||
// The easy case
|
||||
if (lVectorValuesAllEqual(val, vectorLength, seenPhis))
|
||||
return true;
|
||||
|
||||
int64_t vecVals[ISPC_MAX_NVEC];
|
||||
int nElts;
|
||||
if (llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(val->getType()) &&
|
||||
LLVMExtractVectorInts(val, vecVals, &nElts)) {
|
||||
// If we have a vector of compile-time constant integer values,
|
||||
// then go ahead and check them directly..
|
||||
int64_t firstDiv = vecVals[0] / baseValue;
|
||||
for (int i = 1; i < nElts; ++i)
|
||||
if ((vecVals[i] / baseValue) != firstDiv)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(val);
|
||||
if (phi != NULL) {
|
||||
for (unsigned int i = 0; i < seenPhis.size(); ++i)
|
||||
if (phi == seenPhis[i])
|
||||
return true;
|
||||
|
||||
seenPhis.push_back(phi);
|
||||
unsigned int numIncoming = phi->getNumIncomingValues();
|
||||
|
||||
// Check all of the incoming values: if all of them pass, then
|
||||
// we're good.
|
||||
for (unsigned int i = 0; i < numIncoming; ++i) {
|
||||
llvm::Value *incoming = phi->getIncomingValue(i);
|
||||
bool ca = canAdd;
|
||||
bool mult = lAllDivBaseEqual(incoming, baseValue, vectorLength,
|
||||
seenPhis, ca);
|
||||
if (mult == false) {
|
||||
seenPhis.pop_back();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
seenPhis.pop_back();
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(val);
|
||||
if (bop != NULL && bop->getOpcode() == llvm::Instruction::Add &&
|
||||
canAdd == true) {
|
||||
llvm::Value *op0 = bop->getOperand(0);
|
||||
llvm::Value *op1 = bop->getOperand(1);
|
||||
|
||||
// Otherwise we're only going to worry about the following case,
|
||||
// which comes up often when looping over SOA data:
|
||||
// ashr %val, <constant shift>
|
||||
// where %val = add %smear, <0,1,2,3...>
|
||||
// and where the maximum of the <0,...> vector in the add is less than
|
||||
// 1<<(constant shift),
|
||||
// and where %smear is a smear of a value that is a multiple of
|
||||
// baseValue.
|
||||
|
||||
int64_t addConstants[ISPC_MAX_NVEC];
|
||||
if (LLVMExtractVectorInts(op1, addConstants, &nElts) == false)
|
||||
return false;
|
||||
Assert(nElts == vectorLength);
|
||||
|
||||
// Do all of them give the same value when divided by baseValue?
|
||||
int64_t firstConstDiv = addConstants[0] / baseValue;
|
||||
for (int i = 1; i < vectorLength; ++i)
|
||||
if ((addConstants[i] / baseValue) != firstConstDiv)
|
||||
return false;
|
||||
|
||||
if (lVectorValuesAllEqual(op0, vectorLength, seenPhis) == false)
|
||||
return false;
|
||||
|
||||
// Note that canAdd is a reference parameter; setting this ensures
|
||||
// that we don't allow multiple adds in other parts of the chain of
|
||||
// dependent values from here.
|
||||
canAdd = false;
|
||||
|
||||
// Now we need to figure out the required alignment (in numbers of
|
||||
// elements of the underlying type being indexed) of the value to
|
||||
// which these integer addConstant[] values are being added to. We
|
||||
// know that we have addConstant[] values that all give the same
|
||||
// value when divided by baseValue, but we may need a less-strict
|
||||
// alignment than baseValue depending on the actual values.
|
||||
//
|
||||
// As an example, consider a case where the baseValue alignment is
|
||||
// 16, but the addConstants here are <0,1,2,3>. In that case, the
|
||||
// value to which addConstants is added to only needs to be a
|
||||
// multiple of 4. Conversely, if addConstants are <4,5,6,7>, then
|
||||
// we need a multiple of 8 to ensure that the final added result
|
||||
// will still have the same value for all vector elements when
|
||||
// divided by baseValue.
|
||||
//
|
||||
// All that said, here we need to find the maximum value of any of
|
||||
// the addConstants[], mod baseValue. If we round that up to the
|
||||
// next power of 2, we'll have a value that will be no greater than
|
||||
// baseValue and sometimes less.
|
||||
int maxMod = addConstants[0] % baseValue;
|
||||
for (int i = 1; i < vectorLength; ++i)
|
||||
maxMod = std::max(maxMod, int(addConstants[i] % baseValue));
|
||||
int requiredAlignment = lRoundUpPow2(maxMod);
|
||||
|
||||
std::vector<llvm::PHINode *> seenPhisEEM;
|
||||
return lIsExactMultiple(op0, requiredAlignment, vectorLength,
|
||||
seenPhisEEM);
|
||||
}
|
||||
// TODO: could handle mul by a vector of equal constant integer values
|
||||
// and the like here and adjust the 'baseValue' value when it evenly
|
||||
// divides, but unclear if it's worthwhile...
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
/** Given a vector shift right of some value by some amount, try to
|
||||
determine if all of the elements of the final result have the same
|
||||
value (i.e. whether the high bits are all equal, disregarding the low
|
||||
bits that are shifted out.) Returns true if so, and false otherwise.
|
||||
*/
|
||||
static bool
|
||||
lVectorShiftRightAllEqual(llvm::Value *val, llvm::Value *shift,
|
||||
int vectorLength) {
|
||||
// Are we shifting all elements by a compile-time constant amount? If
|
||||
// not, give up.
|
||||
int64_t shiftAmount[ISPC_MAX_NVEC];
|
||||
int nElts;
|
||||
if (LLVMExtractVectorInts(shift, shiftAmount, &nElts) == false)
|
||||
return false;
|
||||
Assert(nElts == vectorLength);
|
||||
|
||||
// Is it the same amount for all elements?
|
||||
for (int i = 0; i < vectorLength; ++i)
|
||||
if (shiftAmount[i] != shiftAmount[0])
|
||||
return false;
|
||||
|
||||
// Now see if the value divided by (1 << shift) can be determined to
|
||||
// have the same value for all vector elements.
|
||||
int pow2 = 1 << shiftAmount[0];
|
||||
bool canAdd = true;
|
||||
std::vector<llvm::PHINode *> seenPhis;
|
||||
bool eq = lAllDivBaseEqual(val, pow2, vectorLength, seenPhis, canAdd);
|
||||
#if 0
|
||||
fprintf(stderr, "check all div base equal:\n");
|
||||
LLVMDumpValue(shift);
|
||||
LLVMDumpValue(val);
|
||||
fprintf(stderr, "----> %s\n\n", eq ? "true" : "false");
|
||||
#endif
|
||||
return eq;
|
||||
}
|
||||
|
||||
|
||||
static bool
|
||||
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
||||
std::vector<llvm::PHINode *> &seenPhis) {
|
||||
@@ -707,11 +954,23 @@ lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
||||
#endif
|
||||
|
||||
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
|
||||
if (bop != NULL)
|
||||
return (LLVMVectorValuesAllEqual(bop->getOperand(0), vectorLength,
|
||||
seenPhis) &&
|
||||
LLVMVectorValuesAllEqual(bop->getOperand(1), vectorLength,
|
||||
seenPhis));
|
||||
if (bop != NULL) {
|
||||
// Easy case: both operands are all equal -> return true
|
||||
if (lVectorValuesAllEqual(bop->getOperand(0), vectorLength,
|
||||
seenPhis) &&
|
||||
lVectorValuesAllEqual(bop->getOperand(1), vectorLength,
|
||||
seenPhis))
|
||||
return true;
|
||||
|
||||
// If it's a shift, take a special path that tries to check if the
|
||||
// high (surviving) bits of the values are equal.
|
||||
if (bop->getOpcode() == llvm::Instruction::AShr ||
|
||||
bop->getOpcode() == llvm::Instruction::LShr)
|
||||
return lVectorShiftRightAllEqual(bop->getOperand(0),
|
||||
bop->getOperand(1), vectorLength);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
|
||||
if (cast != NULL)
|
||||
@@ -923,6 +1182,45 @@ lCheckMulForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength,
|
||||
}
|
||||
|
||||
|
||||
/** Given (op0 AND op1), try and see if we can determine if the result is a
|
||||
linear sequence with a step of "stride" between values. Returns true
|
||||
if so and false otherwise. This pattern comes up when accessing SOA
|
||||
data.
|
||||
*/
|
||||
static bool
|
||||
lCheckAndForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength,
|
||||
int stride, std::vector<llvm::PHINode *> &seenPhis) {
|
||||
// Require op1 to be a compile-time constant
|
||||
int64_t maskValue[ISPC_MAX_NVEC];
|
||||
int nElts;
|
||||
if (LLVMExtractVectorInts(op1, maskValue, &nElts) == false)
|
||||
return false;
|
||||
Assert(nElts == vectorLength);
|
||||
|
||||
// Is op1 a smear of the same value across all lanes? Give up if not.
|
||||
for (int i = 1; i < vectorLength; ++i)
|
||||
if (maskValue[i] != maskValue[0])
|
||||
return false;
|
||||
|
||||
// If the op1 value isn't power of 2 minus one, then also give up.
|
||||
int64_t maskPlusOne = maskValue[0] + 1;
|
||||
bool isPowTwo = (maskPlusOne & (maskPlusOne - 1)) == 0;
|
||||
if (isPowTwo == false)
|
||||
return false;
|
||||
|
||||
// The case we'll covert here is op0 being a linear vector with desired
|
||||
// stride, and where all of the values of op0, when divided by
|
||||
// maskPlusOne, have the same value.
|
||||
if (lVectorIsLinear(op0, vectorLength, stride, seenPhis) == false)
|
||||
return false;
|
||||
|
||||
bool canAdd = true;
|
||||
bool isMult = lAllDivBaseEqual(op0, maskPlusOne, vectorLength, seenPhis,
|
||||
canAdd);
|
||||
return isMult;
|
||||
}
|
||||
|
||||
|
||||
static bool
|
||||
lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
|
||||
std::vector<llvm::PHINode *> &seenPhis) {
|
||||
@@ -971,6 +1269,12 @@ lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
|
||||
bool m1 = lCheckMulForLinear(op1, op0, vectorLength, stride, seenPhis);
|
||||
return m1;
|
||||
}
|
||||
else if (bop->getOpcode() == llvm::Instruction::And) {
|
||||
// Special case for some AND-related patterns that come up when
|
||||
// looping over SOA data
|
||||
bool linear = lCheckAndForLinear(op0, op1, vectorLength, stride, seenPhis);
|
||||
return linear;
|
||||
}
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
25
tests/soa-27.ispc
Normal file
25
tests/soa-27.ispc
Normal file
@@ -0,0 +1,25 @@
|
||||
|
||||
struct Point { float x, y, z; };
|
||||
|
||||
export uniform int width() { return programCount; }
|
||||
|
||||
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
|
||||
float a = aFOO[programIndex];
|
||||
|
||||
soa<8> Point pts[10];
|
||||
|
||||
foreach (i = 1 ... 80) {
|
||||
pts[i].x = b*i;
|
||||
pts[i].y = 2*b*i;
|
||||
pts[i].z = 3*b*i;
|
||||
}
|
||||
pts[0].x = pts[0].y = pts[0].z = 0;
|
||||
|
||||
uniform Point up = pts[4];
|
||||
|
||||
RET[programIndex] = up.z;
|
||||
}
|
||||
|
||||
export void result(uniform float RET[]) {
|
||||
RET[programIndex] = 60;
|
||||
}
|
||||
24
tests/soa-28.ispc
Normal file
24
tests/soa-28.ispc
Normal file
@@ -0,0 +1,24 @@
|
||||
|
||||
struct Point { float x, y, z; };
|
||||
|
||||
export uniform int width() { return programCount; }
|
||||
|
||||
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
|
||||
float a = aFOO[programIndex];
|
||||
|
||||
soa<8> Point pts[10];
|
||||
|
||||
foreach (i = b-5 ... 80) {
|
||||
pts[i].x = b*i;
|
||||
pts[i].y = 2*b*i;
|
||||
pts[i].z = 3*b*i;
|
||||
}
|
||||
|
||||
uniform Point up = pts[4];
|
||||
|
||||
RET[programIndex] = pts[2*programIndex].x;
|
||||
}
|
||||
|
||||
export void result(uniform float RET[]) {
|
||||
RET[programIndex] = 10 * programIndex;
|
||||
}
|
||||
24
tests/soa-29.ispc
Normal file
24
tests/soa-29.ispc
Normal file
@@ -0,0 +1,24 @@
|
||||
|
||||
struct Point { float x, y; int8 zzz0; float z; double aa[3]; int8 zzz; };
|
||||
|
||||
export uniform int width() { return programCount; }
|
||||
|
||||
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float b) {
|
||||
float a = aFOO[programIndex];
|
||||
|
||||
soa<8> Point pts[10];
|
||||
|
||||
for (int i = programIndex; i < 16*b; i += programCount) {
|
||||
pts[i].x = b*i;
|
||||
pts[i].y = 2*b*i;
|
||||
pts[i].z = 3*b*i;
|
||||
}
|
||||
|
||||
uniform Point up = pts[4];
|
||||
|
||||
RET[programIndex] = pts[2*programIndex].x;
|
||||
}
|
||||
|
||||
export void result(uniform float RET[]) {
|
||||
RET[programIndex] = 10 * programIndex;
|
||||
}
|
||||
Reference in New Issue
Block a user