LLVMVectorValuesAllEqual() improvements.

Clean up the API, so the caller doesn't have to pass in a vector so
the function can track PHI nodes (do that internally instead.)

Handle casts in lValuesAreEqual().
This commit is contained in:
Matt Pharr
2012-03-19 11:54:18 -07:00
parent 0664f5a724
commit e264d95019
4 changed files with 53 additions and 26 deletions

View File

@@ -2979,9 +2979,7 @@ void CWriter::visitBinaryOperator(Instruction &I) {
if ((I.getOpcode() == Instruction::Shl ||
I.getOpcode() == Instruction::LShr ||
I.getOpcode() == Instruction::AShr)) {
std::vector<PHINode *> phis;
if (LLVMVectorValuesAllEqual(I.getOperand(1),
vectorWidth, phis)) {
if (LLVMVectorValuesAllEqual(I.getOperand(1))) {
Out << "__extract_element(";
writeOperand(I.getOperand(1));
Out << ", 0) ";

View File

@@ -512,11 +512,6 @@ LLVMUIntAsType(uint64_t val, LLVM_TYPE_CONST llvm::Type *type) {
(potentially many) cases where the two values actually are equal but
this will return false. However, if it does return true, the two
vectors definitely are equal.
@todo This seems to catch all of the cases we currently need it for in
practice, but it's be nice to make it a little more robust/general. In
general, though, a little something called the halting problem means we
won't get all of them.
*/
static bool
lValuesAreEqual(llvm::Value *v0, llvm::Value *v1,
@@ -545,6 +540,15 @@ lValuesAreEqual(llvm::Value *v0, llvm::Value *v1,
seenPhi0, seenPhi1));
}
llvm::CastInst *cast0 = llvm::dyn_cast<llvm::CastInst>(v0);
llvm::CastInst *cast1 = llvm::dyn_cast<llvm::CastInst>(v1);
if (cast0 != NULL && cast1 != NULL) {
if (cast0->getOpcode() != cast1->getOpcode())
return NULL;
return lValuesAreEqual(cast0->getOperand(0), cast1->getOperand(0),
seenPhi0, seenPhi1);
}
llvm::PHINode *phi0 = llvm::dyn_cast<llvm::PHINode>(v0);
llvm::PHINode *phi1 = llvm::dyn_cast<llvm::PHINode>(v1);
if (phi0 != NULL && phi1 != NULL) {
@@ -559,6 +563,8 @@ lValuesAreEqual(llvm::Value *v0, llvm::Value *v1,
// then we're good.
bool anyFailure = false;
for (unsigned int i = 0; i < numIncoming; ++i) {
// FIXME: should it be ok if the incoming blocks are different,
// where we just return faliure in this case?
Assert(phi0->getIncomingBlock(i) == phi1->getIncomingBlock(i));
if (!lValuesAreEqual(phi0->getIncomingValue(i),
phi1->getIncomingValue(i), seenPhi0, seenPhi1)) {
@@ -676,12 +682,14 @@ LLVMExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts) {
}
/** Tests to see if all of the elements of the vector in the 'v' parameter
are equal. Like lValuesAreEqual(), this is a conservative test and may
return false for arrays where the values are actually all equal. */
bool
LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis) {
static bool
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis);
static bool
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis) {
if (vectorLength == 1)
return true;
@@ -707,7 +715,7 @@ LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
if (cast != NULL)
return LLVMVectorValuesAllEqual(cast->getOperand(0), vectorLength,
return lVectorValuesAllEqual(cast->getOperand(0), vectorLength,
seenPhis);
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v);
@@ -752,7 +760,7 @@ LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
// Check all of the incoming values: if all of them are all equal,
// then we're good.
for (unsigned int i = 0; i < numIncoming; ++i) {
if (!LLVMVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength,
if (!lVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength,
seenPhis)) {
seenPhis.pop_back();
return false;
@@ -776,7 +784,7 @@ LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
if (shuffle != NULL) {
llvm::Value *indices = shuffle->getOperand(2);
if (LLVMVectorValuesAllEqual(indices, vectorLength, seenPhis))
if (lVectorValuesAllEqual(indices, vectorLength, seenPhis))
// The easy case--just a smear of the same element across the
// whole vector.
return true;
@@ -801,6 +809,29 @@ LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
}
/** Tests to see if all of the elements of the vector in the 'v' parameter
are equal. This is a conservative test and may return false for arrays
where the values are actually all equal.
*/
bool
LLVMVectorValuesAllEqual(llvm::Value *v) {
LLVM_TYPE_CONST llvm::VectorType *vt =
llvm::dyn_cast<LLVM_TYPE_CONST llvm::VectorType>(v->getType());
Assert(vt != NULL);
int vectorLength = vt->getNumElements();
std::vector<llvm::PHINode *> seenPhis;
bool equal = lVectorValuesAllEqual(v, vectorLength, seenPhis);
Debug(SourcePos(), "LLVMVectorValuesAllEqual(%s) -> %s.",
v->getName().str().c_str(), equal ? "true" : "false");
if (g->debugPrint)
LLVMDumpValue(v);
return equal;
}
static void

View File

@@ -228,8 +228,7 @@ extern llvm::Constant *LLVMMaskAllOff;
/** Tests to see if all of the elements of the vector in the 'v' parameter
are equal. Like lValuesAreEqual(), this is a conservative test and may
return false for arrays where the values are actually all equal. */
extern bool LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis);
extern bool LLVMVectorValuesAllEqual(llvm::Value *v);
/** Given a vector-typed value v, if the vector is a vector with constant
element values, this function extracts those element values into the

13
opt.cpp
View File

@@ -1562,8 +1562,7 @@ lExtractUniforms(llvm::Value **vec, llvm::Instruction *insertBefore) {
return unif;
}
std::vector<llvm::PHINode *> phis;
if (LLVMVectorValuesAllEqual(*vec, g->target.vectorWidth, phis)) {
if (LLVMVectorValuesAllEqual(*vec)) {
// FIXME: we may want to redo all of the expression here, in scalar
// form (if at all possible), for code quality...
llvm::Value *unif =
@@ -2619,8 +2618,10 @@ GSToLoadStorePass::runOnBasicBlock(llvm::BasicBlock &bb) {
constOffsets, "varying+const_offsets",
callInst);
std::vector<llvm::PHINode *> seenPhis;
if (LLVMVectorValuesAllEqual(fullOffsets, g->target.vectorWidth, seenPhis)) {
Debug(SourcePos(), "GSToLoadStore: %s.",
fullOffsets->getName().str().c_str());
if (LLVMVectorValuesAllEqual(fullOffsets)) {
// If all the offsets are equal, then compute the single
// pointer they all represent based on the first one of them
// (arbitrarily).
@@ -3688,9 +3689,7 @@ GatherCoalescePass::runOnBasicBlock(llvm::BasicBlock &bb) {
if (lIsMaskAllOn(mask) == false)
continue;
std::vector<llvm::PHINode *> seenPhis;
if (LLVMVectorValuesAllEqual(variableOffsets, g->target.vectorWidth,
seenPhis) == false)
if (!LLVMVectorValuesAllEqual(variableOffsets))
continue;
// coalesceGroup stores the set of gathers that we're going to try to