Select optimization for LLVM 3.3
This commit is contained in:
202
opt.cpp
202
opt.cpp
@@ -127,6 +127,8 @@ static llvm::Pass *CreateDebugPass(char * output);
|
||||
|
||||
static llvm::Pass *CreateReplaceStdlibShiftPass();
|
||||
|
||||
static llvm::Pass *CreateFixBooleanSelectPass();
|
||||
|
||||
#define DEBUG_START_PASS(NAME) \
|
||||
if (g->debugPrint && \
|
||||
(getenv("FUNC") == NULL || \
|
||||
@@ -659,6 +661,9 @@ Optimize(llvm::Module *module, int optLevel) {
|
||||
optPM.add(CreateMakeInternalFuncsStaticPass());
|
||||
optPM.add(llvm::createGlobalDCEPass());
|
||||
optPM.add(llvm::createConstantMergePass());
|
||||
|
||||
// Should be the last
|
||||
optPM.add(CreateFixBooleanSelectPass(), 400);
|
||||
}
|
||||
|
||||
// Finish up by making sure we didn't mess anything up in the IR along
|
||||
@@ -670,6 +675,7 @@ Optimize(llvm::Module *module, int optLevel) {
|
||||
printf("\n*****\nFINAL OUTPUT\n*****\n");
|
||||
module->dump();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -1022,12 +1028,12 @@ InstructionSimplifyPass::simplifyBoolVec(llvm::Value *value) {
|
||||
if (trunc != NULL) {
|
||||
// Convert trunc({sext,zext}(i1 vector)) -> (i1 vector)
|
||||
llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(value);
|
||||
if (sext &&
|
||||
if (sext &&
|
||||
sext->getOperand(0)->getType() == LLVMTypes::Int1VectorType)
|
||||
return sext->getOperand(0);
|
||||
|
||||
llvm::ZExtInst *zext = llvm::dyn_cast<llvm::ZExtInst>(value);
|
||||
if (zext &&
|
||||
if (zext &&
|
||||
zext->getOperand(0)->getType() == LLVMTypes::Int1VectorType)
|
||||
return zext->getOperand(0);
|
||||
}
|
||||
@@ -1853,7 +1859,7 @@ lIs32BitSafeHelper(llvm::Value *v) {
|
||||
// handle Adds, SExts, Constant Vectors
|
||||
if (llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v)) {
|
||||
if (bop->getOpcode() == llvm::Instruction::Add) {
|
||||
return lIs32BitSafeHelper(bop->getOperand(0))
|
||||
return lIs32BitSafeHelper(bop->getOperand(0))
|
||||
&& lIs32BitSafeHelper(bop->getOperand(1));
|
||||
}
|
||||
return false;
|
||||
@@ -4961,7 +4967,7 @@ bool
|
||||
ReplaceStdlibShiftPass::runOnBasicBlock(llvm::BasicBlock &bb) {
|
||||
DEBUG_START_PASS("ReplaceStdlibShiftPass");
|
||||
bool modifiedAny = false;
|
||||
|
||||
|
||||
llvm::Function *shifts[6];
|
||||
shifts[0] = m->module->getFunction("__shift_i8");
|
||||
shifts[1] = m->module->getFunction("__shift_i16");
|
||||
@@ -4992,19 +4998,19 @@ ReplaceStdlibShiftPass::runOnBasicBlock(llvm::BasicBlock &bb) {
|
||||
}
|
||||
llvm::Value *shuffleIdxs = LLVMInt32Vector(shuffleVals);
|
||||
llvm::Value *zeroVec = llvm::ConstantAggregateZero::get(shiftedVec->getType());
|
||||
llvm::Value *shuffle = new llvm::ShuffleVectorInst(shiftedVec, zeroVec,
|
||||
llvm::Value *shuffle = new llvm::ShuffleVectorInst(shiftedVec, zeroVec,
|
||||
shuffleIdxs, "vecShift", ci);
|
||||
ci->replaceAllUsesWith(shuffle);
|
||||
modifiedAny = true;
|
||||
delete [] shuffleVals;
|
||||
} else {
|
||||
PerformanceWarning(SourcePos(), "Stdlib shift() called without constant shift amount.");
|
||||
PerformanceWarning(SourcePos(), "Stdlib shift() called without constant shift amount.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
DEBUG_END_PASS("ReplaceStdlibShiftPass");
|
||||
|
||||
return modifiedAny;
|
||||
@@ -5015,3 +5021,185 @@ static llvm::Pass *
|
||||
CreateReplaceStdlibShiftPass() {
|
||||
return new ReplaceStdlibShiftPass();
|
||||
}
|
||||
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// FixBooleanSelect
|
||||
//
|
||||
// The problem is that in LLVM 3.3, the optimizer doesn't like
|
||||
// the following instruction sequence:
|
||||
// %cmp = fcmp olt <8 x float> %a, %b
|
||||
// %sext_cmp = sext <8 x i1> %cmp to <8 x i32>
|
||||
// %new_mask = and <8 x i32> %sext_cmp, %mask
|
||||
// and optimizes it to the following:
|
||||
// %cmp = fcmp olt <8 x float> %a, %b
|
||||
// %cond = select <8 x i1> %cmp, <8 x i32> %mask, <8 x i32> zeroinitializer
|
||||
//
|
||||
// It wouldn't be a problem if codegen produced good code for it. But it
|
||||
// doesn't, especially for vectors larger than native vectors.
|
||||
//
|
||||
// This optimization reverts this pattern and should be the last one before
|
||||
// code gen.
|
||||
//
|
||||
// Note that this problem was introduced in LLVM 3.3. But in LLVM 3.4 it was
|
||||
// fixed. See commit r194542.
|
||||
//
|
||||
// After LLVM 3.3 this optimization should probably stay for experimental
|
||||
// purposes and code should be compared with and without this optimization from
|
||||
// time to time to make sure that LLVM does the right thing.
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
class FixBooleanSelectPass : public llvm::FunctionPass {
|
||||
public:
|
||||
static char ID;
|
||||
FixBooleanSelectPass() :FunctionPass(ID) {}
|
||||
|
||||
const char *getPassName() const { return "Resolve \"replace extract insert chains\""; }
|
||||
bool runOnFunction(llvm::Function &F);
|
||||
|
||||
private:
|
||||
llvm::Instruction* fixSelect(llvm::SelectInst* sel, llvm::SExtInst* sext);
|
||||
};
|
||||
|
||||
char FixBooleanSelectPass::ID = 0;
|
||||
|
||||
/** Replace "select %cmp, %value, zero" with an 'and' of the sign-extended
    compare result and %value.

    @param sel   Vector select to replace.  The caller is expected to have
                 checked that its false operand is a zero constant; this
                 function only reads the true operand.
    @param sext  Sign extension of the select's condition.  Its type must be
                 the integer equivalent of the select's result type.
    @return      The instruction that now produces the select's value, or
                 NULL if the select was left untouched because the types do
                 not line up.

    On success 'sel' is erased, and so is a bitcast feeding it if that
    bitcast could be folded away.
 */
llvm::Instruction* FixBooleanSelectPass::fixSelect(llvm::SelectInst* sel, llvm::SExtInst* sext) {
    // Select instruction result type and its integer equivalent
    llvm::VectorType *orig_type = llvm::dyn_cast<llvm::VectorType>(sel->getType());
    if (orig_type == 0) {
        // Not a vector select; there is nothing to and the mask with.
        return 0;
    }
    llvm::VectorType *int_type = llvm::VectorType::getInteger(orig_type);
    if (int_type != sext->getType()) {
        // Every 'and' created below has the mask as one operand, so the mask
        // must have exactly the integer shape of the select's result.
        return 0;
    }

    llvm::Value *true_val = sel->getTrueValue();

    // Result value and optional pointer to instruction to delete
    llvm::Instruction *result = 0, *optional_to_delete = 0;

    // It can be vector of integers or vector of floating point values.
    if (orig_type->getElementType()->isIntegerTy()) {
        // Generate sext+and, remove select.
        result = llvm::BinaryOperator::CreateAnd(sext, true_val, "and_mask", sel);
    } else {
        llvm::BitCastInst* bc = llvm::dyn_cast<llvm::BitCastInst>(true_val);
        llvm::Instruction* bc_src =
            bc ? llvm::dyn_cast<llvm::Instruction>(bc->getOperand(0)) : 0;

        // The bitcast can be folded away only if the select is its single
        // user, its operand is an instruction located in the same basic
        // block as the select, and the operand already has the mask's type.
        // The last check matters: a bitcast from a differently shaped
        // integer vector (e.g. <16 x i16> to <8 x float>) would otherwise be
        // and-ed with a mask of another type, which is invalid IR.
        if (bc_src && bc->hasOneUse() &&
            bc->getSrcTy() == int_type &&
            bc_src->getParent() == sel->getParent()) {
            // bitcast+select => sext+and+bitcast
            llvm::BinaryOperator* and_inst =
                llvm::BinaryOperator::CreateAnd(sext, bc_src, "and_mask", sel);
            // Bitcast back to original type
            result = new llvm::BitCastInst(and_inst, sel->getType(), "bitcast_mask_out", sel);
            // Original bitcast will be removed
            optional_to_delete = bc;
        } else {
            // General case: select => bitcast+sext+and+bitcast
            llvm::BitCastInst* bc_in =
                new llvm::BitCastInst(true_val, int_type, "bitcast_mask_in", sel);
            llvm::BinaryOperator* and_inst =
                llvm::BinaryOperator::CreateAnd(sext, bc_in, "and_mask", sel);
            // Bitcast back to original type
            result = new llvm::BitCastInst(and_inst, sel->getType(), "bitcast_mask_out", sel);
        }
    }

    // Done, finalize.  The select has to go first: it is the only user of
    // the bitcast that may be deleted afterwards.
    sel->replaceAllUsesWith(result);
    sel->eraseFromParent();
    if (optional_to_delete) {
        optional_to_delete->eraseFromParent();
    }

    return result;
}
|
||||
|
||||
bool
|
||||
FixBooleanSelectPass::runOnFunction(llvm::Function &F) {
|
||||
bool modifiedAny = false;
|
||||
|
||||
// LLVM 3.3 only
|
||||
#if defined(LLVM_3_3)
|
||||
|
||||
for (llvm::Function::iterator I = F.begin(), E = F.end();
|
||||
I != E; ++I) {
|
||||
llvm::BasicBlock* bb = &*I;
|
||||
for (llvm::BasicBlock::iterator iter = bb->begin(), e = bb->end(); iter != e; ++iter) {
|
||||
llvm::Instruction *inst = &*iter;
|
||||
|
||||
llvm::CmpInst *cmp = llvm::dyn_cast<llvm::CmpInst>(inst);
|
||||
|
||||
if (cmp &&
|
||||
cmp->getType()->isVectorTy() &&
|
||||
cmp->getType()->getVectorElementType()->isIntegerTy(1)) {
|
||||
|
||||
// Search for select instruction uses.
|
||||
int selects = 0;
|
||||
llvm::VectorType* sext_type = 0;
|
||||
for (llvm::Instruction::use_iterator it=cmp->use_begin(); it!=cmp->use_end(); ++it ) {
|
||||
llvm::SelectInst* sel = llvm::dyn_cast<llvm::SelectInst>(*it);
|
||||
if (sel &&
|
||||
sel->getType()->isVectorTy() &&
|
||||
sel->getType()->getScalarSizeInBits() > 1) {
|
||||
selects++;
|
||||
// We pick the first one, but typical case when all select types are the same.
|
||||
sext_type = llvm::dyn_cast<llvm::VectorType>(sel->getType());
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (selects == 0) {
|
||||
continue;
|
||||
}
|
||||
// Get an integer equivalent, if it's not yet an integer.
|
||||
sext_type = llvm::VectorType::getInteger(sext_type);
|
||||
|
||||
// Do transformation
|
||||
llvm::BasicBlock::iterator iter_copy=iter;
|
||||
llvm::Instruction* next_inst = &*(++iter_copy);
|
||||
// Create or reuse sext
|
||||
llvm::SExtInst* sext = llvm::dyn_cast<llvm::SExtInst>(next_inst);
|
||||
if (sext &&
|
||||
sext->getOperand(0) == cmp &&
|
||||
sext->getDestTy() == sext_type) {
|
||||
// This sext can be reused
|
||||
} else {
|
||||
if (next_inst) {
|
||||
sext = new llvm::SExtInst(cmp, sext_type, "sext_cmp", next_inst);
|
||||
} else {
|
||||
sext = new llvm::SExtInst(cmp, sext_type, "sext_cmp", bb);
|
||||
}
|
||||
}
|
||||
|
||||
// Walk and fix selects
|
||||
std::vector<llvm::SelectInst*> sel_uses;
|
||||
for (llvm::Instruction::use_iterator it=cmp->use_begin(); it!=cmp->use_end(); ++it) {
|
||||
llvm::SelectInst* sel = llvm::dyn_cast<llvm::SelectInst>(*it);
|
||||
if (sel &&
|
||||
sel->getType()->getScalarSizeInBits() == sext_type->getScalarSizeInBits()) {
|
||||
|
||||
// Check that second operand is zero.
|
||||
llvm::Constant* false_cond = llvm::dyn_cast<llvm::Constant>(sel->getFalseValue());
|
||||
if (false_cond &&
|
||||
false_cond->isZeroValue()) {
|
||||
sel_uses.push_back(sel);
|
||||
modifiedAny = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i=0; i<sel_uses.size(); i++) {
|
||||
fixSelect(sel_uses[i], sext);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // LLVM 3.3
|
||||
|
||||
return modifiedAny;
|
||||
}
|
||||
|
||||
|
||||
// Factory for the pass that turns "select cmp, value, zero" back into
// sext+and.  Returns a newly allocated pass object.
static llvm::Pass *
CreateFixBooleanSelectPass() {
    llvm::Pass *fixSelectPass = new FixBooleanSelectPass();
    return fixSelectPass;
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user