Improved optimization of vector select instructions.

Various LLVM optimization passes are turning code like:

%cmp = icmp lt <8 x i32> %foo, %bar
%cmp32 = sext <8 x i1> %cmp to <8 x i32>
. . .
%cmp1 = trunc <8 x i32> %cmp32 to <8 x i1>
%result = select <8 x i1> %cmp1, . . .

Into:

%cmp = icmp lt <8 x i32> %foo, %bar
%cmp32 = zext <8 x i1> %cmp to <8 x i32>   # note: zext
. . .
%cmp1 = icmp ne <8 x i32> %cmp32, zeroinitializer
%result = select <8 x i1> %cmp1, …

Which in turn isn't matched well by the LLVM code generators, which
in turn leads to fairly inefficient code.  (i.e. it doesn't just emit
a vector compare and blend instruction.)

Also, renamed VSelMovmskOptPass to InstructionSimplifyPass to better
describe its functionality.
This commit is contained in:
Matt Pharr
2013-07-24 15:08:07 -07:00
parent 780b0dfe47
commit bba84f247c

175
opt.cpp
View File

@@ -108,7 +108,7 @@
#endif
static llvm::Pass *CreateIntrinsicsOptPass();
static llvm::Pass *CreateVSelMovmskOptPass();
static llvm::Pass *CreateInstructionSimplifyPass();
static llvm::Pass *CreateImproveMemoryOpsPass();
static llvm::Pass *CreateGatherCoalescePass();
@@ -476,7 +476,7 @@ Optimize(llvm::Module *module, int optLevel) {
}
if (!g->opt.disableMaskAllOnOptimizations) {
optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateVSelMovmskOptPass());
optPM.add(CreateInstructionSimplifyPass());
}
optPM.add(llvm::createDeadInstEliminationPass());
@@ -519,7 +519,7 @@ Optimize(llvm::Module *module, int optLevel) {
if (!g->opt.disableMaskAllOnOptimizations) {
optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateVSelMovmskOptPass());
optPM.add(CreateInstructionSimplifyPass());
}
if (g->opt.disableGatherScatterOptimizations == false &&
@@ -539,7 +539,7 @@ Optimize(llvm::Module *module, int optLevel) {
optPM.add(llvm::createFunctionInliningPass());
optPM.add(llvm::createConstantPropagationPass());
optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateVSelMovmskOptPass());
optPM.add(CreateInstructionSimplifyPass());
if (g->opt.disableGatherScatterOptimizations == false &&
g->target->getVectorWidth() > 1) {
@@ -555,18 +555,20 @@ Optimize(llvm::Module *module, int optLevel) {
if (g->opt.disableHandlePseudoMemoryOps == false)
optPM.add(CreateReplacePseudoMemoryOpsPass());
optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateVSelMovmskOptPass());
optPM.add(CreateInstructionSimplifyPass());
optPM.add(llvm::createFunctionInliningPass());
optPM.add(llvm::createArgumentPromotionPass());
optPM.add(llvm::createScalarReplAggregatesPass(sr_threshold, false));
optPM.add(llvm::createInstructionCombiningPass());
optPM.add(CreateInstructionSimplifyPass());
optPM.add(llvm::createCFGSimplificationPass());
optPM.add(llvm::createReassociatePass());
optPM.add(llvm::createLoopRotatePass());
optPM.add(llvm::createLICMPass());
optPM.add(llvm::createLoopUnswitchPass(false));
optPM.add(llvm::createInstructionCombiningPass());
optPM.add(CreateInstructionSimplifyPass());
optPM.add(llvm::createIndVarSimplifyPass());
optPM.add(llvm::createLoopIdiomPass());
optPM.add(llvm::createLoopDeletionPass());
@@ -576,17 +578,19 @@ Optimize(llvm::Module *module, int optLevel) {
optPM.add(CreateIsCompileTimeConstantPass(true));
optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateVSelMovmskOptPass());
optPM.add(CreateInstructionSimplifyPass());
optPM.add(llvm::createMemCpyOptPass());
optPM.add(llvm::createSCCPPass());
optPM.add(llvm::createInstructionCombiningPass());
optPM.add(CreateInstructionSimplifyPass());
optPM.add(llvm::createJumpThreadingPass());
optPM.add(llvm::createCorrelatedValuePropagationPass());
optPM.add(llvm::createDeadStoreEliminationPass());
optPM.add(llvm::createAggressiveDCEPass());
optPM.add(llvm::createCFGSimplificationPass());
optPM.add(llvm::createInstructionCombiningPass());
optPM.add(CreateInstructionSimplifyPass());
optPM.add(llvm::createStripDeadPrototypesPass());
optPM.add(CreateMakeInternalFuncsStaticPass());
optPM.add(llvm::createGlobalDCEPass());
@@ -927,80 +931,153 @@ CreateIntrinsicsOptPass() {
@todo The better thing to do would be to submit a patch to LLVM to get
these; they're presumably pretty simple patterns to match.
*/
class VSelMovmskOpt : public llvm::BasicBlockPass {
class InstructionSimplifyPass : public llvm::BasicBlockPass {
public:
VSelMovmskOpt()
InstructionSimplifyPass()
: BasicBlockPass(ID) { }
const char *getPassName() const { return "Vector Select Optimization"; }
bool runOnBasicBlock(llvm::BasicBlock &BB);
static char ID;
private:
static bool simplifySelect(llvm::SelectInst *selectInst,
llvm::BasicBlock::iterator iter);
static llvm::Value *simplifyBoolVec(llvm::Value *value);
static bool simplifyCall(llvm::CallInst *callInst,
llvm::BasicBlock::iterator iter);
};
char VSelMovmskOpt::ID = 0;
char InstructionSimplifyPass::ID = 0;
llvm::Value *
InstructionSimplifyPass::simplifyBoolVec(llvm::Value *value) {
llvm::TruncInst *trunc = llvm::dyn_cast<llvm::TruncInst>(value);
if (trunc != NULL) {
// Convert trunc({sext,zext}(i1 vector)) -> (i1 vector)
llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(value);
if (sext &&
sext->getOperand(0)->getType() == LLVMTypes::Int1VectorType)
return sext->getOperand(0);
llvm::ZExtInst *zext = llvm::dyn_cast<llvm::ZExtInst>(value);
if (zext &&
zext->getOperand(0)->getType() == LLVMTypes::Int1VectorType)
return zext->getOperand(0);
}
llvm::ICmpInst *icmp = llvm::dyn_cast<llvm::ICmpInst>(value);
if (icmp != NULL) {
// icmp(ne, {sext,zext}(foo), zeroinitializer) -> foo
if (icmp->getSignedPredicate() == llvm::CmpInst::ICMP_NE) {
llvm::Value *op1 = icmp->getOperand(1);
if (llvm::isa<llvm::ConstantAggregateZero>(op1)) {
llvm::Value *op0 = icmp->getOperand(0);
llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(op0);
if (sext)
return sext->getOperand(0);
llvm::ZExtInst *zext = llvm::dyn_cast<llvm::ZExtInst>(op0);
if (zext)
return zext->getOperand(0);
}
}
}
return NULL;
}
bool
VSelMovmskOpt::runOnBasicBlock(llvm::BasicBlock &bb) {
DEBUG_START_PASS("VSelMovmaskOpt");
InstructionSimplifyPass::simplifySelect(llvm::SelectInst *selectInst,
llvm::BasicBlock::iterator iter) {
if (selectInst->getType()->isVectorTy() == false)
return false;
llvm::Value *factor = selectInst->getOperand(0);
// Simplify all-on or all-off mask values
MaskStatus maskStatus = lGetMaskStatus(factor);
llvm::Value *value = NULL;
if (maskStatus == ALL_ON)
// Mask all on -> replace with the first select value
value = selectInst->getOperand(1);
else if (maskStatus == ALL_OFF)
// Mask all off -> replace with the second select value
value = selectInst->getOperand(2);
if (value != NULL) {
llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
iter, value);
return true;
}
// Sometimes earlier LLVM optimization passes generate unnecessarily
// complex expressions for the selection vector, which in turn confuses
// the code generators and leads to sub-optimal code (particularly for
// 8 and 16-bit masks). We'll try to simplify them out here so that
// the code generator patterns match..
if ((factor = simplifyBoolVec(factor)) != NULL) {
llvm::Instruction *newSelect =
llvm::SelectInst::Create(factor, selectInst->getOperand(1),
selectInst->getOperand(2),
selectInst->getName());
llvm::ReplaceInstWithInst(selectInst, newSelect);
return true;
}
return false;
}
bool
InstructionSimplifyPass::simplifyCall(llvm::CallInst *callInst,
llvm::BasicBlock::iterator iter) {
llvm::Function *calledFunc = callInst->getCalledFunction();
// Turn a __movmsk call with a compile-time constant vector into the
// equivalent scalar value.
if (calledFunc == NULL || calledFunc != m->module->getFunction("__movmsk"))
return false;
uint64_t mask;
if (lGetMask(callInst->getArgOperand(0), &mask) == true) {
llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
iter, LLVMInt64(mask));
return true;
}
return false;
}
bool
InstructionSimplifyPass::runOnBasicBlock(llvm::BasicBlock &bb) {
DEBUG_START_PASS("InstructionSimplify");
bool modifiedAny = false;
restart:
for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
llvm::SelectInst *selectInst = llvm::dyn_cast<llvm::SelectInst>(&*iter);
if (selectInst != NULL && selectInst->getType()->isVectorTy()) {
llvm::Value *factor = selectInst->getOperand(0);
MaskStatus maskStatus = lGetMaskStatus(factor);
llvm::Value *value = NULL;
if (maskStatus == ALL_ON)
// Mask all on -> replace with the first select value
value = selectInst->getOperand(1);
else if (maskStatus == ALL_OFF)
// Mask all off -> replace with the second select value
value = selectInst->getOperand(2);
if (value != NULL) {
llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
iter, value);
modifiedAny = true;
goto restart;
}
if (selectInst && simplifySelect(selectInst, iter)) {
modifiedAny = true;
goto restart;
}
llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter);
if (callInst == NULL)
continue;
llvm::Function *calledFunc = callInst->getCalledFunction();
if (calledFunc == NULL || calledFunc != m->module->getFunction("__movmsk"))
continue;
uint64_t mask;
if (lGetMask(callInst->getArgOperand(0), &mask) == true) {
#if 0
fprintf(stderr, "mask %d\n", mask);
callInst->getArgOperand(0)->dump();
fprintf(stderr, "-----------\n");
#endif
llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
iter, LLVMInt64(mask));
if (callInst && simplifyCall(callInst, iter)) {
modifiedAny = true;
goto restart;
}
}
DEBUG_END_PASS("VSelMovMskOpt");
DEBUG_END_PASS("InstructionSimplify");
return modifiedAny;
}
static llvm::Pass *
CreateVSelMovmskOptPass() {
return new VSelMovmskOpt;
CreateInstructionSimplifyPass() {
return new InstructionSimplifyPass;
}