Improved handling of splatted constant vectors in C++ backend.

Now, when we're printing out a constant vector value, we check to see
if it's a splat and, if so, call out to one of the __smear_* functions
in the generated code.
This commit is contained in:
Matt Pharr
2012-04-19 13:11:15 -07:00
parent e4b3d03da5
commit 12c754c92b
3 changed files with 73 additions and 28 deletions

View File

@@ -1096,6 +1096,26 @@ bool CWriter::printCast(unsigned opc, Type *SrcTy, Type *DstTy) {
return false;
}
// FIXME: generalize this/make it not so hard-coded?
//
// Returns the name of the __smear_* function that corresponds to
// matchType: "__smear_float" / "__smear_double" for the two floating-point
// types, and "__smear_i<N>" for integer types of width 1, 8, 16, 32 or 64
// bits.  Returns NULL for every other type (including integers of any
// other width), so callers must check the result before using it.
static const char *lGetSmearFunc(Type *matchType) {
    switch (matchType->getTypeID()) {
    case Type::FloatTyID:
        return "__smear_float";
    case Type::DoubleTyID:
        return "__smear_double";
    case Type::IntegerTyID:
        // Integer names depend on the bit width; resolved below.
        break;
    default:
        return NULL;
    }

    const unsigned width = cast<IntegerType>(matchType)->getBitWidth();
    if (width == 1)
        return "__smear_i1";
    if (width == 8)
        return "__smear_i8";
    if (width == 16)
        return "__smear_i16";
    if (width == 32)
        return "__smear_i32";
    if (width == 64)
        return "__smear_i64";

    // No smear function exists for this integer width.
    return NULL;
}
// printConstant - The LLVM Constant to C Constant converter.
void CWriter::printConstant(Constant *CPV, bool Static) {
if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(CPV)) {
@@ -1435,30 +1455,61 @@ void CWriter::printConstant(Constant *CPV, bool Static) {
Out << ")";
break;
}
case Type::VectorTyID:
printType(Out, CPV->getType());
Out << "(";
case Type::VectorTyID: {
VectorType *VT = dyn_cast<VectorType>(CPV->getType());
const char *smearFunc = lGetSmearFunc(VT->getElementType());
if (ConstantVector *CV = dyn_cast<ConstantVector>(CPV)) {
printConstantVector(CV, Static);
if (isa<ConstantAggregateZero>(CPV)) {
assert(smearFunc != NULL);
Constant *CZ = Constant::getNullValue(VT->getElementType());
Out << smearFunc << "(";
printConstant(CZ, Static);
Out << ")";
}
else if (ConstantVector *CV = dyn_cast<ConstantVector>(CPV)) {
llvm::Constant *splatValue = CV->getSplatValue();
if (splatValue != NULL && smearFunc != NULL) {
Out << smearFunc << "(";
printConstant(splatValue, Static);
Out << ")";
}
else {
printType(Out, CPV->getType());
Out << "(";
printConstantVector(CV, Static);
Out << ")";
}
}
#ifdef LLVM_3_1svn
} else if (ConstantDataSequential *CDS =
dyn_cast<ConstantDataSequential>(CPV)) {
printConstantDataSequential(CDS, Static);
else if (ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(CPV)) {
llvm::Constant *splatValue = CDV->getSplatValue();
if (splatValue != NULL && smearFunc != NULL) {
Out << smearFunc << "(";
printConstant(splatValue, Static);
Out << ")";
}
else {
printType(Out, CPV->getType());
Out << "(";
printConstantDataSequential(CDV, Static);
Out << ")";
}
#endif
} else {
assert(isa<ConstantAggregateZero>(CPV) || isa<UndefValue>(CPV));
VectorType *VT = cast<VectorType>(CPV->getType());
assert(isa<UndefValue>(CPV));
Constant *CZ = Constant::getNullValue(VT->getElementType());
printType(Out, CPV->getType());
Out << "(";
printConstant(CZ, Static);
for (unsigned i = 1, e = VT->getNumElements(); i != e; ++i) {
Out << ", ";
printConstant(CZ, Static);
}
Out << ")";
}
Out << ")";
break;
}
case Type::StructTyID:
if (!Static) {
// call init func...
@@ -4327,23 +4378,8 @@ SmearCleanupPass::runOnBasicBlock(llvm::BasicBlock &bb) {
assert(toMatch != NULL);
{
// FIXME: generalize this/make it not so hard-coded?
Type *matchType = toMatch->getType();
const char *smearFuncName = NULL;
switch (matchType->getTypeID()) {
case Type::FloatTyID: smearFuncName = "__smear_float"; break;
case Type::DoubleTyID: smearFuncName = "__smear_double"; break;
case Type::IntegerTyID: {
switch (cast<IntegerType>(matchType)->getBitWidth()) {
case 8: smearFuncName = "__smear_i8"; break;
case 16: smearFuncName = "__smear_i16"; break;
case 32: smearFuncName = "__smear_i32"; break;
case 64: smearFuncName = "__smear_i64"; break;
}
}
default: break;
}
const char *smearFuncName = lGetSmearFunc(matchType);
if (smearFuncName != NULL) {
Function *smearFunc = module->getFunction(smearFuncName);

View File

@@ -374,6 +374,11 @@ static FORCEINLINE void __store(__vec16_i1 *p, __vec16_i1 v, int align) {
*ptr = v.v;
}
// Returns a __vec16_i1 built by passing the same scalar v as each of the
// sixteen constructor arguments (presumably one per vector lane).
static FORCEINLINE __vec16_i1 __smear_i1(int v) {
    __vec16_i1 smeared(v, v, v, v, v, v, v, v, v, v, v, v, v, v, v, v);
    return smeared;
}
///////////////////////////////////////////////////////////////////////////
// int8

View File

@@ -266,6 +266,10 @@ static FORCEINLINE void __store(__vec4_i1 *p, __vec4_i1 value, int align) {
_mm_storeu_ps((float *)(&p->v), value.v);
}
// Returns a __vec4_i1 built by passing the same scalar v as each of the
// four constructor arguments (presumably one per vector lane).
static FORCEINLINE __vec4_i1 __smear_i1(int v) {
    __vec4_i1 smeared(v, v, v, v);
    return smeared;
}
///////////////////////////////////////////////////////////////////////////
// int8