Move check for linear vector to LLVMVectorIsLinear() function.

This commit is contained in:
Matt Pharr
2012-03-19 11:57:04 -07:00
parent e264d95019
commit 60aae16752
3 changed files with 233 additions and 205 deletions

View File

@@ -832,6 +832,228 @@ LLVMVectorValuesAllEqual(llvm::Value *v) {
}
static bool
lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
std::vector<llvm::PHINode *> &seenPhis);
/** Given a vector of compile-time constant integer values, test to see if
they are a linear sequence of constant integers starting from an
arbirary value but then having a step of value "stride" between
elements.
*/
static bool
lVectorIsLinearConstantInts(
#ifdef LLVM_3_1svn
llvm::ConstantDataVector *cv,
#else
llvm::ConstantVector *cv,
#endif
int vectorLength,
int stride) {
// Flatten the vector out into the elements array
llvm::SmallVector<llvm::Constant *, ISPC_MAX_NVEC> elements;
#ifdef LLVM_3_1svn
for (int i = 0; i < (int)cv->getNumElements(); ++i)
elements.push_back(cv->getElementAsConstant(i));
#else
cv->getVectorElements(elements);
#endif
Assert((int)elements.size() == vectorLength);
llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(elements[0]);
if (ci == NULL)
// Not a vector of integers
return false;
int64_t prevVal = ci->getSExtValue();
// For each element in the array, see if it is both a ConstantInt and
// if the difference between it and the value of the previous element
// is stride. If not, fail.
for (int i = 1; i < vectorLength; ++i) {
ci = llvm::dyn_cast<llvm::ConstantInt>(elements[i]);
if (ci == NULL)
return false;
int64_t nextVal = ci->getSExtValue();
if (prevVal + stride != nextVal)
return false;
prevVal = nextVal;
}
return true;
}
/** Checks to see if (op0 * op1) is a linear vector where the result is a
vector with values that increase by stride.
*/
static bool
lCheckMulForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength,
int stride, std::vector<llvm::PHINode *> &seenPhis) {
// Is the first operand a constant integer value splatted across all of
// the lanes?
#ifdef LLVM_3_1svn
llvm::ConstantDataVector *cv = llvm::dyn_cast<llvm::ConstantDataVector>(op0);
#else
llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(op0);
#endif
if (cv == NULL)
return false;
llvm::Constant *csplat = cv->getSplatValue();
if (csplat == NULL)
return false;
llvm::ConstantInt *splat = llvm::dyn_cast<llvm::ConstantInt>(csplat);
if (splat == NULL)
return false;
// If the splat value doesn't evenly divide the stride we're looking
// for, there's no way that we can get the linear sequence we're
// looking or.
int64_t splatVal = splat->getSExtValue();
if (splatVal == 0 || splatVal > stride || (stride % splatVal) != 0)
return false;
// Check to see if the other operand is a linear vector with stride
// given by stride/splatVal.
return lVectorIsLinear(op1, vectorLength, (int)(stride / splatVal),
seenPhis);
}
static bool
lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
std::vector<llvm::PHINode *> &seenPhis) {
// First try the easy case: if the values are all just constant
// integers and have the expected stride between them, then we're done.
#ifdef LLVM_3_1svn
llvm::ConstantDataVector *cv = llvm::dyn_cast<llvm::ConstantDataVector>(v);
#else
llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
#endif
if (cv != NULL)
return lVectorIsLinearConstantInts(cv, vectorLength, stride);
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
if (bop != NULL) {
// FIXME: is it right to pass the seenPhis to the all equal check as well??
llvm::Value *op0 = bop->getOperand(0), *op1 = bop->getOperand(1);
if (bop->getOpcode() == llvm::Instruction::Add) {
// There are two cases to check if we have an add:
//
// programIndex + unif -> ascending linear seqeuence
// unif + programIndex -> ascending linear sequence
bool l0 = lVectorIsLinear(op0, vectorLength, stride, seenPhis);
bool e1 = lVectorValuesAllEqual(op1, vectorLength, seenPhis);
if (l0 && e1)
return true;
bool e0 = lVectorValuesAllEqual(op0, vectorLength, seenPhis);
bool l1 = lVectorIsLinear(op1, vectorLength, stride, seenPhis);
return (e0 && l1);
}
else if (bop->getOpcode() == llvm::Instruction::Sub)
// For subtraction, we only match:
// programIndex - unif -> ascending linear seqeuence
return (lVectorIsLinear(bop->getOperand(0), vectorLength,
stride, seenPhis) &&
lVectorValuesAllEqual(bop->getOperand(1), vectorLength,
seenPhis));
else if (bop->getOpcode() == llvm::Instruction::Mul) {
// Multiplies are a bit trickier, so are handled in a separate
// function.
bool m0 = lCheckMulForLinear(op0, op1, vectorLength, stride, seenPhis);
if (m0)
return true;
bool m1 = lCheckMulForLinear(op1, op0, vectorLength, stride, seenPhis);
return m1;
}
else
return false;
}
llvm::CastInst *ci = llvm::dyn_cast<llvm::CastInst>(v);
if (ci != NULL)
return lVectorIsLinear(ci->getOperand(0), vectorLength,
stride, seenPhis);
if (llvm::isa<llvm::CallInst>(v) || llvm::isa<llvm::LoadInst>(v))
return false;
llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
if (phi != NULL) {
for (unsigned int i = 0; i < seenPhis.size(); ++i)
if (seenPhis[i] == phi)
return true;
seenPhis.push_back(phi);
unsigned int numIncoming = phi->getNumIncomingValues();
// Check all of the incoming values: if all of them are all equal,
// then we're good.
for (unsigned int i = 0; i < numIncoming; ++i) {
if (!lVectorIsLinear(phi->getIncomingValue(i), vectorLength, stride,
seenPhis)) {
seenPhis.pop_back();
return false;
}
}
seenPhis.pop_back();
return true;
}
// TODO: is any reason to worry about these?
if (llvm::isa<llvm::InsertElementInst>(v))
return false;
// TODO: we could also handle shuffles, but we haven't yet seen any
// cases where doing so would detect cases where actually have a linear
// vector.
llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
if (shuffle != NULL)
return false;
#if 0
fprintf(stderr, "linear check: ");
v->dump();
fprintf(stderr, "\n");
llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
if (inst) {
inst->getParent()->dump();
fprintf(stderr, "\n");
fprintf(stderr, "\n");
}
#endif
return false;
}
/** Given vector of integer-typed values, see if the elements of the array
have a step of 'stride' between their values. This function tries to
handle as many possibilities as possible, including things like all
elements equal to some non-constant value plus an integer offset, etc.
*/
bool
LLVMVectorIsLinear(llvm::Value *v, int stride) {
LLVM_TYPE_CONST llvm::VectorType *vt =
llvm::dyn_cast<LLVM_TYPE_CONST llvm::VectorType>(v->getType());
Assert(vt != NULL);
int vectorLength = vt->getNumElements();
std::vector<llvm::PHINode *> seenPhis;
bool linear = lVectorIsLinear(v, vectorLength, stride, seenPhis);
Debug(SourcePos(), "LLVMVectorIsLinear(%s) -> %s.",
v->getName().str().c_str(), linear ? "true" : "false");
if (g->debugPrint)
LLVMDumpValue(v);
return linear;
}
static void