Merge pull request #671 from dbabokin/select

Select optimization
This commit is contained in:
jbrodman
2013-12-02 08:43:30 -08:00
2 changed files with 198 additions and 8 deletions

View File

@@ -635,6 +635,7 @@ import platform
import smtplib import smtplib
import datetime import datetime
import copy import copy
import multiprocessing
from email.MIMEMultipart import MIMEMultipart from email.MIMEMultipart import MIMEMultipart
from email.MIMEBase import MIMEBase from email.MIMEBase import MIMEBase
from email.mime.text import MIMEText from email.mime.text import MIMEText
@@ -663,13 +664,14 @@ if __name__ == '__main__':
"Try to build compiler with all LLVM\n\talloy.py -r --only=build\n" + "Try to build compiler with all LLVM\n\talloy.py -r --only=build\n" +
"Performance validation run with 10 runs of each test and comparing to branch 'old'\n\talloy.py -r --only=performance --compare-with=old --number=10\n" + "Performance validation run with 10 runs of each test and comparing to branch 'old'\n\talloy.py -r --only=performance --compare-with=old --number=10\n" +
"Validation run. Update fail_db.txt with new fails, send results to my@my.com\n\talloy.py -r --update-errors=F --notify='my@my.com'\n") "Validation run. Update fail_db.txt with new fails, send results to my@my.com\n\talloy.py -r --update-errors=F --notify='my@my.com'\n")
num_threads="%s" % multiprocessing.cpu_count()
parser = MyParser(usage="Usage: alloy.py -r/-b [options]", epilog=examples) parser = MyParser(usage="Usage: alloy.py -r/-b [options]", epilog=examples)
parser.add_option('-b', '--build-llvm', dest='build_llvm', parser.add_option('-b', '--build-llvm', dest='build_llvm',
help='ask to build LLVM', default=False, action="store_true") help='ask to build LLVM', default=False, action="store_true")
parser.add_option('-r', '--run', dest='validation_run', parser.add_option('-r', '--run', dest='validation_run',
help='ask for validation run', default=False, action="store_true") help='ask for validation run', default=False, action="store_true")
parser.add_option('-j', dest='speed', parser.add_option('-j', dest='speed',
help='set -j for make', default="8") help='set -j for make', default=num_threads)
# options for activity "build LLVM" # options for activity "build LLVM"
llvm_group = OptionGroup(parser, "Options for building LLVM", llvm_group = OptionGroup(parser, "Options for building LLVM",
"These options must be used with -b option.") "These options must be used with -b option.")

202
opt.cpp
View File

@@ -127,6 +127,8 @@ static llvm::Pass *CreateDebugPass(char * output);
static llvm::Pass *CreateReplaceStdlibShiftPass(); static llvm::Pass *CreateReplaceStdlibShiftPass();
static llvm::Pass *CreateFixBooleanSelectPass();
#define DEBUG_START_PASS(NAME) \ #define DEBUG_START_PASS(NAME) \
if (g->debugPrint && \ if (g->debugPrint && \
(getenv("FUNC") == NULL || \ (getenv("FUNC") == NULL || \
@@ -659,6 +661,9 @@ Optimize(llvm::Module *module, int optLevel) {
optPM.add(CreateMakeInternalFuncsStaticPass()); optPM.add(CreateMakeInternalFuncsStaticPass());
optPM.add(llvm::createGlobalDCEPass()); optPM.add(llvm::createGlobalDCEPass());
optPM.add(llvm::createConstantMergePass()); optPM.add(llvm::createConstantMergePass());
// Should be the last
optPM.add(CreateFixBooleanSelectPass(), 400);
} }
// Finish up by making sure we didn't mess anything up in the IR along // Finish up by making sure we didn't mess anything up in the IR along
@@ -670,6 +675,7 @@ Optimize(llvm::Module *module, int optLevel) {
printf("\n*****\nFINAL OUTPUT\n*****\n"); printf("\n*****\nFINAL OUTPUT\n*****\n");
module->dump(); module->dump();
} }
} }
@@ -1022,12 +1028,12 @@ InstructionSimplifyPass::simplifyBoolVec(llvm::Value *value) {
if (trunc != NULL) { if (trunc != NULL) {
// Convert trunc({sext,zext}(i1 vector)) -> (i1 vector) // Convert trunc({sext,zext}(i1 vector)) -> (i1 vector)
llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(value); llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(value);
if (sext && if (sext &&
sext->getOperand(0)->getType() == LLVMTypes::Int1VectorType) sext->getOperand(0)->getType() == LLVMTypes::Int1VectorType)
return sext->getOperand(0); return sext->getOperand(0);
llvm::ZExtInst *zext = llvm::dyn_cast<llvm::ZExtInst>(value); llvm::ZExtInst *zext = llvm::dyn_cast<llvm::ZExtInst>(value);
if (zext && if (zext &&
zext->getOperand(0)->getType() == LLVMTypes::Int1VectorType) zext->getOperand(0)->getType() == LLVMTypes::Int1VectorType)
return zext->getOperand(0); return zext->getOperand(0);
} }
@@ -1853,7 +1859,7 @@ lIs32BitSafeHelper(llvm::Value *v) {
// handle Adds, SExts, Constant Vectors // handle Adds, SExts, Constant Vectors
if (llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v)) { if (llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v)) {
if (bop->getOpcode() == llvm::Instruction::Add) { if (bop->getOpcode() == llvm::Instruction::Add) {
return lIs32BitSafeHelper(bop->getOperand(0)) return lIs32BitSafeHelper(bop->getOperand(0))
&& lIs32BitSafeHelper(bop->getOperand(1)); && lIs32BitSafeHelper(bop->getOperand(1));
} }
return false; return false;
@@ -4961,7 +4967,7 @@ bool
ReplaceStdlibShiftPass::runOnBasicBlock(llvm::BasicBlock &bb) { ReplaceStdlibShiftPass::runOnBasicBlock(llvm::BasicBlock &bb) {
DEBUG_START_PASS("ReplaceStdlibShiftPass"); DEBUG_START_PASS("ReplaceStdlibShiftPass");
bool modifiedAny = false; bool modifiedAny = false;
llvm::Function *shifts[6]; llvm::Function *shifts[6];
shifts[0] = m->module->getFunction("__shift_i8"); shifts[0] = m->module->getFunction("__shift_i8");
shifts[1] = m->module->getFunction("__shift_i16"); shifts[1] = m->module->getFunction("__shift_i16");
@@ -4992,19 +4998,19 @@ ReplaceStdlibShiftPass::runOnBasicBlock(llvm::BasicBlock &bb) {
} }
llvm::Value *shuffleIdxs = LLVMInt32Vector(shuffleVals); llvm::Value *shuffleIdxs = LLVMInt32Vector(shuffleVals);
llvm::Value *zeroVec = llvm::ConstantAggregateZero::get(shiftedVec->getType()); llvm::Value *zeroVec = llvm::ConstantAggregateZero::get(shiftedVec->getType());
llvm::Value *shuffle = new llvm::ShuffleVectorInst(shiftedVec, zeroVec, llvm::Value *shuffle = new llvm::ShuffleVectorInst(shiftedVec, zeroVec,
shuffleIdxs, "vecShift", ci); shuffleIdxs, "vecShift", ci);
ci->replaceAllUsesWith(shuffle); ci->replaceAllUsesWith(shuffle);
modifiedAny = true; modifiedAny = true;
delete [] shuffleVals; delete [] shuffleVals;
} else { } else {
PerformanceWarning(SourcePos(), "Stdlib shift() called without constant shift amount."); PerformanceWarning(SourcePos(), "Stdlib shift() called without constant shift amount.");
} }
} }
} }
} }
} }
DEBUG_END_PASS("ReplaceStdlibShiftPass"); DEBUG_END_PASS("ReplaceStdlibShiftPass");
return modifiedAny; return modifiedAny;
@@ -5015,3 +5021,185 @@ static llvm::Pass *
CreateReplaceStdlibShiftPass() { CreateReplaceStdlibShiftPass() {
return new ReplaceStdlibShiftPass(); return new ReplaceStdlibShiftPass();
} }
///////////////////////////////////////////////////////////////////////////////
// FixBooleanSelect
//
// The problem is that in LLVM 3.3, optimizer doesn't like
// the following instruction sequence:
// %cmp = fcmp olt <8 x float> %a, %b
// %sext_cmp = sext <8 x i1> %cmp to <8 x i32>
// %new_mask = and <8 x i32> %sext_cmp, %mask
// and optimizes it to the following:
// %cmp = fcmp olt <8 x float> %a, %b
// %cond = select <8 x i1> %cmp, <8 x i32> %mask, <8 x i32> zeroinitializer
//
// It wouldn't be a problem if codegen produced good code for it. But it
// doesn't, especially for vectors larger than native vectors.
//
// This optimization reverts this pattern and should be the last one before
// code gen.
//
// Note that this problem was introduced in LLVM 3.3. But in LLVM 3.4 it was
// fixed. See commit r194542.
//
// After LLVM 3.3 this optimization should probably stay for experimental
// purposes and code should be compared with and without this optimization from
// time to time to make sure that LLVM does right thing.
///////////////////////////////////////////////////////////////////////////////
class FixBooleanSelectPass : public llvm::FunctionPass {
public:
static char ID;
FixBooleanSelectPass() :FunctionPass(ID) {}
const char *getPassName() const { return "Resolve \"replace extract insert chains\""; }
bool runOnFunction(llvm::Function &F);
private:
llvm::Instruction* fixSelect(llvm::SelectInst* sel, llvm::SExtInst* sext);
};
char FixBooleanSelectPass::ID = 0;
llvm::Instruction* FixBooleanSelectPass::fixSelect(llvm::SelectInst* sel, llvm::SExtInst* sext) {
// Select instruction result type and its integer equivalent
llvm::VectorType *orig_type = llvm::dyn_cast<llvm::VectorType>(sel->getType());
llvm::VectorType *int_type = llvm::VectorType::getInteger(orig_type);
// Result value and optional pointer to instruction to delete
llvm::Instruction *result = 0, *optional_to_delete = 0;
// It can be vector of integers or vector of floating point values.
if (orig_type->getElementType()->isIntegerTy()) {
// Generate sext+and, remove select.
result = llvm::BinaryOperator::CreateAnd(sext, sel->getTrueValue(), "and_mask", sel);
} else {
llvm::BitCastInst* bc = llvm::dyn_cast<llvm::BitCastInst>(sel->getTrueValue());
if (bc && bc->hasOneUse() && bc->getSrcTy()->isIntOrIntVectorTy() && bc->getSrcTy()->isVectorTy() &&
llvm::isa<llvm::Instruction>(bc->getOperand(0)) &&
llvm::dyn_cast<llvm::Instruction>(bc->getOperand(0))->getParent() == sel->getParent()) {
// Bitcast is casting form integer type, it's operand is instruction, which is located in the same basic block (otherwise it's unsafe to use it).
// bitcast+select => sext+and+bicast
// Create and
llvm::BinaryOperator* and_inst = llvm::BinaryOperator::CreateAnd(sext, bc->getOperand(0), "and_mask", sel);
// Bitcast back to original type
result = new llvm::BitCastInst(and_inst, sel->getType(), "bitcast_mask_out", sel);
// Original bitcast will be removed
optional_to_delete = bc;
} else {
// General case: select => bitcast+sext+and+bitcast
// Bitcast
llvm::BitCastInst* bc_in = new llvm::BitCastInst(sel->getTrueValue(), int_type, "bitcast_mask_in", sel);
// And
llvm::BinaryOperator* and_inst = llvm::BinaryOperator::CreateAnd(sext, bc_in, "and_mask", sel);
// Bitcast back to original type
result = new llvm::BitCastInst(and_inst, sel->getType(), "bitcast_mask_out", sel);
}
}
// Done, finalize.
sel->replaceAllUsesWith(result);
sel->eraseFromParent();
if (optional_to_delete) {
optional_to_delete->eraseFromParent();
}
return result;
}
bool
FixBooleanSelectPass::runOnFunction(llvm::Function &F) {
bool modifiedAny = false;
// LLVM 3.3 only
#if defined(LLVM_3_3)
for (llvm::Function::iterator I = F.begin(), E = F.end();
I != E; ++I) {
llvm::BasicBlock* bb = &*I;
for (llvm::BasicBlock::iterator iter = bb->begin(), e = bb->end(); iter != e; ++iter) {
llvm::Instruction *inst = &*iter;
llvm::CmpInst *cmp = llvm::dyn_cast<llvm::CmpInst>(inst);
if (cmp &&
cmp->getType()->isVectorTy() &&
cmp->getType()->getVectorElementType()->isIntegerTy(1)) {
// Search for select instruction uses.
int selects = 0;
llvm::VectorType* sext_type = 0;
for (llvm::Instruction::use_iterator it=cmp->use_begin(); it!=cmp->use_end(); ++it ) {
llvm::SelectInst* sel = llvm::dyn_cast<llvm::SelectInst>(*it);
if (sel &&
sel->getType()->isVectorTy() &&
sel->getType()->getScalarSizeInBits() > 1) {
selects++;
// We pick the first one, but typical case when all select types are the same.
sext_type = llvm::dyn_cast<llvm::VectorType>(sel->getType());
break;
}
}
if (selects == 0) {
continue;
}
// Get an integer equivalent, if it's not yet an integer.
sext_type = llvm::VectorType::getInteger(sext_type);
// Do transformation
llvm::BasicBlock::iterator iter_copy=iter;
llvm::Instruction* next_inst = &*(++iter_copy);
// Create or reuse sext
llvm::SExtInst* sext = llvm::dyn_cast<llvm::SExtInst>(next_inst);
if (sext &&
sext->getOperand(0) == cmp &&
sext->getDestTy() == sext_type) {
// This sext can be reused
} else {
if (next_inst) {
sext = new llvm::SExtInst(cmp, sext_type, "sext_cmp", next_inst);
} else {
sext = new llvm::SExtInst(cmp, sext_type, "sext_cmp", bb);
}
}
// Walk and fix selects
std::vector<llvm::SelectInst*> sel_uses;
for (llvm::Instruction::use_iterator it=cmp->use_begin(); it!=cmp->use_end(); ++it) {
llvm::SelectInst* sel = llvm::dyn_cast<llvm::SelectInst>(*it);
if (sel &&
sel->getType()->getScalarSizeInBits() == sext_type->getScalarSizeInBits()) {
// Check that second operand is zero.
llvm::Constant* false_cond = llvm::dyn_cast<llvm::Constant>(sel->getFalseValue());
if (false_cond &&
false_cond->isZeroValue()) {
sel_uses.push_back(sel);
modifiedAny = true;
}
}
}
for (int i=0; i<sel_uses.size(); i++) {
fixSelect(sel_uses[i], sext);
}
}
}
}
#endif // LLVM 3.3
return modifiedAny;
}
static llvm::Pass *
CreateFixBooleanSelectPass() {
return new FixBooleanSelectPass();
}