Changes in cbackend.cpp to match broadcast generation changes

This commit is contained in:
Dmitry Babokin
2013-04-12 00:10:41 +04:00
parent 4c35d9456a
commit 7371d82bdf

View File

@@ -4395,16 +4395,21 @@ public:
static char ID;
llvm::Module *module;
int vectorWidth;
unsigned int vectorWidth;
private:
unsigned int ChainLength(llvm::InsertElementInst *inst) const;
llvm::Value *getInsertChainSmearValue(llvm::Instruction* inst) const;
llvm::Value *getShuffleSmearValue(llvm::Instruction* inst) const;
};
char SmearCleanupPass::ID = 0;
static int
lChainLength(llvm::InsertElementInst *inst) {
int length = 0;
unsigned int
SmearCleanupPass::ChainLength(llvm::InsertElementInst *inst) const {
unsigned int length = 0;
while (inst != NULL) {
++length;
inst = llvm::dyn_cast<llvm::InsertElementInst>(inst->getOperand(0));
@@ -4413,45 +4418,105 @@ lChainLength(llvm::InsertElementInst *inst) {
}
llvm::Value *
SmearCleanupPass::getInsertChainSmearValue(llvm::Instruction* inst) const {
// TODO: we don't check indexes where we do insertion, so we may trigger
// transformation for a wrong chain.
// This way of doing broadcast is obsolete and should be probably removed
// some day.
llvm::InsertElementInst *insertInst =
llvm::dyn_cast<llvm::InsertElementInst>(inst);
if (!insertInst) {
return NULL;
}
// We consider only chians of vectorWidth length.
if (ChainLength(insertInst) != vectorWidth) {
return NULL;
}
// FIXME: we only want to do this to vectors with width equal to
// the target vector width. But we can't easily get that here, so
// for now we at least avoid one case where we definitely don't
// want to do this.
llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(insertInst->getType());
if (vt->getNumElements() == 1) {
return NULL;
}
llvm::Value *smearValue = NULL;
while (insertInst != NULL) {
// operand 1 is inserted value
llvm::Value *insertValue = insertInst->getOperand(1);
if (smearValue == NULL) {
smearValue = insertValue;
}
else if (smearValue != insertValue) {
return NULL;
}
// operand 0 is a vector to insert into.
insertInst =
llvm::dyn_cast<llvm::InsertElementInst>(insertInst->getOperand(0));
}
assert(smearValue != NULL);
return smearValue;
}
llvm::Value *
SmearCleanupPass::getShuffleSmearValue(llvm::Instruction* inst) const {
llvm::ShuffleVectorInst *shuffleInst =
llvm::dyn_cast<llvm::ShuffleVectorInst>(inst);
if (!shuffleInst) {
return NULL;
}
llvm::Constant* mask =
llvm::dyn_cast<llvm::Constant>(shuffleInst->getOperand(2));
// Check that the shuffle is a broadcast of the first element of the first vector,
// i.e. mask vector is all-zeros vector of expected size.
if (!(mask &&
mask->isNullValue() &&
llvm::dyn_cast<llvm::VectorType>(mask->getType())->getNumElements() == vectorWidth)) {
return NULL;
}
llvm::InsertElementInst *insertInst =
llvm::dyn_cast<llvm::InsertElementInst>(shuffleInst->getOperand(0));
// Check that it's an InsertElementInst that inserts a value to first element.
if (!(insertInst &&
llvm::isa<llvm::Constant>(insertInst->getOperand(2)) &&
llvm::dyn_cast<llvm::Constant>(insertInst->getOperand(2))->isNullValue())) {
return NULL;
}
llvm::Value *result = insertInst->getOperand(1);
return result;
}
bool
SmearCleanupPass::runOnBasicBlock(llvm::BasicBlock &bb) {
bool modifiedAny = false;
restart:
for (llvm::BasicBlock::iterator iter = bb.begin(), e = bb.end(); iter != e; ++iter) {
llvm::InsertElementInst *insertInst =
llvm::dyn_cast<llvm::InsertElementInst>(&*iter);
if (insertInst == NULL)
llvm::Value *smearValue = NULL;
if (!(smearValue = getInsertChainSmearValue(iter)) &&
!(smearValue = getShuffleSmearValue(iter))) {
continue;
// Only do this on the last insert in a chain...
if (lChainLength(insertInst) != vectorWidth)
continue;
// FIXME: we only want to do this to vectors with width equal to
// the target vector width. But we can't easily get that here, so
// for now we at least avoid one case where we definitely don't
// want to do this.
llvm::VectorType *vt = llvm::dyn_cast<llvm::VectorType>(insertInst->getType());
if (vt->getNumElements() == 1)
continue;
llvm::Value *toMatch = NULL;
while (insertInst != NULL) {
llvm::Value *insertValue = insertInst->getOperand(1);
if (toMatch == NULL)
toMatch = insertValue;
else if (toMatch != insertValue)
goto not_equal;
insertInst =
llvm::dyn_cast<llvm::InsertElementInst>(insertInst->getOperand(0));
}
assert(toMatch != NULL);
{
llvm::Type *matchType = toMatch->getType();
const char *smearFuncName = lGetTypedFunc("smear", matchType, vectorWidth);
llvm::Type *smearType = smearValue->getType();
const char *smearFuncName = lGetTypedFunc("smear", smearType, vectorWidth);
if (smearFuncName != NULL) {
llvm::Function *smearFunc = module->getFunction(smearFuncName);
if (smearFunc == NULL) {
@@ -4460,7 +4525,7 @@ SmearCleanupPass::runOnBasicBlock(llvm::BasicBlock &bb) {
// parameter type.
llvm::Constant *sf =
module->getOrInsertFunction(smearFuncName, iter->getType(),
matchType, NULL);
smearType, NULL);
smearFunc = llvm::dyn_cast<llvm::Function>(sf);
assert(smearFunc != NULL);
#if defined(LLVM_3_1)
@@ -4473,10 +4538,10 @@ SmearCleanupPass::runOnBasicBlock(llvm::BasicBlock &bb) {
}
assert(smearFunc != NULL);
llvm::Value *args[1] = { toMatch };
llvm::Value *args[1] = { smearValue };
llvm::ArrayRef<llvm::Value *> argArray(&args[0], &args[1]);
llvm::Instruction *smearCall =
llvm::CallInst::Create(smearFunc, argArray, LLVMGetName(toMatch, "_smear"),
llvm::CallInst::Create(smearFunc, argArray, LLVMGetName(smearValue, "_smear"),
(llvm::Instruction *)NULL);
ReplaceInstWithInst(iter, smearCall);
@@ -4484,9 +4549,6 @@ SmearCleanupPass::runOnBasicBlock(llvm::BasicBlock &bb) {
modifiedAny = true;
goto restart;
}
}
not_equal:
;
}
return modifiedAny;