LLVMVectorValuesAllEqual() improvements.

Clean up the API, so the caller doesn't have to pass in a vector so
the function can track PHI nodes (do that internally instead.)

Handle casts in lValuesAreEqual().
This commit is contained in:
Matt Pharr
2012-03-19 11:54:18 -07:00
parent 0664f5a724
commit e264d95019
4 changed files with 53 additions and 26 deletions

View File

@@ -2979,9 +2979,7 @@ void CWriter::visitBinaryOperator(Instruction &I) {
if ((I.getOpcode() == Instruction::Shl || if ((I.getOpcode() == Instruction::Shl ||
I.getOpcode() == Instruction::LShr || I.getOpcode() == Instruction::LShr ||
I.getOpcode() == Instruction::AShr)) { I.getOpcode() == Instruction::AShr)) {
std::vector<PHINode *> phis; if (LLVMVectorValuesAllEqual(I.getOperand(1))) {
if (LLVMVectorValuesAllEqual(I.getOperand(1),
vectorWidth, phis)) {
Out << "__extract_element("; Out << "__extract_element(";
writeOperand(I.getOperand(1)); writeOperand(I.getOperand(1));
Out << ", 0) "; Out << ", 0) ";

View File

@@ -512,11 +512,6 @@ LLVMUIntAsType(uint64_t val, LLVM_TYPE_CONST llvm::Type *type) {
(potentially many) cases where the two values actually are equal but (potentially many) cases where the two values actually are equal but
this will return false. However, if it does return true, the two this will return false. However, if it does return true, the two
vectors definitely are equal. vectors definitely are equal.
@todo This seems to catch all of the cases we currently need it for in
practice, but it's be nice to make it a little more robust/general. In
general, though, a little something called the halting problem means we
won't get all of them.
*/ */
static bool static bool
lValuesAreEqual(llvm::Value *v0, llvm::Value *v1, lValuesAreEqual(llvm::Value *v0, llvm::Value *v1,
@@ -545,6 +540,15 @@ lValuesAreEqual(llvm::Value *v0, llvm::Value *v1,
seenPhi0, seenPhi1)); seenPhi0, seenPhi1));
} }
llvm::CastInst *cast0 = llvm::dyn_cast<llvm::CastInst>(v0);
llvm::CastInst *cast1 = llvm::dyn_cast<llvm::CastInst>(v1);
if (cast0 != NULL && cast1 != NULL) {
if (cast0->getOpcode() != cast1->getOpcode())
return NULL;
return lValuesAreEqual(cast0->getOperand(0), cast1->getOperand(0),
seenPhi0, seenPhi1);
}
llvm::PHINode *phi0 = llvm::dyn_cast<llvm::PHINode>(v0); llvm::PHINode *phi0 = llvm::dyn_cast<llvm::PHINode>(v0);
llvm::PHINode *phi1 = llvm::dyn_cast<llvm::PHINode>(v1); llvm::PHINode *phi1 = llvm::dyn_cast<llvm::PHINode>(v1);
if (phi0 != NULL && phi1 != NULL) { if (phi0 != NULL && phi1 != NULL) {
@@ -559,6 +563,8 @@ lValuesAreEqual(llvm::Value *v0, llvm::Value *v1,
// then we're good. // then we're good.
bool anyFailure = false; bool anyFailure = false;
for (unsigned int i = 0; i < numIncoming; ++i) { for (unsigned int i = 0; i < numIncoming; ++i) {
// FIXME: should it be ok if the incoming blocks are different,
// where we just return faliure in this case?
Assert(phi0->getIncomingBlock(i) == phi1->getIncomingBlock(i)); Assert(phi0->getIncomingBlock(i) == phi1->getIncomingBlock(i));
if (!lValuesAreEqual(phi0->getIncomingValue(i), if (!lValuesAreEqual(phi0->getIncomingValue(i),
phi1->getIncomingValue(i), seenPhi0, seenPhi1)) { phi1->getIncomingValue(i), seenPhi0, seenPhi1)) {
@@ -676,12 +682,14 @@ LLVMExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts) {
} }
/** Tests to see if all of the elements of the vector in the 'v' parameter static bool
are equal. Like lValuesAreEqual(), this is a conservative test and may lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
return false for arrays where the values are actually all equal. */ std::vector<llvm::PHINode *> &seenPhis);
bool
LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis) { static bool
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis) {
if (vectorLength == 1) if (vectorLength == 1)
return true; return true;
@@ -707,7 +715,7 @@ LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v); llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
if (cast != NULL) if (cast != NULL)
return LLVMVectorValuesAllEqual(cast->getOperand(0), vectorLength, return lVectorValuesAllEqual(cast->getOperand(0), vectorLength,
seenPhis); seenPhis);
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v); llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v);
@@ -752,7 +760,7 @@ LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
// Check all of the incoming values: if all of them are all equal, // Check all of the incoming values: if all of them are all equal,
// then we're good. // then we're good.
for (unsigned int i = 0; i < numIncoming; ++i) { for (unsigned int i = 0; i < numIncoming; ++i) {
if (!LLVMVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength, if (!lVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength,
seenPhis)) { seenPhis)) {
seenPhis.pop_back(); seenPhis.pop_back();
return false; return false;
@@ -776,7 +784,7 @@ LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v); llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
if (shuffle != NULL) { if (shuffle != NULL) {
llvm::Value *indices = shuffle->getOperand(2); llvm::Value *indices = shuffle->getOperand(2);
if (LLVMVectorValuesAllEqual(indices, vectorLength, seenPhis)) if (lVectorValuesAllEqual(indices, vectorLength, seenPhis))
// The easy case--just a smear of the same element across the // The easy case--just a smear of the same element across the
// whole vector. // whole vector.
return true; return true;
@@ -801,6 +809,29 @@ LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
} }
/** Tests to see if all of the elements of the vector in the 'v' parameter
are equal. This is a conservative test and may return false for arrays
where the values are actually all equal.
*/
bool
LLVMVectorValuesAllEqual(llvm::Value *v) {
LLVM_TYPE_CONST llvm::VectorType *vt =
llvm::dyn_cast<LLVM_TYPE_CONST llvm::VectorType>(v->getType());
Assert(vt != NULL);
int vectorLength = vt->getNumElements();
std::vector<llvm::PHINode *> seenPhis;
bool equal = lVectorValuesAllEqual(v, vectorLength, seenPhis);
Debug(SourcePos(), "LLVMVectorValuesAllEqual(%s) -> %s.",
v->getName().str().c_str(), equal ? "true" : "false");
if (g->debugPrint)
LLVMDumpValue(v);
return equal;
}
static void static void

View File

@@ -228,8 +228,7 @@ extern llvm::Constant *LLVMMaskAllOff;
/** Tests to see if all of the elements of the vector in the 'v' parameter /** Tests to see if all of the elements of the vector in the 'v' parameter
are equal. Like lValuesAreEqual(), this is a conservative test and may are equal. Like lValuesAreEqual(), this is a conservative test and may
return false for arrays where the values are actually all equal. */ return false for arrays where the values are actually all equal. */
extern bool LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength, extern bool LLVMVectorValuesAllEqual(llvm::Value *v);
std::vector<llvm::PHINode *> &seenPhis);
/** Given a vector-typed value v, if the vector is a vector with constant /** Given a vector-typed value v, if the vector is a vector with constant
element values, this function extracts those element values into the element values, this function extracts those element values into the

13
opt.cpp
View File

@@ -1562,8 +1562,7 @@ lExtractUniforms(llvm::Value **vec, llvm::Instruction *insertBefore) {
return unif; return unif;
} }
std::vector<llvm::PHINode *> phis; if (LLVMVectorValuesAllEqual(*vec)) {
if (LLVMVectorValuesAllEqual(*vec, g->target.vectorWidth, phis)) {
// FIXME: we may want to redo all of the expression here, in scalar // FIXME: we may want to redo all of the expression here, in scalar
// form (if at all possible), for code quality... // form (if at all possible), for code quality...
llvm::Value *unif = llvm::Value *unif =
@@ -2619,8 +2618,10 @@ GSToLoadStorePass::runOnBasicBlock(llvm::BasicBlock &bb) {
constOffsets, "varying+const_offsets", constOffsets, "varying+const_offsets",
callInst); callInst);
std::vector<llvm::PHINode *> seenPhis; Debug(SourcePos(), "GSToLoadStore: %s.",
if (LLVMVectorValuesAllEqual(fullOffsets, g->target.vectorWidth, seenPhis)) { fullOffsets->getName().str().c_str());
if (LLVMVectorValuesAllEqual(fullOffsets)) {
// If all the offsets are equal, then compute the single // If all the offsets are equal, then compute the single
// pointer they all represent based on the first one of them // pointer they all represent based on the first one of them
// (arbitrarily). // (arbitrarily).
@@ -3688,9 +3689,7 @@ GatherCoalescePass::runOnBasicBlock(llvm::BasicBlock &bb) {
if (lIsMaskAllOn(mask) == false) if (lIsMaskAllOn(mask) == false)
continue; continue;
std::vector<llvm::PHINode *> seenPhis; if (!LLVMVectorValuesAllEqual(variableOffsets))
if (LLVMVectorValuesAllEqual(variableOffsets, g->target.vectorWidth,
seenPhis) == false)
continue; continue;
// coalesceGroup stores the set of gathers that we're going to try to // coalesceGroup stores the set of gathers that we're going to try to