For << and >> with the C++ target, detect when all program instances are shifting by the same amount.

In this case, we now emit calls to potentially-specialized functions for the
left/right shifts that take a single integer value for the shift amount.  These
in turn can be matched to the corresponding intrinsics for the SSE target.

Issue #145.
Matt Pharr
2012-01-19 10:04:32 -07:00
parent 3f89295d10
commit 68f6ea8def
6 changed files with 433 additions and 280 deletions
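
A minimal sketch of the idea for the 4-wide SSE case (illustrative only: the struct and the general __shl below are simplified stand-ins for the definitions in the intrinsics headers changed in this commit, and the final comment paraphrases the call shape the C++ backend now emits):

    #include <smmintrin.h>   // SSE4.1 intrinsics
    #include <stdint.h>

    struct __vec4_i32 { __m128i v; };

    // General case: each program instance may shift by a different amount,
    // so the shift has to be done lane by lane.
    static inline __vec4_i32 __shl(__vec4_i32 a, __vec4_i32 b) {
        __vec4_i32 r;
        r.v = _mm_set_epi32(_mm_extract_epi32(a.v, 3) << _mm_extract_epi32(b.v, 3),
                            _mm_extract_epi32(a.v, 2) << _mm_extract_epi32(b.v, 2),
                            _mm_extract_epi32(a.v, 1) << _mm_extract_epi32(b.v, 1),
                            _mm_extract_epi32(a.v, 0) << _mm_extract_epi32(b.v, 0));
        return r;
    }

    // Specialized case: every instance shifts by the same amount, so a single
    // SSE shift instruction handles the whole vector.
    static inline __vec4_i32 __shl(__vec4_i32 a, int32_t b) {
        __vec4_i32 r;
        r.v = _mm_sll_epi32(a.v, _mm_set_epi32(0, 0, 0, b));
        return r;
    }

    // When the backend proves all elements of the shift amount are equal, it
    // emits a call along the lines of
    //     __shl(value, __extract_element(amount, 0))
    // which resolves to the scalar-amount overload above.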


@@ -24,6 +24,8 @@
#define PRIx64 "llx"
#endif
#include "llvmutil.h"
#include "llvm/CallingConv.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
@@ -232,6 +234,7 @@ namespace {
unsigned NextAnonValueNumber;
std::string includeName;
int vectorWidth;
/// UnnamedStructIDs - This contains a unique ID for each struct that is
/// either anonymous or has no name.
@@ -240,11 +243,13 @@ namespace {
public:
static char ID;
explicit CWriter(formatted_raw_ostream &o, const char *incname,
                 int vecwidth)
    : FunctionPass(ID), Out(o), IL(0), Mang(0), LI(0),
      TheModule(0), TAsm(0), MRI(0), MOFI(0), TCtx(0), TD(0),
      OpaqueCounter(0), NextAnonValueNumber(0),
      includeName(incname ? incname : "generic_defs.h"),
      vectorWidth(vecwidth) {
    initializeLoopInfoPass(*PassRegistry::getPassRegistry());
    FPCounter = 0;
}
@@ -2894,7 +2899,21 @@ void CWriter::visitBinaryOperator(Instruction &I) {
Out << "(";
writeOperand(I.getOperand(0));
Out << ", ";
if ((I.getOpcode() == Instruction::Shl ||
     I.getOpcode() == Instruction::LShr ||
     I.getOpcode() == Instruction::AShr)) {
    std::vector<PHINode *> phis;
    if (LLVMVectorValuesAllEqual(I.getOperand(1),
                                 vectorWidth, phis)) {
        Out << "__extract_element(";
        writeOperand(I.getOperand(1));
        Out << ", 0) ";
    }
    else
        writeOperand(I.getOperand(1));
}
else
    writeOperand(I.getOperand(1));
Out << ")";
return;
}
@@ -3635,7 +3654,7 @@ std::string CWriter::InterpretASMConstraint(InlineAsm::ConstraintInfo& c) {
#endif
std::string E;
if (const llvm::Target *Match = TargetRegistry::lookupTarget(Triple, E))
    TargetAsm = Match->createMCAsmInfo(Triple);
else
    return c.Codes[0];
@@ -4337,7 +4356,7 @@ WriteCXXFile(llvm::Module *module, const char *fn, int vectorWidth,
pm.add(new BitcastCleanupPass);
pm.add(createDeadCodeEliminationPass()); // clean up after smear pass
//CO pm.add(createPrintModulePass(&fos));
pm.add(new CWriter(fos, includeName, vectorWidth));
pm.add(createGCInfoDeleter());
//CO pm.add(createVerifierPass());


@@ -251,6 +251,14 @@ static FORCEINLINE TYPE __select(bool cond, TYPE a, TYPE b) { \
return cond ? a : b; \
}
#define SHIFT_UNIFORM(TYPE, CAST, NAME, OP) \
static FORCEINLINE TYPE NAME(TYPE a, int32_t b) { \
TYPE ret; \
for (int i = 0; i < 16; ++i) \
ret.v[i] = (CAST)(a.v[i]) OP b; \
return ret; \
}
#define SMEAR(VTYPE, NAME, STYPE) \
static FORCEINLINE VTYPE __smear_##NAME(STYPE v) { \
VTYPE ret; \
@@ -386,6 +394,10 @@ BINARY_OP_CAST(__vec16_i8, int8_t, __srem, %)
BINARY_OP_CAST(__vec16_i8, uint8_t, __lshr, >>)
BINARY_OP_CAST(__vec16_i8, int8_t, __ashr, >>)
SHIFT_UNIFORM(__vec16_i8, uint8_t, __lshr, >>)
SHIFT_UNIFORM(__vec16_i8, int8_t, __ashr, >>)
SHIFT_UNIFORM(__vec16_i8, int8_t, __shl, <<)
CMP_OP(__vec16_i8, int8_t, __equal, ==)
CMP_OP(__vec16_i8, int8_t, __not_equal, !=)
CMP_OP(__vec16_i8, uint8_t, __unsigned_less_equal, <=)
@@ -425,6 +437,10 @@ BINARY_OP_CAST(__vec16_i16, int16_t, __srem, %)
BINARY_OP_CAST(__vec16_i16, uint16_t, __lshr, >>)
BINARY_OP_CAST(__vec16_i16, int16_t, __ashr, >>)
SHIFT_UNIFORM(__vec16_i16, uint16_t, __lshr, >>)
SHIFT_UNIFORM(__vec16_i16, int16_t, __ashr, >>)
SHIFT_UNIFORM(__vec16_i16, int16_t, __shl, <<)
CMP_OP(__vec16_i16, int16_t, __equal, ==)
CMP_OP(__vec16_i16, int16_t, __not_equal, !=)
CMP_OP(__vec16_i16, uint16_t, __unsigned_less_equal, <=)
@@ -464,6 +480,10 @@ BINARY_OP_CAST(__vec16_i32, int32_t, __srem, %)
BINARY_OP_CAST(__vec16_i32, uint32_t, __lshr, >>)
BINARY_OP_CAST(__vec16_i32, int32_t, __ashr, >>)
SHIFT_UNIFORM(__vec16_i32, uint32_t, __lshr, >>)
SHIFT_UNIFORM(__vec16_i32, int32_t, __ashr, >>)
SHIFT_UNIFORM(__vec16_i32, int32_t, __shl, <<)
CMP_OP(__vec16_i32, int32_t, __equal, ==)
CMP_OP(__vec16_i32, int32_t, __not_equal, !=)
CMP_OP(__vec16_i32, uint32_t, __unsigned_less_equal, <=)
@@ -503,6 +523,10 @@ BINARY_OP_CAST(__vec16_i64, int64_t, __srem, %)
BINARY_OP_CAST(__vec16_i64, uint64_t, __lshr, >>)
BINARY_OP_CAST(__vec16_i64, int64_t, __ashr, >>)
SHIFT_UNIFORM(__vec16_i64, uint64_t, __lshr, >>)
SHIFT_UNIFORM(__vec16_i64, int64_t, __ashr, >>)
SHIFT_UNIFORM(__vec16_i64, int64_t, __shl, <<)
CMP_OP(__vec16_i64, int64_t, __equal, ==)
CMP_OP(__vec16_i64, int64_t, __not_equal, !=)
CMP_OP(__vec16_i64, uint64_t, __unsigned_less_equal, <=)


@@ -303,6 +303,13 @@ static FORCEINLINE __vec4_i8 __shl(__vec4_i8 a, __vec4_i8 b) {
_mm_extract_epi8(a.v, 3) << _mm_extract_epi8(b.v, 3));
}
static FORCEINLINE __vec4_i8 __shl(__vec4_i8 a, int32_t b) {
return __vec4_i8(_mm_extract_epi8(a.v, 0) << b,
_mm_extract_epi8(a.v, 1) << b,
_mm_extract_epi8(a.v, 2) << b,
_mm_extract_epi8(a.v, 3) << b);
}
static FORCEINLINE __vec4_i8 __udiv(__vec4_i8 a, __vec4_i8 b) {
return __vec4_i8((uint8_t)_mm_extract_epi8(a.v, 0) /
(uint8_t)_mm_extract_epi8(b.v, 0),
@@ -358,6 +365,13 @@ static FORCEINLINE __vec4_i8 __lshr(__vec4_i8 a, __vec4_i8 b) {
(uint8_t)_mm_extract_epi8(b.v, 3));
}
static FORCEINLINE __vec4_i8 __lshr(__vec4_i8 a, int32_t b) {
return __vec4_i8((uint8_t)_mm_extract_epi8(a.v, 0) >> b,
(uint8_t)_mm_extract_epi8(a.v, 1) >> b,
(uint8_t)_mm_extract_epi8(a.v, 2) >> b,
(uint8_t)_mm_extract_epi8(a.v, 3) >> b);
}
static FORCEINLINE __vec4_i8 __ashr(__vec4_i8 a, __vec4_i8 b) {
return __vec4_i8((int8_t)_mm_extract_epi8(a.v, 0) >>
(int8_t)_mm_extract_epi8(b.v, 0),
@@ -369,6 +383,13 @@ static FORCEINLINE __vec4_i8 __ashr(__vec4_i8 a, __vec4_i8 b) {
(int8_t)_mm_extract_epi8(b.v, 3));
}
static FORCEINLINE __vec4_i8 __ashr(__vec4_i8 a, int32_t b) {
return __vec4_i8((int8_t)_mm_extract_epi8(a.v, 0) >> b,
(int8_t)_mm_extract_epi8(a.v, 1) >> b,
(int8_t)_mm_extract_epi8(a.v, 2) >> b,
(int8_t)_mm_extract_epi8(a.v, 3) >> b);
}
static FORCEINLINE __vec4_i1 __equal(__vec4_i8 a, __vec4_i8 b) {
__m128i cmp = _mm_cmpeq_epi8(a.v, b.v);
return __vec4_i1(_mm_extract_epi8(cmp, 0),
@@ -547,6 +568,10 @@ static FORCEINLINE __vec4_i16 __shl(__vec4_i16 a, __vec4_i16 b) {
_mm_extract_epi16(a.v, 3) << _mm_extract_epi16(b.v, 3));
}
static FORCEINLINE __vec4_i16 __shl(__vec4_i16 a, int32_t b) {
return _mm_sll_epi16(a.v, _mm_set_epi32(0, 0, 0, b));
}
static FORCEINLINE __vec4_i16 __udiv(__vec4_i16 a, __vec4_i16 b) {
return __vec4_i16((uint16_t)_mm_extract_epi16(a.v, 0) /
(uint16_t)_mm_extract_epi16(b.v, 0),
@@ -602,6 +627,10 @@ static FORCEINLINE __vec4_i16 __lshr(__vec4_i16 a, __vec4_i16 b) {
(uint16_t)_mm_extract_epi16(b.v, 3));
}
static FORCEINLINE __vec4_i16 __lshr(__vec4_i16 a, int32_t b) {
return _mm_srl_epi16(a.v, _mm_set_epi32(0, 0, 0, b));
}
static FORCEINLINE __vec4_i16 __ashr(__vec4_i16 a, __vec4_i16 b) {
return __vec4_i16((int16_t)_mm_extract_epi16(a.v, 0) >>
(int16_t)_mm_extract_epi16(b.v, 0),
@@ -613,6 +642,10 @@ static FORCEINLINE __vec4_i16 __ashr(__vec4_i16 a, __vec4_i16 b) {
(int16_t)_mm_extract_epi16(b.v, 3));
}
static FORCEINLINE __vec4_i16 __ashr(__vec4_i16 a, int32_t b) {
return _mm_sra_epi16(a.v, _mm_set_epi32(0, 0, 0, b));
}
static FORCEINLINE __vec4_i1 __equal(__vec4_i16 a, __vec4_i16 b) {
__m128i cmp = _mm_cmpeq_epi16(a.v, b.v);
return __vec4_i1(_mm_extract_epi16(cmp, 0),
@@ -789,9 +822,6 @@ static FORCEINLINE __vec4_i32 __xor(__vec4_i32 a, __vec4_i32 b) {
}
static FORCEINLINE __vec4_i32 __shl(__vec4_i32 a, __vec4_i32 b) {
/* fixme: llvm generates thie code for shift left, which is presumably
   more efficient than doing each component individually as below.
@@ -813,57 +843,92 @@ _f___ii: ## @f___ii
ret
*/
return __vec4_i32((uint32_t)_mm_extract_epi32(a.v, 0) <<
                  _mm_extract_epi32(b.v, 0),
                  (uint32_t)_mm_extract_epi32(a.v, 1) <<
                  _mm_extract_epi32(b.v, 1),
                  (uint32_t)_mm_extract_epi32(a.v, 2) <<
                  _mm_extract_epi32(b.v, 2),
                  (uint32_t)_mm_extract_epi32(a.v, 3) <<
                  _mm_extract_epi32(b.v, 3));
}
static FORCEINLINE __vec4_i32 __shl(__vec4_i32 a, int32_t b) {
return _mm_sll_epi32(a.v, _mm_set_epi32(0, 0, 0, b));
}
static FORCEINLINE __vec4_i32 __udiv(__vec4_i32 a, __vec4_i32 b) {
return __vec4_i32((uint32_t)_mm_extract_epi32(a.v, 0) /
                  (uint32_t)_mm_extract_epi32(b.v, 0),
                  (uint32_t)_mm_extract_epi32(a.v, 1) /
                  (uint32_t)_mm_extract_epi32(b.v, 1),
                  (uint32_t)_mm_extract_epi32(a.v, 2) /
                  (uint32_t)_mm_extract_epi32(b.v, 2),
                  (uint32_t)_mm_extract_epi32(a.v, 3) /
                  (uint32_t)_mm_extract_epi32(b.v, 3));
}
static FORCEINLINE __vec4_i32 __sdiv(__vec4_i32 a, __vec4_i32 b) {
return __vec4_i32((int32_t)_mm_extract_epi32(a.v, 0) /
                  (int32_t)_mm_extract_epi32(b.v, 0),
                  (int32_t)_mm_extract_epi32(a.v, 1) /
                  (int32_t)_mm_extract_epi32(b.v, 1),
                  (int32_t)_mm_extract_epi32(a.v, 2) /
                  (int32_t)_mm_extract_epi32(b.v, 2),
                  (int32_t)_mm_extract_epi32(a.v, 3) /
                  (int32_t)_mm_extract_epi32(b.v, 3));
}
static FORCEINLINE __vec4_i32 __urem(__vec4_i32 a, __vec4_i32 b) {
return __vec4_i32((uint32_t)_mm_extract_epi32(a.v, 0) %
                  (uint32_t)_mm_extract_epi32(b.v, 0),
                  (uint32_t)_mm_extract_epi32(a.v, 1) %
                  (uint32_t)_mm_extract_epi32(b.v, 1),
                  (uint32_t)_mm_extract_epi32(a.v, 2) %
                  (uint32_t)_mm_extract_epi32(b.v, 2),
                  (uint32_t)_mm_extract_epi32(a.v, 3) %
                  (uint32_t)_mm_extract_epi32(b.v, 3));
}
static FORCEINLINE __vec4_i32 __srem(__vec4_i32 a, __vec4_i32 b) {
return __vec4_i32((int32_t)_mm_extract_epi32(a.v, 0) %
                  (int32_t)_mm_extract_epi32(b.v, 0),
                  (int32_t)_mm_extract_epi32(a.v, 1) %
                  (int32_t)_mm_extract_epi32(b.v, 1),
                  (int32_t)_mm_extract_epi32(a.v, 2) %
                  (int32_t)_mm_extract_epi32(b.v, 2),
                  (int32_t)_mm_extract_epi32(a.v, 3) %
                  (int32_t)_mm_extract_epi32(b.v, 3));
}
static FORCEINLINE __vec4_i32 __lshr(__vec4_i32 a, __vec4_i32 b) {
return __vec4_i32((uint32_t)_mm_extract_epi32(a.v, 0) >>
                  _mm_extract_epi32(b.v, 0),
                  (uint32_t)_mm_extract_epi32(a.v, 1) >>
                  _mm_extract_epi32(b.v, 1),
                  (uint32_t)_mm_extract_epi32(a.v, 2) >>
                  _mm_extract_epi32(b.v, 2),
                  (uint32_t)_mm_extract_epi32(a.v, 3) >>
                  _mm_extract_epi32(b.v, 3));
}
static FORCEINLINE __vec4_i32 __lshr(__vec4_i32 a, int32_t b) {
return _mm_srl_epi32(a.v, _mm_set_epi32(0, 0, 0, b));
}
static FORCEINLINE __vec4_i32 __ashr(__vec4_i32 a, __vec4_i32 b) {
return __vec4_i32((int32_t)_mm_extract_epi32(a.v, 0) >>
                  _mm_extract_epi32(b.v, 0),
                  (int32_t)_mm_extract_epi32(a.v, 1) >>
                  _mm_extract_epi32(b.v, 1),
                  (int32_t)_mm_extract_epi32(a.v, 2) >>
                  _mm_extract_epi32(b.v, 2),
                  (int32_t)_mm_extract_epi32(a.v, 3) >>
                  _mm_extract_epi32(b.v, 3));
}
static FORCEINLINE __vec4_i32 __ashr(__vec4_i32 a, int32_t b) {
return _mm_sra_epi32(a.v, _mm_set_epi32(0, 0, 0, b));
}
static FORCEINLINE __vec4_i1 __equal(__vec4_i32 a, __vec4_i32 b) {
@@ -1016,6 +1081,12 @@ static FORCEINLINE __vec4_i64 __shl(__vec4_i64 a, __vec4_i64 b) {
_mm_extract_epi64(a.v[1], 1) << _mm_extract_epi64(b.v[1], 1));
}
static FORCEINLINE __vec4_i64 __shl(__vec4_i64 a, int32_t b) {
__m128i amt = _mm_set_epi32(0, 0, 0, b);
return __vec4_i64(_mm_sll_epi64(a.v[0], amt),
_mm_sll_epi64(a.v[1], amt));
}
static FORCEINLINE __vec4_i64 __udiv(__vec4_i64 a, __vec4_i64 b) {
return __vec4_i64((uint64_t)_mm_extract_epi64(a.v[0], 0) /
(uint64_t)_mm_extract_epi64(b.v[0], 0),
@@ -1071,6 +1142,12 @@ static FORCEINLINE __vec4_i64 __lshr(__vec4_i64 a, __vec4_i64 b) {
(uint64_t)_mm_extract_epi64(b.v[1], 1));
}
static FORCEINLINE __vec4_i64 __lshr(__vec4_i64 a, int32_t b) {
__m128i amt = _mm_set_epi32(0, 0, 0, b);
return __vec4_i64(_mm_srl_epi64(a.v[0], amt),
_mm_srl_epi64(a.v[1], amt));
}
static FORCEINLINE __vec4_i64 __ashr(__vec4_i64 a, __vec4_i64 b) {
return __vec4_i64((int64_t)_mm_extract_epi64(a.v[0], 0) >>
(int64_t)_mm_extract_epi64(b.v[0], 0),
@@ -1082,6 +1159,13 @@ static FORCEINLINE __vec4_i64 __ashr(__vec4_i64 a, __vec4_i64 b) {
(int64_t)_mm_extract_epi64(b.v[1], 1));
}
static FORCEINLINE __vec4_i64 __ashr(__vec4_i64 a, int32_t b) {
return __vec4_i64((int64_t)_mm_extract_epi64(a.v[0], 0) >> b,
(int64_t)_mm_extract_epi64(a.v[0], 1) >> b,
(int64_t)_mm_extract_epi64(a.v[1], 0) >> b,
(int64_t)_mm_extract_epi64(a.v[1], 1) >> b);
}
static FORCEINLINE __vec4_i1 __equal(__vec4_i64 a, __vec4_i64 b) {
__m128i cmp0 = _mm_cmpeq_epi64(a.v[0], b.v[0]);
__m128i cmp1 = _mm_cmpeq_epi64(a.v[1], b.v[1]);


@@ -36,7 +36,9 @@
*/
#include "llvmutil.h"
#include "ispc.h"
#include "type.h"
#include <llvm/Instructions.h>
LLVM_TYPE_CONST llvm::Type *LLVMTypes::VoidType = NULL;
LLVM_TYPE_CONST llvm::PointerType *LLVMTypes::VoidPointerType = NULL;
@@ -465,3 +467,239 @@ LLVMBoolVector(const bool *bvec) {
}
return llvm::ConstantVector::get(vals);
}
/** Conservative test to see if two llvm::Values are equal. There are
(potentially many) cases where the two values actually are equal but
this will return false. However, if it does return true, the two
vectors definitely are equal.
@todo This seems to catch all of the cases we currently need it for in
practice, but it's be nice to make it a little more robust/general. In
general, though, a little something called the halting problem means we
won't get all of them.
*/
static bool
lValuesAreEqual(llvm::Value *v0, llvm::Value *v1,
std::vector<llvm::PHINode *> &seenPhi0,
std::vector<llvm::PHINode *> &seenPhi1) {
// Thanks to the fact that LLVM hashes and returns the same pointer for
// constants (of all sorts, even constant expressions), this first test
// actually catches a lot of cases. LLVM's SSA form also helps a lot
// with this..
if (v0 == v1)
return true;
Assert(seenPhi0.size() == seenPhi1.size());
for (unsigned int i = 0; i < seenPhi0.size(); ++i)
if (v0 == seenPhi0[i] && v1 == seenPhi1[i])
return true;
llvm::BinaryOperator *bo0 = llvm::dyn_cast<llvm::BinaryOperator>(v0);
llvm::BinaryOperator *bo1 = llvm::dyn_cast<llvm::BinaryOperator>(v1);
if (bo0 != NULL && bo1 != NULL) {
if (bo0->getOpcode() != bo1->getOpcode())
return false;
return (lValuesAreEqual(bo0->getOperand(0), bo1->getOperand(0),
seenPhi0, seenPhi1) &&
lValuesAreEqual(bo0->getOperand(1), bo1->getOperand(1),
seenPhi0, seenPhi1));
}
llvm::PHINode *phi0 = llvm::dyn_cast<llvm::PHINode>(v0);
llvm::PHINode *phi1 = llvm::dyn_cast<llvm::PHINode>(v1);
if (phi0 != NULL && phi1 != NULL) {
if (phi0->getNumIncomingValues() != phi1->getNumIncomingValues())
return false;
seenPhi0.push_back(phi0);
seenPhi1.push_back(phi1);
unsigned int numIncoming = phi0->getNumIncomingValues();
// Check all of the incoming values: if all of them are all equal,
// then we're good.
bool anyFailure = false;
for (unsigned int i = 0; i < numIncoming; ++i) {
Assert(phi0->getIncomingBlock(i) == phi1->getIncomingBlock(i));
if (!lValuesAreEqual(phi0->getIncomingValue(i),
phi1->getIncomingValue(i), seenPhi0, seenPhi1)) {
anyFailure = true;
break;
}
}
seenPhi0.pop_back();
seenPhi1.pop_back();
return !anyFailure;
}
return false;
}
/** Given an llvm::Value known to be an integer, return its value as
an int64_t.
*/
static int64_t
lGetIntValue(llvm::Value *offset) {
llvm::ConstantInt *intOffset = llvm::dyn_cast<llvm::ConstantInt>(offset);
Assert(intOffset && (intOffset->getBitWidth() == 32 ||
intOffset->getBitWidth() == 64));
return intOffset->getSExtValue();
}
/** This function takes chains of InsertElement instructions along the
lines of:
%v0 = insertelement undef, value_0, i32 index_0
%v1 = insertelement %v1, value_1, i32 index_1
...
%vn = insertelement %vn-1, value_n-1, i32 index_n-1
and initializes the provided elements array such that the i'th
llvm::Value * in the array is the element that was inserted into the
i'th element of the vector.
*/
void
LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth,
llvm::Value **elements) {
for (int i = 0; i < vectorWidth; ++i)
elements[i] = NULL;
while (ie != NULL) {
int64_t iOffset = lGetIntValue(ie->getOperand(2));
Assert(iOffset >= 0 && iOffset < vectorWidth);
Assert(elements[iOffset] == NULL);
elements[iOffset] = ie->getOperand(1);
llvm::Value *insertBase = ie->getOperand(0);
ie = llvm::dyn_cast<llvm::InsertElementInst>(insertBase);
if (ie == NULL) {
if (llvm::isa<llvm::UndefValue>(insertBase))
return;
llvm::ConstantVector *cv =
llvm::dyn_cast<llvm::ConstantVector>(insertBase);
Assert(cv != NULL);
Assert(iOffset < (int)cv->getNumOperands());
elements[iOffset] = cv->getOperand(iOffset);
}
}
}
/** Tests to see if all of the elements of the vector in the 'v' parameter
are equal. Like lValuesAreEqual(), this is a conservative test and may
return false for arrays where the values are actually all equal. */
bool
LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis) {
if (llvm::isa<llvm::ConstantAggregateZero>(v))
return true;
llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
if (cv != NULL)
return (cv->getSplatValue() != NULL);
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
if (bop != NULL)
return (LLVMVectorValuesAllEqual(bop->getOperand(0), vectorLength,
seenPhis) &&
LLVMVectorValuesAllEqual(bop->getOperand(1), vectorLength,
seenPhis));
llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
if (cast != NULL)
return LLVMVectorValuesAllEqual(cast->getOperand(0), vectorLength,
seenPhis);
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v);
if (ie != NULL) {
llvm::Value *elements[ISPC_MAX_NVEC];
LLVMFlattenInsertChain(ie, vectorLength, elements);
// We will ignore any values of elements[] that are NULL; as they
// correspond to undefined values--we just want to see if all of
// the defined values have the same value.
int lastNonNull = 0;
while (lastNonNull < vectorLength && elements[lastNonNull] == NULL)
++lastNonNull;
if (lastNonNull == vectorLength)
// all of them are undef!
return true;
for (int i = lastNonNull; i < vectorLength; ++i) {
if (elements[i] == NULL)
continue;
std::vector<llvm::PHINode *> seenPhi0;
std::vector<llvm::PHINode *> seenPhi1;
if (lValuesAreEqual(elements[lastNonNull], elements[i], seenPhi0,
seenPhi1) == false)
return false;
lastNonNull = i;
}
return true;
}
llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
if (phi) {
for (unsigned int i = 0; i < seenPhis.size(); ++i)
if (seenPhis[i] == phi)
return true;
seenPhis.push_back(phi);
unsigned int numIncoming = phi->getNumIncomingValues();
// Check all of the incoming values: if all of them are all equal,
// then we're good.
for (unsigned int i = 0; i < numIncoming; ++i) {
if (!LLVMVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength,
seenPhis)) {
seenPhis.pop_back();
return false;
}
}
seenPhis.pop_back();
return true;
}
Assert(!llvm::isa<llvm::Constant>(v));
if (llvm::isa<llvm::CallInst>(v) || llvm::isa<llvm::LoadInst>(v) ||
!llvm::isa<llvm::Instruction>(v))
return false;
llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
if (shuffle != NULL) {
llvm::Value *indices = shuffle->getOperand(2);
if (LLVMVectorValuesAllEqual(indices, vectorLength, seenPhis))
// The easy case--just a smear of the same element across the
// whole vector.
return true;
// TODO: handle more general cases?
return false;
}
#if 0
fprintf(stderr, "all equal: ");
v->dump();
fprintf(stderr, "\n");
llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
if (inst) {
inst->getParent()->dump();
fprintf(stderr, "\n");
fprintf(stderr, "\n");
}
#endif
return false;
}


@@ -38,12 +38,23 @@
#ifndef ISPC_LLVMUTIL_H
#define ISPC_LLVMUTIL_H 1
#include "ispc.h"
#include <llvm/LLVMContext.h>
#include <llvm/Type.h>
#include <llvm/DerivedTypes.h>
#include <llvm/Constants.h>
namespace llvm {
class PHINode;
class InsertElementInst;
}
// llvm::Type *s are no longer const in llvm 3.0
#if defined(LLVM_3_0) || defined(LLVM_3_0svn) || defined(LLVM_3_1svn)
#define LLVM_TYPE_CONST
#else
#define LLVM_TYPE_CONST const
#endif
/** This structure holds pointers to a variety of LLVM types; code
    elsewhere can use them from here, ratherthan needing to make more
@@ -99,6 +110,7 @@ extern llvm::Constant *LLVMTrue, *LLVMFalse;
    of LLVMTypes and the LLVMTrue/LLVMFalse constants. However, it can't
    be called until the compilation target is known.
*/
struct Target;
extern void InitLLVMUtil(llvm::LLVMContext *ctx, Target target);
/** Returns an LLVM i8 constant of the given value */
@@ -205,4 +217,13 @@ extern llvm::Constant *LLVMMaskAllOn;
/** LLVM constant value representing an 'all off' SIMD lane mask */
extern llvm::Constant *LLVMMaskAllOff;
/** Tests to see if all of the elements of the vector in the 'v' parameter
are equal. Like lValuesAreEqual(), this is a conservative test and may
return false for arrays where the values are actually all equal. */
extern bool LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis);
void LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth,
llvm::Value **elements);
#endif // ISPC_LLVMUTIL_H
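
For reference, a hedged sketch of how backend code can consult the helper declared above before choosing the scalar-amount shift (the function and variable names here are hypothetical; the real call site is in the C++ backend change earlier in this commit):

    #include <vector>
    #include "llvmutil.h"

    // Returns true if every lane of 'shiftAmount' provably holds the same value,
    // in which case a specialized shift taking a single int32_t can be emitted.
    static bool lShiftAmountIsUniform(llvm::Value *shiftAmount, int vectorWidth) {
        std::vector<llvm::PHINode *> seenPhis;
        // Conservative: may return false even when the lanes are in fact equal.
        return LLVMVectorValuesAllEqual(shiftAmount, vectorWidth, seenPhis);
    }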

opt.cpp

@@ -921,58 +921,6 @@ char GatherScatterFlattenOpt::ID = 0;
llvm::RegisterPass<GatherScatterFlattenOpt> gsf("gs-flatten", "Gather/Scatter Flatten Pass");
/** Given an llvm::Value known to be an integer, return its value as
an int64_t.
*/
static int64_t
lGetIntValue(llvm::Value *offset) {
llvm::ConstantInt *intOffset = llvm::dyn_cast<llvm::ConstantInt>(offset);
Assert(intOffset && (intOffset->getBitWidth() == 32 ||
intOffset->getBitWidth() == 64));
return intOffset->getSExtValue();
}
/** This function takes chains of InsertElement instructions along the
lines of:
%v0 = insertelement undef, value_0, i32 index_0
%v1 = insertelement %v1, value_1, i32 index_1
...
%vn = insertelement %vn-1, value_n-1, i32 index_n-1
and initializes the provided elements array such that the i'th
llvm::Value * in the array is the element that was inserted into the
i'th element of the vector.
*/
static void
lFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth,
llvm::Value **elements) {
for (int i = 0; i < vectorWidth; ++i)
elements[i] = NULL;
while (ie != NULL) {
int64_t iOffset = lGetIntValue(ie->getOperand(2));
Assert(iOffset >= 0 && iOffset < vectorWidth);
Assert(elements[iOffset] == NULL);
elements[iOffset] = ie->getOperand(1);
llvm::Value *insertBase = ie->getOperand(0);
ie = llvm::dyn_cast<llvm::InsertElementInst>(insertBase);
if (ie == NULL) {
if (llvm::isa<llvm::UndefValue>(insertBase))
return;
llvm::ConstantVector *cv =
llvm::dyn_cast<llvm::ConstantVector>(insertBase);
Assert(cv != NULL);
Assert(iOffset < (int)cv->getNumOperands());
elements[iOffset] = cv->getOperand(iOffset);
}
}
}
/** Check to make sure that this value is actually a pointer in the end.
    We need to make sure that given an expression like vec(offset) +
    ptr2int(ptr), lGetBasePointer() doesn't return vec(offset) for the base
@@ -1011,7 +959,7 @@ lGetBasePointer(llvm::Value *v) {
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v);
if (ie != NULL) {
llvm::Value *elements[ISPC_MAX_NVEC];
LLVMFlattenInsertChain(ie, g->target.vectorWidth, elements);
// Make sure none of the elements is undefined.
// TODO: it's probably ok to allow undefined elements and return
@@ -1825,187 +1773,6 @@ llvm::RegisterPass<GSImprovementsPass> gsi("gs-improvements",
"Gather/Scatter Improvements Pass");
/** Conservative test to see if two llvm::Values are equal. There are
(potentially many) cases where the two values actually are equal but
this will return false. However, if it does return true, the two
vectors definitely are equal.
@todo This seems to catch all of the cases we currently need it for in
practice, but it's be nice to make it a little more robust/general. In
general, though, a little something called the halting problem means we
won't get all of them.
*/
static bool
lValuesAreEqual(llvm::Value *v0, llvm::Value *v1,
std::vector<llvm::PHINode *> &seenPhi0,
std::vector<llvm::PHINode *> &seenPhi1) {
// Thanks to the fact that LLVM hashes and returns the same pointer for
// constants (of all sorts, even constant expressions), this first test
// actually catches a lot of cases. LLVM's SSA form also helps a lot
// with this..
if (v0 == v1)
return true;
Assert(seenPhi0.size() == seenPhi1.size());
for (unsigned int i = 0; i < seenPhi0.size(); ++i)
if (v0 == seenPhi0[i] && v1 == seenPhi1[i])
return true;
llvm::BinaryOperator *bo0 = llvm::dyn_cast<llvm::BinaryOperator>(v0);
llvm::BinaryOperator *bo1 = llvm::dyn_cast<llvm::BinaryOperator>(v1);
if (bo0 != NULL && bo1 != NULL) {
if (bo0->getOpcode() != bo1->getOpcode())
return false;
return (lValuesAreEqual(bo0->getOperand(0), bo1->getOperand(0),
seenPhi0, seenPhi1) &&
lValuesAreEqual(bo0->getOperand(1), bo1->getOperand(1),
seenPhi0, seenPhi1));
}
llvm::PHINode *phi0 = llvm::dyn_cast<llvm::PHINode>(v0);
llvm::PHINode *phi1 = llvm::dyn_cast<llvm::PHINode>(v1);
if (phi0 != NULL && phi1 != NULL) {
if (phi0->getNumIncomingValues() != phi1->getNumIncomingValues())
return false;
seenPhi0.push_back(phi0);
seenPhi1.push_back(phi1);
unsigned int numIncoming = phi0->getNumIncomingValues();
// Check all of the incoming values: if all of them are all equal,
// then we're good.
bool anyFailure = false;
for (unsigned int i = 0; i < numIncoming; ++i) {
Assert(phi0->getIncomingBlock(i) == phi1->getIncomingBlock(i));
if (!lValuesAreEqual(phi0->getIncomingValue(i),
phi1->getIncomingValue(i), seenPhi0, seenPhi1)) {
anyFailure = true;
break;
}
}
seenPhi0.pop_back();
seenPhi1.pop_back();
return !anyFailure;
}
return false;
}
/** Tests to see if all of the elements of the vector in the 'v' parameter
are equal. Like lValuesAreEqual(), this is a conservative test and may
return false for arrays where the values are actually all equal. */
static bool
lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
std::vector<llvm::PHINode *> &seenPhis) {
if (llvm::isa<llvm::ConstantAggregateZero>(v))
return true;
llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
if (cv != NULL)
return (cv->getSplatValue() != NULL);
llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
if (bop != NULL)
return (lVectorValuesAllEqual(bop->getOperand(0), vectorLength,
seenPhis) &&
lVectorValuesAllEqual(bop->getOperand(1), vectorLength,
seenPhis));
llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
if (cast != NULL)
return lVectorValuesAllEqual(cast->getOperand(0), vectorLength,
seenPhis);
llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v);
if (ie != NULL) {
llvm::Value *elements[ISPC_MAX_NVEC];
lFlattenInsertChain(ie, vectorLength, elements);
// We will ignore any values of elements[] that are NULL; as they
// correspond to undefined values--we just want to see if all of
// the defined values have the same value.
int lastNonNull = 0;
while (lastNonNull < vectorLength && elements[lastNonNull] == NULL)
++lastNonNull;
if (lastNonNull == vectorLength)
// all of them are undef!
return true;
for (int i = lastNonNull; i < vectorLength; ++i) {
if (elements[i] == NULL)
continue;
std::vector<llvm::PHINode *> seenPhi0;
std::vector<llvm::PHINode *> seenPhi1;
if (lValuesAreEqual(elements[lastNonNull], elements[i], seenPhi0,
seenPhi1) == false)
return false;
lastNonNull = i;
}
return true;
}
llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
if (phi) {
for (unsigned int i = 0; i < seenPhis.size(); ++i)
if (seenPhis[i] == phi)
return true;
seenPhis.push_back(phi);
unsigned int numIncoming = phi->getNumIncomingValues();
// Check all of the incoming values: if all of them are all equal,
// then we're good.
for (unsigned int i = 0; i < numIncoming; ++i) {
if (!lVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength,
seenPhis)) {
seenPhis.pop_back();
return false;
}
}
seenPhis.pop_back();
return true;
}
Assert(!llvm::isa<llvm::Constant>(v));
if (llvm::isa<llvm::CallInst>(v) || llvm::isa<llvm::LoadInst>(v) ||
!llvm::isa<llvm::Instruction>(v))
return false;
llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
if (shuffle != NULL) {
llvm::Value *indices = shuffle->getOperand(2);
if (lVectorValuesAllEqual(indices, vectorLength, seenPhis))
// The easy case--just a smear of the same element across the
// whole vector.
return true;
// TODO: handle more general cases?
return false;
}
#if 0
fprintf(stderr, "all equal: ");
v->dump();
fprintf(stderr, "\n");
llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
if (inst) {
inst->getParent()->dump();
fprintf(stderr, "\n");
fprintf(stderr, "\n");
}
#endif
return false;
}
/** Given a vector of compile-time constant integer values, test to see if
    they are a linear sequence of constant integers starting from an
    arbirary value but then having a step of value "stride" between
@@ -2102,9 +1869,9 @@ lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
// programIndex + unif -> ascending linear seqeuence
// unif + programIndex -> ascending linear sequence
return ((lVectorIsLinear(op0, vectorLength, stride, seenPhis) &&
LLVMVectorValuesAllEqual(op1, vectorLength, seenPhis)) ||
(lVectorIsLinear(op1, vectorLength, stride, seenPhis) &&
LLVMVectorValuesAllEqual(op0, vectorLength, seenPhis)));
else if (bop->getOpcode() == llvm::Instruction::Sub)
// For subtraction, we only match:
//
@@ -2115,7 +1882,7 @@ lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
// And generate code for that as a vector load + shuffle.
return (lVectorIsLinear(bop->getOperand(0), vectorLength,
stride, seenPhis) &&
LLVMVectorValuesAllEqual(bop->getOperand(1), vectorLength,
seenPhis));
else if (bop->getOpcode() == llvm::Instruction::Mul)
// Multiplies are a bit trickier, so are handled in a separate
@@ -2313,7 +2080,7 @@ GSImprovementsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
{
std::vector<llvm::PHINode *> seenPhis;
if (LLVMVectorValuesAllEqual(offsets, g->target.vectorWidth, seenPhis)) {
// If all the offsets are equal, then compute the single
// pointer they all represent based on the first one of them
// (arbitrarily).