Broadcast implementation as InsertElement+Shuffle and related improvements

This commit is contained in:
Dmitry Babokin
2013-04-10 02:18:24 +04:00
parent 603abf70dc
commit 5898532605
6 changed files with 153 additions and 56 deletions

74
ctx.cpp
View File

@@ -1379,6 +1379,19 @@ FunctionEmitContext::MasksAllEqual(llvm::Value *v1, llvm::Value *v2) {
#endif #endif
} }
/** Builds the constant vector <i32 0, i32 1, ..., i32 (width-1)>, where
    width is the target's vector width, and returns it.  No instructions
    are emitted; the result is an llvm::ConstantVector. */
llvm::Value *
FunctionEmitContext::ProgramIndex() {
    const int width = g->target->getVectorWidth();

    // One i32 constant per lane, holding that lane's own index.
    llvm::SmallVector<llvm::Constant*, 16> lanes;
    for (int lane = 0; lane < width; ++lane)
        lanes.push_back(LLVMInt32(lane));

    return llvm::ConstantVector::get(lanes);
}
llvm::Value * llvm::Value *
FunctionEmitContext::GetStringPtr(const std::string &str) { FunctionEmitContext::GetStringPtr(const std::string &str) {
@@ -1729,26 +1742,48 @@ FunctionEmitContext::SmearUniform(llvm::Value *value, const char *name) {
llvm::Value *ret = NULL; llvm::Value *ret = NULL;
llvm::Type *eltType = value->getType(); llvm::Type *eltType = value->getType();
llvm::Type *vecType = NULL;
llvm::PointerType *pt = llvm::PointerType *pt =
llvm::dyn_cast<llvm::PointerType>(eltType); llvm::dyn_cast<llvm::PointerType>(eltType);
if (pt != NULL) { if (pt != NULL) {
// Varying pointers are represented as vectors of i32/i64s // Varying pointers are represented as vectors of i32/i64s
ret = llvm::UndefValue::get(LLVMTypes::VoidPointerVectorType); vecType = LLVMTypes::VoidPointerVectorType;
value = PtrToIntInst(value); value = PtrToIntInst(value);
} }
else else {
// All other varying types are represented as vectors of the // All other varying types are represented as vectors of the
// underlying type. // underlying type.
ret = llvm::UndefValue::get(llvm::VectorType::get(eltType, vecType = llvm::VectorType::get(eltType, g->target->getVectorWidth());
g->target->getVectorWidth()));
for (int i = 0; i < g->target->getVectorWidth(); ++i) {
llvm::Twine n = llvm::Twine("smear.") + llvm::Twine(name ? name : "") +
llvm::Twine(i);
ret = InsertInst(ret, value, i, n.str().c_str());
} }
// Check for a constant case.
if (llvm::Constant *const_val = llvm::dyn_cast<llvm::Constant>(value)) {
ret = llvm::ConstantVector::getSplat(
g->target->getVectorWidth(),
const_val);
return ret;
}
// Generate the following sequence:
// %broadcast_init.0 = insertelement <4 x i32> undef, i32 %val, i32 0
// %broadcast.1 = shufflevector <4 x i32> %broadcast_init.0, <4 x i32> undef,
// <4 x i32> zeroinitializer
//
llvm::Value *undef1 = llvm::UndefValue::get(vecType);
llvm::Value *undef2 = llvm::UndefValue::get(vecType);
// InsertElement
llvm::Twine tw1 = llvm::Twine("broadcast_init.") + llvm::Twine(name ? name : "");
llvm::Value *insert = InsertInst(undef1, value, 0, tw1.str().c_str());
// ShuffleVector
llvm::Constant *zeroVec = llvm::ConstantVector::getSplat(
g->target->getVectorWidth(),
llvm::Constant::getNullValue(llvm::Type::getInt32Ty(*g->ctx)));
llvm::Twine tw2 = llvm::Twine("broadcast.") + llvm::Twine(name ? name : "");
ret = ShuffleInst(insert, undef2, zeroVec, tw2.str().c_str());
return ret; return ret;
} }
@@ -3131,6 +3166,27 @@ FunctionEmitContext::InsertInst(llvm::Value *v, llvm::Value *eltVal, int elt,
} }
/** Emits a shufflevector instruction that combines v1 and v2 according
    to the given mask, appending it to the current basic block.

    @param v1    First source vector.
    @param v2    Second source vector.
    @param mask  Shuffle mask operand.
    @param name  Name for the new instruction; if NULL, one is obtained
                 from LLVMGetName() using v1 and the suffix "_shuffle".
    @return The new instruction, or NULL if any of v1, v2 or mask is NULL. */
llvm::Value *
FunctionEmitContext::ShuffleInst(llvm::Value *v1, llvm::Value *v2, llvm::Value *mask,
                                 const char *name) {
    const bool haveAllOperands = (v1 != NULL) && (v2 != NULL) && (mask != NULL);
    if (!haveAllOperands) {
        // A missing operand is tolerated only if an error has already
        // been counted for this module.
        AssertPos(currentPos, m->errorCount > 0);
        return NULL;
    }

    if (name == NULL) {
        char suffix[] = "_shuffle";
        name = LLVMGetName(v1, suffix);
    }

    llvm::Instruction *shuffle =
        new llvm::ShuffleVectorInst(v1, v2, mask, name, bblock);
    AddDebugPos(shuffle);
    return shuffle;
}
llvm::PHINode * llvm::PHINode *
FunctionEmitContext::PhiNode(llvm::Type *type, int count, FunctionEmitContext::PhiNode(llvm::Type *type, int count,
const char *name) { const char *name) {

7
ctx.h
View File

@@ -295,6 +295,10 @@ public:
that indicates whether the two masks are equal. */ that indicates whether the two masks are equal. */
llvm::Value *MasksAllEqual(llvm::Value *mask1, llvm::Value *mask2); llvm::Value *MasksAllEqual(llvm::Value *mask1, llvm::Value *mask2);
/** Generate ConstantVector, which contains ProgramIndex, i.e.
< i32 0, i32 1, i32 2, i32 3> */
llvm::Value *ProgramIndex();
/** Given a string, create an anonymous global variable to hold its /** Given a string, create an anonymous global variable to hold its
value and return the pointer to the string. */ value and return the pointer to the string. */
llvm::Value *GetStringPtr(const std::string &str); llvm::Value *GetStringPtr(const std::string &str);
@@ -500,6 +504,9 @@ public:
llvm::Value *InsertInst(llvm::Value *v, llvm::Value *eltVal, int elt, llvm::Value *InsertInst(llvm::Value *v, llvm::Value *eltVal, int elt,
const char *name = NULL); const char *name = NULL);
/** Emits a shufflevector instruction that takes v1 and v2 as its source
    vectors and mask as its shuffle mask.  If name is NULL, a name is
    generated from v1 with a "_shuffle" suffix.  Returns NULL if any of
    v1, v2 or mask is NULL. */
llvm::Value *ShuffleInst(llvm::Value *v1, llvm::Value *v2, llvm::Value *mask,
const char *name = NULL);
llvm::PHINode *PhiNode(llvm::Type *type, int count, llvm::PHINode *PhiNode(llvm::Type *type, int count,
const char *name = NULL); const char *name = NULL);
llvm::Instruction *SelectInst(llvm::Value *test, llvm::Value *val0, llvm::Instruction *SelectInst(llvm::Value *test, llvm::Value *val0,

View File

@@ -601,44 +601,74 @@ lGetIntValue(llvm::Value *offset) {
void void
LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth, LLVMFlattenInsertChain(llvm::Value *inst, int vectorWidth,
llvm::Value **elements) { llvm::Value **elements) {
for (int i = 0; i < vectorWidth; ++i) for (int i = 0; i < vectorWidth; ++i) {
elements[i] = NULL; elements[i] = NULL;
}
while (ie != NULL) { // Catch a pattern of InsertElement chain.
int64_t iOffset = lGetIntValue(ie->getOperand(2)); if (llvm::InsertElementInst *ie =
Assert(iOffset >= 0 && iOffset < vectorWidth); llvm::dyn_cast<llvm::InsertElementInst>(inst)) {
Assert(elements[iOffset] == NULL); while (ie != NULL) {
int64_t iOffset = lGetIntValue(ie->getOperand(2));
Assert(iOffset >= 0 && iOffset < vectorWidth);
Assert(elements[iOffset] == NULL);
// Get the scalar value from this insert // Get the scalar value from this insert
elements[iOffset] = ie->getOperand(1); elements[iOffset] = ie->getOperand(1);
// Do we have another insert? // Do we have another insert?
llvm::Value *insertBase = ie->getOperand(0); llvm::Value *insertBase = ie->getOperand(0);
ie = llvm::dyn_cast<llvm::InsertElementInst>(insertBase); ie = llvm::dyn_cast<llvm::InsertElementInst>(insertBase);
if (ie == NULL) { if (ie == NULL) {
if (llvm::isa<llvm::UndefValue>(insertBase)) if (llvm::isa<llvm::UndefValue>(insertBase)) {
return; return;
}
// Get the value out of a constant vector if that's what we // Get the value out of a constant vector if that's what we
// have // have
llvm::ConstantVector *cv = llvm::ConstantVector *cv =
llvm::dyn_cast<llvm::ConstantVector>(insertBase); llvm::dyn_cast<llvm::ConstantVector>(insertBase);
// FIXME: this assert is a little questionable; we probably // FIXME: this assert is a little questionable; we probably
// shouldn't fail in this case but should just return an // shouldn't fail in this case but should just return an
// incomplete result. But there aren't currently any known // incomplete result. But there aren't currently any known
// cases where we have anything other than an undef value or a // cases where we have anything other than an undef value or a
// constant vector at the base, so if that ever does happen, // constant vector at the base, so if that ever does happen,
// it'd be nice to know what happend so that perhaps we can // it'd be nice to know what happend so that perhaps we can
// handle it. // handle it.
// FIXME: Also, should we handle ConstantDataVectors with // FIXME: Also, should we handle ConstantDataVectors with
// LLVM3.1? What about ConstantAggregateZero values?? // LLVM3.1? What about ConstantAggregateZero values??
Assert(cv != NULL); Assert(cv != NULL);
Assert(iOffset < (int)cv->getNumOperands()); Assert(iOffset < (int)cv->getNumOperands());
elements[iOffset] = cv->getOperand((int32_t)iOffset); elements[iOffset] = cv->getOperand((int32_t)iOffset);
}
}
}
// Catch a pattern of broadcast implemented as InsertElement + Shuffle:
// %broadcast_init.0 = insertelement <4 x i32> undef, i32 %val, i32 0
// %broadcast.1 = shufflevector <4 x i32> %smear.0, <4 x i32> undef,
// <4 x i32> zeroinitializer
else if (llvm::ShuffleVectorInst *shuf =
llvm::dyn_cast<llvm::ShuffleVectorInst>(inst)) {
llvm::Value *indices = shuf->getOperand(2);
if (llvm::isa<llvm::ConstantAggregateZero>(indices)) {
llvm::Value *op = shuf->getOperand(0);
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(op);
if (ie != NULL &&
llvm::isa<llvm::UndefValue>(ie->getOperand(0))) {
llvm::ConstantInt *ci =
llvm::dyn_cast<llvm::ConstantInt>(ie->getOperand(2));
if (ci->isZero()) {
for (int i = 0; i < vectorWidth; ++i) {
elements[i] = ie->getOperand(1);
}
return;
}
}
} }
} }
} }
@@ -694,10 +724,10 @@ lIsExactMultiple(llvm::Value *val, int baseValue, int vectorLength,
else else
Assert(LLVMVectorValuesAllEqual(val)); Assert(LLVMVectorValuesAllEqual(val));
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(val); if (llvm::isa<llvm::InsertElementInst>(val) ||
if (ie != NULL) { llvm::isa<llvm::ShuffleVectorInst>(val)) {
llvm::Value *elts[ISPC_MAX_NVEC]; llvm::Value *elts[ISPC_MAX_NVEC];
LLVMFlattenInsertChain(ie, g->target->getVectorWidth(), elts); LLVMFlattenInsertChain(val, g->target->getVectorWidth(), elts);
// We just need to check the scalar first value, since we know that // We just need to check the scalar first value, since we know that
// all elements are equal // all elements are equal
return lIsExactMultiple(elts[0], baseValue, vectorLength, return lIsExactMultiple(elts[0], baseValue, vectorLength,
@@ -1440,10 +1470,10 @@ lExtractFirstVectorElement(llvm::Value *v,
// If we have a chain of insertelement instructions, then we can just // If we have a chain of insertelement instructions, then we can just
// flatten them out and grab the value for the first one. // flatten them out and grab the value for the first one.
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v); if (llvm::isa<llvm::InsertElementInst>(v) ||
if (ie != NULL) { llvm::isa<llvm::ShuffleVectorInst>(v)) {
llvm::Value *elements[ISPC_MAX_NVEC]; llvm::Value *elements[ISPC_MAX_NVEC];
LLVMFlattenInsertChain(ie, vt->getNumElements(), elements); LLVMFlattenInsertChain(v, vt->getNumElements(), elements);
return elements[0]; return elements[0];
} }

View File

@@ -264,8 +264,13 @@ extern bool LLVMExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts);
constant vector. For anything more complex (e.g. some other arbitrary constant vector. For anything more complex (e.g. some other arbitrary
value, it doesn't try to extract element values into the returned value, it doesn't try to extract element values into the returned
array. array.
This also handles the common broadcast pattern:
%broadcast_init.0 = insertelement <4 x i32> undef, i32 %val, i32 0
%broadcast.1 = shufflevector <4 x i32> %broadcast_init.0, <4 x i32> undef,
<4 x i32> zeroinitializer
*/ */
extern void LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth, extern void LLVMFlattenInsertChain(llvm::Value *inst, int vectorWidth,
llvm::Value **elements); llvm::Value **elements);
/** This is a utility routine for debugging that dumps out the given LLVM /** This is a utility routine for debugging that dumps out the given LLVM

13
opt.cpp
View File

@@ -1058,10 +1058,10 @@ lCheckForActualPointer(llvm::Value *v) {
*/ */
static llvm::Value * static llvm::Value *
lGetBasePointer(llvm::Value *v) { lGetBasePointer(llvm::Value *v) {
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v); if (llvm::isa<llvm::InsertElementInst>(v) ||
if (ie != NULL) { llvm::isa<llvm::ShuffleVectorInst>(v)) {
llvm::Value *elements[ISPC_MAX_NVEC]; llvm::Value *elements[ISPC_MAX_NVEC];
LLVMFlattenInsertChain(ie, g->target->getVectorWidth(), elements); LLVMFlattenInsertChain(v, g->target->getVectorWidth(), elements);
// Make sure none of the elements is undefined. // Make sure none of the elements is undefined.
// TODO: it's probably ok to allow undefined elements and return // TODO: it's probably ok to allow undefined elements and return
@@ -1080,9 +1080,12 @@ lGetBasePointer(llvm::Value *v) {
} }
// This case comes up with global/static arrays // This case comes up with global/static arrays
llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v); if (llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v)) {
if (cv != NULL)
return lCheckForActualPointer(cv->getSplatValue()); return lCheckForActualPointer(cv->getSplatValue());
}
else if (llvm::ConstantDataVector *cdv = llvm::dyn_cast<llvm::ConstantDataVector>(v)) {
return lCheckForActualPointer(cdv->getSplatValue());
}
return NULL; return NULL;
} }

View File

@@ -1993,11 +1993,7 @@ ForeachActiveStmt::EmitCode(FunctionEmitContext *ctx) const {
// math...) // math...)
// Get the "program index" vector value // Get the "program index" vector value
llvm::Value *programIndex = llvm::Value *programIndex = ctx->ProgramIndex();
llvm::UndefValue::get(LLVMTypes::Int32VectorType);
for (int i = 0; i < g->target->getVectorWidth(); ++i)
programIndex = ctx->InsertInst(programIndex, LLVMInt32(i), i,
"prog_index");
// And smear the current lane out to a vector // And smear the current lane out to a vector
llvm::Value *firstSet32 = llvm::Value *firstSet32 =