Merge pull request #782 from dbabokin/mic_perf
Fixing MIC performance issue (generate <2,2,2,2> as 2 in cpp).
This commit is contained in:
@@ -3286,11 +3286,17 @@ void CWriter::visitBinaryOperator(llvm::Instruction &I) {
|
|||||||
if ((I.getOpcode() == llvm::Instruction::Shl ||
|
if ((I.getOpcode() == llvm::Instruction::Shl ||
|
||||||
I.getOpcode() == llvm::Instruction::LShr ||
|
I.getOpcode() == llvm::Instruction::LShr ||
|
||||||
I.getOpcode() == llvm::Instruction::AShr)) {
|
I.getOpcode() == llvm::Instruction::AShr)) {
|
||||||
if (LLVMVectorValuesAllEqual(I.getOperand(1))) {
|
llvm::Value *splat = NULL;
|
||||||
|
if (LLVMVectorValuesAllEqual(I.getOperand(1), &splat)) {
|
||||||
|
if (splat) {
|
||||||
|
// Avoid __extract_element(splat(value), 0), if possible.
|
||||||
|
writeOperand(splat);
|
||||||
|
} else {
|
||||||
Out << "__extract_element(";
|
Out << "__extract_element(";
|
||||||
writeOperand(I.getOperand(1));
|
writeOperand(I.getOperand(1));
|
||||||
Out << ", 0) ";
|
Out << ", 0) ";
|
||||||
}
|
}
|
||||||
|
}
|
||||||
else
|
else
|
||||||
writeOperand(I.getOperand(1));
|
writeOperand(I.getOperand(1));
|
||||||
}
|
}
|
||||||
|
|||||||
36
llvmutil.cpp
36
llvmutil.cpp
@@ -818,7 +818,8 @@ LLVMExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts) {
|
|||||||
|
|
||||||
static bool
|
static bool
|
||||||
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
||||||
std::vector<llvm::PHINode *> &seenPhis);
|
std::vector<llvm::PHINode *> &seenPhis,
|
||||||
|
llvm::Value **splatValue = NULL);
|
||||||
|
|
||||||
|
|
||||||
/** This function checks to see if the given (scalar or vector) value is an
|
/** This function checks to see if the given (scalar or vector) value is an
|
||||||
@@ -1068,20 +1069,37 @@ lVectorShiftRightAllEqual(llvm::Value *val, llvm::Value *shift,
|
|||||||
|
|
||||||
static bool
|
static bool
|
||||||
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
||||||
std::vector<llvm::PHINode *> &seenPhis) {
|
std::vector<llvm::PHINode *> &seenPhis,
|
||||||
|
llvm::Value **splatValue) {
|
||||||
if (vectorLength == 1)
|
if (vectorLength == 1)
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
if (llvm::isa<llvm::ConstantAggregateZero>(v))
|
if (llvm::isa<llvm::ConstantAggregateZero>(v)) {
|
||||||
|
if (splatValue) {
|
||||||
|
llvm::ConstantAggregateZero *caz =
|
||||||
|
llvm::dyn_cast<llvm::ConstantAggregateZero>(v);
|
||||||
|
*splatValue = caz->getSequentialElement();
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
|
llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
|
||||||
if (cv != NULL)
|
if (cv != NULL) {
|
||||||
return (cv->getSplatValue() != NULL);
|
llvm::Value* splat = cv->getSplatValue();
|
||||||
|
if (splat != NULL && splatValue) {
|
||||||
|
*splatValue = splat;
|
||||||
|
}
|
||||||
|
return (splat != NULL);
|
||||||
|
}
|
||||||
|
|
||||||
llvm::ConstantDataVector *cdv = llvm::dyn_cast<llvm::ConstantDataVector>(v);
|
llvm::ConstantDataVector *cdv = llvm::dyn_cast<llvm::ConstantDataVector>(v);
|
||||||
if (cdv != NULL)
|
if (cdv != NULL) {
|
||||||
return (cdv->getSplatValue() != NULL);
|
llvm::Value* splat = cdv->getSplatValue();
|
||||||
|
if (splat != NULL && splatValue) {
|
||||||
|
*splatValue = splat;
|
||||||
|
}
|
||||||
|
return (splat != NULL);
|
||||||
|
}
|
||||||
|
|
||||||
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
|
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
|
||||||
if (bop != NULL) {
|
if (bop != NULL) {
|
||||||
@@ -1178,14 +1196,14 @@ lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
|
|||||||
where the values are actually all equal.
|
where the values are actually all equal.
|
||||||
*/
|
*/
|
||||||
bool
|
bool
|
||||||
LLVMVectorValuesAllEqual(llvm::Value *v) {
|
LLVMVectorValuesAllEqual(llvm::Value *v, llvm::Value **splat) {
|
||||||
llvm::VectorType *vt =
|
llvm::VectorType *vt =
|
||||||
llvm::dyn_cast<llvm::VectorType>(v->getType());
|
llvm::dyn_cast<llvm::VectorType>(v->getType());
|
||||||
Assert(vt != NULL);
|
Assert(vt != NULL);
|
||||||
int vectorLength = vt->getNumElements();
|
int vectorLength = vt->getNumElements();
|
||||||
|
|
||||||
std::vector<llvm::PHINode *> seenPhis;
|
std::vector<llvm::PHINode *> seenPhis;
|
||||||
bool equal = lVectorValuesAllEqual(v, vectorLength, seenPhis);
|
bool equal = lVectorValuesAllEqual(v, vectorLength, seenPhis, splat);
|
||||||
|
|
||||||
Debug(SourcePos(), "LLVMVectorValuesAllEqual(%s) -> %s.",
|
Debug(SourcePos(), "LLVMVectorValuesAllEqual(%s) -> %s.",
|
||||||
v->getName().str().c_str(), equal ? "true" : "false");
|
v->getName().str().c_str(), equal ? "true" : "false");
|
||||||
|
|||||||
@@ -228,7 +228,8 @@ extern llvm::Constant *LLVMMaskAllOff;
|
|||||||
/** Tests to see if all of the elements of the vector in the 'v' parameter
|
/** Tests to see if all of the elements of the vector in the 'v' parameter
|
||||||
are equal. Like lValuesAreEqual(), this is a conservative test and may
|
are equal. Like lValuesAreEqual(), this is a conservative test and may
|
||||||
return false for arrays where the values are actually all equal. */
|
return false for arrays where the values are actually all equal. */
|
||||||
extern bool LLVMVectorValuesAllEqual(llvm::Value *v);
|
extern bool LLVMVectorValuesAllEqual(llvm::Value *v,
|
||||||
|
llvm::Value **splat = NULL);
|
||||||
|
|
||||||
/** Given vector of integer-typed values, this function returns true if it
|
/** Given vector of integer-typed values, this function returns true if it
|
||||||
can determine that the elements of the vector have a step of 'stride'
|
can determine that the elements of the vector have a step of 'stride'
|
||||||
|
|||||||
Reference in New Issue
Block a user