Move check for linear vector to LLVMVectorIsLinear() function.
This commit is contained in:
222
llvmutil.cpp
222
llvmutil.cpp
@@ -832,6 +832,228 @@ LLVMVectorValuesAllEqual(llvm::Value *v) {
|
||||
}
|
||||
|
||||
|
||||
static bool
|
||||
lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
|
||||
std::vector<llvm::PHINode *> &seenPhis);
|
||||
|
||||
/** Given a vector of compile-time constant integer values, test to see if
|
||||
they are a linear sequence of constant integers starting from an
|
||||
arbirary value but then having a step of value "stride" between
|
||||
elements.
|
||||
*/
|
||||
static bool
|
||||
lVectorIsLinearConstantInts(
|
||||
#ifdef LLVM_3_1svn
|
||||
llvm::ConstantDataVector *cv,
|
||||
#else
|
||||
llvm::ConstantVector *cv,
|
||||
#endif
|
||||
int vectorLength,
|
||||
int stride) {
|
||||
// Flatten the vector out into the elements array
|
||||
llvm::SmallVector<llvm::Constant *, ISPC_MAX_NVEC> elements;
|
||||
#ifdef LLVM_3_1svn
|
||||
for (int i = 0; i < (int)cv->getNumElements(); ++i)
|
||||
elements.push_back(cv->getElementAsConstant(i));
|
||||
#else
|
||||
cv->getVectorElements(elements);
|
||||
#endif
|
||||
Assert((int)elements.size() == vectorLength);
|
||||
|
||||
llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(elements[0]);
|
||||
if (ci == NULL)
|
||||
// Not a vector of integers
|
||||
return false;
|
||||
|
||||
int64_t prevVal = ci->getSExtValue();
|
||||
|
||||
// For each element in the array, see if it is both a ConstantInt and
|
||||
// if the difference between it and the value of the previous element
|
||||
// is stride. If not, fail.
|
||||
for (int i = 1; i < vectorLength; ++i) {
|
||||
ci = llvm::dyn_cast<llvm::ConstantInt>(elements[i]);
|
||||
if (ci == NULL)
|
||||
return false;
|
||||
|
||||
int64_t nextVal = ci->getSExtValue();
|
||||
if (prevVal + stride != nextVal)
|
||||
return false;
|
||||
|
||||
prevVal = nextVal;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
/** Checks to see if (op0 * op1) is a linear vector where the result is a
|
||||
vector with values that increase by stride.
|
||||
*/
|
||||
static bool
|
||||
lCheckMulForLinear(llvm::Value *op0, llvm::Value *op1, int vectorLength,
|
||||
int stride, std::vector<llvm::PHINode *> &seenPhis) {
|
||||
// Is the first operand a constant integer value splatted across all of
|
||||
// the lanes?
|
||||
#ifdef LLVM_3_1svn
|
||||
llvm::ConstantDataVector *cv = llvm::dyn_cast<llvm::ConstantDataVector>(op0);
|
||||
#else
|
||||
llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(op0);
|
||||
#endif
|
||||
if (cv == NULL)
|
||||
return false;
|
||||
|
||||
llvm::Constant *csplat = cv->getSplatValue();
|
||||
if (csplat == NULL)
|
||||
return false;
|
||||
|
||||
llvm::ConstantInt *splat = llvm::dyn_cast<llvm::ConstantInt>(csplat);
|
||||
if (splat == NULL)
|
||||
return false;
|
||||
|
||||
// If the splat value doesn't evenly divide the stride we're looking
|
||||
// for, there's no way that we can get the linear sequence we're
|
||||
// looking or.
|
||||
int64_t splatVal = splat->getSExtValue();
|
||||
if (splatVal == 0 || splatVal > stride || (stride % splatVal) != 0)
|
||||
return false;
|
||||
|
||||
// Check to see if the other operand is a linear vector with stride
|
||||
// given by stride/splatVal.
|
||||
return lVectorIsLinear(op1, vectorLength, (int)(stride / splatVal),
|
||||
seenPhis);
|
||||
}
|
||||
|
||||
|
||||
static bool
|
||||
lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
|
||||
std::vector<llvm::PHINode *> &seenPhis) {
|
||||
// First try the easy case: if the values are all just constant
|
||||
// integers and have the expected stride between them, then we're done.
|
||||
#ifdef LLVM_3_1svn
|
||||
llvm::ConstantDataVector *cv = llvm::dyn_cast<llvm::ConstantDataVector>(v);
|
||||
#else
|
||||
llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
|
||||
#endif
|
||||
if (cv != NULL)
|
||||
return lVectorIsLinearConstantInts(cv, vectorLength, stride);
|
||||
|
||||
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
|
||||
if (bop != NULL) {
|
||||
// FIXME: is it right to pass the seenPhis to the all equal check as well??
|
||||
llvm::Value *op0 = bop->getOperand(0), *op1 = bop->getOperand(1);
|
||||
|
||||
if (bop->getOpcode() == llvm::Instruction::Add) {
|
||||
// There are two cases to check if we have an add:
|
||||
//
|
||||
// programIndex + unif -> ascending linear seqeuence
|
||||
// unif + programIndex -> ascending linear sequence
|
||||
bool l0 = lVectorIsLinear(op0, vectorLength, stride, seenPhis);
|
||||
bool e1 = lVectorValuesAllEqual(op1, vectorLength, seenPhis);
|
||||
if (l0 && e1)
|
||||
return true;
|
||||
|
||||
bool e0 = lVectorValuesAllEqual(op0, vectorLength, seenPhis);
|
||||
bool l1 = lVectorIsLinear(op1, vectorLength, stride, seenPhis);
|
||||
return (e0 && l1);
|
||||
}
|
||||
else if (bop->getOpcode() == llvm::Instruction::Sub)
|
||||
// For subtraction, we only match:
|
||||
// programIndex - unif -> ascending linear seqeuence
|
||||
return (lVectorIsLinear(bop->getOperand(0), vectorLength,
|
||||
stride, seenPhis) &&
|
||||
lVectorValuesAllEqual(bop->getOperand(1), vectorLength,
|
||||
seenPhis));
|
||||
else if (bop->getOpcode() == llvm::Instruction::Mul) {
|
||||
// Multiplies are a bit trickier, so are handled in a separate
|
||||
// function.
|
||||
bool m0 = lCheckMulForLinear(op0, op1, vectorLength, stride, seenPhis);
|
||||
if (m0)
|
||||
return true;
|
||||
bool m1 = lCheckMulForLinear(op1, op0, vectorLength, stride, seenPhis);
|
||||
return m1;
|
||||
}
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
llvm::CastInst *ci = llvm::dyn_cast<llvm::CastInst>(v);
|
||||
if (ci != NULL)
|
||||
return lVectorIsLinear(ci->getOperand(0), vectorLength,
|
||||
stride, seenPhis);
|
||||
|
||||
if (llvm::isa<llvm::CallInst>(v) || llvm::isa<llvm::LoadInst>(v))
|
||||
return false;
|
||||
|
||||
llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
|
||||
if (phi != NULL) {
|
||||
for (unsigned int i = 0; i < seenPhis.size(); ++i)
|
||||
if (seenPhis[i] == phi)
|
||||
return true;
|
||||
|
||||
seenPhis.push_back(phi);
|
||||
|
||||
unsigned int numIncoming = phi->getNumIncomingValues();
|
||||
// Check all of the incoming values: if all of them are all equal,
|
||||
// then we're good.
|
||||
for (unsigned int i = 0; i < numIncoming; ++i) {
|
||||
if (!lVectorIsLinear(phi->getIncomingValue(i), vectorLength, stride,
|
||||
seenPhis)) {
|
||||
seenPhis.pop_back();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
seenPhis.pop_back();
|
||||
return true;
|
||||
}
|
||||
|
||||
// TODO: is any reason to worry about these?
|
||||
if (llvm::isa<llvm::InsertElementInst>(v))
|
||||
return false;
|
||||
|
||||
// TODO: we could also handle shuffles, but we haven't yet seen any
|
||||
// cases where doing so would detect cases where actually have a linear
|
||||
// vector.
|
||||
llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
|
||||
if (shuffle != NULL)
|
||||
return false;
|
||||
|
||||
#if 0
|
||||
fprintf(stderr, "linear check: ");
|
||||
v->dump();
|
||||
fprintf(stderr, "\n");
|
||||
llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
|
||||
if (inst) {
|
||||
inst->getParent()->dump();
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
/** Given vector of integer-typed values, see if the elements of the array
|
||||
have a step of 'stride' between their values. This function tries to
|
||||
handle as many possibilities as possible, including things like all
|
||||
elements equal to some non-constant value plus an integer offset, etc.
|
||||
*/
|
||||
bool
|
||||
LLVMVectorIsLinear(llvm::Value *v, int stride) {
|
||||
LLVM_TYPE_CONST llvm::VectorType *vt =
|
||||
llvm::dyn_cast<LLVM_TYPE_CONST llvm::VectorType>(v->getType());
|
||||
Assert(vt != NULL);
|
||||
int vectorLength = vt->getNumElements();
|
||||
|
||||
std::vector<llvm::PHINode *> seenPhis;
|
||||
bool linear = lVectorIsLinear(v, vectorLength, stride, seenPhis);
|
||||
Debug(SourcePos(), "LLVMVectorIsLinear(%s) -> %s.",
|
||||
v->getName().str().c_str(), linear ? "true" : "false");
|
||||
if (g->debugPrint)
|
||||
LLVMDumpValue(v);
|
||||
|
||||
return linear;
|
||||
}
|
||||
|
||||
|
||||
static void
|
||||
|
||||
Reference in New Issue
Block a user