Redesign after being hit with the KISS bat.

This commit is contained in:
james.brodman
2013-10-23 14:25:43 -04:00
parent f97a2d68c8
commit 4d289b16c2
3 changed files with 92 additions and 129 deletions

View File

@@ -798,23 +798,6 @@ not_const:
}
define <WIDTH x $1> @__shift_$1(<WIDTH x $1>, i32) nounwind readnone alwaysinline {
%isc = call i1 @__is_compile_time_constant_uniform_int32(i32 %1)
%zeropaddedvec = shufflevector <WIDTH x $1> %0, <WIDTH x $1> zeroinitializer,
<eval(2*WIDTH) x i32> < forloop(i, 0, eval(2*WIDTH-2), `i32 i, ')i32 eval(2*WIDTH-1) >
br i1 %isc, label %is_const, label %not_const
is_const:
; though verbose, this turns into tight code if %1 is a constant
forloop(i, 0, eval(WIDTH-1), `
%delta_`'i = add i32 %1, i
%delta_clamped_`'i = and i32 %delta_`'i, eval(2*WIDTH-1)
%v_`'i = extractelement <eval(2*WIDTH) x $1> %zeropaddedvec, i32 %delta_clamped_`'i')
%ret_0 = insertelement <WIDTH x $1> zeroinitializer, $1 %v_0, i32 0
forloop(i, 1, eval(WIDTH-1), ` %ret_`'i = insertelement <WIDTH x $1> %ret_`'eval(i-1), $1 %v_`'i, i32 i
')
ret <WIDTH x $1> %ret_`'eval(WIDTH-1)
not_const:
%ptr = alloca <WIDTH x $1>, i32 3
%ptr0 = getelementptr <WIDTH x $1> * %ptr, i32 0
store <WIDTH x $1> zeroinitializer, <WIDTH x $1> * %ptr0

168
opt.cpp
View File

@@ -125,7 +125,7 @@ static llvm::Pass *CreateMakeInternalFuncsStaticPass();
static llvm::Pass *CreateDebugPass(char * output);
static llvm::Pass *CreateReplaceExtractInsertChainsPass();
static llvm::Pass *CreateReplaceStdlibShiftPass();
#define DEBUG_START_PASS(NAME) \
if (g->debugPrint && \
@@ -524,6 +524,7 @@ Optimize(llvm::Module *module, int optLevel) {
optPM.add(llvm::createPromoteMemoryToRegisterPass());
optPM.add(llvm::createAggressiveDCEPass());
if (g->opt.disableGatherScatterOptimizations == false &&
g->target->getVectorWidth() > 1) {
optPM.add(llvm::createInstructionCombiningPass(), 210);
@@ -535,6 +536,9 @@ Optimize(llvm::Module *module, int optLevel) {
}
optPM.add(llvm::createDeadInstEliminationPass(), 220);
optPM.add(llvm::createIPConstantPropagationPass());
optPM.add(CreateReplaceStdlibShiftPass());
// Max struct size threshold for scalar replacement is
// 1) 4 fields (r,g,b,w)
// 2) field size: vectorWidth * sizeof(float)
@@ -638,7 +642,6 @@ Optimize(llvm::Module *module, int optLevel) {
optPM.add(CreateIsCompileTimeConstantPass(true));
optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateInstructionSimplifyPass());
optPM.add(CreateReplaceExtractInsertChainsPass());
optPM.add(llvm::createMemCpyOptPass());
optPM.add(llvm::createSCCPPass());
@@ -4883,6 +4886,7 @@ lMatchAvgDownInt16(llvm::Value *inst) {
}
#endif // !LLVM_3_1 && !LLVM_3_2
bool
PeepholePass::runOnBasicBlock(llvm::BasicBlock &bb) {
DEBUG_START_PASS("PeepholePass");
@@ -4928,31 +4932,6 @@ CreatePeepholePass() {
return new PeepholePass;
}
///////////////////////////////////////////////////////////////////////////
// ReplaceExtractInsertChainsPass
/**
We occasionally get chains of ExtractElementInsts followed by
InsertElementInsts. Unfortunately, all of these can't be replaced by
ShuffleVectorInsts as we don't know that things are constant at the time.
This Pass will detect such chains, and replace them with ShuffleVectorInsts
if all the appropriate values are constant.
*/
/// BasicBlockPass declaration for the "replace extract insert chains"
/// rewrite; the per-block work lives in runOnBasicBlock(), which is only
/// declared here.
class ReplaceExtractInsertChainsPass : public llvm::BasicBlockPass {
public:
    /// Pass identifier handed to the BasicBlockPass constructor.
    static char ID;

    ReplaceExtractInsertChainsPass()
        : BasicBlockPass(ID) {}

    /// Processes one basic block; returns a bool (defined out of line).
    bool runOnBasicBlock(llvm::BasicBlock &BB);

    /// Name string reported for this pass.
    const char *getPassName() const {
        return "Resolve \"replace extract insert chains\"";
    }
};
char ReplaceExtractInsertChainsPass::ID = 0;
#include <iostream>
/** Given an llvm::Value known to be an integer, return its value as
@@ -4966,97 +4945,74 @@ lGetIntValue(llvm::Value *offset) {
return intOffset->getSExtValue();
}
///////////////////////////////////////////////////////////////////////////
// ReplaceStdlibShiftPass
/// BasicBlockPass declaration for the stdlib shift replacement; the per-block
/// work lives in runOnBasicBlock(), which is only declared here.
class ReplaceStdlibShiftPass : public llvm::BasicBlockPass {
public:
    /// Pass identifier handed to the BasicBlockPass constructor.
    static char ID;

    ReplaceStdlibShiftPass() : BasicBlockPass(ID) {
    }

    /// Name string reported for this pass. It is deliberately different from
    /// the string used by ReplaceExtractInsertChainsPass so the two passes
    /// can be told apart by name.
    const char *getPassName() const { return "Resolve \"replace stdlib shift\""; }

    /// Processes one basic block; returns a bool (defined out of line).
    bool runOnBasicBlock(llvm::BasicBlock &BB);
};
char ReplaceStdlibShiftPass::ID = 0;
// NOTE(review): this span appears to interleave two bodies of
// runOnBasicBlock -- one for ReplaceExtractInsertChainsPass and one for
// ReplaceStdlibShiftPass (two signatures, two DEBUG_START_PASS /
// DEBUG_END_PASS pairs, braces that do not balance). TODO confirm against
// the real opt.cpp before changing any code here.
bool
ReplaceExtractInsertChainsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
DEBUG_START_PASS("ReplaceExtractInsertChainsPass");
ReplaceStdlibShiftPass::runOnBasicBlock(llvm::BasicBlock &bb) {
DEBUG_START_PASS("ReplaceStdlibShiftPass");
bool modifiedAny = false;
// Look up the six stdlib shift functions by name in the module.
// NOTE(review): the results are stored without a NULL check and are later
// compared directly against getCalledFunction()'s result.
llvm::Function *shifts[6];
shifts[0] = m->module->getFunction("__shift_i8");
shifts[1] = m->module->getFunction("__shift_i16");
shifts[2] = m->module->getFunction("__shift_i32");
shifts[3] = m->module->getFunction("__shift_i64");
shifts[4] = m->module->getFunction("__shift_float");
shifts[5] = m->module->getFunction("__shift_double");
// Initialize our mapping to the first spot in the zero vector
int vectorWidth = g->target->getVectorWidth();
int shuffleMap[vectorWidth];
for (int i = 0; i < vectorWidth; i++) {
shuffleMap[i] = vectorWidth;
}
for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
llvm::Instruction *inst = &*iter;
// Hack-y. 16 is likely the upper limit for now.
llvm::SmallSet<llvm::Value *, 16> inserts;
// save the last Insert in the chain
llvm::Value * lastInsert = NULL;
for (llvm::BasicBlock::iterator i = bb.begin(), e = bb.end(); i != e; ++i) {
// Iterate through the instructions looking for InsertElementInsts
llvm::InsertElementInst *ieInst = llvm::dyn_cast<llvm::InsertElementInst>(&*i);
if (ieInst == NULL) {
// These aren't the instructions you're looking for.
continue;
}
llvm::Value * base = ieInst->getOperand(0);
if ( (llvm::isa<llvm::UndefValue>(base))
|| (llvm::isa<llvm::ConstantAggregateZero>(base))
|| (base == lastInsert)) {
// if source for insert scalar is 0 or an EEInst, add insert
llvm::Value *scalar = ieInst->getOperand(1);
if (llvm::ExtractElementInst *eeInst = llvm::dyn_cast<llvm::ExtractElementInst>(scalar)) {
// We're only going to deal with Inserts into a Constant vector lane
if (llvm::isa<llvm::Constant>(eeInst->getOperand(1))) {
inserts.insert(ieInst);
lastInsert = ieInst;
// Only call instructions are considered; the callee is compared against
// each entry of shifts[].
if (llvm::CallInst *ci = llvm::dyn_cast<llvm::CallInst>(inst)) {
llvm::Function *func = ci->getCalledFunction();
for (int i = 0; i < 6; i++) {
if (shifts[i] == func) {
// we matched a call
llvm::Value *shiftedVec = ci->getArgOperand(0);
llvm::Value *shiftAmt = ci->getArgOperand(1);
// Only a constant shift amount is rewritten; other calls are left alone.
if (llvm::isa<llvm::Constant>(shiftAmt)) {
int vectorWidth = g->target->getVectorWidth();
int shuffleVals[vectorWidth];
int shiftInt = lGetIntValue(shiftAmt);
// Lane i reads source lane i + shiftInt; any index outside
// [0, vectorWidth) is mapped to vectorWidth -- presumably the first lane
// of zeroVec, the second shuffle operand (verify against shufflevector
// index semantics).
// NOTE(review): this loop's `i` shadows the enclosing loop index.
for (int i = 0; i < vectorWidth; i++) {
int s = i + shiftInt;
s = (s < 0) ? vectorWidth : s;
s = (s >= vectorWidth) ? vectorWidth : s;
shuffleVals[i] = s;
}
llvm::Value *shuffleIdxs = LLVMInt32Vector(shuffleVals);
llvm::Value *zeroVec = llvm::ConstantAggregateZero::get(shiftedVec->getType());
llvm::Value *shuffle = new llvm::ShuffleVectorInst(shiftedVec, zeroVec,
shuffleIdxs, "vecShift", ci);
// Uses of the call are redirected to the shuffle; the call instruction
// itself is not erased here.
ci->replaceAllUsesWith(shuffle);
modifiedAny = true;
}
}
}
}
else if (llvm::ConstantInt *ci = llvm::dyn_cast<llvm::ConstantInt>(scalar)) {
if (ci->isZero()) {
inserts.insert(ieInst);
lastInsert = ieInst;
}
}
else {
lastInsert = NULL;
}
}
}
// Look for chains, not insert/shuffle sequences
if (inserts.size() > 1) {
// The vector from which we're extracting elements
llvm::Value * baseVec = NULL;
llvm::Value *ee = llvm::cast<llvm::InsertElementInst>((*inserts.begin()))->getOperand(1);
if (llvm::ExtractElementInst *eeInst = llvm::dyn_cast<llvm::ExtractElementInst>(ee)) {
baseVec = eeInst->getOperand(0);
}
bool sameBase = true;
for (llvm::SmallSet<llvm::Value *,16>::iterator i = inserts.begin(); i != inserts.end(); i++) {
llvm::InsertElementInst *ie = llvm::cast<llvm::InsertElementInst>(*i);
if (llvm::ExtractElementInst *ee = llvm::dyn_cast<llvm::ExtractElementInst>(ie->getOperand(1))) {
if (ee->getOperand(0) != baseVec) {
sameBase = false;
break;
}
int64_t from = lGetIntValue(ee->getIndexOperand());
int64_t to = lGetIntValue(ie->getOperand(2));
shuffleMap[to] = from;
}
}
if (sameBase) {
llvm::Value *shuffleIdxs = LLVMInt32Vector(shuffleMap);
llvm::Value *zeroVec = llvm::ConstantAggregateZero::get(shuffleIdxs->getType());
llvm::Value *shuffle = new llvm::ShuffleVectorInst(baseVec, zeroVec, shuffleIdxs, "shiftInZero", llvm::cast<llvm::Instruction>(lastInsert));
// For now, be lazy and let DCE clean up the Extracts/Inserts.
lastInsert->replaceAllUsesWith(shuffle);
modifiedAny = true;
}
}
DEBUG_END_PASS("ReplaceExtractInsertChainsPass");
DEBUG_END_PASS("ReplaceStdlibShiftPass");
return modifiedAny;
}
// Factory code returning a newly allocated pass object.
// NOTE(review): two factory definitions (CreateReplaceExtractInsertChainsPass
// and CreateReplaceStdlibShiftPass) appear merged under a single return-type
// line and a single closing brace -- TODO confirm which one is intended.
static llvm::Pass *
CreateReplaceExtractInsertChainsPass() {
return new ReplaceExtractInsertChainsPass();
CreateReplaceStdlibShiftPass() {
return new ReplaceStdlibShiftPass();
}

View File

@@ -172,32 +172,56 @@ static inline int64 rotate(int64 v, uniform int i) {
// shift() overload for float: returns __shift_float(v, i), with a uniform
// int shift count i.
// NOTE(review): everything after the first return is unreachable as
// written; a one-line body and an unmasked-block body look merged here --
// confirm which form is intended.
__declspec(safe)
static inline float shift(float v, uniform int i) {
return __shift_float(v, i);
varying float result;
unmasked {
result = __shift_float(v, i);
}
return result;
}
// shift() overload for int8: returns __shift_i8(v, i), with a uniform int
// shift count i.
// NOTE(review): everything after the first return is unreachable as
// written; a one-line body and an unmasked-block body look merged here --
// confirm which form is intended.
__declspec(safe)
static inline int8 shift(int8 v, uniform int i) {
return __shift_i8(v, i);
varying int8 result;
unmasked {
result = __shift_i8(v, i);
}
return result;
}
// shift() overload for int16: returns __shift_i16(v, i), with a uniform int
// shift count i.
// NOTE(review): everything after the first return is unreachable as
// written; a one-line body and an unmasked-block body look merged here --
// confirm which form is intended.
__declspec(safe)
static inline int16 shift(int16 v, uniform int i) {
return __shift_i16(v, i);
varying int16 result;
unmasked {
result = __shift_i16(v, i);
}
return result;
}
// shift() overload for int32: returns __shift_i32(v, i), with a uniform int
// shift count i.
// NOTE(review): everything after the first return is unreachable as
// written; a one-line body and an unmasked-block body look merged here --
// confirm which form is intended.
__declspec(safe)
static inline int32 shift(int32 v, uniform int i) {
return __shift_i32(v, i);
varying int32 result;
unmasked {
result = __shift_i32(v, i);
}
return result;
}
// shift() overload for double: returns __shift_double(v, i), with a uniform
// int shift count i.
// NOTE(review): everything after the first return is unreachable as
// written; a one-line body and an unmasked-block body look merged here --
// confirm which form is intended.
__declspec(safe)
static inline double shift(double v, uniform int i) {
return __shift_double(v, i);
varying double result;
unmasked {
result = __shift_double(v, i);
}
return result;
}
// shift() overload for int64: returns __shift_i64(v, i), with a uniform int
// shift count i.
// NOTE(review): everything after the first return is unreachable as
// written; a one-line body and an unmasked-block body look merged here --
// confirm which form is intended.
__declspec(safe)
static inline int64 shift(int64 v, uniform int i) {
return __shift_i64(v, i);
varying int64 result;
unmasked {
result = __shift_i64(v, i);
}
return result;
}
__declspec(safe)