Added masked load optimization pass.
The pass handles the "all on" and "all off" mask cases: a masked load whose mask is known to be all on becomes a regular load, and one whose mask is all off is replaced with an undef value. Also renamed the load_masked built-ins to masked_load, for consistency with masked_store.
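As background, here is a minimal C++ reference sketch (not ispc's implementation; the function name and the zero-fill for off lanes are illustrative) of what a masked load computes, and why the two special cases are safe to rewrite:

    #include <cstdint>

    // Reference semantics of an N-wide masked load: a lane is read from
    // memory only when its mask bit is on; other lanes are unspecified
    // (zero-filled here for determinism).
    template <int N, typename T>
    void masked_load_ref(const T *src, uint32_t mask, T out[N]) {
        for (int i = 0; i < N; ++i)
            out[i] = (mask & (1u << i)) ? src[i] : T();
    }
    // mask == (1u << N) - 1: every lane is read, so this is a plain load.
    // mask == 0: nothing is read, so the result can be replaced by undef.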
@@ -385,13 +385,13 @@ load_and_broadcast(16, i32, 32)
 load_and_broadcast(16, i64, 64)

 ; no masked load instruction for i8 and i16 types??
-load_masked(16, i8, 8, 1)
-load_masked(16, i16, 16, 2)
+masked_load(16, i8, 8, 1)
+masked_load(16, i16, 16, 2)

 declare <8 x float> @llvm.x86.avx.maskload.ps.256(i8 *, <8 x float> %mask)
 declare <4 x double> @llvm.x86.avx.maskload.pd.256(i8 *, <4 x double> %mask)

-define <16 x i32> @__load_masked_32(i8 *, <16 x i32> %mask) nounwind alwaysinline {
+define <16 x i32> @__masked_load_32(i8 *, <16 x i32> %mask) nounwind alwaysinline {
   %floatmask = bitcast <16 x i32> %mask to <16 x float>
   %mask0 = shufflevector <16 x float> %floatmask, <16 x float> undef,
            <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
@@ -409,7 +409,7 @@ define <16 x i32> @__load_masked_32(i8 *, <16 x i32> %mask) nounwind alwaysinline
 }


-define <16 x i64> @__load_masked_64(i8 *, <16 x i32> %mask) nounwind alwaysinline {
+define <16 x i64> @__masked_load_64(i8 *, <16 x i32> %mask) nounwind alwaysinline {
   ; double up masks, bitcast to doubles
   %mask0 = shufflevector <16 x i32> %mask, <16 x i32> undef,
            <8 x i32> <i32 0, i32 0, i32 1, i32 1, i32 2, i32 2, i32 3, i32 3>
@@ -366,13 +366,13 @@ load_and_broadcast(8, i32, 32)
 load_and_broadcast(8, i64, 64)

 ; no masked load instruction for i8 and i16 types??
-load_masked(8, i8, 8, 1)
-load_masked(8, i16, 16, 2)
+masked_load(8, i8, 8, 1)
+masked_load(8, i16, 16, 2)

 declare <8 x float> @llvm.x86.avx.maskload.ps.256(i8 *, <8 x float> %mask)
 declare <4 x double> @llvm.x86.avx.maskload.pd.256(i8 *, <4 x double> %mask)

-define <8 x i32> @__load_masked_32(i8 *, <8 x i32> %mask) nounwind alwaysinline {
+define <8 x i32> @__masked_load_32(i8 *, <8 x i32> %mask) nounwind alwaysinline {
   %floatmask = bitcast <8 x i32> %mask to <8 x float>
   %floatval = call <8 x float> @llvm.x86.avx.maskload.ps.256(i8 * %0, <8 x float> %floatmask)
   %retval = bitcast <8 x float> %floatval to <8 x i32>
@@ -380,7 +380,7 @@ define <8 x i32> @__load_masked_32(i8 *, <8 x i32> %mask) nounwind alwaysinline
 }


-define <8 x i64> @__load_masked_64(i8 *, <8 x i32> %mask) nounwind alwaysinline {
+define <8 x i64> @__masked_load_64(i8 *, <8 x i32> %mask) nounwind alwaysinline {
   ; double up masks, bitcast to doubles
   %mask0 = shufflevector <8 x i32> %mask, <8 x i32> undef,
            <8 x i32> <i32 0, i32 0, i32 1, i32 1, i32 2, i32 2, i32 3, i32 3>
@@ -175,10 +175,10 @@ load_and_broadcast(WIDTH, i16, 16)
 load_and_broadcast(WIDTH, i32, 32)
 load_and_broadcast(WIDTH, i64, 64)

-declare <WIDTH x i8> @__load_masked_8(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
-declare <WIDTH x i16> @__load_masked_16(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
-declare <WIDTH x i32> @__load_masked_32(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
-declare <WIDTH x i64> @__load_masked_64(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
+declare <WIDTH x i8> @__masked_load_8(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
+declare <WIDTH x i16> @__masked_load_16(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
+declare <WIDTH x i32> @__masked_load_32(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
+declare <WIDTH x i64> @__masked_load_64(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly

 declare void @__masked_store_8(<WIDTH x i8>* nocapture, <WIDTH x i8>,
                                <WIDTH x i1>) nounwind
@@ -429,10 +429,10 @@ load_and_broadcast(8, i16, 16)
 load_and_broadcast(8, i32, 32)
 load_and_broadcast(8, i64, 64)

-load_masked(8, i8, 8, 1)
-load_masked(8, i16, 16, 2)
-load_masked(8, i32, 32, 4)
-load_masked(8, i64, 64, 8)
+masked_load(8, i8, 8, 1)
+masked_load(8, i16, 16, 2)
+masked_load(8, i32, 32, 4)
+masked_load(8, i64, 64, 8)

 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 ;; gather/scatter
@@ -556,10 +556,10 @@ load_and_broadcast(4, i16, 16)
 load_and_broadcast(4, i32, 32)
 load_and_broadcast(4, i64, 64)

-load_masked(4, i8, 8, 1)
-load_masked(4, i16, 16, 2)
-load_masked(4, i32, 32, 4)
-load_masked(4, i64, 64, 8)
+masked_load(4, i8, 8, 1)
+masked_load(4, i16, 16, 2)
+masked_load(4, i32, 32, 4)
+masked_load(4, i64, 64, 8)

 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 ;; gather/scatter
@@ -356,10 +356,10 @@ load_and_broadcast(8, i16, 16)
 load_and_broadcast(8, i32, 32)
 load_and_broadcast(8, i64, 64)

-load_masked(8, i8, 8, 1)
-load_masked(8, i16, 16, 2)
-load_masked(8, i32, 32, 4)
-load_masked(8, i64, 64, 8)
+masked_load(8, i8, 8, 1)
+masked_load(8, i16, 16, 2)
+masked_load(8, i32, 32, 4)
+masked_load(8, i64, 64, 8)

 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 ;; gather/scatter
@@ -455,10 +455,10 @@ load_and_broadcast(4, i16, 16)
 load_and_broadcast(4, i32, 32)
 load_and_broadcast(4, i64, 64)

-load_masked(4, i8, 8, 1)
-load_masked(4, i16, 16, 2)
-load_masked(4, i32, 32, 4)
-load_masked(4, i64, 64, 8)
+masked_load(4, i8, 8, 1)
+masked_load(4, i16, 16, 2)
+masked_load(4, i32, 32, 4)
+masked_load(4, i64, 64, 8)

 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 ;; gather/scatter
@@ -2213,8 +2213,8 @@ define <$1 x $2> @__load_and_broadcast_$3(i8 *, <$1 x i32> %mask) nounwind alwaysinline
 ;; $3: suffix for function name (32, 64, ...)
 ;; $4: alignment for elements of type $2 (4, 8, ...)

-define(`load_masked', `
-define <$1 x $2> @__load_masked_$3(i8 *, <$1 x i32> %mask) nounwind alwaysinline {
+define(`masked_load', `
+define <$1 x $2> @__masked_load_$3(i8 *, <$1 x i32> %mask) nounwind alwaysinline {
 entry:
   %mm = call i32 @__movmsk(<$1 x i32> %mask)

opt.cpp (138 changed lines)
@@ -89,6 +89,7 @@ static llvm::Pass *CreateGatherScatterImprovementsPass();
 static llvm::Pass *CreateLowerGatherScatterPass();
 static llvm::Pass *CreateLowerMaskedStorePass();
 static llvm::Pass *CreateMaskedStoreOptPass();
+static llvm::Pass *CreateMaskedLoadOptPass();
 static llvm::Pass *CreateIsCompileTimeConstantPass(bool isLastTry);
 static llvm::Pass *CreateMakeInternalFuncsStaticPass();

@@ -253,6 +254,7 @@ Optimize(llvm::Module *module, int optLevel) {
     if (!g->opt.disableMaskAllOnOptimizations) {
         optPM.add(CreateIntrinsicsOptPass());
         optPM.add(CreateMaskedStoreOptPass());
+        optPM.add(CreateMaskedLoadOptPass());
     }
     optPM.add(llvm::createDeadInstEliminationPass());

@@ -290,6 +292,7 @@ Optimize(llvm::Module *module, int optLevel) {
     if (!g->opt.disableMaskAllOnOptimizations) {
         optPM.add(CreateIntrinsicsOptPass());
         optPM.add(CreateMaskedStoreOptPass());
+        optPM.add(CreateMaskedLoadOptPass());
     }
     optPM.add(CreateLowerMaskedStorePass());
     if (!g->opt.disableGatherScatterOptimizations)
@@ -298,6 +301,9 @@ Optimize(llvm::Module *module, int optLevel) {
         optPM.add(CreateLowerMaskedStorePass());
         optPM.add(CreateLowerGatherScatterPass());
     }
+    if (!g->opt.disableMaskAllOnOptimizations) {
+        optPM.add(CreateMaskedLoadOptPass());
+    }
     optPM.add(llvm::createFunctionInliningPass());
     optPM.add(llvm::createConstantPropagationPass());
     optPM.add(CreateIntrinsicsOptPass());
@@ -461,13 +467,14 @@ IntrinsicsOpt::IntrinsicsOpt()
    of the mask values in turn and concatenating them into a single integer.
    In other words, given the 4-wide mask: < 0xffffffff, 0, 0, 0xffffffff >,
    we have 0b1001 = 9.
-
-   @todo This will break if we ever do 32-wide compilation, in which case
-   it won't be possible to distinguish between -1 for "don't know" and
-   "known and all bits on".
 */
 static int
 lGetMask(llvm::Value *factor) {
+    /* FIXME: This will break if we ever do 32-wide compilation, in which case
+       it won't be possible to distinguish between -1 for "don't know" and
+       "known and all bits on". */
+    assert(g->target.vectorWidth < 32);
+
     llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(factor);
     if (cv) {
         int mask = 0;
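To make the comment above concrete, here is a small standalone C++ sketch (illustrative only; movmsk_ref is a hypothetical helper, not the compiler's __movmsk) of packing lane sign bits into the integer that lGetMask compares against:

    #include <cassert>
    #include <cstdint>

    // Pack the high (sign) bit of each 32-bit mask lane into bit i of an int.
    static int movmsk_ref(const uint32_t *lanes, int width) {
        int mask = 0;
        for (int i = 0; i < width; ++i)
            if (lanes[i] & 0x80000000u)
                mask |= 1 << i;
        return mask;
    }

    int main() {
        // The 4-wide mask from the comment: < 0xffffffff, 0, 0, 0xffffffff >
        const uint32_t m[4] = { 0xffffffffu, 0u, 0u, 0xffffffffu };
        assert(movmsk_ref(m, 4) == 9);   // 0b1001
        return 0;
    }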
@@ -1344,6 +1351,105 @@ CreateMaskedStoreOptPass() {
 }


+///////////////////////////////////////////////////////////////////////////
+// MaskedLoadOptPass
+
+/** Masked load improvements for the all on/all off mask cases.
+*/
+class MaskedLoadOptPass : public llvm::BasicBlockPass {
+public:
+    static char ID;
+    MaskedLoadOptPass() : BasicBlockPass(ID) { }
+
+    const char *getPassName() const { return "Masked Load Improvements"; }
+    bool runOnBasicBlock(llvm::BasicBlock &BB);
+};
+
+
+char MaskedLoadOptPass::ID = 0;
+
+llvm::RegisterPass<MaskedLoadOptPass> ml("masked-load-improvements",
+                                         "Masked Load Improvements Pass");
+
+
+struct MLInfo {
+    MLInfo(const char *name, const int a)
+        : align(a) {
+        func = m->module->getFunction(name);
+        Assert(func != NULL);
+    }
+    llvm::Function *func;
+    const int align;
+};
+
+
+bool
+MaskedLoadOptPass::runOnBasicBlock(llvm::BasicBlock &bb) {
+    MLInfo mlInfo[] = {
+        MLInfo("__masked_load_8", 1),
+        MLInfo("__masked_load_16", 2),
+        MLInfo("__masked_load_32", 4),
+        MLInfo("__masked_load_64", 8)
+    };
+
+    bool modifiedAny = false;
+ restart:
+    // Iterate over all of the instructions to look for one of the various
+    // masked load functions
+    for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
+        llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter);
+        if (!callInst)
+            continue;
+
+        llvm::Function *called = callInst->getCalledFunction();
+        int nFuncs = sizeof(mlInfo) / sizeof(mlInfo[0]);
+        MLInfo *info = NULL;
+        for (int i = 0; i < nFuncs; ++i) {
+            if (mlInfo[i].func != NULL && called == mlInfo[i].func) {
+                info = &mlInfo[i];
+                break;
+            }
+        }
+        if (info == NULL)
+            continue;
+
+        // Got one; grab the operands
+        llvm::Value *ptr = callInst->getArgOperand(0);
+        llvm::Value *mask = callInst->getArgOperand(1);
+        int allOnMask = (1 << g->target.vectorWidth) - 1;
+
+        int maskAsInt = lGetMask(mask);
+        if (maskAsInt == 0) {
+            // Zero mask - no-op, so replace the load with an undef value
+            llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
+                                       iter, llvm::UndefValue::get(callInst->getType()));
+            modifiedAny = true;
+            goto restart;
+        }
+        else if (maskAsInt == allOnMask) {
+            // The mask is all on, so turn this into a regular load
+            LLVM_TYPE_CONST llvm::Type *ptrType =
+                llvm::PointerType::get(callInst->getType(), 0);
+            ptr = new llvm::BitCastInst(ptr, ptrType, "ptr_cast_for_load",
+                                        callInst);
+            llvm::Instruction *load =
+                new llvm::LoadInst(ptr, callInst->getName(), false /* not volatile */,
+                                   info->align, (llvm::Instruction *)NULL);
+            lCopyMetadata(load, callInst);
+            llvm::ReplaceInstWithInst(callInst, load);
+            modifiedAny = true;
+            goto restart;
+        }
+    }
+    return modifiedAny;
+}
+
+
+static llvm::Pass *
+CreateMaskedLoadOptPass() {
+    return new MaskedLoadOptPass;
+}
+
+
 ///////////////////////////////////////////////////////////////////////////
 // LowerMaskedStorePass

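For testing the new pass in isolation, a hedged sketch of driving it with the legacy PassManager that the surrounding code targets (the include paths and driver function are assumptions, not part of this commit):

    #include "llvm/PassManager.h"   // legacy pass manager, LLVM 3.x era
    #include "llvm/Module.h"

    // Hypothetical driver: run only the masked-load optimization over a
    // module. CreateMaskedLoadOptPass() is the factory defined above.
    static bool RunMaskedLoadOptOnly(llvm::Module *module) {
        llvm::PassManager pm;
        pm.add(CreateMaskedLoadOptPass());
        return pm.run(*module);     // true if anything was rewritten
    }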
@@ -1901,21 +2007,21 @@ bool
 GSImprovementsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
     GatherImpInfo gInfo[] = {
         GatherImpInfo("__pseudo_gather_base_offsets32_8", "__load_and_broadcast_8",
-                      "__load_masked_8", 1),
+                      "__masked_load_8", 1),
         GatherImpInfo("__pseudo_gather_base_offsets32_16", "__load_and_broadcast_16",
-                      "__load_masked_16", 2),
+                      "__masked_load_16", 2),
         GatherImpInfo("__pseudo_gather_base_offsets32_32", "__load_and_broadcast_32",
-                      "__load_masked_32", 4),
+                      "__masked_load_32", 4),
         GatherImpInfo("__pseudo_gather_base_offsets32_64", "__load_and_broadcast_64",
-                      "__load_masked_64", 8),
+                      "__masked_load_64", 8),
         GatherImpInfo("__pseudo_gather_base_offsets64_8", "__load_and_broadcast_8",
-                      "__load_masked_8", 1),
+                      "__masked_load_8", 1),
         GatherImpInfo("__pseudo_gather_base_offsets64_16", "__load_and_broadcast_16",
-                      "__load_masked_16", 2),
+                      "__masked_load_16", 2),
         GatherImpInfo("__pseudo_gather_base_offsets64_32", "__load_and_broadcast_32",
-                      "__load_masked_32", 4),
+                      "__masked_load_32", 4),
         GatherImpInfo("__pseudo_gather_base_offsets64_64", "__load_and_broadcast_64",
-                      "__load_masked_64", 8)
+                      "__masked_load_64", 8)
     };
     ScatterImpInfo sInfo[] = {
         ScatterImpInfo("__pseudo_scatter_base_offsets32_8", "__pseudo_masked_store_8",
@@ -2100,11 +2206,11 @@ GSImprovementsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
             llvm::ArrayRef<llvm::Value *> argArray(&args[0], &args[2]);
             llvm::Instruction *newCall =
                 llvm::CallInst::Create(gatherInfo->loadMaskedFunc, argArray,
-                                       "load_masked", (llvm::Instruction *)NULL);
+                                       "masked_load", (llvm::Instruction *)NULL);
 #else
             llvm::Instruction *newCall =
                 llvm::CallInst::Create(gatherInfo->loadMaskedFunc, &args[0],
-                                       &args[2], "load_masked");
+                                       &args[2], "masked_load");
 #endif
             lCopyMetadata(newCall, callInst);
             llvm::ReplaceInstWithInst(callInst, newCall);
@@ -2420,8 +2526,8 @@ MakeInternalFuncsStaticPass::runOnModule(llvm::Module &module) {
         "__gather_elt64_i32", "__gather_elt64_i64",
         "__load_and_broadcast_8", "__load_and_broadcast_16",
         "__load_and_broadcast_32", "__load_and_broadcast_64",
-        "__load_masked_8", "__load_masked_16",
-        "__load_masked_32", "__load_masked_64",
+        "__masked_load_8", "__masked_load_16",
+        "__masked_load_32", "__masked_load_64",
         "__masked_store_8", "__masked_store_16",
         "__masked_store_32", "__masked_store_64",
         "__masked_store_blend_8", "__masked_store_blend_16",