From e264d950193b642fe5131435570f18be931a5e42 Mon Sep 17 00:00:00 2001 From: Matt Pharr Date: Mon, 19 Mar 2012 11:54:18 -0700 Subject: [PATCH] LLVMVectorValuesAllEqual() improvements. Clean up the API, so the caller doesn't have to pass in a vector so the function can track PHI nodes (do that internally instead.) Handle casts in lValuesAreEqual(). --- cbackend.cpp | 4 +--- llvmutil.cpp | 59 +++++++++++++++++++++++++++++++++++++++------------- llvmutil.h | 3 +-- opt.cpp | 13 ++++++------ 4 files changed, 53 insertions(+), 26 deletions(-) diff --git a/cbackend.cpp b/cbackend.cpp index 6aade4ed..b1a0a907 100644 --- a/cbackend.cpp +++ b/cbackend.cpp @@ -2979,9 +2979,7 @@ void CWriter::visitBinaryOperator(Instruction &I) { if ((I.getOpcode() == Instruction::Shl || I.getOpcode() == Instruction::LShr || I.getOpcode() == Instruction::AShr)) { - std::vector phis; - if (LLVMVectorValuesAllEqual(I.getOperand(1), - vectorWidth, phis)) { + if (LLVMVectorValuesAllEqual(I.getOperand(1))) { Out << "__extract_element("; writeOperand(I.getOperand(1)); Out << ", 0) "; diff --git a/llvmutil.cpp b/llvmutil.cpp index 55bc45a0..a2eaad6f 100644 --- a/llvmutil.cpp +++ b/llvmutil.cpp @@ -512,11 +512,6 @@ LLVMUIntAsType(uint64_t val, LLVM_TYPE_CONST llvm::Type *type) { (potentially many) cases where the two values actually are equal but this will return false. However, if it does return true, the two vectors definitely are equal. - - @todo This seems to catch all of the cases we currently need it for in - practice, but it's be nice to make it a little more robust/general. In - general, though, a little something called the halting problem means we - won't get all of them. */ static bool lValuesAreEqual(llvm::Value *v0, llvm::Value *v1, @@ -545,6 +540,15 @@ lValuesAreEqual(llvm::Value *v0, llvm::Value *v1, seenPhi0, seenPhi1)); } + llvm::CastInst *cast0 = llvm::dyn_cast(v0); + llvm::CastInst *cast1 = llvm::dyn_cast(v1); + if (cast0 != NULL && cast1 != NULL) { + if (cast0->getOpcode() != cast1->getOpcode()) + return NULL; + return lValuesAreEqual(cast0->getOperand(0), cast1->getOperand(0), + seenPhi0, seenPhi1); + } + llvm::PHINode *phi0 = llvm::dyn_cast(v0); llvm::PHINode *phi1 = llvm::dyn_cast(v1); if (phi0 != NULL && phi1 != NULL) { @@ -559,6 +563,8 @@ lValuesAreEqual(llvm::Value *v0, llvm::Value *v1, // then we're good. bool anyFailure = false; for (unsigned int i = 0; i < numIncoming; ++i) { + // FIXME: should it be ok if the incoming blocks are different, + // where we just return faliure in this case? Assert(phi0->getIncomingBlock(i) == phi1->getIncomingBlock(i)); if (!lValuesAreEqual(phi0->getIncomingValue(i), phi1->getIncomingValue(i), seenPhi0, seenPhi1)) { @@ -676,12 +682,14 @@ LLVMExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts) { } -/** Tests to see if all of the elements of the vector in the 'v' parameter - are equal. Like lValuesAreEqual(), this is a conservative test and may - return false for arrays where the values are actually all equal. */ -bool -LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength, - std::vector &seenPhis) { +static bool +lVectorValuesAllEqual(llvm::Value *v, int vectorLength, + std::vector &seenPhis); + + +static bool +lVectorValuesAllEqual(llvm::Value *v, int vectorLength, + std::vector &seenPhis) { if (vectorLength == 1) return true; @@ -707,7 +715,7 @@ LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength, llvm::CastInst *cast = llvm::dyn_cast(v); if (cast != NULL) - return LLVMVectorValuesAllEqual(cast->getOperand(0), vectorLength, + return lVectorValuesAllEqual(cast->getOperand(0), vectorLength, seenPhis); llvm::InsertElementInst *ie = llvm::dyn_cast(v); @@ -752,7 +760,7 @@ LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength, // Check all of the incoming values: if all of them are all equal, // then we're good. for (unsigned int i = 0; i < numIncoming; ++i) { - if (!LLVMVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength, + if (!lVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength, seenPhis)) { seenPhis.pop_back(); return false; @@ -776,7 +784,7 @@ LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength, llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast(v); if (shuffle != NULL) { llvm::Value *indices = shuffle->getOperand(2); - if (LLVMVectorValuesAllEqual(indices, vectorLength, seenPhis)) + if (lVectorValuesAllEqual(indices, vectorLength, seenPhis)) // The easy case--just a smear of the same element across the // whole vector. return true; @@ -801,6 +809,29 @@ LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength, } +/** Tests to see if all of the elements of the vector in the 'v' parameter + are equal. This is a conservative test and may return false for arrays + where the values are actually all equal. +*/ +bool +LLVMVectorValuesAllEqual(llvm::Value *v) { + LLVM_TYPE_CONST llvm::VectorType *vt = + llvm::dyn_cast(v->getType()); + Assert(vt != NULL); + int vectorLength = vt->getNumElements(); + + std::vector seenPhis; + bool equal = lVectorValuesAllEqual(v, vectorLength, seenPhis); + + Debug(SourcePos(), "LLVMVectorValuesAllEqual(%s) -> %s.", + v->getName().str().c_str(), equal ? "true" : "false"); + if (g->debugPrint) + LLVMDumpValue(v); + + return equal; +} + + static void diff --git a/llvmutil.h b/llvmutil.h index 8be95f99..ab9fba82 100644 --- a/llvmutil.h +++ b/llvmutil.h @@ -228,8 +228,7 @@ extern llvm::Constant *LLVMMaskAllOff; /** Tests to see if all of the elements of the vector in the 'v' parameter are equal. Like lValuesAreEqual(), this is a conservative test and may return false for arrays where the values are actually all equal. */ -extern bool LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength, - std::vector &seenPhis); +extern bool LLVMVectorValuesAllEqual(llvm::Value *v); /** Given a vector-typed value v, if the vector is a vector with constant element values, this function extracts those element values into the diff --git a/opt.cpp b/opt.cpp index cb708636..5d6c3dfd 100644 --- a/opt.cpp +++ b/opt.cpp @@ -1562,8 +1562,7 @@ lExtractUniforms(llvm::Value **vec, llvm::Instruction *insertBefore) { return unif; } - std::vector phis; - if (LLVMVectorValuesAllEqual(*vec, g->target.vectorWidth, phis)) { + if (LLVMVectorValuesAllEqual(*vec)) { // FIXME: we may want to redo all of the expression here, in scalar // form (if at all possible), for code quality... llvm::Value *unif = @@ -2619,8 +2618,10 @@ GSToLoadStorePass::runOnBasicBlock(llvm::BasicBlock &bb) { constOffsets, "varying+const_offsets", callInst); - std::vector seenPhis; - if (LLVMVectorValuesAllEqual(fullOffsets, g->target.vectorWidth, seenPhis)) { + Debug(SourcePos(), "GSToLoadStore: %s.", + fullOffsets->getName().str().c_str()); + + if (LLVMVectorValuesAllEqual(fullOffsets)) { // If all the offsets are equal, then compute the single // pointer they all represent based on the first one of them // (arbitrarily). @@ -3688,9 +3689,7 @@ GatherCoalescePass::runOnBasicBlock(llvm::BasicBlock &bb) { if (lIsMaskAllOn(mask) == false) continue; - std::vector seenPhis; - if (LLVMVectorValuesAllEqual(variableOffsets, g->target.vectorWidth, - seenPhis) == false) + if (!LLVMVectorValuesAllEqual(variableOffsets)) continue; // coalesceGroup stores the set of gathers that we're going to try to