Redesign after being hit with the KISS bat.

This commit is contained in:
james.brodman
2013-10-23 14:25:43 -04:00
parent f97a2d68c8
commit 4d289b16c2
3 changed files with 92 additions and 129 deletions

168
opt.cpp
View File

@@ -125,7 +125,7 @@ static llvm::Pass *CreateMakeInternalFuncsStaticPass();
static llvm::Pass *CreateDebugPass(char * output);
static llvm::Pass *CreateReplaceExtractInsertChainsPass();
static llvm::Pass *CreateReplaceStdlibShiftPass();
#define DEBUG_START_PASS(NAME) \
if (g->debugPrint && \
@@ -524,6 +524,7 @@ Optimize(llvm::Module *module, int optLevel) {
optPM.add(llvm::createPromoteMemoryToRegisterPass());
optPM.add(llvm::createAggressiveDCEPass());
if (g->opt.disableGatherScatterOptimizations == false &&
g->target->getVectorWidth() > 1) {
optPM.add(llvm::createInstructionCombiningPass(), 210);
@@ -535,6 +536,9 @@ Optimize(llvm::Module *module, int optLevel) {
}
optPM.add(llvm::createDeadInstEliminationPass(), 220);
optPM.add(llvm::createIPConstantPropagationPass());
optPM.add(CreateReplaceStdlibShiftPass());
// Max struct size threshold for scalar replacement is
// 1) 4 fields (r,g,b,w)
// 2) field size: vectorWidth * sizeof(float)
@@ -638,7 +642,6 @@ Optimize(llvm::Module *module, int optLevel) {
optPM.add(CreateIsCompileTimeConstantPass(true));
optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateInstructionSimplifyPass());
optPM.add(CreateReplaceExtractInsertChainsPass());
optPM.add(llvm::createMemCpyOptPass());
optPM.add(llvm::createSCCPPass());
@@ -4883,6 +4886,7 @@ lMatchAvgDownInt16(llvm::Value *inst) {
}
#endif // !LLVM_3_1 && !LLVM_3_2
bool
PeepholePass::runOnBasicBlock(llvm::BasicBlock &bb) {
DEBUG_START_PASS("PeepholePass");
@@ -4928,31 +4932,6 @@ CreatePeepholePass() {
return new PeepholePass;
}
///////////////////////////////////////////////////////////////////////////
// ReplaceExtractInsertChainsPass
/**
We occasionally get chains of ExtractElementInsts followed by
InsertElementInsts. Unfortunately, all of these can't be replaced by
ShuffleVectorInsts as we don't know that things are constant at the time.
This Pass will detect such chains, and replace them with ShuffleVectorInsts
if all the appropriate values are constant.
*/
/** Basic-block pass declaration for the extract/insert chain rewrite.
    Only the boilerplate lives here; the transformation itself is
    implemented out of line in runOnBasicBlock().
 */
class ReplaceExtractInsertChainsPass : public llvm::BasicBlockPass {
public:
    // Pass identification object; passed to the BasicBlockPass constructor.
    static char ID;

    ReplaceExtractInsertChainsPass()
        : BasicBlockPass(ID) {}

    // Defined out of line; the returned flag reports whether the block
    // was modified.
    bool runOnBasicBlock(llvm::BasicBlock &BB);

    // Human-readable name reported for this pass.
    const char *getPassName() const {
        return "Resolve \"replace extract insert chains\"";
    }
};

char ReplaceExtractInsertChainsPass::ID = 0;
#include <iostream>
/** Given an llvm::Value known to be an integer, return its value as
@@ -4966,97 +4945,74 @@ lGetIntValue(llvm::Value *offset) {
return intOffset->getSExtValue();
}
///////////////////////////////////////////////////////////////////////////
// ReplaceStdlibShiftPass
/** Basic-block pass that targets calls to the stdlib "__shift_*"
    functions.  Its runOnBasicBlock() looks for such calls whose shift
    amount is a compile-time constant and substitutes a ShuffleVectorInst
    that shuffles the shifted vector against a zero vector.
 */
class ReplaceStdlibShiftPass : public llvm::BasicBlockPass {
public:
    // Pass identification object; passed to the BasicBlockPass constructor.
    static char ID;
    ReplaceStdlibShiftPass() : BasicBlockPass(ID) {
    }

    // The name previously returned here was copy-pasted from
    // ReplaceExtractInsertChainsPass, which made the two passes
    // indistinguishable in any output that prints pass names.
    const char *getPassName() const { return "Resolve \"replace stdlib shift\""; }

    // Defined out of line; the returned flag reports whether the block
    // was modified.
    bool runOnBasicBlock(llvm::BasicBlock &BB);
};

char ReplaceStdlibShiftPass::ID = 0;
bool
ReplaceExtractInsertChainsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
DEBUG_START_PASS("ReplaceExtractInsertChainsPass");
ReplaceStdlibShiftPass::runOnBasicBlock(llvm::BasicBlock &bb) {
DEBUG_START_PASS("ReplaceStdlibShiftPass");
bool modifiedAny = false;
llvm::Function *shifts[6];
shifts[0] = m->module->getFunction("__shift_i8");
shifts[1] = m->module->getFunction("__shift_i16");
shifts[2] = m->module->getFunction("__shift_i32");
shifts[3] = m->module->getFunction("__shift_i64");
shifts[4] = m->module->getFunction("__shift_float");
shifts[5] = m->module->getFunction("__shift_double");
// Initialize our mapping to the first spot in the zero vector
int vectorWidth = g->target->getVectorWidth();
int shuffleMap[vectorWidth];
for (int i = 0; i < vectorWidth; i++) {
shuffleMap[i] = vectorWidth;
}
for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
llvm::Instruction *inst = &*iter;
// Hack-y. 16 is likely the upper limit for now.
llvm::SmallSet<llvm::Value *, 16> inserts;
// save the last Insert in the chain
llvm::Value * lastInsert = NULL;
for (llvm::BasicBlock::iterator i = bb.begin(), e = bb.end(); i != e; ++i) {
// Iterate through the instructions looking for InsertElementInsts
llvm::InsertElementInst *ieInst = llvm::dyn_cast<llvm::InsertElementInst>(&*i);
if (ieInst == NULL) {
// These aren't the instructions you're looking for.
continue;
}
llvm::Value * base = ieInst->getOperand(0);
if ( (llvm::isa<llvm::UndefValue>(base))
|| (llvm::isa<llvm::ConstantAggregateZero>(base))
|| (base == lastInsert)) {
// if source for insert scalar is 0 or an EEInst, add insert
llvm::Value *scalar = ieInst->getOperand(1);
if (llvm::ExtractElementInst *eeInst = llvm::dyn_cast<llvm::ExtractElementInst>(scalar)) {
// We're only going to deal with Inserts into a Constant vector lane
if (llvm::isa<llvm::Constant>(eeInst->getOperand(1))) {
inserts.insert(ieInst);
lastInsert = ieInst;
if (llvm::CallInst *ci = llvm::dyn_cast<llvm::CallInst>(inst)) {
llvm::Function *func = ci->getCalledFunction();
for (int i = 0; i < 6; i++) {
if (shifts[i] == func) {
// we matched a call
llvm::Value *shiftedVec = ci->getArgOperand(0);
llvm::Value *shiftAmt = ci->getArgOperand(1);
if (llvm::isa<llvm::Constant>(shiftAmt)) {
int vectorWidth = g->target->getVectorWidth();
int shuffleVals[vectorWidth];
int shiftInt = lGetIntValue(shiftAmt);
for (int i = 0; i < vectorWidth; i++) {
int s = i + shiftInt;
s = (s < 0) ? vectorWidth : s;
s = (s >= vectorWidth) ? vectorWidth : s;
shuffleVals[i] = s;
}
llvm::Value *shuffleIdxs = LLVMInt32Vector(shuffleVals);
llvm::Value *zeroVec = llvm::ConstantAggregateZero::get(shiftedVec->getType());
llvm::Value *shuffle = new llvm::ShuffleVectorInst(shiftedVec, zeroVec,
shuffleIdxs, "vecShift", ci);
ci->replaceAllUsesWith(shuffle);
modifiedAny = true;
}
}
}
}
else if (llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(scalar)) {
if (ci->isZero()) {
inserts.insert(ieInst);
lastInsert = ieInst;
}
}
else {
lastInsert = NULL;
}
}
}
// Look for chains, not insert/shuffle sequences
if (inserts.size() > 1) {
// The vector from which we're extracting elements
llvm::Value * baseVec = NULL;
llvm::Value *ee = llvm::cast<llvm::InsertElementInst>((*inserts.begin()))->getOperand(1);
if (llvm::ExtractElementInst *eeInst = llvm::dyn_cast<llvm::ExtractElementInst>(ee)) {
baseVec = eeInst->getOperand(0);
}
bool sameBase = true;
for (llvm::SmallSet<llvm::Value *,16>::iterator i = inserts.begin(); i != inserts.end(); i++) {
llvm::InsertElementInst *ie = llvm::cast<llvm::InsertElementInst>(*i);
if (llvm::ExtractElementInst *ee = llvm::dyn_cast<llvm::ExtractElementInst>(ie->getOperand(1))) {
if (ee->getOperand(0) != baseVec) {
sameBase = false;
break;
}
int64_t from = lGetIntValue(ee->getIndexOperand());
int64_t to = lGetIntValue(ie->getOperand(2));
shuffleMap[to] = from;
}
}
if (sameBase) {
llvm::Value *shuffleIdxs = LLVMInt32Vector(shuffleMap);
llvm::Value *zeroVec = llvm::ConstantAggregateZero::get(shuffleIdxs->getType());
llvm::Value *shuffle = new llvm::ShuffleVectorInst(baseVec, zeroVec, shuffleIdxs, "shiftInZero", llvm::cast<llvm::Instruction>(lastInsert));
// For now, be lazy and let DCE clean up the Extracts/Inserts.
lastInsert->replaceAllUsesWith(shuffle);
modifiedAny = true;
}
}
DEBUG_END_PASS("ReplaceExtractInsertChainsPass");
DEBUG_END_PASS("ReplaceStdlibShiftPass");
return modifiedAny;
}
static llvm::Pass *
CreateReplaceExtractInsertChainsPass() {
return new ReplaceExtractInsertChainsPass();
CreateReplaceStdlibShiftPass() {
return new ReplaceStdlibShiftPass();
}