From 12c754c92b7fe9d3e005c3ff278fb0655c83242d Mon Sep 17 00:00:00 2001 From: Matt Pharr Date: Thu, 19 Apr 2012 13:11:15 -0700 Subject: [PATCH] Improved handling of splatted constant vectors in C++ backend. Now, when we're printing out a constant vector value, we check to see if it's a splat and call out to one of the __splat_* functions in the generated code if to. --- cbackend.cpp | 92 ++++++++++++++++++++++---------- examples/intrinsics/generic-16.h | 5 ++ examples/intrinsics/sse4.h | 4 ++ 3 files changed, 73 insertions(+), 28 deletions(-) diff --git a/cbackend.cpp b/cbackend.cpp index 0c582ce0..ebe96c45 100644 --- a/cbackend.cpp +++ b/cbackend.cpp @@ -1096,6 +1096,26 @@ bool CWriter::printCast(unsigned opc, Type *SrcTy, Type *DstTy) { return false; } + +// FIXME: generalize this/make it not so hard-coded? +static const char *lGetSmearFunc(Type *matchType) { + switch (matchType->getTypeID()) { + case Type::FloatTyID: return "__smear_float"; + case Type::DoubleTyID: return "__smear_double"; + case Type::IntegerTyID: { + switch (cast(matchType)->getBitWidth()) { + case 1: return "__smear_i1"; + case 8: return "__smear_i8"; + case 16: return "__smear_i16"; + case 32: return "__smear_i32"; + case 64: return "__smear_i64"; + } + } + default: return NULL; + } +} + + // printConstant - The LLVM Constant to C Constant converter. void CWriter::printConstant(Constant *CPV, bool Static) { if (const ConstantExpr *CE = dyn_cast(CPV)) { @@ -1435,30 +1455,61 @@ void CWriter::printConstant(Constant *CPV, bool Static) { Out << ")"; break; } - case Type::VectorTyID: - printType(Out, CPV->getType()); - Out << "("; + case Type::VectorTyID: { + VectorType *VT = dyn_cast(CPV->getType()); + const char *smearFunc = lGetSmearFunc(VT->getElementType()); - if (ConstantVector *CV = dyn_cast(CPV)) { - printConstantVector(CV, Static); + if (isa(CPV)) { + assert(smearFunc != NULL); + + Constant *CZ = Constant::getNullValue(VT->getElementType()); + Out << smearFunc << "("; + printConstant(CZ, Static); + Out << ")"; + } + else if (ConstantVector *CV = dyn_cast(CPV)) { + llvm::Constant *splatValue = CV->getSplatValue(); + if (splatValue != NULL && smearFunc != NULL) { + Out << smearFunc << "("; + printConstant(splatValue, Static); + Out << ")"; + } + else { + printType(Out, CPV->getType()); + Out << "("; + printConstantVector(CV, Static); + Out << ")"; + } + } #ifdef LLVM_3_1svn - } else if (ConstantDataSequential *CDS = - dyn_cast(CPV)) { - printConstantDataSequential(CDS, Static); + else if (ConstantDataVector *CDV = dyn_cast(CPV)) { + llvm::Constant *splatValue = CDV->getSplatValue(); + if (splatValue != NULL && smearFunc != NULL) { + Out << smearFunc << "("; + printConstant(splatValue, Static); + Out << ")"; + } + else { + printType(Out, CPV->getType()); + Out << "("; + printConstantDataSequential(CDV, Static); + Out << ")"; + } #endif } else { - assert(isa(CPV) || isa(CPV)); - VectorType *VT = cast(CPV->getType()); + assert(isa(CPV)); Constant *CZ = Constant::getNullValue(VT->getElementType()); + printType(Out, CPV->getType()); + Out << "("; printConstant(CZ, Static); for (unsigned i = 1, e = VT->getNumElements(); i != e; ++i) { Out << ", "; printConstant(CZ, Static); } + Out << ")"; } - Out << ")"; break; - + } case Type::StructTyID: if (!Static) { // call init func... @@ -4327,23 +4378,8 @@ SmearCleanupPass::runOnBasicBlock(llvm::BasicBlock &bb) { assert(toMatch != NULL); { - // FIXME: generalize this/make it not so hard-coded? Type *matchType = toMatch->getType(); - const char *smearFuncName = NULL; - - switch (matchType->getTypeID()) { - case Type::FloatTyID: smearFuncName = "__smear_float"; break; - case Type::DoubleTyID: smearFuncName = "__smear_double"; break; - case Type::IntegerTyID: { - switch (cast(matchType)->getBitWidth()) { - case 8: smearFuncName = "__smear_i8"; break; - case 16: smearFuncName = "__smear_i16"; break; - case 32: smearFuncName = "__smear_i32"; break; - case 64: smearFuncName = "__smear_i64"; break; - } - } - default: break; - } + const char *smearFuncName = lGetSmearFunc(matchType); if (smearFuncName != NULL) { Function *smearFunc = module->getFunction(smearFuncName); diff --git a/examples/intrinsics/generic-16.h b/examples/intrinsics/generic-16.h index 861db2a4..d6a5c121 100644 --- a/examples/intrinsics/generic-16.h +++ b/examples/intrinsics/generic-16.h @@ -374,6 +374,11 @@ static FORCEINLINE void __store(__vec16_i1 *p, __vec16_i1 v, int align) { *ptr = v.v; } +static FORCEINLINE __vec16_i1 __smear_i1(int v) { + return __vec16_i1(v, v, v, v, v, v, v, v, + v, v, v, v, v, v, v, v); +} + /////////////////////////////////////////////////////////////////////////// // int8 diff --git a/examples/intrinsics/sse4.h b/examples/intrinsics/sse4.h index c6299893..48a67719 100644 --- a/examples/intrinsics/sse4.h +++ b/examples/intrinsics/sse4.h @@ -266,6 +266,10 @@ static FORCEINLINE void __store(__vec4_i1 *p, __vec4_i1 value, int align) { _mm_storeu_ps((float *)(&p->v), value.v); } +static FORCEINLINE __vec4_i1 __smear_i1(int v) { + return __vec4_i1(v, v, v, v); +} + /////////////////////////////////////////////////////////////////////////// // int8