Fix bugs in LLVMExtractFirstVectorElement().

When we're manually scalarizing the extraction of the first element
of a vector value, we need to be careful about handling constant values
and about where new instructions are inserted.  The old code was
sloppy about this, which in turn led to invalid IR in some cases.
For example, the two bugs below were essentially due to generating
an extractelement instruction from a zeroinitializer value and then
inserting it in the wrong basic block, so that a phi node that used
that value was malformed.

Fixes issues #240 and #229.
This commit is contained in:
Matt Pharr
2012-04-19 09:45:04 -07:00
parent a2bb899a6b
commit 326c45fa17
3 changed files with 52 additions and 26 deletions

View File

@@ -1390,19 +1390,38 @@ LLVMDumpValue(llvm::Value *v) {
static llvm::Value * static llvm::Value *
lExtractFirstVectorElement(llvm::Value *v, llvm::Instruction *insertBefore, lExtractFirstVectorElement(llvm::Value *v,
std::map<llvm::PHINode *, llvm::PHINode *> &phiMap) { std::map<llvm::PHINode *, llvm::PHINode *> &phiMap) {
// If it's not an instruction (i.e. is a constant), then we can just
// emit an extractelement instruction and let the regular optimizer do
// the rest.
if (llvm::isa<llvm::Instruction>(v) == false)
return llvm::ExtractElementInst::Create(v, LLVMInt32(0), "first_elt",
insertBefore);
llvm::VectorType *vt = llvm::VectorType *vt =
llvm::dyn_cast<llvm::VectorType>(v->getType()); llvm::dyn_cast<llvm::VectorType>(v->getType());
Assert(vt != NULL); Assert(vt != NULL);
// First, handle various constant types; do the extraction manually, as
// appropriate.
if (llvm::isa<llvm::ConstantAggregateZero>(v) == true) {
Assert(vt->getElementType()->isIntegerTy());
return llvm::ConstantInt::get(vt->getElementType(), 0);
}
if (llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v)) {
#ifdef LLVM_3_1svn
return cv->getOperand(0);
#else
llvm::SmallVector<llvm::Constant *, ISPC_MAX_NVEC> elements;
cv->getVectorElements(elements);
return elements[0];
#endif // LLVM_3_1
}
#ifdef LLVM_3_1svn
if (llvm::ConstantDataVector *cdv =
llvm::dyn_cast<llvm::ConstantDataVector>(v))
return cdv->getElementAsConstant(0);
#endif // LLVM_3_1
// Otherwise, all that we should have at this point is an instruction
// of some sort
Assert(llvm::isa<llvm::Constant>(v) == false);
Assert(llvm::isa<llvm::Instruction>(v) == true);
std::string newName = v->getName().str() + std::string(".elt0"); std::string newName = v->getName().str() + std::string(".elt0");
// Rewrite regular binary operators and casts to the scalarized // Rewrite regular binary operators and casts to the scalarized
@@ -1410,20 +1429,24 @@ lExtractFirstVectorElement(llvm::Value *v, llvm::Instruction *insertBefore,
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v); llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
if (bop != NULL) { if (bop != NULL) {
llvm::Value *v0 = lExtractFirstVectorElement(bop->getOperand(0), llvm::Value *v0 = lExtractFirstVectorElement(bop->getOperand(0),
insertBefore, phiMap); phiMap);
llvm::Value *v1 = lExtractFirstVectorElement(bop->getOperand(1), llvm::Value *v1 = lExtractFirstVectorElement(bop->getOperand(1),
insertBefore, phiMap); phiMap);
// Note that the new binary operator is inserted immediately before
// the previous vector one
return llvm::BinaryOperator::Create(bop->getOpcode(), v0, v1, return llvm::BinaryOperator::Create(bop->getOpcode(), v0, v1,
newName, insertBefore); newName, bop);
} }
llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v); llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
if (cast != NULL) { if (cast != NULL) {
llvm::Value *v = lExtractFirstVectorElement(cast->getOperand(0), llvm::Value *v = lExtractFirstVectorElement(cast->getOperand(0),
insertBefore, phiMap); phiMap);
// Similarly, the equivalent scalar cast instruction goes right
// before the vector cast
return llvm::CastInst::Create(cast->getOpcode(), v, return llvm::CastInst::Create(cast->getOpcode(), v,
vt->getElementType(), newName, vt->getElementType(), newName,
insertBefore); cast);
} }
llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v); llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
@@ -1438,8 +1461,7 @@ lExtractFirstVectorElement(llvm::Value *v, llvm::Instruction *insertBefore,
// return the pointer and not get stuck in an infinite loop. // return the pointer and not get stuck in an infinite loop.
// //
// The insertion point for the new phi node also has to be the // The insertion point for the new phi node also has to be the
// start of the bblock of the original phi node, which isn't // start of the bblock of the original phi node.
// necessarily the same bblock as insertBefore is in!
llvm::Instruction *phiInsertPos = phi->getParent()->begin(); llvm::Instruction *phiInsertPos = phi->getParent()->begin();
llvm::PHINode *scalarPhi = llvm::PHINode *scalarPhi =
llvm::PHINode::Create(vt->getElementType(), llvm::PHINode::Create(vt->getElementType(),
@@ -1449,7 +1471,7 @@ lExtractFirstVectorElement(llvm::Value *v, llvm::Instruction *insertBefore,
for (unsigned i = 0; i < phi->getNumIncomingValues(); ++i) { for (unsigned i = 0; i < phi->getNumIncomingValues(); ++i) {
llvm::Value *v = lExtractFirstVectorElement(phi->getIncomingValue(i), llvm::Value *v = lExtractFirstVectorElement(phi->getIncomingValue(i),
insertBefore, phiMap); phiMap);
scalarPhi->addIncoming(v, phi->getIncomingBlock(i)); scalarPhi->addIncoming(v, phi->getIncomingBlock(i));
} }
@@ -1466,15 +1488,22 @@ lExtractFirstVectorElement(llvm::Value *v, llvm::Instruction *insertBefore,
} }
// Worst case, for everything else, just do a regular extract element // Worst case, for everything else, just do a regular extract element
return llvm::ExtractElementInst::Create(v, LLVMInt32(0), "first_elt", // instruction, which we insert immediately after the instruction we
insertBefore); // have here.
llvm::Instruction *insertAfter = llvm::dyn_cast<llvm::Instruction>(v);
Assert(insertAfter != NULL);
llvm::Instruction *ee =
llvm::ExtractElementInst::Create(v, LLVMInt32(0), "first_elt",
(llvm::Instruction *)NULL);
ee->insertAfter(insertAfter);
return ee;
} }
llvm::Value * llvm::Value *
LLVMExtractFirstVectorElement(llvm::Value *v, llvm::Instruction *insertBefore) { LLVMExtractFirstVectorElement(llvm::Value *v) {
std::map<llvm::PHINode *, llvm::PHINode *> phiMap; std::map<llvm::PHINode *, llvm::PHINode *> phiMap;
llvm::Value *ret = lExtractFirstVectorElement(v, insertBefore, phiMap); llvm::Value *ret = lExtractFirstVectorElement(v, phiMap);
return ret; return ret;
} }

View File

@@ -274,8 +274,7 @@ extern void LLVMDumpValue(llvm::Value *v);
worth of values just to extract the first element, in cases where only worth of values just to extract the first element, in cases where only
the first element's value is needed. the first element's value is needed.
*/ */
extern llvm::Value *LLVMExtractFirstVectorElement(llvm::Value *v, extern llvm::Value *LLVMExtractFirstVectorElement(llvm::Value *v);
llvm::Instruction *insertBefore);
/** This function takes two vectors, expected to be the same length, and /** This function takes two vectors, expected to be the same length, and
returns a new vector of twice the length that represents concatenating returns a new vector of twice the length that represents concatenating

View File

@@ -2295,8 +2295,7 @@ struct GatherImpInfo {
static llvm::Value * static llvm::Value *
lComputeCommonPointer(llvm::Value *base, llvm::Value *offsets, lComputeCommonPointer(llvm::Value *base, llvm::Value *offsets,
llvm::Instruction *insertBefore) { llvm::Instruction *insertBefore) {
llvm::Value *firstOffset = LLVMExtractFirstVectorElement(offsets, llvm::Value *firstOffset = LLVMExtractFirstVectorElement(offsets);
insertBefore);
return lGEPInst(base, firstOffset, "ptr", insertBefore); return lGEPInst(base, firstOffset, "ptr", insertBefore);
} }
@@ -3290,8 +3289,7 @@ lComputeBasePtr(llvm::CallInst *gatherInst, llvm::Instruction *insertBefore) {
// All of the variable offsets values should be the same, due to // All of the variable offsets values should be the same, due to
// checking for this in GatherCoalescePass::runOnBasicBlock(). Thus, // checking for this in GatherCoalescePass::runOnBasicBlock(). Thus,
// extract the first value and use that as a scalar. // extract the first value and use that as a scalar.
llvm::Value *variable = LLVMExtractFirstVectorElement(variableOffsets, llvm::Value *variable = LLVMExtractFirstVectorElement(variableOffsets);
insertBefore);
if (variable->getType() == LLVMTypes::Int64Type) if (variable->getType() == LLVMTypes::Int64Type)
offsetScale = new llvm::ZExtInst(offsetScale, LLVMTypes::Int64Type, offsetScale = new llvm::ZExtInst(offsetScale, LLVMTypes::Int64Type,
"scale_to64", insertBefore); "scale_to64", insertBefore);