diff --git a/builtins/util.m4 b/builtins/util.m4 index c1582e51..0e017322 100644 --- a/builtins/util.m4 +++ b/builtins/util.m4 @@ -798,23 +798,6 @@ not_const: } define @__shift_$1(, i32) nounwind readnone alwaysinline { - %isc = call i1 @__is_compile_time_constant_uniform_int32(i32 %1) - %zeropaddedvec = shufflevector %0, zeroinitializer, - < forloop(i, 0, eval(2*WIDTH-2), `i32 i, ')i32 eval(2*WIDTH-1) > - br i1 %isc, label %is_const, label %not_const - -is_const: - ; though verbose, this turms into tight code if %1 is a constant -forloop(i, 0, eval(WIDTH-1), ` - %delta_`'i = add i32 %1, i - %delta_clamped_`'i = and i32 %delta_`'i, eval(2*WIDTH-1) - %v_`'i = extractelement %zeropaddedvec, i32 %delta_clamped_`'i') - %ret_0 = insertelement zeroinitializer, $1 %v_0, i32 0 -forloop(i, 1, eval(WIDTH-1), ` %ret_`'i = insertelement %ret_`'eval(i-1), $1 %v_`'i, i32 i -') - ret %ret_`'eval(WIDTH-1) - -not_const: %ptr = alloca , i32 3 %ptr0 = getelementptr * %ptr, i32 0 store zeroinitializer, * %ptr0 diff --git a/opt.cpp b/opt.cpp index 0146e7cf..b1a22a1c 100644 --- a/opt.cpp +++ b/opt.cpp @@ -125,7 +125,7 @@ static llvm::Pass *CreateMakeInternalFuncsStaticPass(); static llvm::Pass *CreateDebugPass(char * output); -static llvm::Pass *CreateReplaceExtractInsertChainsPass(); +static llvm::Pass *CreateReplaceStdlibShiftPass(); #define DEBUG_START_PASS(NAME) \ if (g->debugPrint && \ @@ -524,6 +524,7 @@ Optimize(llvm::Module *module, int optLevel) { optPM.add(llvm::createPromoteMemoryToRegisterPass()); optPM.add(llvm::createAggressiveDCEPass()); + if (g->opt.disableGatherScatterOptimizations == false && g->target->getVectorWidth() > 1) { optPM.add(llvm::createInstructionCombiningPass(), 210); @@ -535,6 +536,9 @@ Optimize(llvm::Module *module, int optLevel) { } optPM.add(llvm::createDeadInstEliminationPass(), 220); + optPM.add(llvm::createIPConstantPropagationPass()); + optPM.add(CreateReplaceStdlibShiftPass()); + // Max struct size threshold for scalar replacement is // 1) 4 fields (r,g,b,w) // 2) field size: vectorWidth * sizeof(float) @@ -638,7 +642,6 @@ Optimize(llvm::Module *module, int optLevel) { optPM.add(CreateIsCompileTimeConstantPass(true)); optPM.add(CreateIntrinsicsOptPass()); optPM.add(CreateInstructionSimplifyPass()); - optPM.add(CreateReplaceExtractInsertChainsPass()); optPM.add(llvm::createMemCpyOptPass()); optPM.add(llvm::createSCCPPass()); @@ -4883,6 +4886,7 @@ lMatchAvgDownInt16(llvm::Value *inst) { } #endif // !LLVM_3_1 && !LLVM_3_2 + bool PeepholePass::runOnBasicBlock(llvm::BasicBlock &bb) { DEBUG_START_PASS("PeepholePass"); @@ -4928,31 +4932,6 @@ CreatePeepholePass() { return new PeepholePass; } -/////////////////////////////////////////////////////////////////////////// -// ReplaceExtractInsertChainsPass - -/** - We occassionally get chains of ExtractElementInsts followed by - InsertElementInsts. Unfortunately, all of these can't be replaced by - ShuffleVectorInsts as we don't know that things are constant at the time. - - This Pass will detect such chains, and replace them with ShuffleVectorInsts - if all the appropriate values are constant. - */ - -class ReplaceExtractInsertChainsPass : public llvm::BasicBlockPass { -public: - static char ID; - ReplaceExtractInsertChainsPass() : BasicBlockPass(ID) { - } - - const char *getPassName() const { return "Resolve \"replace extract insert chains\""; } - bool runOnBasicBlock(llvm::BasicBlock &BB); - -}; - -char ReplaceExtractInsertChainsPass::ID = 0; - #include /** Given an llvm::Value known to be an integer, return its value as @@ -4966,97 +4945,74 @@ lGetIntValue(llvm::Value *offset) { return intOffset->getSExtValue(); } +/////////////////////////////////////////////////////////////////////////// +// ReplaceStdlibShiftPass + +class ReplaceStdlibShiftPass : public llvm::BasicBlockPass { +public: + static char ID; + ReplaceStdlibShiftPass() : BasicBlockPass(ID) { + } + + const char *getPassName() const { return "Resolve \"replace extract insert chains\""; } + bool runOnBasicBlock(llvm::BasicBlock &BB); + +}; + +char ReplaceStdlibShiftPass::ID = 0; + bool -ReplaceExtractInsertChainsPass::runOnBasicBlock(llvm::BasicBlock &bb) { - DEBUG_START_PASS("ReplaceExtractInsertChainsPass"); +ReplaceStdlibShiftPass::runOnBasicBlock(llvm::BasicBlock &bb) { + DEBUG_START_PASS("ReplaceStdlibShiftPass"); bool modifiedAny = false; + + llvm::Function *shifts[6]; + shifts[0] = m->module->getFunction("__shift_i8"); + shifts[1] = m->module->getFunction("__shift_i16"); + shifts[2] = m->module->getFunction("__shift_i32"); + shifts[3] = m->module->getFunction("__shift_i64"); + shifts[4] = m->module->getFunction("__shift_float"); + shifts[5] = m->module->getFunction("__shift_double"); - // Initialize our mapping to the first spot in the zero vector - int vectorWidth = g->target->getVectorWidth(); - int shuffleMap[vectorWidth]; - for (int i = 0; i < vectorWidth; i++) { - shuffleMap[i] = vectorWidth; - } + for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) { + llvm::Instruction *inst = &*iter; - // Hack-y. 16 is likely the upper limit for now. - llvm::SmallSet inserts; - - // save the last Insert in the chain - llvm::Value * lastInsert = NULL; - - for (llvm::BasicBlock::iterator i = bb.begin(), e = bb.end(); i != e; ++i) { - // Iterate through the instructions looking for InsertElementInsts - llvm::InsertElementInst *ieInst = llvm::dyn_cast(&*i); - if (ieInst == NULL) { - // These aren't the instructions you're looking for. - continue; - } - - llvm::Value * base = ieInst->getOperand(0); - if ( (llvm::isa(base)) - || (llvm::isa(base)) - || (base == lastInsert)) { - // if source for insert scalar is 0 or an EEInst, add insert - llvm::Value *scalar = ieInst->getOperand(1); - if (llvm::ExtractElementInst *eeInst = llvm::dyn_cast(scalar)) { - // We're only going to deal with Inserts into a Constant vector lane - if (llvm::isa(eeInst->getOperand(1))) { - inserts.insert(ieInst); - lastInsert = ieInst; + if (llvm::CallInst *ci = llvm::dyn_cast(inst)) { + llvm::Function *func = ci->getCalledFunction(); + for (int i = 0; i < 6; i++) { + if (shifts[i] == func) { + // we matched a call + llvm::Value *shiftedVec = ci->getArgOperand(0); + llvm::Value *shiftAmt = ci->getArgOperand(1); + if (llvm::isa(shiftAmt)) { + int vectorWidth = g->target->getVectorWidth(); + int shuffleVals[vectorWidth]; + int shiftInt = lGetIntValue(shiftAmt); + for (int i = 0; i < vectorWidth; i++) { + int s = i + shiftInt; + s = (s < 0) ? vectorWidth : s; + s = (s >= vectorWidth) ? vectorWidth : s; + shuffleVals[i] = s; + } + llvm::Value *shuffleIdxs = LLVMInt32Vector(shuffleVals); + llvm::Value *zeroVec = llvm::ConstantAggregateZero::get(shiftedVec->getType()); + llvm::Value *shuffle = new llvm::ShuffleVectorInst(shiftedVec, zeroVec, + shuffleIdxs, "vecShift", ci); + ci->replaceAllUsesWith(shuffle); + modifiedAny = true; + } + } } } - else if (llvm::ConstantInt *ci = llvm::dyn_cast(scalar)) { - if (ci->isZero()) { - inserts.insert(ieInst); - lastInsert = ieInst; - } - } - else { - lastInsert = NULL; - } - } } - // Look for chains, not insert/shuffle sequences - if (inserts.size() > 1) { - // The vector from which we're extracting elements - llvm::Value * baseVec = NULL; - llvm::Value *ee = llvm::cast((*inserts.begin()))->getOperand(1); - if (llvm::ExtractElementInst *eeInst = llvm::dyn_cast(ee)) { - baseVec = eeInst->getOperand(0); - } - - bool sameBase = true; - for (llvm::SmallSet::iterator i = inserts.begin(); i != inserts.end(); i++) { - llvm::InsertElementInst *ie = llvm::cast(*i); - if (llvm::ExtractElementInst *ee = llvm::dyn_cast(ie->getOperand(1))) { - if (ee->getOperand(0) != baseVec) { - sameBase = false; - break; - } - int64_t from = lGetIntValue(ee->getIndexOperand()); - int64_t to = lGetIntValue(ie->getOperand(2)); - shuffleMap[to] = from; - } - } - if (sameBase) { - llvm::Value *shuffleIdxs = LLVMInt32Vector(shuffleMap); - llvm::Value *zeroVec = llvm::ConstantAggregateZero::get(shuffleIdxs->getType()); - llvm::Value *shuffle = new llvm::ShuffleVectorInst(baseVec, zeroVec, shuffleIdxs, "shiftInZero", llvm::cast(lastInsert)); - // For now, be lazy and let DCE clean up the Extracts/Inserts. - lastInsert->replaceAllUsesWith(shuffle); - - modifiedAny = true; - } - } - - DEBUG_END_PASS("ReplaceExtractInsertChainsPass"); + DEBUG_END_PASS("ReplaceStdlibShiftPass"); return modifiedAny; } static llvm::Pass * -CreateReplaceExtractInsertChainsPass() { - return new ReplaceExtractInsertChainsPass(); +CreateReplaceStdlibShiftPass() { + return new ReplaceStdlibShiftPass(); } diff --git a/stdlib.ispc b/stdlib.ispc index 248f664a..6768594b 100644 --- a/stdlib.ispc +++ b/stdlib.ispc @@ -172,32 +172,56 @@ static inline int64 rotate(int64 v, uniform int i) { __declspec(safe) static inline float shift(float v, uniform int i) { - return __shift_float(v, i); + varying float result; + unmasked { + result = __shift_float(v, i); + } + return result; } __declspec(safe) static inline int8 shift(int8 v, uniform int i) { - return __shift_i8(v, i); + varying int8 result; + unmasked { + result = __shift_i8(v, i); + } + return result; } __declspec(safe) static inline int16 shift(int16 v, uniform int i) { - return __shift_i16(v, i); + varying int16 result; + unmasked { + result = __shift_i16(v, i); + } + return result; } __declspec(safe) static inline int32 shift(int32 v, uniform int i) { - return __shift_i32(v, i); + varying int32 result; + unmasked { + result = __shift_i32(v, i); + } + return result; } __declspec(safe) static inline double shift(double v, uniform int i) { - return __shift_double(v, i); + varying double result; + unmasked { + result = __shift_double(v, i); + } + return result; } __declspec(safe) static inline int64 shift(int64 v, uniform int i) { - return __shift_i64(v, i); + varying int64 result; + unmasked { + result = __shift_i64(v, i); + } + return result; } __declspec(safe)