AVX updates / improvements.

Add optimization patterns to detect and simplify masked loads and stores
  whose mask is all on or all off.
Enable the AVX target for LLVM 3.0 builds (it still generally hits bugs and
  unimplemented features on the LLVM side, but it's getting there).
This commit is contained in:
Matt Pharr
2011-07-25 07:41:37 +01:00
parent 0932dcd98b
commit 16be1d313e
2 changed files with 92 additions and 6 deletions

View File

@@ -91,7 +91,11 @@ static void usage(int ret) {
printf(" disable-gather-scatter-flattening\tDisable flattening when all lanes are on\n"); printf(" disable-gather-scatter-flattening\tDisable flattening when all lanes are on\n");
printf(" disable-uniform-memory-optimizations\tDisable uniform-based coherent memory access\n"); printf(" disable-uniform-memory-optimizations\tDisable uniform-based coherent memory access\n");
printf(" disable-masked-store-optimizations\tDisable lowering to regular stores when possible\n"); printf(" disable-masked-store-optimizations\tDisable lowering to regular stores when possible\n");
#if defined(LLVM_3_0) || defined(LLVM_3_0svn)
printf(" [--target={sse2,sse4,sse4x2,avx}] Select target ISA (SSE4 is default unless compiling for atom; then SSE2 is.)\n");
#else
printf(" [--target={sse2,sse4,sse4x2}] Select target ISA (SSE4 is default unless compiling for atom; then SSE2 is.)\n"); printf(" [--target={sse2,sse4,sse4x2}] Select target ISA (SSE4 is default unless compiling for atom; then SSE2 is.)\n");
#endif // LLVM 3.0
printf(" [--version]\t\t\t\tPrint ispc version\n"); printf(" [--version]\t\t\t\tPrint ispc version\n");
printf(" [--woff]\t\t\t\tDisable warnings\n"); printf(" [--woff]\t\t\t\tDisable warnings\n");
printf(" [--wno-perf]\t\t\tDon't issue warnings related to performance-related issues\n"); printf(" [--wno-perf]\t\t\tDon't issue warnings related to performance-related issues\n");
@@ -118,13 +122,13 @@ static void lDoTarget(const char *target) {
g->target.nativeVectorWidth = 4; g->target.nativeVectorWidth = 4;
g->target.vectorWidth = 8; g->target.vectorWidth = 8;
} }
#if 0 #if defined(LLVM_3_0) || defined(LLVM_3_0svn)
else if (!strcasecmp(target, "avx")) { else if (!strcasecmp(target, "avx")) {
g->target.isa = Target::AVX; g->target.isa = Target::AVX;
g->target.nativeVectorWidth = 8; g->target.nativeVectorWidth = 8;
g->target.vectorWidth = 8; g->target.vectorWidth = 8;
} }
#endif #endif // LLVM 3.0
else else
usage(1); usage(1);
} }

90
opt.cpp
View File

@@ -304,6 +304,7 @@ Optimize(llvm::Module *module, int optLevel) {
true /* simplify lib calls */, true /* simplify lib calls */,
false /* may have exceptions */, false /* may have exceptions */,
llvm::createFunctionInliningPass()); llvm::createFunctionInliningPass());
#else #else
llvm::PassManagerBuilder builder; llvm::PassManagerBuilder builder;
builder.OptLevel = 3; builder.OptLevel = 3;
@@ -346,9 +347,9 @@ Optimize(llvm::Module *module, int optLevel) {
/** This is a relatively simple optimization pass that does a few small /** This is a relatively simple optimization pass that does a few small
optimizations that LLVM's x86 optimizer doesn't currently handle. optimizations that LLVM's x86 optimizer doesn't currently handle.
(Specifically, MOVMSK of a constant can be replaced with the (Specifically, MOVMSK of a constant can be replaced with the
corresponding constant value, and a BLENDVPS with either an 'all on' or corresponding constant value, and BLENDVPS and AVX masked load/store with
'all off' blend factor can be replaced with the corresponding value of either an 'all on' or 'all off' mask can be replaced with simpler
one of the two operands. operations.
@todo The better thing to do would be to submit a patch to LLVM to get @todo The better thing to do would be to submit a patch to LLVM to get
these; they're presumably pretty simple patterns to match. these; they're presumably pretty simple patterns to match.
@@ -408,8 +409,13 @@ IntrinsicsOpt::IntrinsicsOpt()
llvm::Function *sseMovmsk = llvm::Function *sseMovmsk =
llvm::Intrinsic::getDeclaration(m->module, llvm::Intrinsic::x86_sse_movmsk_ps); llvm::Intrinsic::getDeclaration(m->module, llvm::Intrinsic::x86_sse_movmsk_ps);
maskInstructions.push_back(sseMovmsk); maskInstructions.push_back(sseMovmsk);
maskInstructions.push_back(m->module->getFunction("llvm.x86.avx.movmsk.ps"));
maskInstructions.push_back(m->module->getFunction("__movmsk")); maskInstructions.push_back(m->module->getFunction("__movmsk"));
#if defined(LLVM_3_0) || defined(LLVM_3_0svn)
// Only LLVM 3.0 builds register the 256-bit AVX movmsk intrinsic as a
// mask instruction, alongside the SSE movmsk and __movmsk entries.
llvm::Function *avxMovmsk =
llvm::Intrinsic::getDeclaration(m->module, llvm::Intrinsic::x86_avx_movmsk_ps_256);
assert(avxMovmsk != NULL);
maskInstructions.push_back(avxMovmsk);
#endif
// And all of the blend instructions // And all of the blend instructions
blendInstructions.push_back(BlendInstruction( blendInstructions.push_back(BlendInstruction(
@@ -494,6 +500,19 @@ lIsUndef(llvm::Value *value) {
bool bool
IntrinsicsOpt::runOnBasicBlock(llvm::BasicBlock &bb) { IntrinsicsOpt::runOnBasicBlock(llvm::BasicBlock &bb) {
#if defined(LLVM_3_0) || defined(LLVM_3_0svn)
// Fetch the declarations of the 256-bit AVX masked load/store intrinsics
// (32-bit "ps" and 64-bit "pd" element variants) so that calls to them can
// be recognized by comparing against getCalledFunction() below.
llvm::Function *avxMaskedLoad32 =
llvm::Intrinsic::getDeclaration(m->module, llvm::Intrinsic::x86_avx_maskload_ps_256);
llvm::Function *avxMaskedLoad64 =
llvm::Intrinsic::getDeclaration(m->module, llvm::Intrinsic::x86_avx_maskload_pd_256);
llvm::Function *avxMaskedStore32 =
llvm::Intrinsic::getDeclaration(m->module, llvm::Intrinsic::x86_avx_maskstore_ps_256);
llvm::Function *avxMaskedStore64 =
llvm::Intrinsic::getDeclaration(m->module, llvm::Intrinsic::x86_avx_maskstore_pd_256);
// All four must resolve; the pattern matching below depends on them.
assert(avxMaskedLoad32 != NULL && avxMaskedStore32 != NULL);
assert(avxMaskedLoad64 != NULL && avxMaskedStore64 != NULL);
#endif
bool modifiedAny = false; bool modifiedAny = false;
restart: restart:
for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) { for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
@@ -564,6 +583,69 @@ IntrinsicsOpt::runOnBasicBlock(llvm::BasicBlock &bb) {
goto restart; goto restart;
} }
} }
#if defined(LLVM_3_0) || defined(LLVM_3_0svn)
else if (callInst->getCalledFunction() == avxMaskedLoad32 ||
         callInst->getCalledFunction() == avxMaskedLoad64) {
    // AVX masked load: operand 0 is the (i8 *) address, operand 1 is the
    // mask.  Simplify the call when the mask is a known constant.
    llvm::Value *factor = callInst->getArgOperand(1);
    int mask = lGetMask(factor);
    // NOTE(review): the all-on test below compares against 0xff (8 lanes);
    // the pd_256 variant has a 4-lane mask--confirm what lGetMask()
    // returns for it.
    if (mask == 0) {
        // nothing being loaded, replace with undef value
        llvm::Type *returnType = callInst->getType();
        assert(llvm::isa<llvm::VectorType>(returnType));
        llvm::Value *undefValue = llvm::UndefValue::get(returnType);
        llvm::ReplaceInstWithValue(iter->getParent()->getInstList(),
                                   iter, undefValue);
        modifiedAny = true;
        goto restart;
    }
    else if (mask == 0xff) {
        // all lanes active; replace with a regular load
        llvm::Type *returnType = callInst->getType();
        assert(llvm::isa<llvm::VectorType>(returnType));
        // cast the i8 * to the appropriate type
        llvm::Value *castPtr =
            new llvm::BitCastInst(callInst->getArgOperand(0),
                                  llvm::PointerType::get(returnType, 0),
                                  "ptr2vec", callInst);
        lCopyMetadata(castPtr, callInst);
        // The masked load intrinsic makes no promise about the alignment
        // of its address, and an alignment of 0 on a load means "ABI
        // alignment of the loaded type" (the full vector width here).
        // Explicitly request an unaligned load so that we don't
        // introduce an alignment assumption the original call lacked.
        const unsigned align = 1;
        llvm::Instruction *loadInst =
            new llvm::LoadInst(castPtr, "load", false /* not volatile */,
                               align, (llvm::Instruction *)NULL);
        lCopyMetadata(loadInst, callInst);
        llvm::ReplaceInstWithInst(callInst, loadInst);
        modifiedAny = true;
        goto restart;
    }
}
else if (callInst->getCalledFunction() == avxMaskedStore32 ||
         callInst->getCalledFunction() == avxMaskedStore64) {
    // AVX masked store: operand 0 is the (i8 *) address, operand 1 is the
    // mask, and operand 2 is the vector value to store.
    // NOTE: mask is the 2nd parameter, not the 3rd one!!
    llvm::Value *factor = callInst->getArgOperand(1);
    int mask = lGetMask(factor);
    // NOTE(review): the all-on test below compares against 0xff (8 lanes);
    // the pd_256 variant has a 4-lane mask--confirm what lGetMask()
    // returns for it.
    if (mask == 0) {
        // nothing actually being stored, just remove the inst
        callInst->eraseFromParent();
        modifiedAny = true;
        goto restart;
    }
    else if (mask == 0xff) {
        // all lanes storing, so replace with a regular store.  The value
        // being stored is the 3rd operand; operand 1 is the mask, and
        // storing that would write the mask itself to memory.
        llvm::Value *rvalue = callInst->getArgOperand(2);
        llvm::Type *storeType = rvalue->getType();
        llvm::Value *castPtr =
            new llvm::BitCastInst(callInst->getArgOperand(0),
                                  llvm::PointerType::get(storeType, 0),
                                  "ptr2vec", callInst);
        lCopyMetadata(castPtr, callInst);
        llvm::StoreInst *storeInst =
            new llvm::StoreInst(rvalue, castPtr, (llvm::Instruction *)NULL);
        // The masked store intrinsic makes no promise about the alignment
        // of its address; without an explicit alignment the new store
        // would default to the ABI alignment of the full vector type.
        storeInst->setAlignment(1);
        lCopyMetadata(storeInst, callInst);
        llvm::ReplaceInstWithInst(callInst, storeInst);
        modifiedAny = true;
        goto restart;
    }
}
#endif
} }
return modifiedAny; return modifiedAny;
} }