Added masked load optimization pass.

This pass handles the "all on" and "all off" mask cases appropriately.

Also renamed load_masked stuff in built-ins to masked_load for consistency with
masked_store.
This commit is contained in:
Matt Pharr
2012-01-04 11:51:26 -08:00
parent 75f18c7c66
commit 562d61caff
9 changed files with 152 additions and 46 deletions

View File

@@ -385,13 +385,13 @@ load_and_broadcast(16, i32, 32)
load_and_broadcast(16, i64, 64)
; no masked load instruction for i8 and i16 types??
load_masked(16, i8, 8, 1)
load_masked(16, i16, 16, 2)
masked_load(16, i8, 8, 1)
masked_load(16, i16, 16, 2)
declare <8 x float> @llvm.x86.avx.maskload.ps.256(i8 *, <8 x float> %mask)
declare <4 x double> @llvm.x86.avx.maskload.pd.256(i8 *, <4 x double> %mask)
define <16 x i32> @__load_masked_32(i8 *, <16 x i32> %mask) nounwind alwaysinline {
define <16 x i32> @__masked_load_32(i8 *, <16 x i32> %mask) nounwind alwaysinline {
%floatmask = bitcast <16 x i32> %mask to <16 x float>
%mask0 = shufflevector <16 x float> %floatmask, <16 x float> undef,
<8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
@@ -409,7 +409,7 @@ define <16 x i32> @__load_masked_32(i8 *, <16 x i32> %mask) nounwind alwaysinlin
}
define <16 x i64> @__load_masked_64(i8 *, <16 x i32> %mask) nounwind alwaysinline {
define <16 x i64> @__masked_load_64(i8 *, <16 x i32> %mask) nounwind alwaysinline {
; double up masks, bitcast to doubles
%mask0 = shufflevector <16 x i32> %mask, <16 x i32> undef,
<8 x i32> <i32 0, i32 0, i32 1, i32 1, i32 2, i32 2, i32 3, i32 3>

View File

@@ -366,13 +366,13 @@ load_and_broadcast(8, i32, 32)
load_and_broadcast(8, i64, 64)
; no masked load instruction for i8 and i16 types??
load_masked(8, i8, 8, 1)
load_masked(8, i16, 16, 2)
masked_load(8, i8, 8, 1)
masked_load(8, i16, 16, 2)
declare <8 x float> @llvm.x86.avx.maskload.ps.256(i8 *, <8 x float> %mask)
declare <4 x double> @llvm.x86.avx.maskload.pd.256(i8 *, <4 x double> %mask)
define <8 x i32> @__load_masked_32(i8 *, <8 x i32> %mask) nounwind alwaysinline {
define <8 x i32> @__masked_load_32(i8 *, <8 x i32> %mask) nounwind alwaysinline {
%floatmask = bitcast <8 x i32> %mask to <8 x float>
%floatval = call <8 x float> @llvm.x86.avx.maskload.ps.256(i8 * %0, <8 x float> %floatmask)
%retval = bitcast <8 x float> %floatval to <8 x i32>
@@ -380,7 +380,7 @@ define <8 x i32> @__load_masked_32(i8 *, <8 x i32> %mask) nounwind alwaysinline
}
define <8 x i64> @__load_masked_64(i8 *, <8 x i32> %mask) nounwind alwaysinline {
define <8 x i64> @__masked_load_64(i8 *, <8 x i32> %mask) nounwind alwaysinline {
; double up masks, bitcast to doubles
%mask0 = shufflevector <8 x i32> %mask, <8 x i32> undef,
<8 x i32> <i32 0, i32 0, i32 1, i32 1, i32 2, i32 2, i32 3, i32 3>

View File

@@ -175,10 +175,10 @@ load_and_broadcast(WIDTH, i16, 16)
load_and_broadcast(WIDTH, i32, 32)
load_and_broadcast(WIDTH, i64, 64)
declare <WIDTH x i8> @__load_masked_8(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
declare <WIDTH x i16> @__load_masked_16(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
declare <WIDTH x i32> @__load_masked_32(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
declare <WIDTH x i64> @__load_masked_64(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
declare <WIDTH x i8> @__masked_load_8(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
declare <WIDTH x i16> @__masked_load_16(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
declare <WIDTH x i32> @__masked_load_32(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
declare <WIDTH x i64> @__masked_load_64(i8 * nocapture, <WIDTH x i1> %mask) nounwind readonly
declare void @__masked_store_8(<WIDTH x i8>* nocapture, <WIDTH x i8>,
<WIDTH x i1>) nounwind

View File

@@ -429,10 +429,10 @@ load_and_broadcast(8, i16, 16)
load_and_broadcast(8, i32, 32)
load_and_broadcast(8, i64, 64)
load_masked(8, i8, 8, 1)
load_masked(8, i16, 16, 2)
load_masked(8, i32, 32, 4)
load_masked(8, i64, 64, 8)
masked_load(8, i8, 8, 1)
masked_load(8, i16, 16, 2)
masked_load(8, i32, 32, 4)
masked_load(8, i64, 64, 8)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; gather/scatter

View File

@@ -556,10 +556,10 @@ load_and_broadcast(4, i16, 16)
load_and_broadcast(4, i32, 32)
load_and_broadcast(4, i64, 64)
load_masked(4, i8, 8, 1)
load_masked(4, i16, 16, 2)
load_masked(4, i32, 32, 4)
load_masked(4, i64, 64, 8)
masked_load(4, i8, 8, 1)
masked_load(4, i16, 16, 2)
masked_load(4, i32, 32, 4)
masked_load(4, i64, 64, 8)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; gather/scatter

View File

@@ -356,10 +356,10 @@ load_and_broadcast(8, i16, 16)
load_and_broadcast(8, i32, 32)
load_and_broadcast(8, i64, 64)
load_masked(8, i8, 8, 1)
load_masked(8, i16, 16, 2)
load_masked(8, i32, 32, 4)
load_masked(8, i64, 64, 8)
masked_load(8, i8, 8, 1)
masked_load(8, i16, 16, 2)
masked_load(8, i32, 32, 4)
masked_load(8, i64, 64, 8)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; gather/scatter

View File

@@ -455,10 +455,10 @@ load_and_broadcast(4, i16, 16)
load_and_broadcast(4, i32, 32)
load_and_broadcast(4, i64, 64)
load_masked(4, i8, 8, 1)
load_masked(4, i16, 16, 2)
load_masked(4, i32, 32, 4)
load_masked(4, i64, 64, 8)
masked_load(4, i8, 8, 1)
masked_load(4, i16, 16, 2)
masked_load(4, i32, 32, 4)
masked_load(4, i64, 64, 8)
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;; gather/scatter

View File

@@ -2213,8 +2213,8 @@ define <$1 x $2> @__load_and_broadcast_$3(i8 *, <$1 x i32> %mask) nounwind alway
;; $3: suffix for function name (32, 64, ...)
;; $4: alignment for elements of type $2 (4, 8, ...)
define(`load_masked', `
define <$1 x $2> @__load_masked_$3(i8 *, <$1 x i32> %mask) nounwind alwaysinline {
define(`masked_load', `
define <$1 x $2> @__masked_load_$3(i8 *, <$1 x i32> %mask) nounwind alwaysinline {
entry:
%mm = call i32 @__movmsk(<$1 x i32> %mask)

138
opt.cpp
View File

@@ -89,6 +89,7 @@ static llvm::Pass *CreateGatherScatterImprovementsPass();
static llvm::Pass *CreateLowerGatherScatterPass();
static llvm::Pass *CreateLowerMaskedStorePass();
static llvm::Pass *CreateMaskedStoreOptPass();
static llvm::Pass *CreateMaskedLoadOptPass();
static llvm::Pass *CreateIsCompileTimeConstantPass(bool isLastTry);
static llvm::Pass *CreateMakeInternalFuncsStaticPass();
@@ -253,6 +254,7 @@ Optimize(llvm::Module *module, int optLevel) {
if (!g->opt.disableMaskAllOnOptimizations) {
optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateMaskedStoreOptPass());
optPM.add(CreateMaskedLoadOptPass());
}
optPM.add(llvm::createDeadInstEliminationPass());
@@ -290,6 +292,7 @@ Optimize(llvm::Module *module, int optLevel) {
if (!g->opt.disableMaskAllOnOptimizations) {
optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateMaskedStoreOptPass());
optPM.add(CreateMaskedLoadOptPass());
}
optPM.add(CreateLowerMaskedStorePass());
if (!g->opt.disableGatherScatterOptimizations)
@@ -298,6 +301,9 @@ Optimize(llvm::Module *module, int optLevel) {
optPM.add(CreateLowerMaskedStorePass());
optPM.add(CreateLowerGatherScatterPass());
}
if (!g->opt.disableMaskAllOnOptimizations) {
optPM.add(CreateMaskedLoadOptPass());
}
optPM.add(llvm::createFunctionInliningPass());
optPM.add(llvm::createConstantPropagationPass());
optPM.add(CreateIntrinsicsOptPass());
@@ -461,13 +467,14 @@ IntrinsicsOpt::IntrinsicsOpt()
of the mask values in turn and concatenating them into a single integer.
In other words, given the 4-wide mask: < 0xffffffff, 0, 0, 0xffffffff >,
we have 0b1001 = 9.
@todo This will break if we ever do 32-wide compilation, in which case
it don't be possible to distinguish between -1 for "don't know" and
"known and all bits on".
*/
static int
lGetMask(llvm::Value *factor) {
/* FIXME: This will break if we ever do 32-wide compilation, in which case
it don't be possible to distinguish between -1 for "don't know" and
"known and all bits on". */
assert(g->target.vectorWidth < 32);
llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(factor);
if (cv) {
int mask = 0;
@@ -1344,6 +1351,105 @@ CreateMaskedStoreOptPass() {
}
///////////////////////////////////////////////////////////////////////////
// MaskedLoadOptPass
/** Masked load improvements for the all on/all off mask cases.
*/
class MaskedLoadOptPass : public llvm::BasicBlockPass {
public:
static char ID;
MaskedLoadOptPass() : BasicBlockPass(ID) { }
const char *getPassName() const { return "Masked Load Improvements"; }
bool runOnBasicBlock(llvm::BasicBlock &BB);
};
char MaskedLoadOptPass::ID = 0;
llvm::RegisterPass<MaskedLoadOptPass> ml("masked-load-improvements",
"Masked Load Improvements Pass");
struct MLInfo {
MLInfo(const char *name, const int a)
: align(a) {
func = m->module->getFunction(name);
Assert(func != NULL);
}
llvm::Function *func;
const int align;
};
bool
MaskedLoadOptPass::runOnBasicBlock(llvm::BasicBlock &bb) {
MLInfo mlInfo[] = {
MLInfo("__masked_load_8", 1),
MLInfo("__masked_load_16", 2),
MLInfo("__masked_load_32", 4),
MLInfo("__masked_load_64", 8)
};
bool modifiedAny = false;
restart:
// Iterate over all of the instructions to look for one of the various
// masked load functions
for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
llvm::CallInst *callInst = llvm::dyn_cast<llvm::CallInst>(&*iter);
if (!callInst)
continue;
llvm::Function *called = callInst->getCalledFunction();
int nFuncs = sizeof(mlInfo) / sizeof(mlInfo[0]);
MLInfo *info = NULL;
for (int i = 0; i < nFuncs; ++i) {
if (mlInfo[i].func != NULL && called == mlInfo[i].func) {
info = &mlInfo[i];
break;
}
}
if (info == NULL)
continue;
// Got one; grab the operands
llvm::Value *ptr = callInst->getArgOperand(0);
llvm::Value *mask = callInst->getArgOperand(1);
int allOnMask = (1 << g->target.vectorWidth) - 1;
int maskAsInt = lGetMask(mask);
if (maskAsInt == 0) {
// Zero mask - no-op, so replace the load with an undef value
llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
iter, llvm::UndefValue::get(callInst->getType()));
modifiedAny = true;
goto restart;
}
else if (maskAsInt == allOnMask) {
// The mask is all on, so turn this into a regular load
LLVM_TYPE_CONST llvm::Type *ptrType =
llvm::PointerType::get(callInst->getType(), 0);
ptr = new llvm::BitCastInst(ptr, ptrType, "ptr_cast_for_load",
callInst);
llvm::Instruction *load =
new llvm::LoadInst(ptr, callInst->getName(), false /* not volatile */,
info->align, (llvm::Instruction *)NULL);
lCopyMetadata(load, callInst);
llvm::ReplaceInstWithInst(callInst, load);
modifiedAny = true;
goto restart;
}
}
return modifiedAny;
}
static llvm::Pass *
CreateMaskedLoadOptPass() {
return new MaskedLoadOptPass;
}
///////////////////////////////////////////////////////////////////////////
// LowerMaskedStorePass
@@ -1901,21 +2007,21 @@ bool
GSImprovementsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
GatherImpInfo gInfo[] = {
GatherImpInfo("__pseudo_gather_base_offsets32_8", "__load_and_broadcast_8",
"__load_masked_8", 1),
"__masked_load_8", 1),
GatherImpInfo("__pseudo_gather_base_offsets32_16", "__load_and_broadcast_16",
"__load_masked_16", 2),
"__masked_load_16", 2),
GatherImpInfo("__pseudo_gather_base_offsets32_32", "__load_and_broadcast_32",
"__load_masked_32", 4),
"__masked_load_32", 4),
GatherImpInfo("__pseudo_gather_base_offsets32_64", "__load_and_broadcast_64",
"__load_masked_64", 8),
"__masked_load_64", 8),
GatherImpInfo("__pseudo_gather_base_offsets64_8", "__load_and_broadcast_8",
"__load_masked_8", 1),
"__masked_load_8", 1),
GatherImpInfo("__pseudo_gather_base_offsets64_16", "__load_and_broadcast_16",
"__load_masked_16", 2),
"__masked_load_16", 2),
GatherImpInfo("__pseudo_gather_base_offsets64_32", "__load_and_broadcast_32",
"__load_masked_32", 4),
"__masked_load_32", 4),
GatherImpInfo("__pseudo_gather_base_offsets64_64", "__load_and_broadcast_64",
"__load_masked_64", 8)
"__masked_load_64", 8)
};
ScatterImpInfo sInfo[] = {
ScatterImpInfo("__pseudo_scatter_base_offsets32_8", "__pseudo_masked_store_8",
@@ -2100,11 +2206,11 @@ GSImprovementsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
llvm::ArrayRef<llvm::Value *> argArray(&args[0], &args[2]);
llvm::Instruction *newCall =
llvm::CallInst::Create(gatherInfo->loadMaskedFunc, argArray,
"load_masked", (llvm::Instruction *)NULL);
"masked_load", (llvm::Instruction *)NULL);
#else
llvm::Instruction *newCall =
llvm::CallInst::Create(gatherInfo->loadMaskedFunc, &args[0],
&args[2], "load_masked");
&args[2], "masked_load");
#endif
lCopyMetadata(newCall, callInst);
llvm::ReplaceInstWithInst(callInst, newCall);
@@ -2420,8 +2526,8 @@ MakeInternalFuncsStaticPass::runOnModule(llvm::Module &module) {
"__gather_elt64_i32", "__gather_elt64_i64",
"__load_and_broadcast_8", "__load_and_broadcast_16",
"__load_and_broadcast_32", "__load_and_broadcast_64",
"__load_masked_8", "__load_masked_16",
"__load_masked_32", "__load_masked_64",
"__masked_load_8", "__masked_load_16",
"__masked_load_32", "__masked_load_64",
"__masked_store_8", "__masked_store_16",
"__masked_store_32", "__masked_store_64",
"__masked_store_blend_8", "__masked_store_blend_16",