From bba84f247c34f67ed28a357d19a4a7414c590c2b Mon Sep 17 00:00:00 2001 From: Matt Pharr Date: Wed, 24 Jul 2013 15:08:07 -0700 Subject: [PATCH] Improved optimization of vector select instructions. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Various LLVM optimization passes are turning code like: %cmp = icmp lt <8 x i32> %foo, %bar %cmp32 = sext <8 x i1> %cmp to <8 x i32> . . . %cmp1 = trunc <8 x i32> %cmp32 to <8 x i1> %result = select <8 x i1> %cmp1, . . . Into: %cmp = icmp lt <8 x i32> %foo, %bar %cmp32 = zext <8 x i1> %cmp to <8 x i32> # note: zext . . . %cmp1 = icmp ne <8 x i32> %cmp32, zeroinitializer %result = select <8 x i1> %cmp1, … Which in turn isn't matched well by the LLVM code generators, which in turn leads to fairly inefficient code. (i.e. it doesn't just emit a vector compare and blend instruction.) Also, renamed VSelMovmskOptPass to InstructionSimplifyPass to better describe its functionality. --- opt.cpp | 175 ++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 126 insertions(+), 49 deletions(-) diff --git a/opt.cpp b/opt.cpp index 4701e7df..8efdbc67 100644 --- a/opt.cpp +++ b/opt.cpp @@ -108,7 +108,7 @@ #endif static llvm::Pass *CreateIntrinsicsOptPass(); -static llvm::Pass *CreateVSelMovmskOptPass(); +static llvm::Pass *CreateInstructionSimplifyPass(); static llvm::Pass *CreateImproveMemoryOpsPass(); static llvm::Pass *CreateGatherCoalescePass(); @@ -476,7 +476,7 @@ Optimize(llvm::Module *module, int optLevel) { } if (!g->opt.disableMaskAllOnOptimizations) { optPM.add(CreateIntrinsicsOptPass()); - optPM.add(CreateVSelMovmskOptPass()); + optPM.add(CreateInstructionSimplifyPass()); } optPM.add(llvm::createDeadInstEliminationPass()); @@ -519,7 +519,7 @@ Optimize(llvm::Module *module, int optLevel) { if (!g->opt.disableMaskAllOnOptimizations) { optPM.add(CreateIntrinsicsOptPass()); - optPM.add(CreateVSelMovmskOptPass()); + optPM.add(CreateInstructionSimplifyPass()); } if 
(g->opt.disableGatherScatterOptimizations == false && @@ -539,7 +539,7 @@ Optimize(llvm::Module *module, int optLevel) { optPM.add(llvm::createFunctionInliningPass()); optPM.add(llvm::createConstantPropagationPass()); optPM.add(CreateIntrinsicsOptPass()); - optPM.add(CreateVSelMovmskOptPass()); + optPM.add(CreateInstructionSimplifyPass()); if (g->opt.disableGatherScatterOptimizations == false && g->target->getVectorWidth() > 1) { @@ -555,18 +555,20 @@ Optimize(llvm::Module *module, int optLevel) { if (g->opt.disableHandlePseudoMemoryOps == false) optPM.add(CreateReplacePseudoMemoryOpsPass()); optPM.add(CreateIntrinsicsOptPass()); - optPM.add(CreateVSelMovmskOptPass()); + optPM.add(CreateInstructionSimplifyPass()); optPM.add(llvm::createFunctionInliningPass()); optPM.add(llvm::createArgumentPromotionPass()); optPM.add(llvm::createScalarReplAggregatesPass(sr_threshold, false)); optPM.add(llvm::createInstructionCombiningPass()); + optPM.add(CreateInstructionSimplifyPass()); optPM.add(llvm::createCFGSimplificationPass()); optPM.add(llvm::createReassociatePass()); optPM.add(llvm::createLoopRotatePass()); optPM.add(llvm::createLICMPass()); optPM.add(llvm::createLoopUnswitchPass(false)); optPM.add(llvm::createInstructionCombiningPass()); + optPM.add(CreateInstructionSimplifyPass()); optPM.add(llvm::createIndVarSimplifyPass()); optPM.add(llvm::createLoopIdiomPass()); optPM.add(llvm::createLoopDeletionPass()); @@ -576,17 +578,19 @@ Optimize(llvm::Module *module, int optLevel) { optPM.add(CreateIsCompileTimeConstantPass(true)); optPM.add(CreateIntrinsicsOptPass()); - optPM.add(CreateVSelMovmskOptPass()); + optPM.add(CreateInstructionSimplifyPass()); optPM.add(llvm::createMemCpyOptPass()); optPM.add(llvm::createSCCPPass()); optPM.add(llvm::createInstructionCombiningPass()); + optPM.add(CreateInstructionSimplifyPass()); optPM.add(llvm::createJumpThreadingPass()); optPM.add(llvm::createCorrelatedValuePropagationPass()); optPM.add(llvm::createDeadStoreEliminationPass()); 
optPM.add(llvm::createAggressiveDCEPass()); optPM.add(llvm::createCFGSimplificationPass()); optPM.add(llvm::createInstructionCombiningPass()); + optPM.add(CreateInstructionSimplifyPass()); optPM.add(llvm::createStripDeadPrototypesPass()); optPM.add(CreateMakeInternalFuncsStaticPass()); optPM.add(llvm::createGlobalDCEPass()); @@ -927,80 +931,153 @@ CreateIntrinsicsOptPass() { @todo The better thing to do would be to submit a patch to LLVM to get these; they're presumably pretty simple patterns to match. */ -class VSelMovmskOpt : public llvm::BasicBlockPass { +class InstructionSimplifyPass : public llvm::BasicBlockPass { public: - VSelMovmskOpt() + InstructionSimplifyPass() : BasicBlockPass(ID) { } const char *getPassName() const { return "Vector Select Optimization"; } bool runOnBasicBlock(llvm::BasicBlock &BB); static char ID; + +private: + static bool simplifySelect(llvm::SelectInst *selectInst, + llvm::BasicBlock::iterator iter); + static llvm::Value *simplifyBoolVec(llvm::Value *value); + static bool simplifyCall(llvm::CallInst *callInst, + llvm::BasicBlock::iterator iter); }; -char VSelMovmskOpt::ID = 0; +char InstructionSimplifyPass::ID = 0; + + +llvm::Value * +InstructionSimplifyPass::simplifyBoolVec(llvm::Value *value) { + llvm::TruncInst *trunc = llvm::dyn_cast<llvm::TruncInst>(value); + if (trunc != NULL) { + // Convert trunc({sext,zext}(i1 vector)) -> (i1 vector) + llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(value); + if (sext && + sext->getOperand(0)->getType() == LLVMTypes::Int1VectorType) + return sext->getOperand(0); + + llvm::ZExtInst *zext = llvm::dyn_cast<llvm::ZExtInst>(value); + if (zext && + zext->getOperand(0)->getType() == LLVMTypes::Int1VectorType) + return zext->getOperand(0); + } + + llvm::ICmpInst *icmp = llvm::dyn_cast<llvm::ICmpInst>(value); + if (icmp != NULL) { + // icmp(ne, {sext,zext}(foo), zeroinitializer) -> foo + if (icmp->getSignedPredicate() == llvm::CmpInst::ICMP_NE) { + llvm::Value *op1 = icmp->getOperand(1); + if (llvm::isa<llvm::ConstantAggregateZero>(op1)) { + llvm::Value *op0 = icmp->getOperand(0); +
llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(op0); + if (sext) + return sext->getOperand(0); + llvm::ZExtInst *zext = llvm::dyn_cast<llvm::ZExtInst>(op0); + if (zext) + return zext->getOperand(0); + } + } + } + return NULL; +} bool -VSelMovmskOpt::runOnBasicBlock(llvm::BasicBlock &bb) { - DEBUG_START_PASS("VSelMovmaskOpt"); +InstructionSimplifyPass::simplifySelect(llvm::SelectInst *selectInst, + llvm::BasicBlock::iterator iter) { + if (selectInst->getType()->isVectorTy() == false) + return false; + + llvm::Value *factor = selectInst->getOperand(0); + + // Simplify all-on or all-off mask values + MaskStatus maskStatus = lGetMaskStatus(factor); + llvm::Value *value = NULL; + if (maskStatus == ALL_ON) + // Mask all on -> replace with the first select value + value = selectInst->getOperand(1); + else if (maskStatus == ALL_OFF) + // Mask all off -> replace with the second select value + value = selectInst->getOperand(2); + if (value != NULL) { + llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), + iter, value); + return true; + } + + // Sometimes earlier LLVM optimization passes generate unnecessarily + // complex expressions for the selection vector, which in turn confuses + // the code generators and leads to sub-optimal code (particularly for + // 8 and 16-bit masks). We'll try to simplify them out here so that + // the code generator patterns match.. + if ((factor = simplifyBoolVec(factor)) != NULL) { + llvm::Instruction *newSelect = + llvm::SelectInst::Create(factor, selectInst->getOperand(1), + selectInst->getOperand(2), + selectInst->getName()); + llvm::ReplaceInstWithInst(selectInst, newSelect); + return true; + } + + return false; +} + + +bool +InstructionSimplifyPass::simplifyCall(llvm::CallInst *callInst, + llvm::BasicBlock::iterator iter) { + llvm::Function *calledFunc = callInst->getCalledFunction(); + + // Turn a __movmsk call with a compile-time constant vector into the + // equivalent scalar value.
+ if (calledFunc == NULL || calledFunc != m->module->getFunction("__movmsk")) + return false; + + uint64_t mask; + if (lGetMask(callInst->getArgOperand(0), &mask) == true) { + llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), + iter, LLVMInt64(mask)); + return true; + } + return false; +} + + +bool +InstructionSimplifyPass::runOnBasicBlock(llvm::BasicBlock &bb) { + DEBUG_START_PASS("InstructionSimplify"); bool modifiedAny = false; restart: for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) { llvm::SelectInst *selectInst = llvm::dyn_cast<llvm::SelectInst>(&*iter); - if (selectInst != NULL && selectInst->getType()->isVectorTy()) { - llvm::Value *factor = selectInst->getOperand(0); - - MaskStatus maskStatus = lGetMaskStatus(factor); - llvm::Value *value = NULL; - if (maskStatus == ALL_ON) - // Mask all on -> replace with the first select value - value = selectInst->getOperand(1); - else if (maskStatus == ALL_OFF) - // Mask all off -> replace with the second select value - value = selectInst->getOperand(2); - - if (value != NULL) { - llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), - iter, value); - modifiedAny = true; - goto restart; - } + if (selectInst && simplifySelect(selectInst, iter)) { + modifiedAny = true; + goto restart; } llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter); - if (callInst == NULL) - continue; - - llvm::Function *calledFunc = callInst->getCalledFunction(); - if (calledFunc == NULL || calledFunc != m->module->getFunction("__movmsk")) - continue; - - uint64_t mask; - if (lGetMask(callInst->getArgOperand(0), &mask) == true) { -#if 0 - fprintf(stderr, "mask %d\n", mask); - callInst->getArgOperand(0)->dump(); - fprintf(stderr, "-----------\n"); -#endif - llvm::ReplaceInstWithValue(iter->getParent()->getInstList(), - iter, LLVMInt64(mask)); + if (callInst && simplifyCall(callInst, iter)) { modifiedAny = true; goto restart; } } - DEBUG_END_PASS("VSelMovMskOpt"); + DEBUG_END_PASS("InstructionSimplify"); return
modifiedAny; } static llvm::Pass * -CreateVSelMovmskOptPass() { - return new VSelMovmskOpt; +CreateInstructionSimplifyPass() { + return new InstructionSimplifyPass; }