diff --git a/builtins/target-avx-x2.ll b/builtins/target-avx-x2.ll
index 90e2680c..53f57c88 100644
--- a/builtins/target-avx-x2.ll
+++ b/builtins/target-avx-x2.ll
@@ -385,13 +385,13 @@ load_and_broadcast(16, i32, 32)
 load_and_broadcast(16, i64, 64)
 ; no masked load instruction for i8 and i16 types??
-load_masked(16, i8, 8, 1)
-load_masked(16, i16, 16, 2)
+masked_load(16, i8, 8, 1)
+masked_load(16, i16, 16, 2)
 declare <8 x float> @llvm.x86.avx.maskload.ps.256(i8 *, <8 x float> %mask)
 declare <4 x double> @llvm.x86.avx.maskload.pd.256(i8 *, <4 x double> %mask)
-define <16 x i32> @__load_masked_32(i8 *, <16 x i32> %mask) nounwind alwaysinline {
+define <16 x i32> @__masked_load_32(i8 *, <16 x i32> %mask) nounwind alwaysinline {
   %floatmask = bitcast <16 x i32> %mask to <16 x float>
   %mask0 = shufflevector <16 x float> %floatmask, <16 x float> undef, <8 x i32>
@@ -409,7 +409,7 @@ define <16 x i32> @__load_masked_32(i8 *, <16 x i32> %mask) nounwind alwaysinlin
 }
-define <16 x i64> @__load_masked_64(i8 *, <16 x i32> %mask) nounwind alwaysinline {
+define <16 x i64> @__masked_load_64(i8 *, <16 x i32> %mask) nounwind alwaysinline {
   ; double up masks, bitcast to doubles
   %mask0 = shufflevector <16 x i32> %mask, <16 x i32> undef, <8 x i32>
diff --git a/builtins/target-avx.ll b/builtins/target-avx.ll
index dc7339bd..b86ca712 100644
--- a/builtins/target-avx.ll
+++ b/builtins/target-avx.ll
@@ -366,13 +366,13 @@ load_and_broadcast(8, i32, 32)
 load_and_broadcast(8, i64, 64)
 ; no masked load instruction for i8 and i16 types??
-load_masked(8, i8, 8, 1)
-load_masked(8, i16, 16, 2)
+masked_load(8, i8, 8, 1)
+masked_load(8, i16, 16, 2)
 declare <8 x float> @llvm.x86.avx.maskload.ps.256(i8 *, <8 x float> %mask)
 declare <4 x double> @llvm.x86.avx.maskload.pd.256(i8 *, <4 x double> %mask)
-define <8 x i32> @__load_masked_32(i8 *, <8 x i32> %mask) nounwind alwaysinline {
+define <8 x i32> @__masked_load_32(i8 *, <8 x i32> %mask) nounwind alwaysinline {
   %floatmask = bitcast <8 x i32> %mask to <8 x float>
   %floatval = call <8 x float> @llvm.x86.avx.maskload.ps.256(i8 * %0, <8 x float> %floatmask)
   %retval = bitcast <8 x float> %floatval to <8 x i32>
@@ -380,7 +380,7 @@ define <8 x i32> @__load_masked_32(i8 *, <8 x i32> %mask) nounwind alwaysinline
 }
-define <8 x i64> @__load_masked_64(i8 *, <8 x i32> %mask) nounwind alwaysinline {
+define <8 x i64> @__masked_load_64(i8 *, <8 x i32> %mask) nounwind alwaysinline {
   ; double up masks, bitcast to doubles
   %mask0 = shufflevector <8 x i32> %mask, <8 x i32> undef, <8 x i32>
diff --git a/builtins/target-generic-common.ll b/builtins/target-generic-common.ll
index b59e8d53..4c815de1 100644
--- a/builtins/target-generic-common.ll
+++ b/builtins/target-generic-common.ll
@@ -175,10 +175,10 @@ load_and_broadcast(WIDTH, i16, 16)
 load_and_broadcast(WIDTH, i32, 32)
 load_and_broadcast(WIDTH, i64, 64)
-declare @__load_masked_8(i8 * nocapture, %mask) nounwind readonly
-declare @__load_masked_16(i8 * nocapture, %mask) nounwind readonly
-declare @__load_masked_32(i8 * nocapture, %mask) nounwind readonly
-declare @__load_masked_64(i8 * nocapture, %mask) nounwind readonly
+declare @__masked_load_8(i8 * nocapture, %mask) nounwind readonly
+declare @__masked_load_16(i8 * nocapture, %mask) nounwind readonly
+declare @__masked_load_32(i8 * nocapture, %mask) nounwind readonly
+declare @__masked_load_64(i8 * nocapture, %mask) nounwind readonly
 declare void @__masked_store_8(* nocapture, , ) nounwind
diff --git a/builtins/target-sse2-x2.ll b/builtins/target-sse2-x2.ll
index a9d71ea9..c0030f31 100644
--- a/builtins/target-sse2-x2.ll
+++ b/builtins/target-sse2-x2.ll
@@ -429,10 +429,10 @@ load_and_broadcast(8, i16, 16)
 load_and_broadcast(8, i32, 32)
 load_and_broadcast(8, i64, 64)
-load_masked(8, i8, 8, 1)
-load_masked(8, i16, 16, 2)
-load_masked(8, i32, 32, 4)
-load_masked(8, i64, 64, 8)
+masked_load(8, i8, 8, 1)
+masked_load(8, i16, 16, 2)
+masked_load(8, i32, 32, 4)
+masked_load(8, i64, 64, 8)
 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 ;; gather/scatter
diff --git a/builtins/target-sse2.ll b/builtins/target-sse2.ll
index 1a297199..8d9911d8 100644
--- a/builtins/target-sse2.ll
+++ b/builtins/target-sse2.ll
@@ -556,10 +556,10 @@ load_and_broadcast(4, i16, 16)
 load_and_broadcast(4, i32, 32)
 load_and_broadcast(4, i64, 64)
-load_masked(4, i8, 8, 1)
-load_masked(4, i16, 16, 2)
-load_masked(4, i32, 32, 4)
-load_masked(4, i64, 64, 8)
+masked_load(4, i8, 8, 1)
+masked_load(4, i16, 16, 2)
+masked_load(4, i32, 32, 4)
+masked_load(4, i64, 64, 8)
 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 ;; gather/scatter
diff --git a/builtins/target-sse4-x2.ll b/builtins/target-sse4-x2.ll
index 764f8613..b7cd36ec 100644
--- a/builtins/target-sse4-x2.ll
+++ b/builtins/target-sse4-x2.ll
@@ -356,10 +356,10 @@ load_and_broadcast(8, i16, 16)
 load_and_broadcast(8, i32, 32)
 load_and_broadcast(8, i64, 64)
-load_masked(8, i8, 8, 1)
-load_masked(8, i16, 16, 2)
-load_masked(8, i32, 32, 4)
-load_masked(8, i64, 64, 8)
+masked_load(8, i8, 8, 1)
+masked_load(8, i16, 16, 2)
+masked_load(8, i32, 32, 4)
+masked_load(8, i64, 64, 8)
 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 ;; gather/scatter
diff --git a/builtins/target-sse4.ll b/builtins/target-sse4.ll
index 7eadde4b..68ff49d9 100644
--- a/builtins/target-sse4.ll
+++ b/builtins/target-sse4.ll
@@ -455,10 +455,10 @@ load_and_broadcast(4, i16, 16)
 load_and_broadcast(4, i32, 32)
 load_and_broadcast(4, i64, 64)
-load_masked(4, i8, 8, 1)
-load_masked(4, i16, 16, 2)
-load_masked(4, i32, 32, 4)
-load_masked(4, i64, 64, 8)
+masked_load(4, i8, 8, 1)
+masked_load(4, i16, 16, 2)
+masked_load(4, i32, 32, 4)
+masked_load(4, i64, 64, 8)
 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 ;; gather/scatter
diff --git a/builtins/util.m4 b/builtins/util.m4
index 8853e81c..5eaac67e 100644
--- a/builtins/util.m4
+++ b/builtins/util.m4
@@ -2213,8 +2213,8 @@ define <$1 x $2> @__load_and_broadcast_$3(i8 *, <$1 x i32> %mask) nounwind alway
 ;; $3: suffix for function name (32, 64, ...)
 ;; $4: alignment for elements of type $2 (4, 8, ...)
-define(`load_masked', `
-define <$1 x $2> @__load_masked_$3(i8 *, <$1 x i32> %mask) nounwind alwaysinline {
+define(`masked_load', `
+define <$1 x $2> @__masked_load_$3(i8 *, <$1 x i32> %mask) nounwind alwaysinline {
 entry:
   %mm = call i32 @__movmsk(<$1 x i32> %mask)
diff --git a/opt.cpp b/opt.cpp
index 17458a06..0259cfe5 100644
--- a/opt.cpp
+++ b/opt.cpp
@@ -89,6 +89,7 @@ static llvm::Pass *CreateGatherScatterImprovementsPass();
 static llvm::Pass *CreateLowerGatherScatterPass();
 static llvm::Pass *CreateLowerMaskedStorePass();
 static llvm::Pass *CreateMaskedStoreOptPass();
+static llvm::Pass *CreateMaskedLoadOptPass();
 static llvm::Pass *CreateIsCompileTimeConstantPass(bool isLastTry);
 static llvm::Pass *CreateMakeInternalFuncsStaticPass();
@@ -253,6 +254,7 @@ Optimize(llvm::Module *module, int optLevel) {
         if (!g->opt.disableMaskAllOnOptimizations) {
             optPM.add(CreateIntrinsicsOptPass());
             optPM.add(CreateMaskedStoreOptPass());
+            optPM.add(CreateMaskedLoadOptPass());
         }
         optPM.add(llvm::createDeadInstEliminationPass());
@@ -290,6 +292,7 @@ Optimize(llvm::Module *module, int optLevel) {
         if (!g->opt.disableMaskAllOnOptimizations) {
             optPM.add(CreateIntrinsicsOptPass());
             optPM.add(CreateMaskedStoreOptPass());
+            optPM.add(CreateMaskedLoadOptPass());
         }
         optPM.add(CreateLowerMaskedStorePass());
         if (!g->opt.disableGatherScatterOptimizations)
@@ -298,6 +301,9 @@ Optimize(llvm::Module *module, int optLevel) {
             optPM.add(CreateLowerMaskedStorePass());
             optPM.add(CreateLowerGatherScatterPass());
         }
+        if (!g->opt.disableMaskAllOnOptimizations) {
+            optPM.add(CreateMaskedLoadOptPass());
+        }
         optPM.add(llvm::createFunctionInliningPass());
         optPM.add(llvm::createConstantPropagationPass());
         optPM.add(CreateIntrinsicsOptPass());
@@ -461,13 +467,14 @@ IntrinsicsOpt::IntrinsicsOpt()
     of the mask values in turn and concatenating them into a single integer.
     In other words, given the 4-wide mask: < 0xffffffff, 0, 0, 0xffffffff >,
     we have 0b1001 = 9.
-
-    @todo This will break if we ever do 32-wide compilation, in which case
-    it won't be possible to distinguish between -1 for "don't know" and
-    "known and all bits on".
 */
 static int
 lGetMask(llvm::Value *factor) {
+    /* FIXME: This will break if we ever do 32-wide compilation, in which case
+       it won't be possible to distinguish between -1 for "don't know" and
+       "known and all bits on". */
+    assert(g->target.vectorWidth < 32);
+
     llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(factor);
     if (cv) {
         int mask = 0;
@@ -1344,6 +1351,105 @@ CreateMaskedStoreOptPass() {
 }
+
+///////////////////////////////////////////////////////////////////////////
+// MaskedLoadOptPass
+
+/** Masked load improvements for the all on/all off mask cases.
+*/
+class MaskedLoadOptPass : public llvm::BasicBlockPass {
+public:
+    static char ID;
+    MaskedLoadOptPass() : BasicBlockPass(ID) { }
+
+    const char *getPassName() const { return "Masked Load Improvements"; }
+    bool runOnBasicBlock(llvm::BasicBlock &BB);
+};
+
+
+char MaskedLoadOptPass::ID = 0;
+
+llvm::RegisterPass<MaskedLoadOptPass> ml("masked-load-improvements",
+                                         "Masked Load Improvements Pass");
+
+struct MLInfo {
+    MLInfo(const char *name, const int a)
+        : align(a) {
+        func = m->module->getFunction(name);
+        Assert(func != NULL);
+    }
+    llvm::Function *func;
+    const int align;
+};
+
+
+bool
+MaskedLoadOptPass::runOnBasicBlock(llvm::BasicBlock &bb) {
+    MLInfo mlInfo[] = {
+        MLInfo("__masked_load_8", 1),
+        MLInfo("__masked_load_16", 2),
+        MLInfo("__masked_load_32", 4),
+        MLInfo("__masked_load_64", 8)
+    };
+
+    bool modifiedAny = false;
+ restart:
+    // Iterate over all of the instructions to look for one of the various
+    // masked load functions
+    for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
+        llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter);
+        if (!callInst)
+            continue;
+
+        llvm::Function *called = callInst->getCalledFunction();
+        int nFuncs = sizeof(mlInfo) / sizeof(mlInfo[0]);
+        MLInfo *info = NULL;
+        for (int i = 0; i < nFuncs; ++i) {
+            if (mlInfo[i].func != NULL && called == mlInfo[i].func) {
+                info = &mlInfo[i];
+                break;
+            }
+        }
+        if (info == NULL)
+            continue;
+
+        // Got one; grab the operands
+        llvm::Value *ptr = callInst->getArgOperand(0);
+        llvm::Value *mask = callInst->getArgOperand(1);
+        int allOnMask = (1 << g->target.vectorWidth) - 1;
+
+        int maskAsInt = lGetMask(mask);
+        if (maskAsInt == 0) {
+            // Zero mask - no-op, so replace the load with an undef value
+            llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
+                                       iter, llvm::UndefValue::get(callInst->getType()));
+            modifiedAny = true;
+            goto restart;
+        }
+        else if (maskAsInt == allOnMask) {
+            // The mask is all on, so turn this into a regular load
+            LLVM_TYPE_CONST llvm::Type *ptrType =
+                llvm::PointerType::get(callInst->getType(), 0);
+            ptr = new llvm::BitCastInst(ptr, ptrType, "ptr_cast_for_load",
+                                        callInst);
+            llvm::Instruction *load =
+                new llvm::LoadInst(ptr, callInst->getName(), false /* not volatile */,
+                                   info->align, (llvm::Instruction *)NULL);
+            lCopyMetadata(load, callInst);
+            llvm::ReplaceInstWithInst(callInst, load);
+            modifiedAny = true;
+            goto restart;
+        }
+    }
+    return modifiedAny;
+}
+
+
+static llvm::Pass *
+CreateMaskedLoadOptPass() {
+    return new MaskedLoadOptPass;
+}
+
+
 ///////////////////////////////////////////////////////////////////////////
 // LowerMaskedStorePass
@@ -1901,21 +2007,21 @@ bool
 GSImprovementsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
     GatherImpInfo gInfo[] = {
         GatherImpInfo("__pseudo_gather_base_offsets32_8", "__load_and_broadcast_8",
-                      "__load_masked_8", 1),
+                      "__masked_load_8", 1),
         GatherImpInfo("__pseudo_gather_base_offsets32_16", "__load_and_broadcast_16",
-                      "__load_masked_16", 2),
+                      "__masked_load_16", 2),
         GatherImpInfo("__pseudo_gather_base_offsets32_32", "__load_and_broadcast_32",
-                      "__load_masked_32", 4),
+                      "__masked_load_32", 4),
         GatherImpInfo("__pseudo_gather_base_offsets32_64", "__load_and_broadcast_64",
-                      "__load_masked_64", 8),
+                      "__masked_load_64", 8),
         GatherImpInfo("__pseudo_gather_base_offsets64_8", "__load_and_broadcast_8",
-                      "__load_masked_8", 1),
+                      "__masked_load_8", 1),
         GatherImpInfo("__pseudo_gather_base_offsets64_16", "__load_and_broadcast_16",
-                      "__load_masked_16", 2),
+                      "__masked_load_16", 2),
GatherImpInfo("__pseudo_gather_base_offsets64_32", "__load_and_broadcast_32", - "__load_masked_32", 4), + "__masked_load_32", 4), GatherImpInfo("__pseudo_gather_base_offsets64_64", "__load_and_broadcast_64", - "__load_masked_64", 8) + "__masked_load_64", 8) }; ScatterImpInfo sInfo[] = { ScatterImpInfo("__pseudo_scatter_base_offsets32_8", "__pseudo_masked_store_8", @@ -2100,11 +2206,11 @@ GSImprovementsPass::runOnBasicBlock(llvm::BasicBlock &bb) { llvm::ArrayRef argArray(&args[0], &args[2]); llvm::Instruction *newCall = llvm::CallInst::Create(gatherInfo->loadMaskedFunc, argArray, - "load_masked", (llvm::Instruction *)NULL); + "masked_load", (llvm::Instruction *)NULL); #else llvm::Instruction *newCall = llvm::CallInst::Create(gatherInfo->loadMaskedFunc, &args[0], - &args[2], "load_masked"); + &args[2], "masked_load"); #endif lCopyMetadata(newCall, callInst); llvm::ReplaceInstWithInst(callInst, newCall); @@ -2420,8 +2526,8 @@ MakeInternalFuncsStaticPass::runOnModule(llvm::Module &module) { "__gather_elt64_i32", "__gather_elt64_i64", "__load_and_broadcast_8", "__load_and_broadcast_16", "__load_and_broadcast_32", "__load_and_broadcast_64", - "__load_masked_8", "__load_masked_16", - "__load_masked_32", "__load_masked_64", + "__masked_load_8", "__masked_load_16", + "__masked_load_32", "__masked_load_64", "__masked_store_8", "__masked_store_16", "__masked_store_32", "__masked_store_64", "__masked_store_blend_8", "__masked_store_blend_16",