diff --git a/cbackend.cpp b/cbackend.cpp index 33906778..1c9626b5 100644 --- a/cbackend.cpp +++ b/cbackend.cpp @@ -3286,10 +3286,16 @@ void CWriter::visitBinaryOperator(llvm::Instruction &I) { if ((I.getOpcode() == llvm::Instruction::Shl || I.getOpcode() == llvm::Instruction::LShr || I.getOpcode() == llvm::Instruction::AShr)) { - if (LLVMVectorValuesAllEqual(I.getOperand(1))) { - Out << "__extract_element("; - writeOperand(I.getOperand(1)); - Out << ", 0) "; + llvm::Value *splat = NULL; + if (LLVMVectorValuesAllEqual(I.getOperand(1), &splat)) { + if (splat) { + // Avoid __extract_element(splat(value), 0), if possible. + writeOperand(splat); + } else { + Out << "__extract_element("; + writeOperand(I.getOperand(1)); + Out << ", 0) "; + } } else writeOperand(I.getOperand(1)); diff --git a/llvmutil.cpp b/llvmutil.cpp index 275cf794..5707bbc9 100644 --- a/llvmutil.cpp +++ b/llvmutil.cpp @@ -818,7 +818,8 @@ LLVMExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts) { static bool lVectorValuesAllEqual(llvm::Value *v, int vectorLength, - std::vector &seenPhis); + std::vector &seenPhis, + llvm::Value **splatValue = NULL); /** This function checks to see if the given (scalar or vector) value is an @@ -1068,20 +1069,37 @@ lVectorShiftRightAllEqual(llvm::Value *val, llvm::Value *shift, static bool lVectorValuesAllEqual(llvm::Value *v, int vectorLength, - std::vector &seenPhis) { + std::vector &seenPhis, + llvm::Value **splatValue) { if (vectorLength == 1) return true; - if (llvm::isa(v)) + if (llvm::isa(v)) { + if (splatValue) { + llvm::ConstantAggregateZero *caz = + llvm::dyn_cast(v); + *splatValue = caz->getSequentialElement(); + } return true; + } llvm::ConstantVector *cv = llvm::dyn_cast(v); - if (cv != NULL) - return (cv->getSplatValue() != NULL); + if (cv != NULL) { + llvm::Value* splat = cv->getSplatValue(); + if (splat != NULL && splatValue) { + *splatValue = splat; + } + return (splat != NULL); + } llvm::ConstantDataVector *cdv = llvm::dyn_cast(v); - if (cdv != NULL) - return (cdv->getSplatValue() != NULL); + if (cdv != NULL) { + llvm::Value* splat = cdv->getSplatValue(); + if (splat != NULL && splatValue) { + *splatValue = splat; + } + return (splat != NULL); + } llvm::BinaryOperator *bop = llvm::dyn_cast(v); if (bop != NULL) { @@ -1178,14 +1196,14 @@ lVectorValuesAllEqual(llvm::Value *v, int vectorLength, where the values are actually all equal. */ bool -LLVMVectorValuesAllEqual(llvm::Value *v) { +LLVMVectorValuesAllEqual(llvm::Value *v, llvm::Value **splat) { llvm::VectorType *vt = llvm::dyn_cast(v->getType()); Assert(vt != NULL); int vectorLength = vt->getNumElements(); std::vector seenPhis; - bool equal = lVectorValuesAllEqual(v, vectorLength, seenPhis); + bool equal = lVectorValuesAllEqual(v, vectorLength, seenPhis, splat); Debug(SourcePos(), "LLVMVectorValuesAllEqual(%s) -> %s.", v->getName().str().c_str(), equal ? "true" : "false"); diff --git a/llvmutil.h b/llvmutil.h index d6c5ede0..96310b94 100644 --- a/llvmutil.h +++ b/llvmutil.h @@ -228,7 +228,8 @@ extern llvm::Constant *LLVMMaskAllOff; /** Tests to see if all of the elements of the vector in the 'v' parameter are equal. Like lValuesAreEqual(), this is a conservative test and may return false for arrays where the values are actually all equal. */ -extern bool LLVMVectorValuesAllEqual(llvm::Value *v); +extern bool LLVMVectorValuesAllEqual(llvm::Value *v, + llvm::Value **splat = NULL); /** Given vector of integer-typed values, this function returns true if it can determine that the elements of the vector have a step of 'stride'