Improved optimization of vector select instructions.

Various LLVM optimization passes are turning code like:

%cmp = icmp lt <8 x i32> %foo, %bar
%cmp32 = sext <8 x i1> %cmp to <8 x i32>
. . .
%cmp1 = trunc <8 x i32> %cmp32 to <8 x i1>
%result = select <8 x i1> %cmp1, . . .

Into:

%cmp = icmp lt <8 x i32> %foo, %bar
%cmp32 = zext <8 x i1> %cmp to <8 x i32>   # note: zext
. . .
%cmp1 = icmp ne <8 x i32> %cmp32, zeroinitializer
%result = select <8 x i1> %cmp1, …

Which in turn isn't matched well by the LLVM code generators, which
in turn leads to fairly inefficient code.  (i.e. it doesn't just emit
a vector compare and blend instruction.)

Also, renamed VSelMovmskOptPass to InstructionSimplifyPass to better
describe its functionality.
This commit is contained in:
Matt Pharr
2013-07-24 15:08:07 -07:00
parent 780b0dfe47
commit bba84f247c

175
opt.cpp
View File

@@ -108,7 +108,7 @@
#endif #endif
static llvm::Pass *CreateIntrinsicsOptPass(); static llvm::Pass *CreateIntrinsicsOptPass();
static llvm::Pass *CreateVSelMovmskOptPass(); static llvm::Pass *CreateInstructionSimplifyPass();
static llvm::Pass *CreateImproveMemoryOpsPass(); static llvm::Pass *CreateImproveMemoryOpsPass();
static llvm::Pass *CreateGatherCoalescePass(); static llvm::Pass *CreateGatherCoalescePass();
@@ -476,7 +476,7 @@ Optimize(llvm::Module *module, int optLevel) {
} }
if (!g->opt.disableMaskAllOnOptimizations) { if (!g->opt.disableMaskAllOnOptimizations) {
optPM.add(CreateIntrinsicsOptPass()); optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateVSelMovmskOptPass()); optPM.add(CreateInstructionSimplifyPass());
} }
optPM.add(llvm::createDeadInstEliminationPass()); optPM.add(llvm::createDeadInstEliminationPass());
@@ -519,7 +519,7 @@ Optimize(llvm::Module *module, int optLevel) {
if (!g->opt.disableMaskAllOnOptimizations) { if (!g->opt.disableMaskAllOnOptimizations) {
optPM.add(CreateIntrinsicsOptPass()); optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateVSelMovmskOptPass()); optPM.add(CreateInstructionSimplifyPass());
} }
if (g->opt.disableGatherScatterOptimizations == false && if (g->opt.disableGatherScatterOptimizations == false &&
@@ -539,7 +539,7 @@ Optimize(llvm::Module *module, int optLevel) {
optPM.add(llvm::createFunctionInliningPass()); optPM.add(llvm::createFunctionInliningPass());
optPM.add(llvm::createConstantPropagationPass()); optPM.add(llvm::createConstantPropagationPass());
optPM.add(CreateIntrinsicsOptPass()); optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateVSelMovmskOptPass()); optPM.add(CreateInstructionSimplifyPass());
if (g->opt.disableGatherScatterOptimizations == false && if (g->opt.disableGatherScatterOptimizations == false &&
g->target->getVectorWidth() > 1) { g->target->getVectorWidth() > 1) {
@@ -555,18 +555,20 @@ Optimize(llvm::Module *module, int optLevel) {
if (g->opt.disableHandlePseudoMemoryOps == false) if (g->opt.disableHandlePseudoMemoryOps == false)
optPM.add(CreateReplacePseudoMemoryOpsPass()); optPM.add(CreateReplacePseudoMemoryOpsPass());
optPM.add(CreateIntrinsicsOptPass()); optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateVSelMovmskOptPass()); optPM.add(CreateInstructionSimplifyPass());
optPM.add(llvm::createFunctionInliningPass()); optPM.add(llvm::createFunctionInliningPass());
optPM.add(llvm::createArgumentPromotionPass()); optPM.add(llvm::createArgumentPromotionPass());
optPM.add(llvm::createScalarReplAggregatesPass(sr_threshold, false)); optPM.add(llvm::createScalarReplAggregatesPass(sr_threshold, false));
optPM.add(llvm::createInstructionCombiningPass()); optPM.add(llvm::createInstructionCombiningPass());
optPM.add(CreateInstructionSimplifyPass());
optPM.add(llvm::createCFGSimplificationPass()); optPM.add(llvm::createCFGSimplificationPass());
optPM.add(llvm::createReassociatePass()); optPM.add(llvm::createReassociatePass());
optPM.add(llvm::createLoopRotatePass()); optPM.add(llvm::createLoopRotatePass());
optPM.add(llvm::createLICMPass()); optPM.add(llvm::createLICMPass());
optPM.add(llvm::createLoopUnswitchPass(false)); optPM.add(llvm::createLoopUnswitchPass(false));
optPM.add(llvm::createInstructionCombiningPass()); optPM.add(llvm::createInstructionCombiningPass());
optPM.add(CreateInstructionSimplifyPass());
optPM.add(llvm::createIndVarSimplifyPass()); optPM.add(llvm::createIndVarSimplifyPass());
optPM.add(llvm::createLoopIdiomPass()); optPM.add(llvm::createLoopIdiomPass());
optPM.add(llvm::createLoopDeletionPass()); optPM.add(llvm::createLoopDeletionPass());
@@ -576,17 +578,19 @@ Optimize(llvm::Module *module, int optLevel) {
optPM.add(CreateIsCompileTimeConstantPass(true)); optPM.add(CreateIsCompileTimeConstantPass(true));
optPM.add(CreateIntrinsicsOptPass()); optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateVSelMovmskOptPass()); optPM.add(CreateInstructionSimplifyPass());
optPM.add(llvm::createMemCpyOptPass()); optPM.add(llvm::createMemCpyOptPass());
optPM.add(llvm::createSCCPPass()); optPM.add(llvm::createSCCPPass());
optPM.add(llvm::createInstructionCombiningPass()); optPM.add(llvm::createInstructionCombiningPass());
optPM.add(CreateInstructionSimplifyPass());
optPM.add(llvm::createJumpThreadingPass()); optPM.add(llvm::createJumpThreadingPass());
optPM.add(llvm::createCorrelatedValuePropagationPass()); optPM.add(llvm::createCorrelatedValuePropagationPass());
optPM.add(llvm::createDeadStoreEliminationPass()); optPM.add(llvm::createDeadStoreEliminationPass());
optPM.add(llvm::createAggressiveDCEPass()); optPM.add(llvm::createAggressiveDCEPass());
optPM.add(llvm::createCFGSimplificationPass()); optPM.add(llvm::createCFGSimplificationPass());
optPM.add(llvm::createInstructionCombiningPass()); optPM.add(llvm::createInstructionCombiningPass());
optPM.add(CreateInstructionSimplifyPass());
optPM.add(llvm::createStripDeadPrototypesPass()); optPM.add(llvm::createStripDeadPrototypesPass());
optPM.add(CreateMakeInternalFuncsStaticPass()); optPM.add(CreateMakeInternalFuncsStaticPass());
optPM.add(llvm::createGlobalDCEPass()); optPM.add(llvm::createGlobalDCEPass());
@@ -927,80 +931,153 @@ CreateIntrinsicsOptPass() {
@todo The better thing to do would be to submit a patch to LLVM to get @todo The better thing to do would be to submit a patch to LLVM to get
these; they're presumably pretty simple patterns to match. these; they're presumably pretty simple patterns to match.
*/ */
class VSelMovmskOpt : public llvm::BasicBlockPass { class InstructionSimplifyPass : public llvm::BasicBlockPass {
public: public:
VSelMovmskOpt() InstructionSimplifyPass()
: BasicBlockPass(ID) { } : BasicBlockPass(ID) { }
const char *getPassName() const { return "Vector Select Optimization"; } const char *getPassName() const { return "Vector Select Optimization"; }
bool runOnBasicBlock(llvm::BasicBlock &BB); bool runOnBasicBlock(llvm::BasicBlock &BB);
static char ID; static char ID;
private:
static bool simplifySelect(llvm::SelectInst *selectInst,
llvm::BasicBlock::iterator iter);
static llvm::Value *simplifyBoolVec(llvm::Value *value);
static bool simplifyCall(llvm::CallInst *callInst,
llvm::BasicBlock::iterator iter);
}; };
char VSelMovmskOpt::ID = 0; char InstructionSimplifyPass::ID = 0;
llvm::Value *
InstructionSimplifyPass::simplifyBoolVec(llvm::Value *value) {
llvm::TruncInst *trunc = llvm::dyn_cast<llvm::TruncInst>(value);
if (trunc != NULL) {
// Convert trunc({sext,zext}(i1 vector)) -> (i1 vector)
llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(value);
if (sext &&
sext->getOperand(0)->getType() == LLVMTypes::Int1VectorType)
return sext->getOperand(0);
llvm::ZExtInst *zext = llvm::dyn_cast<llvm::ZExtInst>(value);
if (zext &&
zext->getOperand(0)->getType() == LLVMTypes::Int1VectorType)
return zext->getOperand(0);
}
llvm::ICmpInst *icmp = llvm::dyn_cast<llvm::ICmpInst>(value);
if (icmp != NULL) {
// icmp(ne, {sext,zext}(foo), zeroinitializer) -> foo
if (icmp->getSignedPredicate() == llvm::CmpInst::ICMP_NE) {
llvm::Value *op1 = icmp->getOperand(1);
if (llvm::isa<llvm::ConstantAggregateZero>(op1)) {
llvm::Value *op0 = icmp->getOperand(0);
llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(op0);
if (sext)
return sext->getOperand(0);
llvm::ZExtInst *zext = llvm::dyn_cast<llvm::ZExtInst>(op0);
if (zext)
return zext->getOperand(0);
}
}
}
return NULL;
}
bool bool
VSelMovmskOpt::runOnBasicBlock(llvm::BasicBlock &bb) { InstructionSimplifyPass::simplifySelect(llvm::SelectInst *selectInst,
DEBUG_START_PASS("VSelMovmaskOpt"); llvm::BasicBlock::iterator iter) {
if (selectInst->getType()->isVectorTy() == false)
return false;
llvm::Value *factor = selectInst->getOperand(0);
// Simplify all-on or all-off mask values
MaskStatus maskStatus = lGetMaskStatus(factor);
llvm::Value *value = NULL;
if (maskStatus == ALL_ON)
// Mask all on -> replace with the first select value
value = selectInst->getOperand(1);
else if (maskStatus == ALL_OFF)
// Mask all off -> replace with the second select value
value = selectInst->getOperand(2);
if (value != NULL) {
llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
iter, value);
return true;
}
// Sometimes earlier LLVM optimization passes generate unnecessarily
// complex expressions for the selection vector, which in turn confuses
// the code generators and leads to sub-optimal code (particularly for
// 8 and 16-bit masks). We'll try to simplify them out here so that
// the code generator patterns match..
if ((factor = simplifyBoolVec(factor)) != NULL) {
llvm::Instruction *newSelect =
llvm::SelectInst::Create(factor, selectInst->getOperand(1),
selectInst->getOperand(2),
selectInst->getName());
llvm::ReplaceInstWithInst(selectInst, newSelect);
return true;
}
return false;
}
bool
InstructionSimplifyPass::simplifyCall(llvm::CallInst *callInst,
llvm::BasicBlock::iterator iter) {
llvm::Function *calledFunc = callInst->getCalledFunction();
// Turn a __movmsk call with a compile-time constant vector into the
// equivalent scalar value.
if (calledFunc == NULL || calledFunc != m->module->getFunction("__movmsk"))
return false;
uint64_t mask;
if (lGetMask(callInst->getArgOperand(0), &mask) == true) {
llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
iter, LLVMInt64(mask));
return true;
}
return false;
}
bool
InstructionSimplifyPass::runOnBasicBlock(llvm::BasicBlock &bb) {
DEBUG_START_PASS("InstructionSimplify");
bool modifiedAny = false; bool modifiedAny = false;
restart: restart:
for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) { for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
llvm::SelectInst *selectInst = llvm::dyn_cast<llvm::SelectInst>(&*iter); llvm::SelectInst *selectInst = llvm::dyn_cast<llvm::SelectInst>(&*iter);
if (selectInst != NULL && selectInst->getType()->isVectorTy()) { if (selectInst && simplifySelect(selectInst, iter)) {
llvm::Value *factor = selectInst->getOperand(0); modifiedAny = true;
goto restart;
MaskStatus maskStatus = lGetMaskStatus(factor);
llvm::Value *value = NULL;
if (maskStatus == ALL_ON)
// Mask all on -> replace with the first select value
value = selectInst->getOperand(1);
else if (maskStatus == ALL_OFF)
// Mask all off -> replace with the second select value
value = selectInst->getOperand(2);
if (value != NULL) {
llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
iter, value);
modifiedAny = true;
goto restart;
}
} }
llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter); llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter);
if (callInst == NULL) if (callInst && simplifyCall(callInst, iter)) {
continue;
llvm::Function *calledFunc = callInst->getCalledFunction();
if (calledFunc == NULL || calledFunc != m->module->getFunction("__movmsk"))
continue;
uint64_t mask;
if (lGetMask(callInst->getArgOperand(0), &mask) == true) {
#if 0
fprintf(stderr, "mask %d\n", mask);
callInst->getArgOperand(0)->dump();
fprintf(stderr, "-----------\n");
#endif
llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
iter, LLVMInt64(mask));
modifiedAny = true; modifiedAny = true;
goto restart; goto restart;
} }
} }
DEBUG_END_PASS("VSelMovMskOpt"); DEBUG_END_PASS("InstructionSimplify");
return modifiedAny; return modifiedAny;
} }
static llvm::Pass * static llvm::Pass *
CreateVSelMovmskOptPass() { CreateInstructionSimplifyPass() {
return new VSelMovmskOpt; return new InstructionSimplifyPass;
} }