From ccdbddd388bf494bf3cb4aaf6a90cbb684cd18f0 Mon Sep 17 00:00:00 2001
From: Matt Pharr
Date: Tue, 6 Aug 2013 08:59:46 -0700
Subject: [PATCH] Add peephole optimization to match int8/int16 averages.

Match the following patterns in IR and turn them into target-specific
intrinsics (e.g. PAVGB on x86) when possible:

(unsigned int8)(((unsigned int16)a + (unsigned int16)b + 1)/2)
(unsigned int8)(((unsigned int16)a + (unsigned int16)b)/2)
(unsigned int16)(((unsigned int32)a + (unsigned int32)b + 1)/2)
(unsigned int16)(((unsigned int32)a + (unsigned int32)b)/2)
(int8)(((int16)a + (int16)b + 1)/2)
(int8)(((int16)a + (int16)b)/2)
(int16)(((int32)a + (int32)b + 1)/2)
(int16)(((int32)a + (int32)b)/2)
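
For example (illustrative only, not part of this change), C-style
source such as

    (uint8_t)(((uint16_t)a + (uint16_t)b + 1) / 2)

shows up in vectorized IR as a zext from <N x i8> to <N x i16>, an add
(plus an add of the splatted constant 1), a udiv by 2 (or lshr by 1),
and a trunc back to <N x i8>. That is the shape the matchers below look
for; the resulting __avg_up_uint8 builtin can then lower to a single
PAVGB on SSE targets.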
---
 opt.cpp | 393 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 393 insertions(+)

diff --git a/opt.cpp b/opt.cpp
index b363f0e1..8899c64d 100644
--- a/opt.cpp
+++ b/opt.cpp
@@ -84,6 +84,7 @@
 #include
 #include
 #include
+#include <llvm/Support/PatternMatch.h>
 #include
 #ifdef ISPC_IS_LINUX
 #include
@@ -103,6 +104,7 @@
 
 static llvm::Pass *CreateIntrinsicsOptPass();
 static llvm::Pass *CreateInstructionSimplifyPass();
+static llvm::Pass *CreatePeepholePass();
 
 static llvm::Pass *CreateImproveMemoryOpsPass();
 static llvm::Pass *CreateGatherCoalescePass();
@@ -459,6 +461,9 @@ Optimize(llvm::Module *module, int optLevel) {
         optPM.add(llvm::createDeadInstEliminationPass());
         optPM.add(llvm::createCFGSimplificationPass());
 
+        optPM.add(llvm::createPromoteMemoryToRegisterPass());
+        optPM.add(llvm::createAggressiveDCEPass());
+
         if (g->opt.disableGatherScatterOptimizations == false &&
             g->target->getVectorWidth() > 1) {
             optPM.add(llvm::createInstructionCombiningPass());
@@ -500,6 +505,7 @@ Optimize(llvm::Module *module, int optLevel) {
         // InstructionCombiningPass. See r184459 for details.
         optPM.add(llvm::createSimplifyLibCallsPass());
 #endif
+        optPM.add(llvm::createAggressiveDCEPass());
         optPM.add(llvm::createInstructionCombiningPass());
         optPM.add(llvm::createJumpThreadingPass());
         optPM.add(llvm::createCFGSimplificationPass());
@@ -539,6 +545,7 @@ Optimize(llvm::Module *module, int optLevel) {
 
         optPM.add(llvm::createIPSCCPPass());
         optPM.add(llvm::createDeadArgEliminationPass());
+        optPM.add(llvm::createAggressiveDCEPass());
         optPM.add(llvm::createInstructionCombiningPass());
         optPM.add(llvm::createCFGSimplificationPass());
 
@@ -581,6 +588,9 @@ Optimize(llvm::Module *module, int optLevel) {
         optPM.add(llvm::createCFGSimplificationPass());
         optPM.add(llvm::createInstructionCombiningPass());
         optPM.add(CreateInstructionSimplifyPass());
+        optPM.add(CreatePeepholePass());
+        optPM.add(llvm::createFunctionInliningPass());
+        optPM.add(llvm::createAggressiveDCEPass());
         optPM.add(llvm::createStripDeadPrototypesPass());
         optPM.add(CreateMakeInternalFuncsStaticPass());
         optPM.add(llvm::createGlobalDCEPass());
@@ -4430,3 +4440,386 @@ static llvm::Pass *
 CreateMakeInternalFuncsStaticPass() {
     return new MakeInternalFuncsStaticPass;
 }
+
+
+///////////////////////////////////////////////////////////////////////////
+// PeepholePass
+
+class PeepholePass : public llvm::BasicBlockPass {
+public:
+    PeepholePass();
+
+    const char *getPassName() const { return "Peephole Optimizations"; }
+    bool runOnBasicBlock(llvm::BasicBlock &BB);
+
+    static char ID;
+};
+
+char PeepholePass::ID = 0;
+
+PeepholePass::PeepholePass()
+    : BasicBlockPass(ID) {
+}
+
+using namespace llvm::PatternMatch;
+
+template <typename Op_t, unsigned Opcode>
+struct CastClassTypes_match {
+    Op_t Op;
+    const llvm::Type *fromType, *toType;
+
+    CastClassTypes_match(const Op_t &OpMatch, const llvm::Type *f,
+                         const llvm::Type *t)
+        : Op(OpMatch), fromType(f), toType(t) {}
+
+    template <typename OpTy>
+    bool match(OpTy *V) {
+        if (llvm::Operator *O = llvm::dyn_cast<llvm::Operator>(V))
+            return (O->getOpcode() == Opcode && Op.match(O->getOperand(0)) &&
+                    O->getType() == toType &&
+                    O->getOperand(0)->getType() == fromType);
+        return false;
+    }
+};
+
+template <typename OpTy>
+inline CastClassTypes_match<OpTy, llvm::Instruction::SExt>
+m_SExt8To16(const OpTy &Op) {
+    return CastClassTypes_match<OpTy, llvm::Instruction::SExt>(
+        Op,
+        LLVMTypes::Int8VectorType,
+        LLVMTypes::Int16VectorType);
+}
+
+template <typename OpTy>
+inline CastClassTypes_match<OpTy, llvm::Instruction::ZExt>
+m_ZExt8To16(const OpTy &Op) {
+    return CastClassTypes_match<OpTy, llvm::Instruction::ZExt>(
+        Op,
+        LLVMTypes::Int8VectorType,
+        LLVMTypes::Int16VectorType);
+}
+
+template <typename OpTy>
+inline CastClassTypes_match<OpTy, llvm::Instruction::Trunc>
+m_Trunc16To8(const OpTy &Op) {
+    return CastClassTypes_match<OpTy, llvm::Instruction::Trunc>(
+        Op,
+        LLVMTypes::Int16VectorType,
+        LLVMTypes::Int8VectorType);
+}
+
+template <typename OpTy>
+inline CastClassTypes_match<OpTy, llvm::Instruction::SExt>
+m_SExt16To32(const OpTy &Op) {
+    return CastClassTypes_match<OpTy, llvm::Instruction::SExt>(
+        Op,
+        LLVMTypes::Int16VectorType,
+        LLVMTypes::Int32VectorType);
+}
+
+template <typename OpTy>
+inline CastClassTypes_match<OpTy, llvm::Instruction::ZExt>
+m_ZExt16To32(const OpTy &Op) {
+    return CastClassTypes_match<OpTy, llvm::Instruction::ZExt>(
+        Op,
+        LLVMTypes::Int16VectorType,
+        LLVMTypes::Int32VectorType);
+}
+
+template <typename OpTy>
+inline CastClassTypes_match<OpTy, llvm::Instruction::Trunc>
+m_Trunc32To16(const OpTy &Op) {
+    return CastClassTypes_match<OpTy, llvm::Instruction::Trunc>(
+        Op,
+        LLVMTypes::Int32VectorType,
+        LLVMTypes::Int16VectorType);
+}
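+
+// Usage sketch (illustrative, not used elsewhere in this patch): these
+// helpers compose with LLVM's PatternMatch combinators, e.g. to match
+// the zero-extended sum of two int8 vectors:
+//
+//     llvm::Value *a, *b;
+//     bool matched = match(v, m_Add(m_ZExt8To16(m_Value(a)),
+//                                   m_ZExt8To16(m_Value(b))));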
+
+template <typename Op_t>
+struct UDiv2_match {
+    Op_t Op;
+
+    UDiv2_match(const Op_t &OpMatch)
+        : Op(OpMatch) {}
+
+    template <typename OpTy>
+    bool match(OpTy *V) {
+        llvm::BinaryOperator *bop;
+        llvm::ConstantDataVector *cdv;
+        if ((bop = llvm::dyn_cast<llvm::BinaryOperator>(V)) &&
+            (cdv = llvm::dyn_cast<llvm::ConstantDataVector>(bop->getOperand(1))) &&
+            cdv->getSplatValue() != NULL) {
+            const llvm::APInt &apInt = cdv->getUniqueInteger();
+
+            switch (bop->getOpcode()) {
+            case llvm::Instruction::UDiv:
+                // divide by 2
+                return (apInt == 2 && Op.match(bop->getOperand(0)));
+            case llvm::Instruction::LShr:
+                // logical shift right by 1
+                return (apInt == 1 && Op.match(bop->getOperand(0)));
+            default:
+                return false;
+            }
+        }
+        return false;
+    }
+};
+
+template <typename V>
+inline UDiv2_match<V>
+m_UDiv2(const V &v) {
+    return UDiv2_match<V>(v);
+}
+
+template <typename Op_t>
+struct SDiv2_match {
+    Op_t Op;
+
+    SDiv2_match(const Op_t &OpMatch)
+        : Op(OpMatch) {}
+
+    template <typename OpTy>
+    bool match(OpTy *V) {
+        llvm::BinaryOperator *bop;
+        llvm::ConstantDataVector *cdv;
+        if ((bop = llvm::dyn_cast<llvm::BinaryOperator>(V)) &&
+            (cdv = llvm::dyn_cast<llvm::ConstantDataVector>(bop->getOperand(1))) &&
+            cdv->getSplatValue() != NULL) {
+            const llvm::APInt &apInt = cdv->getUniqueInteger();
+
+            switch (bop->getOpcode()) {
+            case llvm::Instruction::SDiv:
+                // divide by 2
+                return (apInt == 2 && Op.match(bop->getOperand(0)));
+            case llvm::Instruction::AShr:
+                // arithmetic shift right by 1
+                return (apInt == 1 && Op.match(bop->getOperand(0)));
+            default:
+                return false;
+            }
+        }
+        return false;
+    }
+};
+
+template <typename V>
+inline SDiv2_match<V>
+m_SDiv2(const V &v) {
+    return SDiv2_match<V>(v);
+}
+
+// Returns true if the given function has a call to an intrinsic function
+// in its definition.
+static bool
+lHasIntrinsicInDefinition(llvm::Function *func) {
+    llvm::Function::iterator bbiter = func->begin();
+    for (; bbiter != func->end(); ++bbiter) {
+        for (llvm::BasicBlock::iterator institer = bbiter->begin();
+             institer != bbiter->end(); ++institer) {
+            if (llvm::isa<llvm::IntrinsicInst>(institer))
+                return true;
+        }
+    }
+    return false;
+}
+
+static llvm::Instruction *
+lGetBinaryIntrinsic(const char *name, llvm::Value *opa, llvm::Value *opb) {
+    llvm::Function *func = m->module->getFunction(name);
+    Assert(func != NULL);
+
+    // Make sure that the definition of the llvm::Function has a call to an
+    // intrinsic function in its instructions; otherwise we will generate
+    // infinite loops where we "helpfully" turn the default implementations
+    // of target builtins like __avg_up_uint8 that are implemented with plain
+    // arithmetic ops into recursive calls to themselves.
+    if (lHasIntrinsicInDefinition(func))
+        return lCallInst(func, opa, opb, name);
+    else
+        return NULL;
+}
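+
+// As an example of why that check matters: on a target without a native
+// averaging instruction, the builtins library may define the builtin
+// with plain arithmetic, roughly like the following (hypothetical
+// sketch, not actual builtins code):
+//
+//     define <WIDTH x i8> @__avg_up_uint8(<WIDTH x i8> %a, <WIDTH x i8> %b) {
+//         ; zext to i16, add, add 1, lshr by 1, trunc back to i8 --
+//         ; exactly the pattern the matchers below recognize, so
+//         ; rewriting its body into a call to itself would recurse
+//         ; forever.
+//     }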
+
+//////////////////////////////////////////////////
+
+static llvm::Instruction *
+lMatchAvgUpUInt8(llvm::Value *inst) {
+    // (unsigned int8)(((unsigned int16)a + (unsigned int16)b + 1)/2)
+    llvm::Value *opa, *opb;
+    const llvm::APInt *delta;
+    if (match(inst, m_Trunc16To8(m_UDiv2(m_CombineOr(
+            m_CombineOr(
+                m_Add(m_ZExt8To16(m_Value(opa)),
+                      m_Add(m_ZExt8To16(m_Value(opb)), m_APInt(delta))),
+                m_Add(m_Add(m_ZExt8To16(m_Value(opa)), m_APInt(delta)),
+                      m_ZExt8To16(m_Value(opb)))),
+            m_Add(m_Add(m_ZExt8To16(m_Value(opa)), m_ZExt8To16(m_Value(opb))),
+                  m_APInt(delta))))))) {
+        if (*delta != 1)
+            return NULL;
+
+        return lGetBinaryIntrinsic("__avg_up_uint8", opa, opb);
+    }
+    return NULL;
+}
+
+static llvm::Instruction *
+lMatchAvgDownUInt8(llvm::Value *inst) {
+    // (unsigned int8)(((unsigned int16)a + (unsigned int16)b)/2)
+    llvm::Value *opa, *opb;
+    if (match(inst, m_Trunc16To8(m_UDiv2(
+            m_Add(m_ZExt8To16(m_Value(opa)),
+                  m_ZExt8To16(m_Value(opb))))))) {
+        return lGetBinaryIntrinsic("__avg_down_uint8", opa, opb);
+    }
+    return NULL;
+}
+
+static llvm::Instruction *
+lMatchAvgUpUInt16(llvm::Value *inst) {
+    // (unsigned int16)(((unsigned int32)a + (unsigned int32)b + 1)/2)
+    llvm::Value *opa, *opb;
+    const llvm::APInt *delta;
+    if (match(inst, m_Trunc32To16(m_UDiv2(m_CombineOr(
+            m_CombineOr(
+                m_Add(m_ZExt16To32(m_Value(opa)),
+                      m_Add(m_ZExt16To32(m_Value(opb)), m_APInt(delta))),
+                m_Add(m_Add(m_ZExt16To32(m_Value(opa)), m_APInt(delta)),
+                      m_ZExt16To32(m_Value(opb)))),
+            m_Add(m_Add(m_ZExt16To32(m_Value(opa)), m_ZExt16To32(m_Value(opb))),
+                  m_APInt(delta))))))) {
+        if (*delta != 1)
+            return NULL;
+
+        return lGetBinaryIntrinsic("__avg_up_uint16", opa, opb);
+    }
+    return NULL;
+}
+
+static llvm::Instruction *
+lMatchAvgDownUInt16(llvm::Value *inst) {
+    // (unsigned int16)(((unsigned int32)a + (unsigned int32)b)/2)
+    llvm::Value *opa, *opb;
+    if (match(inst, m_Trunc32To16(m_UDiv2(
+            m_Add(m_ZExt16To32(m_Value(opa)),
+                  m_ZExt16To32(m_Value(opb))))))) {
+        return lGetBinaryIntrinsic("__avg_down_uint16", opa, opb);
+    }
+    return NULL;
+}
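+
+// For reference, the IR accepted for the unsigned round-down cases has
+// the following shape (illustrative; %a and %b are <N x i8> values):
+//
+//     %wa  = zext <N x i8> %a to <N x i16>
+//     %wb  = zext <N x i8> %b to <N x i16>
+//     %sum = add <N x i16> %wa, %wb
+//     %avg = udiv <N x i16> %sum, <splat of 2>    ; or lshr by 1
+//     %res = trunc <N x i16> %avg to <N x i8>
+//
+// The nested m_CombineOr cases in the round-up matchers cover the three
+// associations into which the +1 constant may have been folded by
+// earlier passes. The signed variants below are identical except that
+// they use sext/sdiv/ashr.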
+
+static llvm::Instruction *
+lMatchAvgUpInt8(llvm::Value *inst) {
+    // (int8)(((int16)a + (int16)b + 1)/2)
+    llvm::Value *opa, *opb;
+    const llvm::APInt *delta;
+    if (match(inst, m_Trunc16To8(m_SDiv2(m_CombineOr(
+            m_CombineOr(
+                m_Add(m_SExt8To16(m_Value(opa)),
+                      m_Add(m_SExt8To16(m_Value(opb)), m_APInt(delta))),
+                m_Add(m_Add(m_SExt8To16(m_Value(opa)), m_APInt(delta)),
+                      m_SExt8To16(m_Value(opb)))),
+            m_Add(m_Add(m_SExt8To16(m_Value(opa)), m_SExt8To16(m_Value(opb))),
+                  m_APInt(delta))))))) {
+        if (*delta != 1)
+            return NULL;
+
+        return lGetBinaryIntrinsic("__avg_up_int8", opa, opb);
+    }
+    return NULL;
+}
+
+static llvm::Instruction *
+lMatchAvgDownInt8(llvm::Value *inst) {
+    // (int8)(((int16)a + (int16)b)/2)
+    llvm::Value *opa, *opb;
+    if (match(inst, m_Trunc16To8(m_SDiv2(
+            m_Add(m_SExt8To16(m_Value(opa)),
+                  m_SExt8To16(m_Value(opb))))))) {
+        return lGetBinaryIntrinsic("__avg_down_int8", opa, opb);
+    }
+    return NULL;
+}
+
+static llvm::Instruction *
+lMatchAvgUpInt16(llvm::Value *inst) {
+    // (int16)(((int32)a + (int32)b + 1)/2)
+    llvm::Value *opa, *opb;
+    const llvm::APInt *delta;
+    if (match(inst, m_Trunc32To16(m_SDiv2(m_CombineOr(
+            m_CombineOr(
+                m_Add(m_SExt16To32(m_Value(opa)),
+                      m_Add(m_SExt16To32(m_Value(opb)), m_APInt(delta))),
+                m_Add(m_Add(m_SExt16To32(m_Value(opa)), m_APInt(delta)),
+                      m_SExt16To32(m_Value(opb)))),
+            m_Add(m_Add(m_SExt16To32(m_Value(opa)), m_SExt16To32(m_Value(opb))),
+                  m_APInt(delta))))))) {
+        if (*delta != 1)
+            return NULL;
+
+        return lGetBinaryIntrinsic("__avg_up_int16", opa, opb);
+    }
+    return NULL;
+}
+
+static llvm::Instruction *
+lMatchAvgDownInt16(llvm::Value *inst) {
+    // (int16)(((int32)a + (int32)b)/2)
+    llvm::Value *opa, *opb;
+    if (match(inst, m_Trunc32To16(m_SDiv2(
+            m_Add(m_SExt16To32(m_Value(opa)),
+                  m_SExt16To32(m_Value(opb))))))) {
+        return lGetBinaryIntrinsic("__avg_down_int16", opa, opb);
+    }
+    return NULL;
+}
+
+bool
+PeepholePass::runOnBasicBlock(llvm::BasicBlock &bb) {
+    DEBUG_START_PASS("PeepholePass");
+
+    bool modifiedAny = false;
+ restart:
+    for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end();
+         iter != e; ++iter) {
+        llvm::Instruction *inst = &*iter;
+
+        llvm::Instruction *builtinCall = lMatchAvgUpUInt8(inst);
+        if (!builtinCall)
+            builtinCall = lMatchAvgUpUInt16(inst);
+        if (!builtinCall)
+            builtinCall = lMatchAvgDownUInt8(inst);
+        if (!builtinCall)
+            builtinCall = lMatchAvgDownUInt16(inst);
+        if (!builtinCall)
+            builtinCall = lMatchAvgUpInt8(inst);
+        if (!builtinCall)
+            builtinCall = lMatchAvgUpInt16(inst);
+        if (!builtinCall)
+            builtinCall = lMatchAvgDownInt8(inst);
+        if (!builtinCall)
+            builtinCall = lMatchAvgDownInt16(inst);
+
+        if (builtinCall != NULL) {
+            llvm::ReplaceInstWithInst(inst, builtinCall);
+            modifiedAny = true;
+            // ReplaceInstWithInst invalidates the iterator, so rescan
+            // the block from the top.
+            goto restart;
+        }
+    }
+
+    DEBUG_END_PASS("PeepholePass");
+
+    return modifiedAny;
+}
+
+static llvm::Pass *
+CreatePeepholePass() {
+    return new PeepholePass;
+}