For << and >> with C++, detect when all instances are shifting by the same amount.
In this case, we now emit calls to potentially-specialized functions for the left/right shifts that take a single integer value for the shift amount. These in turn can be matched to the corresponding intrinsics for the SSE target. Issue #145.
This commit is contained in:
243
opt.cpp
243
opt.cpp
@@ -921,58 +921,6 @@ char GatherScatterFlattenOpt::ID = 0;
|
||||
// Static registration object for the GatherScatterFlattenOpt pass; it is
// registered under the identifier "gs-flatten" with the descriptive name
// "Gather/Scatter Flatten Pass".
llvm::RegisterPass<GatherScatterFlattenOpt> gsf("gs-flatten", "Gather/Scatter Flatten Pass");
|
||||
|
||||
|
||||
/** Given an llvm::Value known to be an integer, return its value as
|
||||
an int64_t.
|
||||
*/
|
||||
static int64_t
|
||||
lGetIntValue(llvm::Value *offset) {
|
||||
llvm::ConstantInt *intOffset = llvm::dyn_cast<llvm::ConstantInt>(offset);
|
||||
Assert(intOffset && (intOffset->getBitWidth() == 32 ||
|
||||
intOffset->getBitWidth() == 64));
|
||||
return intOffset->getSExtValue();
|
||||
}
|
||||
|
||||
/** This function takes chains of InsertElement instructions along the
|
||||
lines of:
|
||||
|
||||
%v0 = insertelement undef, value_0, i32 index_0
|
||||
%v1 = insertelement %v1, value_1, i32 index_1
|
||||
...
|
||||
%vn = insertelement %vn-1, value_n-1, i32 index_n-1
|
||||
|
||||
and initializes the provided elements array such that the i'th
|
||||
llvm::Value * in the array is the element that was inserted into the
|
||||
i'th element of the vector.
|
||||
*/
|
||||
static void
|
||||
lFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth,
|
||||
llvm::Value **elements) {
|
||||
for (int i = 0; i < vectorWidth; ++i)
|
||||
elements[i] = NULL;
|
||||
|
||||
while (ie != NULL) {
|
||||
int64_t iOffset = lGetIntValue(ie->getOperand(2));
|
||||
Assert(iOffset >= 0 && iOffset < vectorWidth);
|
||||
Assert(elements[iOffset] == NULL);
|
||||
|
||||
elements[iOffset] = ie->getOperand(1);
|
||||
|
||||
llvm::Value *insertBase = ie->getOperand(0);
|
||||
ie = llvm::dyn_cast<llvm::InsertElementInst>(insertBase);
|
||||
if (ie == NULL) {
|
||||
if (llvm::isa<llvm::UndefValue>(insertBase))
|
||||
return;
|
||||
|
||||
llvm::ConstantVector *cv =
|
||||
llvm::dyn_cast<llvm::ConstantVector>(insertBase);
|
||||
Assert(cv != NULL);
|
||||
Assert(iOffset < (int)cv->getNumOperands());
|
||||
elements[iOffset] = cv->getOperand(iOffset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/** Check to make sure that this value is actually a pointer in the end.
|
||||
We need to make sure that given an expression like vec(offset) +
|
||||
ptr2int(ptr), lGetBasePointer() doesn't return vec(offset) for the base
|
||||
@@ -1011,7 +959,7 @@ lGetBasePointer(llvm::Value *v) {
|
||||
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v);
|
||||
if (ie != NULL) {
|
||||
llvm::Value *elements[ISPC_MAX_NVEC];
|
||||
lFlattenInsertChain(ie, g->target.vectorWidth, elements);
|
||||
LLVMFlattenInsertChain(ie, g->target.vectorWidth, elements);
|
||||
|
||||
// Make sure none of the elements is undefined.
|
||||
// TODO: it's probably ok to allow undefined elements and return
|
||||
@@ -1825,187 +1773,6 @@ llvm::RegisterPass<GSImprovementsPass> gsi("gs-improvements",
|
||||
"Gather/Scatter Improvements Pass");
|
||||
|
||||
|
||||
/** Conservative test to see if two llvm::Values are equal. There are
|
||||
(potentially many) cases where the two values actually are equal but
|
||||
this will return false. However, if it does return true, the two
|
||||
vectors definitely are equal.
|
||||
|
||||
@todo This seems to catch all of the cases we currently need it for in
|
||||
practice, but it's be nice to make it a little more robust/general. In
|
||||
general, though, a little something called the halting problem means we
|
||||
won't get all of them.
|
||||
*/
|
||||
static bool
|
||||
lValuesAreEqual(llvm::Value *v0, llvm::Value *v1,
|
||||
std::vector<llvm::PHINode *> &seenPhi0,
|
||||
std::vector<llvm::PHINode *> &seenPhi1) {
|
||||
// Thanks to the fact that LLVM hashes and returns the same pointer for
|
||||
// constants (of all sorts, even constant expressions), this first test
|
||||
// actually catches a lot of cases. LLVM's SSA form also helps a lot
|
||||
// with this..
|
||||
if (v0 == v1)
|
||||
return true;
|
||||
|
||||
Assert(seenPhi0.size() == seenPhi1.size());
|
||||
for (unsigned int i = 0; i < seenPhi0.size(); ++i)
|
||||
if (v0 == seenPhi0[i] && v1 == seenPhi1[i])
|
||||
return true;
|
||||
|
||||
llvm::BinaryOperator *bo0 = llvm::dyn_cast<llvm::BinaryOperator>(v0);
|
||||
llvm::BinaryOperator *bo1 = llvm::dyn_cast<llvm::BinaryOperator>(v1);
|
||||
if (bo0 != NULL && bo1 != NULL) {
|
||||
if (bo0->getOpcode() != bo1->getOpcode())
|
||||
return false;
|
||||
return (lValuesAreEqual(bo0->getOperand(0), bo1->getOperand(0),
|
||||
seenPhi0, seenPhi1) &&
|
||||
lValuesAreEqual(bo0->getOperand(1), bo1->getOperand(1),
|
||||
seenPhi0, seenPhi1));
|
||||
}
|
||||
|
||||
llvm::PHINode *phi0 = llvm::dyn_cast<llvm::PHINode>(v0);
|
||||
llvm::PHINode *phi1 = llvm::dyn_cast<llvm::PHINode>(v1);
|
||||
if (phi0 != NULL && phi1 != NULL) {
|
||||
if (phi0->getNumIncomingValues() != phi1->getNumIncomingValues())
|
||||
return false;
|
||||
|
||||
seenPhi0.push_back(phi0);
|
||||
seenPhi1.push_back(phi1);
|
||||
|
||||
unsigned int numIncoming = phi0->getNumIncomingValues();
|
||||
// Check all of the incoming values: if all of them are all equal,
|
||||
// then we're good.
|
||||
bool anyFailure = false;
|
||||
for (unsigned int i = 0; i < numIncoming; ++i) {
|
||||
Assert(phi0->getIncomingBlock(i) == phi1->getIncomingBlock(i));
|
||||
if (!lValuesAreEqual(phi0->getIncomingValue(i),
|
||||
phi1->getIncomingValue(i), seenPhi0, seenPhi1)) {
|
||||
anyFailure = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
seenPhi0.pop_back();
|
||||
seenPhi1.pop_back();
|
||||
|
||||
return !anyFailure;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
/** Tests to see if all of the elements of the vector in the 'v' parameter
|
||||
are equal. Like lValuesAreEqual(), this is a conservative test and may
|
||||
return false for arrays where the values are actually all equal. */
|
||||
static bool
|
||||
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
||||
std::vector<llvm::PHINode *> &seenPhis) {
|
||||
if (llvm::isa<llvm::ConstantAggregateZero>(v))
|
||||
return true;
|
||||
|
||||
llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
|
||||
if (cv != NULL)
|
||||
return (cv->getSplatValue() != NULL);
|
||||
|
||||
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
|
||||
if (bop != NULL)
|
||||
return (lVectorValuesAllEqual(bop->getOperand(0), vectorLength,
|
||||
seenPhis) &&
|
||||
lVectorValuesAllEqual(bop->getOperand(1), vectorLength,
|
||||
seenPhis));
|
||||
|
||||
llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
|
||||
if (cast != NULL)
|
||||
return lVectorValuesAllEqual(cast->getOperand(0), vectorLength,
|
||||
seenPhis);
|
||||
|
||||
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v);
|
||||
if (ie != NULL) {
|
||||
llvm::Value *elements[ISPC_MAX_NVEC];
|
||||
lFlattenInsertChain(ie, vectorLength, elements);
|
||||
|
||||
// We will ignore any values of elements[] that are NULL; as they
|
||||
// correspond to undefined values--we just want to see if all of
|
||||
// the defined values have the same value.
|
||||
int lastNonNull = 0;
|
||||
while (lastNonNull < vectorLength && elements[lastNonNull] == NULL)
|
||||
++lastNonNull;
|
||||
|
||||
if (lastNonNull == vectorLength)
|
||||
// all of them are undef!
|
||||
return true;
|
||||
|
||||
for (int i = lastNonNull; i < vectorLength; ++i) {
|
||||
if (elements[i] == NULL)
|
||||
continue;
|
||||
|
||||
std::vector<llvm::PHINode *> seenPhi0;
|
||||
std::vector<llvm::PHINode *> seenPhi1;
|
||||
if (lValuesAreEqual(elements[lastNonNull], elements[i], seenPhi0,
|
||||
seenPhi1) == false)
|
||||
return false;
|
||||
lastNonNull = i;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
|
||||
if (phi) {
|
||||
for (unsigned int i = 0; i < seenPhis.size(); ++i)
|
||||
if (seenPhis[i] == phi)
|
||||
return true;
|
||||
|
||||
seenPhis.push_back(phi);
|
||||
|
||||
unsigned int numIncoming = phi->getNumIncomingValues();
|
||||
// Check all of the incoming values: if all of them are all equal,
|
||||
// then we're good.
|
||||
for (unsigned int i = 0; i < numIncoming; ++i) {
|
||||
if (!lVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength,
|
||||
seenPhis)) {
|
||||
seenPhis.pop_back();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
seenPhis.pop_back();
|
||||
return true;
|
||||
}
|
||||
|
||||
Assert(!llvm::isa<llvm::Constant>(v));
|
||||
|
||||
if (llvm::isa<llvm::CallInst>(v) || llvm::isa<llvm::LoadInst>(v) ||
|
||||
!llvm::isa<llvm::Instruction>(v))
|
||||
return false;
|
||||
|
||||
llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
|
||||
if (shuffle != NULL) {
|
||||
llvm::Value *indices = shuffle->getOperand(2);
|
||||
if (lVectorValuesAllEqual(indices, vectorLength, seenPhis))
|
||||
// The easy case--just a smear of the same element across the
|
||||
// whole vector.
|
||||
return true;
|
||||
|
||||
// TODO: handle more general cases?
|
||||
return false;
|
||||
}
|
||||
|
||||
#if 0
|
||||
fprintf(stderr, "all equal: ");
|
||||
v->dump();
|
||||
fprintf(stderr, "\n");
|
||||
llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
|
||||
if (inst) {
|
||||
inst->getParent()->dump();
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
/** Given a vector of compile-time constant integer values, test to see if
|
||||
they are a linear sequence of constant integers starting from an
|
||||
arbitrary value but then having a step of value "stride" between
|
||||
@@ -2102,9 +1869,9 @@ lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
|
||||
// programIndex + unif -> ascending linear sequence
|
||||
// unif + programIndex -> ascending linear sequence
|
||||
return ((lVectorIsLinear(op0, vectorLength, stride, seenPhis) &&
|
||||
lVectorValuesAllEqual(op1, vectorLength, seenPhis)) ||
|
||||
LLVMVectorValuesAllEqual(op1, vectorLength, seenPhis)) ||
|
||||
(lVectorIsLinear(op1, vectorLength, stride, seenPhis) &&
|
||||
lVectorValuesAllEqual(op0, vectorLength, seenPhis)));
|
||||
LLVMVectorValuesAllEqual(op0, vectorLength, seenPhis)));
|
||||
else if (bop->getOpcode() == llvm::Instruction::Sub)
|
||||
// For subtraction, we only match:
|
||||
//
|
||||
@@ -2115,7 +1882,7 @@ lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
|
||||
// And generate code for that as a vector load + shuffle.
|
||||
return (lVectorIsLinear(bop->getOperand(0), vectorLength,
|
||||
stride, seenPhis) &&
|
||||
lVectorValuesAllEqual(bop->getOperand(1), vectorLength,
|
||||
LLVMVectorValuesAllEqual(bop->getOperand(1), vectorLength,
|
||||
seenPhis));
|
||||
else if (bop->getOpcode() == llvm::Instruction::Mul)
|
||||
// Multiplies are a bit trickier, so are handled in a separate
|
||||
@@ -2313,7 +2080,7 @@ GSImprovementsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
|
||||
|
||||
{
|
||||
std::vector<llvm::PHINode *> seenPhis;
|
||||
if (lVectorValuesAllEqual(offsets, g->target.vectorWidth, seenPhis)) {
|
||||
if (LLVMVectorValuesAllEqual(offsets, g->target.vectorWidth, seenPhis)) {
|
||||
// If all the offsets are equal, then compute the single
|
||||
// pointer they all represent based on the first one of them
|
||||
// (arbitrarily).
|
||||
|
||||
Reference in New Issue
Block a user