Add peephole optimization to match int8/int16 averages.
Match the following patterns in IR, turning them into target-specific intrinsics (e.g. PAVGB on x86) when possible. (unsigned int8)(((unsigned int16)a + (unsigned int16)b + 1)/2) (unsigned int8)(((unsigned int16)a + (unsigned int16)b)/2) (unsigned int16)(((unsigned int32)a + (unsigned int32)b + 1)/2) (unsigned int16)(((unsigned int32)a + (unsigned int32)b)/2) (int8)(((int16)a + (int16)b + 1)/2) (int8)(((int16)a + (int16)b)/2) (int16)(((int32)a + (int32)b + 1)/2) (int16)(((int32)a + (int32)b)/2)
This commit is contained in:
393
opt.cpp
393
opt.cpp
@@ -84,6 +84,7 @@
|
||||
#include <llvm/Analysis/Passes.h>
|
||||
#include <llvm/Support/raw_ostream.h>
|
||||
#include <llvm/DebugInfo.h>
|
||||
#include <llvm/Support/PatternMatch.h>
|
||||
#include <llvm/Support/Dwarf.h>
|
||||
#ifdef ISPC_IS_LINUX
|
||||
#include <alloca.h>
|
||||
@@ -103,6 +104,7 @@
|
||||
|
||||
static llvm::Pass *CreateIntrinsicsOptPass();
|
||||
static llvm::Pass *CreateInstructionSimplifyPass();
|
||||
static llvm::Pass *CreatePeepholePass();
|
||||
|
||||
static llvm::Pass *CreateImproveMemoryOpsPass();
|
||||
static llvm::Pass *CreateGatherCoalescePass();
|
||||
@@ -459,6 +461,9 @@ Optimize(llvm::Module *module, int optLevel) {
|
||||
optPM.add(llvm::createDeadInstEliminationPass());
|
||||
optPM.add(llvm::createCFGSimplificationPass());
|
||||
|
||||
optPM.add(llvm::createPromoteMemoryToRegisterPass());
|
||||
optPM.add(llvm::createAggressiveDCEPass());
|
||||
|
||||
if (g->opt.disableGatherScatterOptimizations == false &&
|
||||
g->target->getVectorWidth() > 1) {
|
||||
optPM.add(llvm::createInstructionCombiningPass());
|
||||
@@ -500,6 +505,7 @@ Optimize(llvm::Module *module, int optLevel) {
|
||||
// InstructionCombiningPass. See r184459 for details.
|
||||
optPM.add(llvm::createSimplifyLibCallsPass());
|
||||
#endif
|
||||
optPM.add(llvm::createAggressiveDCEPass());
|
||||
optPM.add(llvm::createInstructionCombiningPass());
|
||||
optPM.add(llvm::createJumpThreadingPass());
|
||||
optPM.add(llvm::createCFGSimplificationPass());
|
||||
@@ -539,6 +545,7 @@ Optimize(llvm::Module *module, int optLevel) {
|
||||
|
||||
optPM.add(llvm::createIPSCCPPass());
|
||||
optPM.add(llvm::createDeadArgEliminationPass());
|
||||
optPM.add(llvm::createAggressiveDCEPass());
|
||||
optPM.add(llvm::createInstructionCombiningPass());
|
||||
optPM.add(llvm::createCFGSimplificationPass());
|
||||
|
||||
@@ -581,6 +588,9 @@ Optimize(llvm::Module *module, int optLevel) {
|
||||
optPM.add(llvm::createCFGSimplificationPass());
|
||||
optPM.add(llvm::createInstructionCombiningPass());
|
||||
optPM.add(CreateInstructionSimplifyPass());
|
||||
optPM.add(CreatePeepholePass());
|
||||
optPM.add(llvm::createFunctionInliningPass());
|
||||
optPM.add(llvm::createAggressiveDCEPass());
|
||||
optPM.add(llvm::createStripDeadPrototypesPass());
|
||||
optPM.add(CreateMakeInternalFuncsStaticPass());
|
||||
optPM.add(llvm::createGlobalDCEPass());
|
||||
@@ -4430,3 +4440,386 @@ static llvm::Pass *
|
||||
CreateMakeInternalFuncsStaticPass() {
|
||||
return new MakeInternalFuncsStaticPass;
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// PeepholePass
|
||||
|
||||
class PeepholePass : public llvm::BasicBlockPass {
|
||||
public:
|
||||
PeepholePass();
|
||||
|
||||
const char *getPassName() const { return "Peephole Optimizations"; }
|
||||
bool runOnBasicBlock(llvm::BasicBlock &BB);
|
||||
|
||||
static char ID;
|
||||
};
|
||||
|
||||
char PeepholePass::ID = 0;
|
||||
|
||||
PeepholePass::PeepholePass()
|
||||
: BasicBlockPass(ID) {
|
||||
}
|
||||
|
||||
using namespace llvm::PatternMatch;
|
||||
|
||||
template<typename Op_t, unsigned Opcode>
|
||||
struct CastClassTypes_match {
|
||||
Op_t Op;
|
||||
const llvm::Type *fromType, *toType;
|
||||
|
||||
CastClassTypes_match(const Op_t &OpMatch, const llvm::Type *f,
|
||||
const llvm::Type *t)
|
||||
: Op(OpMatch), fromType(f), toType(t) {}
|
||||
|
||||
template<typename OpTy>
|
||||
bool match(OpTy *V) {
|
||||
if (llvm::Operator *O = llvm::dyn_cast<llvm::Operator>(V))
|
||||
return (O->getOpcode() == Opcode && Op.match(O->getOperand(0)) &&
|
||||
O->getType() == toType &&
|
||||
O->getOperand(0)->getType() == fromType);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename OpTy>
|
||||
inline CastClassTypes_match<OpTy, llvm::Instruction::SExt>
|
||||
m_SExt8To16(const OpTy &Op) {
|
||||
return CastClassTypes_match<OpTy, llvm::Instruction::SExt>(
|
||||
Op,
|
||||
LLVMTypes::Int8VectorType,
|
||||
LLVMTypes::Int16VectorType);
|
||||
}
|
||||
|
||||
template<typename OpTy>
|
||||
inline CastClassTypes_match<OpTy, llvm::Instruction::ZExt>
|
||||
m_ZExt8To16(const OpTy &Op) {
|
||||
return CastClassTypes_match<OpTy, llvm::Instruction::ZExt>(
|
||||
Op,
|
||||
LLVMTypes::Int8VectorType,
|
||||
LLVMTypes::Int16VectorType);
|
||||
}
|
||||
|
||||
|
||||
template<typename OpTy>
|
||||
inline CastClassTypes_match<OpTy, llvm::Instruction::Trunc>
|
||||
m_Trunc16To8(const OpTy &Op) {
|
||||
return CastClassTypes_match<OpTy, llvm::Instruction::Trunc>(
|
||||
Op,
|
||||
LLVMTypes::Int16VectorType,
|
||||
LLVMTypes::Int8VectorType);
|
||||
}
|
||||
|
||||
template<typename OpTy>
|
||||
inline CastClassTypes_match<OpTy, llvm::Instruction::SExt>
|
||||
m_SExt16To32(const OpTy &Op) {
|
||||
return CastClassTypes_match<OpTy, llvm::Instruction::SExt>(
|
||||
Op,
|
||||
LLVMTypes::Int16VectorType,
|
||||
LLVMTypes::Int32VectorType);
|
||||
}
|
||||
|
||||
template<typename OpTy>
|
||||
inline CastClassTypes_match<OpTy, llvm::Instruction::ZExt>
|
||||
m_ZExt16To32(const OpTy &Op) {
|
||||
return CastClassTypes_match<OpTy, llvm::Instruction::ZExt>(
|
||||
Op,
|
||||
LLVMTypes::Int16VectorType,
|
||||
LLVMTypes::Int32VectorType);
|
||||
}
|
||||
|
||||
|
||||
template<typename OpTy>
|
||||
inline CastClassTypes_match<OpTy, llvm::Instruction::Trunc>
|
||||
m_Trunc32To16(const OpTy &Op) {
|
||||
return CastClassTypes_match<OpTy, llvm::Instruction::Trunc>(
|
||||
Op,
|
||||
LLVMTypes::Int32VectorType,
|
||||
LLVMTypes::Int16VectorType);
|
||||
}
|
||||
|
||||
template<typename Op_t>
|
||||
struct UDiv2_match {
|
||||
Op_t Op;
|
||||
|
||||
UDiv2_match(const Op_t &OpMatch)
|
||||
: Op(OpMatch) {}
|
||||
|
||||
template<typename OpTy>
|
||||
bool match(OpTy *V) {
|
||||
llvm::BinaryOperator *bop;
|
||||
llvm::ConstantDataVector *cdv;
|
||||
if ((bop = llvm::dyn_cast<llvm::BinaryOperator>(V)) &&
|
||||
(cdv = llvm::dyn_cast<llvm::ConstantDataVector>(bop->getOperand(1))) &&
|
||||
cdv->getSplatValue() != NULL) {
|
||||
const llvm::APInt &apInt = cdv->getUniqueInteger();
|
||||
|
||||
switch (bop->getOpcode()) {
|
||||
case llvm::Instruction::UDiv:
|
||||
// divide by 2
|
||||
return (apInt.isIntN(2) && Op.match(bop->getOperand(0)));
|
||||
case llvm::Instruction::LShr:
|
||||
// shift left by 1
|
||||
return (apInt.isIntN(1) && Op.match(bop->getOperand(0)));
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename V>
|
||||
inline UDiv2_match<V>
|
||||
m_UDiv2(const V &v) {
|
||||
return UDiv2_match<V>(v);
|
||||
}
|
||||
|
||||
template<typename Op_t>
|
||||
struct SDiv2_match {
|
||||
Op_t Op;
|
||||
|
||||
SDiv2_match(const Op_t &OpMatch)
|
||||
: Op(OpMatch) {}
|
||||
|
||||
template<typename OpTy>
|
||||
bool match(OpTy *V) {
|
||||
llvm::BinaryOperator *bop;
|
||||
llvm::ConstantDataVector *cdv;
|
||||
if ((bop = llvm::dyn_cast<llvm::BinaryOperator>(V)) &&
|
||||
(cdv = llvm::dyn_cast<llvm::ConstantDataVector>(bop->getOperand(1))) &&
|
||||
cdv->getSplatValue() != NULL) {
|
||||
const llvm::APInt &apInt = cdv->getUniqueInteger();
|
||||
|
||||
switch (bop->getOpcode()) {
|
||||
case llvm::Instruction::SDiv:
|
||||
// divide by 2
|
||||
return (apInt.isIntN(2) && Op.match(bop->getOperand(0)));
|
||||
case llvm::Instruction::AShr:
|
||||
// shift left by 1
|
||||
return (apInt.isIntN(1) && Op.match(bop->getOperand(0)));
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename V>
|
||||
inline SDiv2_match<V>
|
||||
m_SDiv2(const V &v) {
|
||||
return SDiv2_match<V>(v);
|
||||
}
|
||||
// Returns true if the given function has a call to an intrinsic function
|
||||
// in its definition.
|
||||
static bool
|
||||
lHasIntrinsicInDefinition(llvm::Function *func) {
|
||||
llvm::Function::iterator bbiter = func->begin();
|
||||
for (; bbiter != func->end(); ++bbiter) {
|
||||
for (llvm::BasicBlock::iterator institer = bbiter->begin();
|
||||
institer != bbiter->end(); ++institer) {
|
||||
if (llvm::isa<llvm::IntrinsicInst>(institer))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static llvm::Instruction *
|
||||
lGetBinaryIntrinsic(const char *name, llvm::Value *opa, llvm::Value *opb) {
|
||||
llvm::Function *func = m->module->getFunction(name);
|
||||
Assert(func != NULL);
|
||||
|
||||
// Make sure that the definition of the llvm::Function has a call to an
|
||||
// intrinsic function in its instructions; otherwise we will generate
|
||||
// infinite loops where we "helpfully" turn the default implementations
|
||||
// of target builtins like __avg_up_uint8 that are implemented with plain
|
||||
// arithmetic ops into recursive calls to themselves.
|
||||
if (lHasIntrinsicInDefinition(func))
|
||||
return lCallInst(func, opa, opb, name);
|
||||
else
|
||||
return NULL;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
|
||||
static llvm::Instruction *
|
||||
lMatchAvgUpUInt8(llvm::Value *inst) {
|
||||
// (unsigned int8)(((unsigned int16)a + (unsigned int16)b + 1)/2)
|
||||
llvm::Value *opa, *opb;
|
||||
const llvm::APInt *delta;
|
||||
if (match(inst, m_Trunc16To8(m_UDiv2(m_CombineOr(
|
||||
m_CombineOr(
|
||||
m_Add(m_ZExt8To16(m_Value(opa)),
|
||||
m_Add(m_ZExt8To16(m_Value(opb)), m_APInt(delta))),
|
||||
m_Add(m_Add(m_ZExt8To16(m_Value(opa)), m_APInt(delta)),
|
||||
m_ZExt8To16(m_Value(opb)))),
|
||||
m_Add(m_Add(m_ZExt8To16(m_Value(opa)), m_ZExt8To16(m_Value(opb))),
|
||||
m_APInt(delta))))))) {
|
||||
if (delta->isIntN(1) == false)
|
||||
return false;
|
||||
|
||||
return lGetBinaryIntrinsic("__avg_up_uint8", opa, opb);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
static llvm::Instruction *
|
||||
lMatchAvgDownUInt8(llvm::Value *inst) {
|
||||
// (unsigned int8)(((unsigned int16)a + (unsigned int16)b)/2)
|
||||
llvm::Value *opa, *opb;
|
||||
if (match(inst, m_Trunc16To8(m_UDiv2(
|
||||
m_Add(m_ZExt8To16(m_Value(opa)),
|
||||
m_ZExt8To16(m_Value(opb))))))) {
|
||||
return lGetBinaryIntrinsic("__avg_down_uint8", opa, opb);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static llvm::Instruction *
|
||||
lMatchAvgUpUInt16(llvm::Value *inst) {
|
||||
// (unsigned int16)(((unsigned int32)a + (unsigned int32)b + 1)/2)
|
||||
llvm::Value *opa, *opb;
|
||||
const llvm::APInt *delta;
|
||||
if (match(inst, m_Trunc32To16(m_UDiv2(m_CombineOr(
|
||||
m_CombineOr(
|
||||
m_Add(m_ZExt16To32(m_Value(opa)),
|
||||
m_Add(m_ZExt16To32(m_Value(opb)), m_APInt(delta))),
|
||||
m_Add(m_Add(m_ZExt16To32(m_Value(opa)), m_APInt(delta)),
|
||||
m_ZExt16To32(m_Value(opb)))),
|
||||
m_Add(m_Add(m_ZExt16To32(m_Value(opa)), m_ZExt16To32(m_Value(opb))),
|
||||
m_APInt(delta))))))) {
|
||||
if (delta->isIntN(1) == false)
|
||||
return false;
|
||||
|
||||
return lGetBinaryIntrinsic("__avg_up_uint16", opa, opb);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
static llvm::Instruction *
|
||||
lMatchAvgDownUInt16(llvm::Value *inst) {
|
||||
// (unsigned int16)(((unsigned int32)a + (unsigned int32)b)/2)
|
||||
llvm::Value *opa, *opb;
|
||||
if (match(inst, m_Trunc32To16(m_UDiv2(
|
||||
m_Add(m_ZExt16To32(m_Value(opa)),
|
||||
m_ZExt16To32(m_Value(opb))))))) {
|
||||
return lGetBinaryIntrinsic("__avg_down_uint16", opa, opb);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
static llvm::Instruction *
|
||||
lMatchAvgUpInt8(llvm::Value *inst) {
|
||||
// (int8)(((int16)a + (int16)b + 1)/2)
|
||||
llvm::Value *opa, *opb;
|
||||
const llvm::APInt *delta;
|
||||
if (match(inst, m_Trunc16To8(m_SDiv2(m_CombineOr(
|
||||
m_CombineOr(
|
||||
m_Add(m_SExt8To16(m_Value(opa)),
|
||||
m_Add(m_SExt8To16(m_Value(opb)), m_APInt(delta))),
|
||||
m_Add(m_Add(m_SExt8To16(m_Value(opa)), m_APInt(delta)),
|
||||
m_SExt8To16(m_Value(opb)))),
|
||||
m_Add(m_Add(m_SExt8To16(m_Value(opa)), m_SExt8To16(m_Value(opb))),
|
||||
m_APInt(delta))))))) {
|
||||
if (delta->isIntN(1) == false)
|
||||
return false;
|
||||
|
||||
return lGetBinaryIntrinsic("__avg_up_int8", opa, opb);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
static llvm::Instruction *
|
||||
lMatchAvgDownInt8(llvm::Value *inst) {
|
||||
// (int8)(((int16)a + (int16)b)/2)
|
||||
llvm::Value *opa, *opb;
|
||||
if (match(inst, m_Trunc16To8(m_SDiv2(
|
||||
m_Add(m_SExt8To16(m_Value(opa)),
|
||||
m_SExt8To16(m_Value(opb))))))) {
|
||||
return lGetBinaryIntrinsic("__avg_down_int8", opa, opb);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static llvm::Instruction *
|
||||
lMatchAvgUpInt16(llvm::Value *inst) {
|
||||
// (int16)(((int32)a + (int32)b + 1)/2)
|
||||
llvm::Value *opa, *opb;
|
||||
const llvm::APInt *delta;
|
||||
if (match(inst, m_Trunc32To16(m_SDiv2(m_CombineOr(
|
||||
m_CombineOr(
|
||||
m_Add(m_SExt16To32(m_Value(opa)),
|
||||
m_Add(m_SExt16To32(m_Value(opb)), m_APInt(delta))),
|
||||
m_Add(m_Add(m_SExt16To32(m_Value(opa)), m_APInt(delta)),
|
||||
m_SExt16To32(m_Value(opb)))),
|
||||
m_Add(m_Add(m_SExt16To32(m_Value(opa)), m_SExt16To32(m_Value(opb))),
|
||||
m_APInt(delta))))))) {
|
||||
if (delta->isIntN(1) == false)
|
||||
return false;
|
||||
|
||||
return lGetBinaryIntrinsic("__avg_up_int16", opa, opb);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static llvm::Instruction *
|
||||
lMatchAvgDownInt16(llvm::Value *inst) {
|
||||
// (int16)(((int32)a + (int32)b)/2)
|
||||
llvm::Value *opa, *opb;
|
||||
if (match(inst, m_Trunc32To16(m_SDiv2(
|
||||
m_Add(m_SExt16To32(m_Value(opa)),
|
||||
m_SExt16To32(m_Value(opb))))))) {
|
||||
return lGetBinaryIntrinsic("__avg_down_int16", opa, opb);
|
||||
}
|
||||
return NULL;
|
||||
}
|
||||
|
||||
bool
|
||||
PeepholePass::runOnBasicBlock(llvm::BasicBlock &bb) {
|
||||
DEBUG_START_PASS("PeepholePass");
|
||||
|
||||
bool modifiedAny = false;
|
||||
restart:
|
||||
for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
|
||||
llvm::Instruction *inst = &*iter;
|
||||
|
||||
llvm::Instruction *builtinCall = NULL;
|
||||
if (!builtinCall)
|
||||
builtinCall = lMatchAvgUpUInt8(inst);
|
||||
if (!builtinCall)
|
||||
builtinCall = lMatchAvgUpUInt16(inst);
|
||||
if (!builtinCall)
|
||||
builtinCall = lMatchAvgDownUInt8(inst);
|
||||
if (!builtinCall)
|
||||
builtinCall = lMatchAvgDownUInt16(inst);
|
||||
if (!builtinCall)
|
||||
builtinCall = lMatchAvgUpInt8(inst);
|
||||
if (!builtinCall)
|
||||
builtinCall = lMatchAvgUpInt16(inst);
|
||||
if (!builtinCall)
|
||||
builtinCall = lMatchAvgDownInt8(inst);
|
||||
if (!builtinCall)
|
||||
builtinCall = lMatchAvgDownInt16(inst);
|
||||
|
||||
if (builtinCall != NULL) {
|
||||
llvm::ReplaceInstWithInst(inst, builtinCall);
|
||||
modifiedAny = true;
|
||||
goto restart;
|
||||
}
|
||||
}
|
||||
|
||||
DEBUG_END_PASS("PeepholePass");
|
||||
|
||||
return modifiedAny;
|
||||
}
|
||||
|
||||
static llvm::Pass *
|
||||
CreatePeepholePass() {
|
||||
return new PeepholePass;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user