Extract constant offsets from gather/scatter base+offsets offset vectors.

When we're able to turn a general gather/scatter into the "base + offsets"
form, we now try to extract out any constant components of the offsets and
then pass them as a separate parameter to the gather/scatter function
implementation.

We then in turn carefully emit code for the addressing calculation so that
these constant offsets match LLVM's patterns to detect this case, such that
we get the constant offsets directly encoded in the instruction's addressing
calculation in many cases, saving arithmetic instructions to do these
calculations.

Improves performance of stencil by ~15%.  Other workloads unchanged.
This commit is contained in:
Matt Pharr
2012-01-24 14:41:15 -08:00
parent 7be2c399b1
commit a5b7fca7e0
5 changed files with 614 additions and 355 deletions

435
opt.cpp
View File

@@ -205,6 +205,7 @@ lCallInst(llvm::Function *func, llvm::Value *arg0, llvm::Value *arg1,
}
#if 0
static llvm::Instruction *
lCallInst(llvm::Function *func, llvm::Value *arg0, llvm::Value *arg1,
llvm::Value *arg2, llvm::Value *arg3, const char *name,
@@ -218,7 +219,7 @@ lCallInst(llvm::Function *func, llvm::Value *arg0, llvm::Value *arg1,
name, insertBefore);
#endif
}
#endif
static llvm::Instruction *
lCallInst(llvm::Function *func, llvm::Value *arg0, llvm::Value *arg1,
@@ -234,6 +235,21 @@ lCallInst(llvm::Function *func, llvm::Value *arg0, llvm::Value *arg1,
#endif
}
static llvm::Instruction *
lCallInst(llvm::Function *func, llvm::Value *arg0, llvm::Value *arg1,
llvm::Value *arg2, llvm::Value *arg3, llvm::Value *arg4,
llvm::Value *arg5, const char *name,
llvm::Instruction *insertBefore = NULL) {
llvm::Value *args[6] = { arg0, arg1, arg2, arg3, arg4, arg5 };
#if defined(LLVM_3_0) || defined(LLVM_3_0svn) || defined(LLVM_3_1svn)
llvm::ArrayRef<llvm::Value *> newArgArray(&args[0], &args[6]);
return llvm::CallInst::Create(func, newArgArray, name, insertBefore);
#else
return llvm::CallInst::Create(func, &newArgs[0], &newArgs[6],
name, insertBefore);
#endif
}
///////////////////////////////////////////////////////////////////////////
void
@@ -302,10 +318,13 @@ Optimize(llvm::Module *module, int optLevel) {
// Early optimizations to try to reduce the total amount of code to
// work with if we can
optPM.add(CreateDetectGSBaseOffsetsPass());
optPM.add(llvm::createReassociatePass());
optPM.add(llvm::createConstantPropagationPass());
optPM.add(llvm::createConstantPropagationPass());
optPM.add(llvm::createDeadInstEliminationPass());
optPM.add(llvm::createCFGSimplificationPass());
optPM.add(CreateDetectGSBaseOffsetsPass());
if (!g->opt.disableMaskAllOnOptimizations) {
optPM.add(CreateIntrinsicsOptPass());
optPM.add(CreateVSelMovmskOptPass());
@@ -314,11 +333,7 @@ Optimize(llvm::Module *module, int optLevel) {
}
optPM.add(llvm::createDeadInstEliminationPass());
optPM.add(llvm::createConstantPropagationPass());
optPM.add(llvm::createDeadInstEliminationPass());
// On to more serious optimizations
optPM.add(llvm::createCFGSimplificationPass());
if (runSROA)
optPM.add(llvm::createScalarReplAggregatesPass());
optPM.add(llvm::createInstructionCombiningPass());
@@ -1173,6 +1188,166 @@ lGetBasePtrAndOffsets(llvm::Value *ptrs, llvm::Value **offsets,
}
static llvm::Value *
lGetZeroOffsetVector(llvm::Value *origVec) {
if (origVec->getType() == LLVMTypes::Int32VectorType)
return LLVMInt32Vector((int32_t)0);
else
return LLVMInt64Vector((int64_t)0);
}
#if 0
static void
lPrint(llvm::Value *v, int indent = 0) {
if (llvm::isa<llvm::PHINode>(v))
return;
fprintf(stderr, "%*c", indent, ' ');
v->dump();
llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
if (inst != NULL) {
for (int i = 0; i < (int)inst->getNumOperands(); ++i) {
llvm::Value *op = inst->getOperand(i);
if (llvm::isa<llvm::Constant>(op) == false)
lPrint(op, indent+4);
}
}
}
#endif
/** Given a vector expression in vec, separate it into a compile-time
constant component and a variable component, returning the two parts in
*constOffset and *variableOffset. (It should be the case that the sum
of these two is exactly equal to the original vector.)
This routine only handles some (important) patterns; in some cases it
will fail and return components that are actually compile-time
constants in *variableOffset.
Finally, if there aren't any constant (or, respectivaly, variable)
components, the corresponding return value may be set to NULL.
*/
static void
lExtractConstantOffset(llvm::Value *vec, llvm::Value **constOffset,
llvm::Value **variableOffset,
llvm::Instruction *insertBefore) {
if (llvm::isa<llvm::ConstantVector>(vec) ||
llvm::isa<llvm::ConstantAggregateZero>(vec)) {
*constOffset = vec;
*variableOffset = NULL;
return;
}
llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(vec);
if (sext != NULL) {
// Check the sext target.
llvm::Value *co, *vo;
lExtractConstantOffset(sext->getOperand(0), &co, &vo, insertBefore);
// make new sext instructions for the two parts
if (co == NULL)
*constOffset = NULL;
else
*constOffset = new llvm::SExtInst(co, sext->getType(),
"const_offset_sext", insertBefore);
if (vo == NULL)
*variableOffset = NULL;
else
*variableOffset = new llvm::SExtInst(vo, sext->getType(),
"variable_offset_sext",
insertBefore);
return;
}
// FIXME? handle bitcasts / type casts here
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(vec);
if (bop != NULL) {
llvm::Value *op0 = bop->getOperand(0);
llvm::Value *op1 = bop->getOperand(1);
llvm::Value *c0, *v0, *c1, *v1;
if (bop->getOpcode() == llvm::Instruction::Add) {
lExtractConstantOffset(op0, &c0, &v0, insertBefore);
lExtractConstantOffset(op1, &c1, &v1, insertBefore);
if (c0 == NULL)
*constOffset = c1;
else if (c1 == NULL)
*constOffset = c0;
else
*constOffset =
llvm::BinaryOperator::Create(llvm::Instruction::Add, c0, c1,
"const_op", insertBefore);
if (v0 == NULL)
*variableOffset = v1;
else if (v1 == NULL)
*variableOffset = v0;
else
*variableOffset =
llvm::BinaryOperator::Create(llvm::Instruction::Add, v0, v1,
"variable_op", insertBefore);
return;
}
else if (bop->getOpcode() == llvm::Instruction::Mul) {
lExtractConstantOffset(op0, &c0, &v0, insertBefore);
lExtractConstantOffset(op1, &c1, &v1, insertBefore);
// Given the product of constant and variable terms, we have:
// (c0 + v0) * (c1 + v1) == (c0 c1) + (v0 c1 + c0 v1 + v0 v1)
// Note that the first term is a constant and the last three are
// variable.
if (c0 != NULL && c1 != NULL)
*constOffset =
llvm::BinaryOperator::Create(llvm::Instruction::Mul, c0, c1,
"const_mul", insertBefore);
else
*constOffset = NULL;
llvm::Value *va = NULL, *vb = NULL, *vc = NULL;
if (v0 != NULL && c1 != NULL)
va = llvm::BinaryOperator::Create(llvm::Instruction::Mul, v0, c1,
"va_mul", insertBefore);
if (c0 != NULL && v1 != NULL)
vb = llvm::BinaryOperator::Create(llvm::Instruction::Mul, c0, v1,
"vb_mul", insertBefore);
if (v0 != NULL && v1 != NULL)
vc = llvm::BinaryOperator::Create(llvm::Instruction::Mul, v0, v1,
"vc_mul", insertBefore);
llvm::Value *vab = NULL;
if (va != NULL && vb != NULL)
vab = llvm::BinaryOperator::Create(llvm::Instruction::Add, va, vb,
"vab_add", insertBefore);
else if (va != NULL)
vab = va;
else
vab = vb;
if (vab != NULL && vc != NULL)
*variableOffset =
llvm::BinaryOperator::Create(llvm::Instruction::Add, vab, vc,
"vabc_add", insertBefore);
else if (vab != NULL)
*variableOffset = vab;
else
*variableOffset = vc;
return;
}
}
// Nothing matched, just return what we have as a variable component
*constOffset = NULL;
*variableOffset = vec;
}
/* Returns true if the given value is a constant vector of integers with
the value 2, 4, 8 in all of the elements. (Returns the splatted value
in *splat, if so). */
@@ -1277,6 +1452,123 @@ lExtractOffsetVector248Scale(llvm::Value **vec) {
return LLVMInt32(1);
}
#if 0
static llvm::Value *
lExtractUniforms(llvm::Value **vec, llvm::Instruction *insertBefore) {
fprintf(stderr, " lextract: ");
(*vec)->dump();
fprintf(stderr, "\n");
if (llvm::isa<llvm::ConstantVector>(*vec) ||
llvm::isa<llvm::ConstantAggregateZero>(*vec))
return NULL;
llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(*vec);
if (sext != NULL) {
llvm::Value *sextOp = sext->getOperand(0);
// Check the sext target.
llvm::Value *unif = lExtractUniforms(&sextOp, insertBefore);
if (unif == NULL)
return NULL;
// make a new sext instruction so that we end up with the right
// type
*vec = new llvm::SExtInst(sextOp, sext->getType(), "offset_sext", sext);
return unif;
}
std::vector<llvm::PHINode *> phis;
if (LLVMVectorValuesAllEqual(*vec, g->target.vectorWidth, phis)) {
// FIXME: we may want to redo all of the expression here, in scalar
// form (if at all possible), for code quality...
llvm::Value *unif =
llvm::ExtractElementInst::Create(*vec, LLVMInt32(0),
"first_uniform", insertBefore);
*vec = NULL;
return unif;
}
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(*vec);
if (bop == NULL)
return NULL;
llvm::Value *op0 = bop->getOperand(0), *op1 = bop->getOperand(1);
if (bop->getOpcode() == llvm::Instruction::Add) {
llvm::Value *s0 = lExtractUniforms(&op0, insertBefore);
llvm::Value *s1 = lExtractUniforms(&op1, insertBefore);
if (s0 == NULL && s1 == NULL)
return NULL;
if (op0 == NULL)
*vec = op1;
else if (op1 == NULL)
*vec = op0;
else
*vec = llvm::BinaryOperator::Create(llvm::Instruction::Add,
op0, op1, "new_add", insertBefore);
if (s0 == NULL)
return s1;
else if (s1 == NULL)
return s0;
else
return llvm::BinaryOperator::Create(llvm::Instruction::Add, s0, s1,
"add_unif", insertBefore);
}
#if 0
else if (bop->getOpcode() == llvm::Instruction::Mul) {
// Check each operand for being one of the scale factors we care about.
int splat;
if (lIs248Splat(op0, &splat)) {
*vec = op1;
return LLVMInt32(splat);
}
else if (lIs248Splat(op1, &splat)) {
*vec = op0;
return LLVMInt32(splat);
}
else
return LLVMInt32(1);
}
#endif
else
return NULL;
}
static void
lExtractUniformsFromOffset(llvm::Value **basePtr, llvm::Value **offsetVector,
llvm::Value *offsetScale,
llvm::Instruction *insertBefore) {
#if 1
(*basePtr)->dump();
printf("\n");
(*offsetVector)->dump();
printf("\n");
offsetScale->dump();
printf("-----\n");
#endif
llvm::Value *uniformDelta = lExtractUniforms(offsetVector, insertBefore);
if (uniformDelta == NULL)
return;
llvm::Value *index[1] = { uniformDelta };
#if defined(LLVM_3_0) || defined(LLVM_3_0svn) || defined(LLVM_3_1svn)
llvm::ArrayRef<llvm::Value *> arrayRef(&index[0], &index[1]);
*basePtr = llvm::GetElementPtrInst::Create(*basePtr, arrayRef, "new_base",
insertBefore);
#else
*basePtr = llvm::GetElementPtrInst::Create(*basePtr, &index[0],
&index[1], "new_base",
insertBefore);
#endif
// this should only happen if we have only uniforms, but that in turn
// shouldn't be a gather/scatter!
Assert(*offsetVector != NULL);
}
#endif
struct GSInfo {
GSInfo(const char *pgFuncName, const char *pgboFuncName,
@@ -1367,7 +1659,24 @@ DetectGSBaseOffsetsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
// to the next instruction...
continue;
llvm::Value *offsetScale = lExtractOffsetVector248Scale(&offsetVector);
// Try to decompose the offset vector into a compile time constant
// component and a varying component. The constant component is
// passed as a separate parameter to the gather/scatter functions,
// which in turn allows their implementations to end up emitting
// x86 instructions with constant offsets encoded in them.
llvm::Value *constOffset, *variableOffset;
lExtractConstantOffset(offsetVector, &constOffset, &variableOffset,
callInst);
if (constOffset == NULL)
constOffset = lGetZeroOffsetVector(offsetVector);
if (variableOffset == NULL)
variableOffset = lGetZeroOffsetVector(offsetVector);
// See if the varying component is scaled by 2, 4, or 8. If so,
// extract that scale factor and rewrite variableOffset to remove
// it. (This also is pulled out so that we can match the scales by
// 2/4/8 offered by x86 addressing operators.)
llvm::Value *offsetScale = lExtractOffsetVector248Scale(&variableOffset);
// Cast the base pointer to a void *, since that's what the
// __pseudo_*_base_offsets_* functions want.
@@ -1386,11 +1695,15 @@ DetectGSBaseOffsetsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
// walk past the sext to get the i32 offset values and then
// call out to the corresponding 32-bit gather/scatter
// function.
llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(offsetVector);
llvm::SExtInst *sext = llvm::dyn_cast<llvm::SExtInst>(variableOffset);
if (sext != NULL &&
sext->getOperand(0)->getType() == LLVMTypes::Int32VectorType) {
offsetVector = sext->getOperand(0);
variableOffset = sext->getOperand(0);
gatherScatterFunc = info->baseOffsets32Func;
if (constOffset->getType() != LLVMTypes::Int32VectorType)
constOffset =
new llvm::TruncInst(constOffset, LLVMTypes::Int32VectorType,
"trunc_const_offset", callInst);
}
}
@@ -1403,8 +1716,8 @@ DetectGSBaseOffsetsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
// the instruction isn't inserted into a basic block and that
// way we can then call ReplaceInstWithInst().
llvm::Instruction *newCall =
lCallInst(gatherScatterFunc, basePtr, offsetVector, offsetScale,
mask, "newgather", NULL);
lCallInst(gatherScatterFunc, basePtr, variableOffset, offsetScale,
constOffset, mask, "newgather", NULL);
lCopyMetadata(newCall, callInst);
llvm::ReplaceInstWithInst(callInst, newCall);
}
@@ -1416,8 +1729,8 @@ DetectGSBaseOffsetsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
// base+offsets instruction. See above for why passing NULL
// for the Instruction * is intended.
llvm::Instruction *newCall =
lCallInst(gatherScatterFunc, basePtr, offsetVector, offsetScale,
storeValue, mask, "", NULL);
lCallInst(gatherScatterFunc, basePtr, variableOffset, offsetScale,
constOffset, storeValue, mask, "", NULL);
lCopyMetadata(newCall, callInst);
llvm::ReplaceInstWithInst(callInst, newCall);
}
@@ -2016,6 +2329,26 @@ struct GatherImpInfo {
};
static llvm::Value *
lComputeCommonPointer(llvm::Value *base, llvm::Value *offsets,
llvm::Instruction *insertBefore) {
llvm::Value *firstOffset =
llvm::ExtractElementInst::Create(offsets, LLVMInt32(0), "first_offset",
insertBefore);
llvm::Value *offsetIndex[1] = { firstOffset };
#if defined(LLVM_3_0) || defined(LLVM_3_0svn) || defined(LLVM_3_1svn)
llvm::ArrayRef<llvm::Value *> arrayRef(&offsetIndex[0], &offsetIndex[1]);
return
llvm::GetElementPtrInst::Create(base, arrayRef, "ptr", insertBefore);
#else
return
llvm::GetElementPtrInst::Create(base, &offsetIndex[0], &offsetIndex[1],
"ptr", insertBefore);
#endif
}
struct ScatterImpInfo {
ScatterImpInfo(const char *pName, const char *msName,
LLVM_TYPE_CONST llvm::Type *vpt, int a)
@@ -2109,45 +2442,42 @@ GSToLoadStorePass::runOnBasicBlock(llvm::BasicBlock &bb) {
Assert(ok);
llvm::Value *base = callInst->getArgOperand(0);
llvm::Value *offsets = callInst->getArgOperand(1);
llvm::Value *varyingOffsets = callInst->getArgOperand(1);
llvm::Value *offsetScale = callInst->getArgOperand(2);
llvm::Value *storeValue = (scatterInfo != NULL) ? callInst->getArgOperand(3) : NULL;
llvm::Value *mask = callInst->getArgOperand((gatherInfo != NULL) ? 3 : 4);
llvm::Value *constOffsets = callInst->getArgOperand(3);
llvm::Value *storeValue = (scatterInfo != NULL) ? callInst->getArgOperand(4) : NULL;
llvm::Value *mask = callInst->getArgOperand((gatherInfo != NULL) ? 4 : 5);
// Compute the full offset vector: offsetScale * varyingOffsets + constOffsets
llvm::ConstantInt *offsetScaleInt =
llvm::dyn_cast<llvm::ConstantInt>(offsetScale);
Assert(offsetScaleInt != NULL);
uint64_t scaleValue = offsetScaleInt->getZExtValue();
if (offsets->getType() == LLVMTypes::Int64VectorType)
// offsetScale is an i32, so sext it so that if we use it in a
// multiply below, it has the same type as the i64 offset used
// as the other operand...
offsetScale = new llvm::SExtInst(offsetScale, LLVMTypes::Int64Type,
"offset_sext", callInst);
std::vector<llvm::Constant *> scales;
for (int i = 0; i < g->target.vectorWidth; ++i) {
if (varyingOffsets->getType() == LLVMTypes::Int64VectorType)
scales.push_back(LLVMInt64(scaleValue));
else
scales.push_back(LLVMInt32(scaleValue));
}
llvm::Constant *offsetScaleVec = llvm::ConstantVector::get(scales);
llvm::Value *scaledVarying =
llvm::BinaryOperator::Create(llvm::Instruction::Mul, offsetScaleVec,
varyingOffsets, "scaled_varying", callInst);
llvm::Value *fullOffsets =
llvm::BinaryOperator::Create(llvm::Instruction::Add, scaledVarying,
constOffsets, "varying+const_offsets",
callInst);
{
std::vector<llvm::PHINode *> seenPhis;
if (LLVMVectorValuesAllEqual(offsets, g->target.vectorWidth, seenPhis)) {
if (LLVMVectorValuesAllEqual(fullOffsets, g->target.vectorWidth, seenPhis)) {
// If all the offsets are equal, then compute the single
// pointer they all represent based on the first one of them
// (arbitrarily).
// FIXME: the code from here to where ptr is computed is highly
// redundant with the case for a vector linear below.
llvm::Value *firstOffset =
llvm::ExtractElementInst::Create(offsets, LLVMInt32(0), "first_offset",
callInst);
llvm::Value *indices[1] = { firstOffset };
#if defined(LLVM_3_0) || defined(LLVM_3_0svn) || defined(LLVM_3_1svn)
llvm::ArrayRef<llvm::Value *> arrayRef(&indices[0], &indices[1]);
llvm::Value *ptr =
llvm::GetElementPtrInst::Create(base, arrayRef, "ptr", callInst);
#else
llvm::Value *ptr =
llvm::GetElementPtrInst::Create(base, &indices[0], &indices[1],
"ptr", callInst);
#endif
llvm::Value *ptr = lComputeCommonPointer(base, fullOffsets, callInst);
lCopyMetadata(ptr, callInst);
if (gatherInfo != NULL) {
@@ -2175,9 +2505,11 @@ GSToLoadStorePass::runOnBasicBlock(llvm::BasicBlock &bb) {
llvm::ExtractElementInst::Create(storeValue, LLVMInt32(0), "rvalue_first",
callInst);
lCopyMetadata(first, callInst);
ptr = new llvm::BitCastInst(ptr, llvm::PointerType::get(first->getType(), 0),
"ptr2rvalue_type", callInst);
lCopyMetadata(ptr, callInst);
llvm::Instruction *sinst = new llvm::StoreInst(first, ptr, false,
scatterInfo->align);
lCopyMetadata(sinst, callInst);
@@ -2190,34 +2522,15 @@ GSToLoadStorePass::runOnBasicBlock(llvm::BasicBlock &bb) {
}
int step = gatherInfo ? gatherInfo->align : scatterInfo->align;
step /= (int)offsetScaleInt->getZExtValue();
std::vector<llvm::PHINode *> seenPhis;
if (step > 0 && lVectorIsLinear(offsets, g->target.vectorWidth,
if (step > 0 && lVectorIsLinear(fullOffsets, g->target.vectorWidth,
step, seenPhis)) {
// We have a linear sequence of memory locations being accessed
// starting with the location given by the offset from
// offsetElements[0], with stride of 4 or 8 bytes (for 32 bit
// and 64 bit gather/scatters, respectively.)
// Get the base pointer using the first guy's offset.
llvm::Value *firstOffset =
llvm::ExtractElementInst::Create(offsets, LLVMInt32(0), "first_offset",
callInst);
llvm::Value *scaledOffset =
llvm::BinaryOperator::Create(llvm::Instruction::Mul, firstOffset,
offsetScale, "scaled_offset", callInst);
llvm::Value *indices[1] = { scaledOffset };
#if defined(LLVM_3_0) || defined(LLVM_3_0svn) || defined(LLVM_3_1svn)
llvm::ArrayRef<llvm::Value *> arrayRef(&indices[0], &indices[1]);
llvm::Value *ptr =
llvm::GetElementPtrInst::Create(base, arrayRef, "ptr", callInst);
#else
llvm::Value *ptr =
llvm::GetElementPtrInst::Create(base, &indices[0], &indices[1],
"ptr", callInst);
#endif
llvm::Value *ptr = lComputeCommonPointer(base, fullOffsets, callInst);
lCopyMetadata(ptr, callInst);
if (gatherInfo != NULL) {