Improved optimization of vector select instructions.
Various LLVM optimization passes turn code like:

    %cmp = icmp lt <8 x i32> %foo, %bar
    %cmp32 = sext <8 x i1> %cmp to <8 x i32>
    . . .
    %cmp1 = trunc <8 x i32> %cmp32 to <8 x i1>
    %result = select <8 x i1> %cmp1, . . .

into:

    %cmp = icmp lt <8 x i32> %foo, %bar
    %cmp32 = zext <8 x i1> %cmp to <8 x i32>   # note: zext
    . . .
    %cmp1 = icmp ne <8 x i32> %cmp32, zeroinitializer
    %result = select <8 x i1> %cmp1, …

The second form isn't matched well by the LLVM code generators, which leads to fairly inefficient code (i.e. they don't just emit a vector compare and a blend instruction). This change simplifies such select conditions back to the original compare result. Also, renamed VSelMovmskOptPass to InstructionSimplifyPass to better describe its functionality.
This commit is contained in:
175
opt.cpp
175
opt.cpp
@@ -108,7 +108,7 @@
|
||||
#endif
|
||||
|
||||
static llvm::Pass *CreateIntrinsicsOptPass();
|
||||
static llvm::Pass *CreateVSelMovmskOptPass();
|
||||
static llvm::Pass *CreateInstructionSimplifyPass();
|
||||
|
||||
static llvm::Pass *CreateImproveMemoryOpsPass();
|
||||
static llvm::Pass *CreateGatherCoalescePass();
|
||||
@@ -476,7 +476,7 @@ Optimize(llvm::Module *module, int optLevel) {
|
||||
}
|
||||
if (!g->opt.disableMaskAllOnOptimizations) {
|
||||
optPM.add(CreateIntrinsicsOptPass());
|
||||
optPM.add(CreateVSelMovmskOptPass());
|
||||
optPM.add(CreateInstructionSimplifyPass());
|
||||
}
|
||||
optPM.add(llvm::createDeadInstEliminationPass());
|
||||
|
||||
@@ -519,7 +519,7 @@ Optimize(llvm::Module *module, int optLevel) {
|
||||
|
||||
if (!g->opt.disableMaskAllOnOptimizations) {
|
||||
optPM.add(CreateIntrinsicsOptPass());
|
||||
optPM.add(CreateVSelMovmskOptPass());
|
||||
optPM.add(CreateInstructionSimplifyPass());
|
||||
}
|
||||
|
||||
if (g->opt.disableGatherScatterOptimizations == false &&
|
||||
@@ -539,7 +539,7 @@ Optimize(llvm::Module *module, int optLevel) {
|
||||
optPM.add(llvm::createFunctionInliningPass());
|
||||
optPM.add(llvm::createConstantPropagationPass());
|
||||
optPM.add(CreateIntrinsicsOptPass());
|
||||
optPM.add(CreateVSelMovmskOptPass());
|
||||
optPM.add(CreateInstructionSimplifyPass());
|
||||
|
||||
if (g->opt.disableGatherScatterOptimizations == false &&
|
||||
g->target->getVectorWidth() > 1) {
|
||||
@@ -555,18 +555,20 @@ Optimize(llvm::Module *module, int optLevel) {
|
||||
if (g->opt.disableHandlePseudoMemoryOps == false)
|
||||
optPM.add(CreateReplacePseudoMemoryOpsPass());
|
||||
optPM.add(CreateIntrinsicsOptPass());
|
||||
optPM.add(CreateVSelMovmskOptPass());
|
||||
optPM.add(CreateInstructionSimplifyPass());
|
||||
|
||||
optPM.add(llvm::createFunctionInliningPass());
|
||||
optPM.add(llvm::createArgumentPromotionPass());
|
||||
optPM.add(llvm::createScalarReplAggregatesPass(sr_threshold, false));
|
||||
optPM.add(llvm::createInstructionCombiningPass());
|
||||
optPM.add(CreateInstructionSimplifyPass());
|
||||
optPM.add(llvm::createCFGSimplificationPass());
|
||||
optPM.add(llvm::createReassociatePass());
|
||||
optPM.add(llvm::createLoopRotatePass());
|
||||
optPM.add(llvm::createLICMPass());
|
||||
optPM.add(llvm::createLoopUnswitchPass(false));
|
||||
optPM.add(llvm::createInstructionCombiningPass());
|
||||
optPM.add(CreateInstructionSimplifyPass());
|
||||
optPM.add(llvm::createIndVarSimplifyPass());
|
||||
optPM.add(llvm::createLoopIdiomPass());
|
||||
optPM.add(llvm::createLoopDeletionPass());
|
||||
@@ -576,17 +578,19 @@ Optimize(llvm::Module *module, int optLevel) {
|
||||
|
||||
optPM.add(CreateIsCompileTimeConstantPass(true));
|
||||
optPM.add(CreateIntrinsicsOptPass());
|
||||
optPM.add(CreateVSelMovmskOptPass());
|
||||
optPM.add(CreateInstructionSimplifyPass());
|
||||
|
||||
optPM.add(llvm::createMemCpyOptPass());
|
||||
optPM.add(llvm::createSCCPPass());
|
||||
optPM.add(llvm::createInstructionCombiningPass());
|
||||
optPM.add(CreateInstructionSimplifyPass());
|
||||
optPM.add(llvm::createJumpThreadingPass());
|
||||
optPM.add(llvm::createCorrelatedValuePropagationPass());
|
||||
optPM.add(llvm::createDeadStoreEliminationPass());
|
||||
optPM.add(llvm::createAggressiveDCEPass());
|
||||
optPM.add(llvm::createCFGSimplificationPass());
|
||||
optPM.add(llvm::createInstructionCombiningPass());
|
||||
optPM.add(CreateInstructionSimplifyPass());
|
||||
optPM.add(llvm::createStripDeadPrototypesPass());
|
||||
optPM.add(CreateMakeInternalFuncsStaticPass());
|
||||
optPM.add(llvm::createGlobalDCEPass());
|
||||
@@ -927,80 +931,153 @@ CreateIntrinsicsOptPass() {
|
||||
@todo The better thing to do would be to submit a patch to LLVM to get
|
||||
these; they're presumably pretty simple patterns to match.
|
||||
*/
|
||||
class VSelMovmskOpt : public llvm::BasicBlockPass {
|
||||
class InstructionSimplifyPass : public llvm::BasicBlockPass {
|
||||
public:
|
||||
VSelMovmskOpt()
|
||||
InstructionSimplifyPass()
|
||||
: BasicBlockPass(ID) { }
|
||||
|
||||
const char *getPassName() const { return "Vector Select Optimization"; }
|
||||
bool runOnBasicBlock(llvm::BasicBlock &BB);
|
||||
|
||||
static char ID;
|
||||
|
||||
private:
|
||||
static bool simplifySelect(llvm::SelectInst *selectInst,
|
||||
llvm::BasicBlock::iterator iter);
|
||||
static llvm::Value *simplifyBoolVec(llvm::Value *value);
|
||||
static bool simplifyCall(llvm::CallInst *callInst,
|
||||
llvm::BasicBlock::iterator iter);
|
||||
};
|
||||
|
||||
char VSelMovmskOpt::ID = 0;
|
||||
char InstructionSimplifyPass::ID = 0;
|
||||
|
||||
|
||||
/** Tries to strip a redundant widen-then-narrow round trip from a boolean
    vector value.

    Two shapes are recognized:

      trunc({sext,zext}(x))                       -> x
      icmp ne ({sext,zext}(x)), zeroinitializer   -> x

    In both cases x is only returned when its type is
    LLVMTypes::Int1VectorType, so the caller can substitute the result for
    a boolean vector (e.g. a select condition) without changing its type.

    @param value  The value to examine.
    @return       The underlying boolean vector if one of the patterns
                  above matches, NULL otherwise.
 */
llvm::Value *
InstructionSimplifyPass::simplifyBoolVec(llvm::Value *value) {
    // First find the (possibly) extended value that is being narrowed
    // back down to a boolean vector.
    llvm::Value *extended = NULL;

    llvm::TruncInst *trunc = llvm::dyn_cast<llvm::TruncInst>(value);
    if (trunc != NULL) {
        // trunc({sext,zext}(i1 vector)) -> (i1 vector)
        // Look through the trunc: it is the trunc's operand, not the
        // trunc itself, that may be the sext/zext.
        extended = trunc->getOperand(0);
    }
    else {
        llvm::ICmpInst *icmp = llvm::dyn_cast<llvm::ICmpInst>(value);
        // icmp(ne, {sext,zext}(i1 vector), zeroinitializer) -> (i1 vector)
        if (icmp != NULL &&
            icmp->getSignedPredicate() == llvm::CmpInst::ICMP_NE &&
            llvm::isa<llvm::ConstantAggregateZero>(icmp->getOperand(1)))
            extended = icmp->getOperand(0);
    }

    if (extended == NULL)
        return NULL;

    // Now see whether that value is a sign- or zero-extension, and if so
    // grab what was extended.
    llvm::Value *source = NULL;
    if (llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(extended))
        source = sext->getOperand(0);
    else if (llvm::ZExtInst *zext = llvm::dyn_cast<llvm::ZExtInst>(extended))
        source = zext->getOperand(0);

    // Only hand back the source if it is already a boolean vector;
    // otherwise (e.g. an i8 vector that was extended to i32) it is not a
    // legal replacement for the original i1 vector value.
    if (source != NULL && source->getType() == LLVMTypes::Int1VectorType)
        return source;

    return NULL;
}
|
||||
|
||||
|
||||
bool
|
||||
VSelMovmskOpt::runOnBasicBlock(llvm::BasicBlock &bb) {
|
||||
DEBUG_START_PASS("VSelMovmaskOpt");
|
||||
/** Attempts to simplify a single vector select instruction.

    Non-vector selects are left alone.  If the select's condition is known
    (via lGetMaskStatus()) to be all on or all off, the select is replaced
    by the corresponding operand.  Otherwise, if simplifyBoolVec() can
    reduce the condition, the select is rebuilt with the reduced condition.

    @param selectInst  The select instruction to examine.
    @param iter        Iterator referring to the instruction within its
                       basic block; used when replacing it with a value.
    @return            true if the instruction was replaced (the caller's
                       iterator should then be considered stale), false if
                       nothing was changed.
 */
InstructionSimplifyPass::simplifySelect(llvm::SelectInst *selectInst,
                                        llvm::BasicBlock::iterator iter) {
    if (!selectInst->getType()->isVectorTy())
        return false;

    llvm::Value *cond = selectInst->getOperand(0);
    llvm::Value *onTrue = selectInst->getOperand(1);
    llvm::Value *onFalse = selectInst->getOperand(2);

    // A condition that is all on or all off makes the select pointless:
    // forward the operand that would have been chosen for every lane.
    MaskStatus status = lGetMaskStatus(cond);
    if (status == ALL_ON || status == ALL_OFF) {
        llvm::Value *chosen = (status == ALL_ON) ? onTrue : onFalse;
        if (chosen != NULL) {
            llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
                                       iter, chosen);
            return true;
        }
    }

    // Earlier LLVM passes sometimes leave behind needlessly elaborate
    // expressions for the selection vector; these confuse the code
    // generators and give sub-optimal code (particularly for 8 and 16-bit
    // masks).  Boil them down here so the code generator patterns match.
    llvm::Value *simplerCond = simplifyBoolVec(cond);
    if (simplerCond == NULL)
        return false;

    llvm::Instruction *replacement =
        llvm::SelectInst::Create(simplerCond, onTrue, onFalse,
                                 selectInst->getName());
    llvm::ReplaceInstWithInst(selectInst, replacement);
    return true;
}
|
||||
|
||||
|
||||
bool
|
||||
InstructionSimplifyPass::simplifyCall(llvm::CallInst *callInst,
|
||||
llvm::BasicBlock::iterator iter) {
|
||||
llvm::Function *calledFunc = callInst->getCalledFunction();
|
||||
|
||||
// Turn a __movmsk call with a compile-time constant vector into the
|
||||
// equivalent scalar value.
|
||||
if (calledFunc == NULL || calledFunc != m->module->getFunction("__movmsk"))
|
||||
return false;
|
||||
|
||||
uint64_t mask;
|
||||
if (lGetMask(callInst->getArgOperand(0), &mask) == true) {
|
||||
llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
|
||||
iter, LLVMInt64(mask));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
bool
|
||||
InstructionSimplifyPass::runOnBasicBlock(llvm::BasicBlock &bb) {
|
||||
DEBUG_START_PASS("InstructionSimplify");
|
||||
|
||||
bool modifiedAny = false;
|
||||
|
||||
restart:
|
||||
for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
|
||||
llvm::SelectInst *selectInst = llvm::dyn_cast<llvm::SelectInst>(&*iter);
|
||||
if (selectInst != NULL && selectInst->getType()->isVectorTy()) {
|
||||
llvm::Value *factor = selectInst->getOperand(0);
|
||||
|
||||
MaskStatus maskStatus = lGetMaskStatus(factor);
|
||||
llvm::Value *value = NULL;
|
||||
if (maskStatus == ALL_ON)
|
||||
// Mask all on -> replace with the first select value
|
||||
value = selectInst->getOperand(1);
|
||||
else if (maskStatus == ALL_OFF)
|
||||
// Mask all off -> replace with the second select value
|
||||
value = selectInst->getOperand(2);
|
||||
|
||||
if (value != NULL) {
|
||||
llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
|
||||
iter, value);
|
||||
modifiedAny = true;
|
||||
goto restart;
|
||||
}
|
||||
if (selectInst && simplifySelect(selectInst, iter)) {
|
||||
modifiedAny = true;
|
||||
goto restart;
|
||||
}
|
||||
|
||||
llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter);
|
||||
if (callInst == NULL)
|
||||
continue;
|
||||
|
||||
llvm::Function *calledFunc = callInst->getCalledFunction();
|
||||
if (calledFunc == NULL || calledFunc != m->module->getFunction("__movmsk"))
|
||||
continue;
|
||||
|
||||
uint64_t mask;
|
||||
if (lGetMask(callInst->getArgOperand(0), &mask) == true) {
|
||||
#if 0
|
||||
fprintf(stderr, "mask %d\n", mask);
|
||||
callInst->getArgOperand(0)->dump();
|
||||
fprintf(stderr, "-----------\n");
|
||||
#endif
|
||||
llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
|
||||
iter, LLVMInt64(mask));
|
||||
if (callInst && simplifyCall(callInst, iter)) {
|
||||
modifiedAny = true;
|
||||
goto restart;
|
||||
}
|
||||
}
|
||||
|
||||
DEBUG_END_PASS("VSelMovMskOpt");
|
||||
DEBUG_END_PASS("InstructionSimplify");
|
||||
|
||||
return modifiedAny;
|
||||
}
|
||||
|
||||
|
||||
static llvm::Pass *
|
||||
CreateVSelMovmskOptPass() {
|
||||
return new VSelMovmskOpt;
|
||||
CreateInstructionSimplifyPass() {
|
||||
return new InstructionSimplifyPass;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user