Merge pull request #782 from dbabokin/mic_perf

Fixing MIC performance issue (generate <2,2,2,2> as 2 in cpp).
This commit is contained in:
Dmitry Babokin
2014-04-17 21:24:40 +04:00
3 changed files with 39 additions and 14 deletions

View File

@@ -3286,10 +3286,16 @@ void CWriter::visitBinaryOperator(llvm::Instruction &I) {
if ((I.getOpcode() == llvm::Instruction::Shl ||
I.getOpcode() == llvm::Instruction::LShr ||
I.getOpcode() == llvm::Instruction::AShr)) {
if (LLVMVectorValuesAllEqual(I.getOperand(1))) {
Out << "__extract_element(";
writeOperand(I.getOperand(1));
Out << ", 0) ";
llvm::Value *splat = NULL;
if (LLVMVectorValuesAllEqual(I.getOperand(1), &splat)) {
if (splat) {
// Avoid __extract_element(splat(value), 0), if possible.
writeOperand(splat);
} else {
Out << "__extract_element(";
writeOperand(I.getOperand(1));
Out << ", 0) ";
}
}
else
writeOperand(I.getOperand(1));

View File

@@ -818,7 +818,8 @@ LLVMExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts) {
static bool
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis);
std::vector<llvm::PHINode *> &seenPhis,
llvm::Value **splatValue = NULL);
/** This function checks to see if the given (scalar or vector) value is an
@@ -1068,20 +1069,37 @@ lVectorShiftRightAllEqual(llvm::Value *val, llvm::Value *shift,
static bool
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis) {
std::vector<llvm::PHINode *> &seenPhis,
llvm::Value **splatValue) {
if (vectorLength == 1)
return true;
if (llvm::isa<llvm::ConstantAggregateZero>(v))
if (llvm::isa<llvm::ConstantAggregateZero>(v)) {
if (splatValue) {
llvm::ConstantAggregateZero *caz =
llvm::dyn_cast<llvm::ConstantAggregateZero>(v);
*splatValue = caz->getSequentialElement();
}
return true;
}
llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
if (cv != NULL)
return (cv->getSplatValue() != NULL);
if (cv != NULL) {
llvm::Value* splat = cv->getSplatValue();
if (splat != NULL && splatValue) {
*splatValue = splat;
}
return (splat != NULL);
}
llvm::ConstantDataVector *cdv = llvm::dyn_cast<llvm::ConstantDataVector>(v);
if (cdv != NULL)
return (cdv->getSplatValue() != NULL);
if (cdv != NULL) {
llvm::Value* splat = cdv->getSplatValue();
if (splat != NULL && splatValue) {
*splatValue = splat;
}
return (splat != NULL);
}
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
if (bop != NULL) {
@@ -1178,14 +1196,14 @@ lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
where the values are actually all equal.
*/
bool
LLVMVectorValuesAllEqual(llvm::Value *v) {
LLVMVectorValuesAllEqual(llvm::Value *v, llvm::Value **splat) {
llvm::VectorType *vt =
llvm::dyn_cast<llvm::VectorType>(v->getType());
Assert(vt != NULL);
int vectorLength = vt->getNumElements();
std::vector<llvm::PHINode *> seenPhis;
bool equal = lVectorValuesAllEqual(v, vectorLength, seenPhis);
bool equal = lVectorValuesAllEqual(v, vectorLength, seenPhis, splat);
Debug(SourcePos(), "LLVMVectorValuesAllEqual(%s) -> %s.",
v->getName().str().c_str(), equal ? "true" : "false");

View File

@@ -228,7 +228,8 @@ extern llvm::Constant *LLVMMaskAllOff;
/** Tests to see if all of the elements of the vector in the 'v' parameter
are equal. Like lValuesAreEqual(), this is a conservative test and may
return false for arrays where the values are actually all equal. */
extern bool LLVMVectorValuesAllEqual(llvm::Value *v);
extern bool LLVMVectorValuesAllEqual(llvm::Value *v,
llvm::Value **splat = NULL);
/** Given vector of integer-typed values, this function returns true if it
can determine that the elements of the vector have a step of 'stride'