diff --git a/cbackend.cpp b/cbackend.cpp
index 3c9ce0af..3ae9f5e0 100644
--- a/cbackend.cpp
+++ b/cbackend.cpp
@@ -24,6 +24,8 @@
 #define PRIx64 "llx"
 #endif
 
+#include "llvmutil.h"
+
 #include "llvm/CallingConv.h"
 #include "llvm/Constants.h"
 #include "llvm/DerivedTypes.h"
@@ -232,6 +234,7 @@ namespace {
     unsigned NextAnonValueNumber;
 
     std::string includeName;
+    int vectorWidth;
 
     /// UnnamedStructIDs - This contains a unique ID for each struct that is
     /// either anonymous or has no name.
@@ -240,11 +243,13 @@ namespace {
   public:
     static char ID;
-    explicit CWriter(formatted_raw_ostream &o, const char *incname)
+    explicit CWriter(formatted_raw_ostream &o, const char *incname,
+                     int vecwidth)
       : FunctionPass(ID), Out(o), IL(0), Mang(0), LI(0), TheModule(0),
         TAsm(0), MRI(0), MOFI(0), TCtx(0), TD(0), OpaqueCounter(0),
         NextAnonValueNumber(0),
-        includeName(incname ? incname : "generic_defs.h") {
+        includeName(incname ? incname : "generic_defs.h"),
+        vectorWidth(vecwidth) {
       initializeLoopInfoPass(*PassRegistry::getPassRegistry());
       FPCounter = 0;
     }
 
@@ -2894,7 +2899,21 @@ void CWriter::visitBinaryOperator(Instruction &I) {
     Out << "(";
     writeOperand(I.getOperand(0));
     Out << ", ";
-    writeOperand(I.getOperand(1));
+    if ((I.getOpcode() == Instruction::Shl ||
+         I.getOpcode() == Instruction::LShr ||
+         I.getOpcode() == Instruction::AShr)) {
+        std::vector<llvm::PHINode *> phis;
+        if (LLVMVectorValuesAllEqual(I.getOperand(1),
+                                     vectorWidth, phis)) {
+            Out << "__extract_element(";
+            writeOperand(I.getOperand(1));
+            Out << ", 0) ";
+        }
+        else
+            writeOperand(I.getOperand(1));
+    }
+    else
+        writeOperand(I.getOperand(1));
     Out << ")";
     return;
   }
@@ -3635,7 +3654,7 @@ std::string CWriter::InterpretASMConstraint(InlineAsm::ConstraintInfo& c) {
 #endif
 
   std::string E;
-  if (const Target *Match = TargetRegistry::lookupTarget(Triple, E))
+  if (const llvm::Target *Match = TargetRegistry::lookupTarget(Triple, E))
     TargetAsm = Match->createMCAsmInfo(Triple);
   else
     return c.Codes[0];
@@ -4337,7 +4356,7 @@ WriteCXXFile(llvm::Module *module, const char *fn, int vectorWidth,
     pm.add(new BitcastCleanupPass);
     pm.add(createDeadCodeEliminationPass()); // clean up after smear pass
     //CO pm.add(createPrintModulePass(&fos));
-    pm.add(new CWriter(fos, includeName));
+    pm.add(new CWriter(fos, includeName, vectorWidth));
     pm.add(createGCInfoDeleter());
 
     //CO pm.add(createVerifierPass());
diff --git a/examples/intrinsics/generic-16.h b/examples/intrinsics/generic-16.h
index c7600918..ffeb4680 100644
--- a/examples/intrinsics/generic-16.h
+++ b/examples/intrinsics/generic-16.h
@@ -251,6 +251,14 @@ static FORCEINLINE TYPE __select(bool cond, TYPE a, TYPE b) { \
     return cond ?
a : b; \ } +#define SHIFT_UNIFORM(TYPE, CAST, NAME, OP) \ +static FORCEINLINE TYPE NAME(TYPE a, int32_t b) { \ + TYPE ret; \ + for (int i = 0; i < 16; ++i) \ + ret.v[i] = (CAST)(a.v[i]) OP b; \ + return ret; \ +} + #define SMEAR(VTYPE, NAME, STYPE) \ static FORCEINLINE VTYPE __smear_##NAME(STYPE v) { \ VTYPE ret; \ @@ -386,6 +394,10 @@ BINARY_OP_CAST(__vec16_i8, int8_t, __srem, %) BINARY_OP_CAST(__vec16_i8, uint8_t, __lshr, >>) BINARY_OP_CAST(__vec16_i8, int8_t, __ashr, >>) +SHIFT_UNIFORM(__vec16_i8, uint8_t, __lshr, >>) +SHIFT_UNIFORM(__vec16_i8, int8_t, __ashr, >>) +SHIFT_UNIFORM(__vec16_i8, int8_t, __shl, <<) + CMP_OP(__vec16_i8, int8_t, __equal, ==) CMP_OP(__vec16_i8, int8_t, __not_equal, !=) CMP_OP(__vec16_i8, uint8_t, __unsigned_less_equal, <=) @@ -425,6 +437,10 @@ BINARY_OP_CAST(__vec16_i16, int16_t, __srem, %) BINARY_OP_CAST(__vec16_i16, uint16_t, __lshr, >>) BINARY_OP_CAST(__vec16_i16, int16_t, __ashr, >>) +SHIFT_UNIFORM(__vec16_i16, uint16_t, __lshr, >>) +SHIFT_UNIFORM(__vec16_i16, int16_t, __ashr, >>) +SHIFT_UNIFORM(__vec16_i16, int16_t, __shl, <<) + CMP_OP(__vec16_i16, int16_t, __equal, ==) CMP_OP(__vec16_i16, int16_t, __not_equal, !=) CMP_OP(__vec16_i16, uint16_t, __unsigned_less_equal, <=) @@ -464,6 +480,10 @@ BINARY_OP_CAST(__vec16_i32, int32_t, __srem, %) BINARY_OP_CAST(__vec16_i32, uint32_t, __lshr, >>) BINARY_OP_CAST(__vec16_i32, int32_t, __ashr, >>) +SHIFT_UNIFORM(__vec16_i32, uint32_t, __lshr, >>) +SHIFT_UNIFORM(__vec16_i32, int32_t, __ashr, >>) +SHIFT_UNIFORM(__vec16_i32, int32_t, __shl, <<) + CMP_OP(__vec16_i32, int32_t, __equal, ==) CMP_OP(__vec16_i32, int32_t, __not_equal, !=) CMP_OP(__vec16_i32, uint32_t, __unsigned_less_equal, <=) @@ -503,6 +523,10 @@ BINARY_OP_CAST(__vec16_i64, int64_t, __srem, %) BINARY_OP_CAST(__vec16_i64, uint64_t, __lshr, >>) BINARY_OP_CAST(__vec16_i64, int64_t, __ashr, >>) +SHIFT_UNIFORM(__vec16_i64, uint64_t, __lshr, >>) +SHIFT_UNIFORM(__vec16_i64, int64_t, __ashr, >>) +SHIFT_UNIFORM(__vec16_i64, int64_t, __shl, <<) + CMP_OP(__vec16_i64, int64_t, __equal, ==) CMP_OP(__vec16_i64, int64_t, __not_equal, !=) CMP_OP(__vec16_i64, uint64_t, __unsigned_less_equal, <=) diff --git a/examples/intrinsics/sse4.h b/examples/intrinsics/sse4.h index 81444ecb..2dc48b06 100644 --- a/examples/intrinsics/sse4.h +++ b/examples/intrinsics/sse4.h @@ -303,6 +303,13 @@ static FORCEINLINE __vec4_i8 __shl(__vec4_i8 a, __vec4_i8 b) { _mm_extract_epi8(a.v, 3) << _mm_extract_epi8(b.v, 3)); } +static FORCEINLINE __vec4_i8 __shl(__vec4_i8 a, int32_t b) { + return __vec4_i8(_mm_extract_epi8(a.v, 0) << b, + _mm_extract_epi8(a.v, 1) << b, + _mm_extract_epi8(a.v, 2) << b, + _mm_extract_epi8(a.v, 3) << b); +} + static FORCEINLINE __vec4_i8 __udiv(__vec4_i8 a, __vec4_i8 b) { return __vec4_i8((uint8_t)_mm_extract_epi8(a.v, 0) / (uint8_t)_mm_extract_epi8(b.v, 0), @@ -358,6 +365,13 @@ static FORCEINLINE __vec4_i8 __lshr(__vec4_i8 a, __vec4_i8 b) { (uint8_t)_mm_extract_epi8(b.v, 3)); } +static FORCEINLINE __vec4_i8 __lshr(__vec4_i8 a, int32_t b) { + return __vec4_i8((uint8_t)_mm_extract_epi8(a.v, 0) >> b, + (uint8_t)_mm_extract_epi8(a.v, 1) >> b, + (uint8_t)_mm_extract_epi8(a.v, 2) >> b, + (uint8_t)_mm_extract_epi8(a.v, 3) >> b); +} + static FORCEINLINE __vec4_i8 __ashr(__vec4_i8 a, __vec4_i8 b) { return __vec4_i8((int8_t)_mm_extract_epi8(a.v, 0) >> (int8_t)_mm_extract_epi8(b.v, 0), @@ -369,6 +383,13 @@ static FORCEINLINE __vec4_i8 __ashr(__vec4_i8 a, __vec4_i8 b) { (int8_t)_mm_extract_epi8(b.v, 3)); } +static FORCEINLINE __vec4_i8 __ashr(__vec4_i8 a, int32_t b) { + return 
__vec4_i8((int8_t)_mm_extract_epi8(a.v, 0) >> b, + (int8_t)_mm_extract_epi8(a.v, 1) >> b, + (int8_t)_mm_extract_epi8(a.v, 2) >> b, + (int8_t)_mm_extract_epi8(a.v, 3) >> b); +} + static FORCEINLINE __vec4_i1 __equal(__vec4_i8 a, __vec4_i8 b) { __m128i cmp = _mm_cmpeq_epi8(a.v, b.v); return __vec4_i1(_mm_extract_epi8(cmp, 0), @@ -547,6 +568,10 @@ static FORCEINLINE __vec4_i16 __shl(__vec4_i16 a, __vec4_i16 b) { _mm_extract_epi16(a.v, 3) << _mm_extract_epi16(b.v, 3)); } +static FORCEINLINE __vec4_i16 __shl(__vec4_i16 a, int32_t b) { + return _mm_sll_epi16(a.v, _mm_set_epi32(0, 0, 0, b)); +} + static FORCEINLINE __vec4_i16 __udiv(__vec4_i16 a, __vec4_i16 b) { return __vec4_i16((uint16_t)_mm_extract_epi16(a.v, 0) / (uint16_t)_mm_extract_epi16(b.v, 0), @@ -602,6 +627,10 @@ static FORCEINLINE __vec4_i16 __lshr(__vec4_i16 a, __vec4_i16 b) { (uint16_t)_mm_extract_epi16(b.v, 3)); } +static FORCEINLINE __vec4_i16 __lshr(__vec4_i16 a, int32_t b) { + return _mm_srl_epi16(a.v, _mm_set_epi32(0, 0, 0, b)); +} + static FORCEINLINE __vec4_i16 __ashr(__vec4_i16 a, __vec4_i16 b) { return __vec4_i16((int16_t)_mm_extract_epi16(a.v, 0) >> (int16_t)_mm_extract_epi16(b.v, 0), @@ -613,6 +642,10 @@ static FORCEINLINE __vec4_i16 __ashr(__vec4_i16 a, __vec4_i16 b) { (int16_t)_mm_extract_epi16(b.v, 3)); } +static FORCEINLINE __vec4_i16 __ashr(__vec4_i16 a, int32_t b) { + return _mm_sra_epi16(a.v, _mm_set_epi32(0, 0, 0, b)); +} + static FORCEINLINE __vec4_i1 __equal(__vec4_i16 a, __vec4_i16 b) { __m128i cmp = _mm_cmpeq_epi16(a.v, b.v); return __vec4_i1(_mm_extract_epi16(cmp, 0), @@ -789,9 +822,6 @@ static FORCEINLINE __vec4_i32 __xor(__vec4_i32 a, __vec4_i32 b) { } static FORCEINLINE __vec4_i32 __shl(__vec4_i32 a, __vec4_i32 b) { - // FIXME: if we can determine at compile time that b has the same value - // across all elements, then we can use _mm_sll_epi32. - /* fixme: llvm generates thie code for shift left, which is presumably more efficient than doing each component individually as below. 
@@ -813,57 +843,92 @@ _f___ii: ## @f___ii ret */ - return __vec4_i32((uint32_t)_mm_extract_epi32(a.v, 0) << _mm_extract_epi32(b.v, 0), - (uint32_t)_mm_extract_epi32(a.v, 1) << _mm_extract_epi32(b.v, 1), - (uint32_t)_mm_extract_epi32(a.v, 2) << _mm_extract_epi32(b.v, 2), - (uint32_t)_mm_extract_epi32(a.v, 3) << _mm_extract_epi32(b.v, 3)); + return __vec4_i32((uint32_t)_mm_extract_epi32(a.v, 0) << + _mm_extract_epi32(b.v, 0), + (uint32_t)_mm_extract_epi32(a.v, 1) << + _mm_extract_epi32(b.v, 1), + (uint32_t)_mm_extract_epi32(a.v, 2) << + _mm_extract_epi32(b.v, 2), + (uint32_t)_mm_extract_epi32(a.v, 3) << + _mm_extract_epi32(b.v, 3)); +} + +static FORCEINLINE __vec4_i32 __shl(__vec4_i32 a, int32_t b) { + return _mm_sll_epi32(a.v, _mm_set_epi32(0, 0, 0, b)); } static FORCEINLINE __vec4_i32 __udiv(__vec4_i32 a, __vec4_i32 b) { - return __vec4_i32((uint32_t)_mm_extract_epi32(a.v, 0) / (uint32_t)_mm_extract_epi32(b.v, 0), - (uint32_t)_mm_extract_epi32(a.v, 1) / (uint32_t)_mm_extract_epi32(b.v, 1), - (uint32_t)_mm_extract_epi32(a.v, 2) / (uint32_t)_mm_extract_epi32(b.v, 2), - (uint32_t)_mm_extract_epi32(a.v, 3) / (uint32_t)_mm_extract_epi32(b.v, 3)); + return __vec4_i32((uint32_t)_mm_extract_epi32(a.v, 0) / + (uint32_t)_mm_extract_epi32(b.v, 0), + (uint32_t)_mm_extract_epi32(a.v, 1) / + (uint32_t)_mm_extract_epi32(b.v, 1), + (uint32_t)_mm_extract_epi32(a.v, 2) / + (uint32_t)_mm_extract_epi32(b.v, 2), + (uint32_t)_mm_extract_epi32(a.v, 3) / + (uint32_t)_mm_extract_epi32(b.v, 3)); } static FORCEINLINE __vec4_i32 __sdiv(__vec4_i32 a, __vec4_i32 b) { - return __vec4_i32((int32_t)_mm_extract_epi32(a.v, 0) / (int32_t)_mm_extract_epi32(b.v, 0), - (int32_t)_mm_extract_epi32(a.v, 1) / (int32_t)_mm_extract_epi32(b.v, 1), - (int32_t)_mm_extract_epi32(a.v, 2) / (int32_t)_mm_extract_epi32(b.v, 2), - (int32_t)_mm_extract_epi32(a.v, 3) / (int32_t)_mm_extract_epi32(b.v, 3)); + return __vec4_i32((int32_t)_mm_extract_epi32(a.v, 0) / + (int32_t)_mm_extract_epi32(b.v, 0), + (int32_t)_mm_extract_epi32(a.v, 1) / + (int32_t)_mm_extract_epi32(b.v, 1), + (int32_t)_mm_extract_epi32(a.v, 2) / + (int32_t)_mm_extract_epi32(b.v, 2), + (int32_t)_mm_extract_epi32(a.v, 3) / + (int32_t)_mm_extract_epi32(b.v, 3)); } static FORCEINLINE __vec4_i32 __urem(__vec4_i32 a, __vec4_i32 b) { - return __vec4_i32((uint32_t)_mm_extract_epi32(a.v, 0) % (uint32_t)_mm_extract_epi32(b.v, 0), - (uint32_t)_mm_extract_epi32(a.v, 1) % (uint32_t)_mm_extract_epi32(b.v, 1), - (uint32_t)_mm_extract_epi32(a.v, 2) % (uint32_t)_mm_extract_epi32(b.v, 2), - (uint32_t)_mm_extract_epi32(a.v, 3) % (uint32_t)_mm_extract_epi32(b.v, 3)); + return __vec4_i32((uint32_t)_mm_extract_epi32(a.v, 0) % + (uint32_t)_mm_extract_epi32(b.v, 0), + (uint32_t)_mm_extract_epi32(a.v, 1) % + (uint32_t)_mm_extract_epi32(b.v, 1), + (uint32_t)_mm_extract_epi32(a.v, 2) % + (uint32_t)_mm_extract_epi32(b.v, 2), + (uint32_t)_mm_extract_epi32(a.v, 3) % + (uint32_t)_mm_extract_epi32(b.v, 3)); } static FORCEINLINE __vec4_i32 __srem(__vec4_i32 a, __vec4_i32 b) { - return __vec4_i32((int32_t)_mm_extract_epi32(a.v, 0) % (int32_t)_mm_extract_epi32(b.v, 0), - (int32_t)_mm_extract_epi32(a.v, 1) % (int32_t)_mm_extract_epi32(b.v, 1), - (int32_t)_mm_extract_epi32(a.v, 2) % (int32_t)_mm_extract_epi32(b.v, 2), - (int32_t)_mm_extract_epi32(a.v, 3) % (int32_t)_mm_extract_epi32(b.v, 3)); + return __vec4_i32((int32_t)_mm_extract_epi32(a.v, 0) % + (int32_t)_mm_extract_epi32(b.v, 0), + (int32_t)_mm_extract_epi32(a.v, 1) % + (int32_t)_mm_extract_epi32(b.v, 1), + (int32_t)_mm_extract_epi32(a.v, 2) % + 
(int32_t)_mm_extract_epi32(b.v, 2), + (int32_t)_mm_extract_epi32(a.v, 3) % + (int32_t)_mm_extract_epi32(b.v, 3)); } static FORCEINLINE __vec4_i32 __lshr(__vec4_i32 a, __vec4_i32 b) { - // FIXME: if we can determine at compile time that b has the same value - // across all elements, e.g. using gcc's __builtin_constant_p, then we - // can use _mm_srl_epi32. - return __vec4_i32((uint32_t)_mm_extract_epi32(a.v, 0) >> _mm_extract_epi32(b.v, 0), - (uint32_t)_mm_extract_epi32(a.v, 1) >> _mm_extract_epi32(b.v, 1), - (uint32_t)_mm_extract_epi32(a.v, 2) >> _mm_extract_epi32(b.v, 2), - (uint32_t)_mm_extract_epi32(a.v, 3) >> _mm_extract_epi32(b.v, 3)); + return __vec4_i32((uint32_t)_mm_extract_epi32(a.v, 0) >> + _mm_extract_epi32(b.v, 0), + (uint32_t)_mm_extract_epi32(a.v, 1) >> + _mm_extract_epi32(b.v, 1), + (uint32_t)_mm_extract_epi32(a.v, 2) >> + _mm_extract_epi32(b.v, 2), + (uint32_t)_mm_extract_epi32(a.v, 3) >> + _mm_extract_epi32(b.v, 3)); +} + +static FORCEINLINE __vec4_i32 __lshr(__vec4_i32 a, int32_t b) { + return _mm_srl_epi32(a.v, _mm_set_epi32(0, 0, 0, b)); } static FORCEINLINE __vec4_i32 __ashr(__vec4_i32 a, __vec4_i32 b) { - // FIXME: if we can determine at compile time that b has the same value - // across all elements, then we can use _mm_sra_epi32. - return __vec4_i32((int32_t)_mm_extract_epi32(a.v, 0) >> _mm_extract_epi32(b.v, 0), - (int32_t)_mm_extract_epi32(a.v, 1) >> _mm_extract_epi32(b.v, 1), - (int32_t)_mm_extract_epi32(a.v, 2) >> _mm_extract_epi32(b.v, 2), - (int32_t)_mm_extract_epi32(a.v, 3) >> _mm_extract_epi32(b.v, 3)); + return __vec4_i32((int32_t)_mm_extract_epi32(a.v, 0) >> + _mm_extract_epi32(b.v, 0), + (int32_t)_mm_extract_epi32(a.v, 1) >> + _mm_extract_epi32(b.v, 1), + (int32_t)_mm_extract_epi32(a.v, 2) >> + _mm_extract_epi32(b.v, 2), + (int32_t)_mm_extract_epi32(a.v, 3) >> + _mm_extract_epi32(b.v, 3)); +} + +static FORCEINLINE __vec4_i32 __ashr(__vec4_i32 a, int32_t b) { + return _mm_sra_epi32(a.v, _mm_set_epi32(0, 0, 0, b)); } static FORCEINLINE __vec4_i1 __equal(__vec4_i32 a, __vec4_i32 b) { @@ -1016,6 +1081,12 @@ static FORCEINLINE __vec4_i64 __shl(__vec4_i64 a, __vec4_i64 b) { _mm_extract_epi64(a.v[1], 1) << _mm_extract_epi64(b.v[1], 1)); } +static FORCEINLINE __vec4_i64 __shl(__vec4_i64 a, int32_t b) { + __m128i amt = _mm_set_epi32(0, 0, 0, b); + return __vec4_i64(_mm_sll_epi64(a.v[0], amt), + _mm_sll_epi64(a.v[1], amt)); +} + static FORCEINLINE __vec4_i64 __udiv(__vec4_i64 a, __vec4_i64 b) { return __vec4_i64((uint64_t)_mm_extract_epi64(a.v[0], 0) / (uint64_t)_mm_extract_epi64(b.v[0], 0), @@ -1071,6 +1142,12 @@ static FORCEINLINE __vec4_i64 __lshr(__vec4_i64 a, __vec4_i64 b) { (uint64_t)_mm_extract_epi64(b.v[1], 1)); } +static FORCEINLINE __vec4_i64 __lshr(__vec4_i64 a, int32_t b) { + __m128i amt = _mm_set_epi32(0, 0, 0, b); + return __vec4_i64(_mm_srl_epi64(a.v[0], amt), + _mm_srl_epi64(a.v[1], amt)); +} + static FORCEINLINE __vec4_i64 __ashr(__vec4_i64 a, __vec4_i64 b) { return __vec4_i64((int64_t)_mm_extract_epi64(a.v[0], 0) >> (int64_t)_mm_extract_epi64(b.v[0], 0), @@ -1082,6 +1159,13 @@ static FORCEINLINE __vec4_i64 __ashr(__vec4_i64 a, __vec4_i64 b) { (int64_t)_mm_extract_epi64(b.v[1], 1)); } +static FORCEINLINE __vec4_i64 __ashr(__vec4_i64 a, int32_t b) { + return __vec4_i64((int64_t)_mm_extract_epi64(a.v[0], 0) >> b, + (int64_t)_mm_extract_epi64(a.v[0], 1) >> b, + (int64_t)_mm_extract_epi64(a.v[1], 0) >> b, + (int64_t)_mm_extract_epi64(a.v[1], 1) >> b); +} + static FORCEINLINE __vec4_i1 __equal(__vec4_i64 a, __vec4_i64 b) { __m128i cmp0 = 
_mm_cmpeq_epi64(a.v[0], b.v[0]);
     __m128i cmp1 = _mm_cmpeq_epi64(a.v[1], b.v[1]);
diff --git a/llvmutil.cpp b/llvmutil.cpp
index 804b13e0..808babbc 100644
--- a/llvmutil.cpp
+++ b/llvmutil.cpp
@@ -36,7 +36,9 @@
  */
 
 #include "llvmutil.h"
+#include "ispc.h"
 #include "type.h"
+#include
 
 LLVM_TYPE_CONST llvm::Type *LLVMTypes::VoidType = NULL;
 LLVM_TYPE_CONST llvm::PointerType *LLVMTypes::VoidPointerType = NULL;
@@ -465,3 +467,239 @@ LLVMBoolVector(const bool *bvec) {
     }
     return llvm::ConstantVector::get(vals);
 }
+
+
+/** Conservative test to see if two llvm::Values are equal. There are
+    (potentially many) cases where the two values actually are equal but
+    this will return false. However, if it does return true, the two
+    vectors definitely are equal.
+
+    @todo This seems to catch all of the cases we currently need it for in
+    practice, but it's be nice to make it a little more robust/general. In
+    general, though, a little something called the halting problem means we
+    won't get all of them.
+*/
+static bool
+lValuesAreEqual(llvm::Value *v0, llvm::Value *v1,
+                std::vector<llvm::PHINode *> &seenPhi0,
+                std::vector<llvm::PHINode *> &seenPhi1) {
+    // Thanks to the fact that LLVM hashes and returns the same pointer for
+    // constants (of all sorts, even constant expressions), this first test
+    // actually catches a lot of cases. LLVM's SSA form also helps a lot
+    // with this..
+    if (v0 == v1)
+        return true;
+
+    Assert(seenPhi0.size() == seenPhi1.size());
+    for (unsigned int i = 0; i < seenPhi0.size(); ++i)
+        if (v0 == seenPhi0[i] && v1 == seenPhi1[i])
+            return true;
+
+    llvm::BinaryOperator *bo0 = llvm::dyn_cast<llvm::BinaryOperator>(v0);
+    llvm::BinaryOperator *bo1 = llvm::dyn_cast<llvm::BinaryOperator>(v1);
+    if (bo0 != NULL && bo1 != NULL) {
+        if (bo0->getOpcode() != bo1->getOpcode())
+            return false;
+        return (lValuesAreEqual(bo0->getOperand(0), bo1->getOperand(0),
+                                seenPhi0, seenPhi1) &&
+                lValuesAreEqual(bo0->getOperand(1), bo1->getOperand(1),
+                                seenPhi0, seenPhi1));
+    }
+
+    llvm::PHINode *phi0 = llvm::dyn_cast<llvm::PHINode>(v0);
+    llvm::PHINode *phi1 = llvm::dyn_cast<llvm::PHINode>(v1);
+    if (phi0 != NULL && phi1 != NULL) {
+        if (phi0->getNumIncomingValues() != phi1->getNumIncomingValues())
+            return false;
+
+        seenPhi0.push_back(phi0);
+        seenPhi1.push_back(phi1);
+
+        unsigned int numIncoming = phi0->getNumIncomingValues();
+        // Check all of the incoming values: if all of them are all equal,
+        // then we're good.
+        bool anyFailure = false;
+        for (unsigned int i = 0; i < numIncoming; ++i) {
+            Assert(phi0->getIncomingBlock(i) == phi1->getIncomingBlock(i));
+            if (!lValuesAreEqual(phi0->getIncomingValue(i),
+                                 phi1->getIncomingValue(i), seenPhi0, seenPhi1)) {
+                anyFailure = true;
+                break;
+            }
+        }
+
+        seenPhi0.pop_back();
+        seenPhi1.pop_back();
+
+        return !anyFailure;
+    }
+
+    return false;
+}
+
+
+/** Given an llvm::Value known to be an integer, return its value as
+    an int64_t.
+*/
+static int64_t
+lGetIntValue(llvm::Value *offset) {
+    llvm::ConstantInt *intOffset = llvm::dyn_cast<llvm::ConstantInt>(offset);
+    Assert(intOffset && (intOffset->getBitWidth() == 32 ||
+                         intOffset->getBitWidth() == 64));
+    return intOffset->getSExtValue();
+}
+
+
+/** This function takes chains of InsertElement instructions along the
+    lines of:
+
+    %v0 = insertelement undef, value_0, i32 index_0
+    %v1 = insertelement %v1, value_1, i32 index_1
+    ...
+    %vn = insertelement %vn-1, value_n-1, i32 index_n-1
+
+    and initializes the provided elements array such that the i'th
+    llvm::Value * in the array is the element that was inserted into the
+    i'th element of the vector.
+*/
+void
+LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth,
+                       llvm::Value **elements) {
+    for (int i = 0; i < vectorWidth; ++i)
+        elements[i] = NULL;
+
+    while (ie != NULL) {
+        int64_t iOffset = lGetIntValue(ie->getOperand(2));
+        Assert(iOffset >= 0 && iOffset < vectorWidth);
+        Assert(elements[iOffset] == NULL);
+
+        elements[iOffset] = ie->getOperand(1);
+
+        llvm::Value *insertBase = ie->getOperand(0);
+        ie = llvm::dyn_cast<llvm::InsertElementInst>(insertBase);
+        if (ie == NULL) {
+            if (llvm::isa<llvm::UndefValue>(insertBase))
+                return;
+
+            llvm::ConstantVector *cv =
+                llvm::dyn_cast<llvm::ConstantVector>(insertBase);
+            Assert(cv != NULL);
+            Assert(iOffset < (int)cv->getNumOperands());
+            elements[iOffset] = cv->getOperand(iOffset);
+        }
+    }
+}
+
+
+/** Tests to see if all of the elements of the vector in the 'v' parameter
+    are equal. Like lValuesAreEqual(), this is a conservative test and may
+    return false for arrays where the values are actually all equal. */
+bool
+LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
+                         std::vector<llvm::PHINode *> &seenPhis) {
+    if (llvm::isa<llvm::ConstantAggregateZero>(v))
+        return true;
+
+    llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
+    if (cv != NULL)
+        return (cv->getSplatValue() != NULL);
+
+    llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
+    if (bop != NULL)
+        return (LLVMVectorValuesAllEqual(bop->getOperand(0), vectorLength,
+                                         seenPhis) &&
+                LLVMVectorValuesAllEqual(bop->getOperand(1), vectorLength,
+                                         seenPhis));
+
+    llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
+    if (cast != NULL)
+        return LLVMVectorValuesAllEqual(cast->getOperand(0), vectorLength,
+                                        seenPhis);
+
+    llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v);
+    if (ie != NULL) {
+        llvm::Value *elements[ISPC_MAX_NVEC];
+        LLVMFlattenInsertChain(ie, vectorLength, elements);
+
+        // We will ignore any values of elements[] that are NULL; as they
+        // correspond to undefined values--we just want to see if all of
+        // the defined values have the same value.
+        int lastNonNull = 0;
+        while (lastNonNull < vectorLength && elements[lastNonNull] == NULL)
+            ++lastNonNull;
+
+        if (lastNonNull == vectorLength)
+            // all of them are undef!
+            return true;
+
+        for (int i = lastNonNull; i < vectorLength; ++i) {
+            if (elements[i] == NULL)
+                continue;
+
+            std::vector<llvm::PHINode *> seenPhi0;
+            std::vector<llvm::PHINode *> seenPhi1;
+            if (lValuesAreEqual(elements[lastNonNull], elements[i], seenPhi0,
+                                seenPhi1) == false)
+                return false;
+            lastNonNull = i;
+        }
+        return true;
+    }
+
+    llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
+    if (phi) {
+        for (unsigned int i = 0; i < seenPhis.size(); ++i)
+            if (seenPhis[i] == phi)
+                return true;
+
+        seenPhis.push_back(phi);
+
+        unsigned int numIncoming = phi->getNumIncomingValues();
+        // Check all of the incoming values: if all of them are all equal,
+        // then we're good.
+        for (unsigned int i = 0; i < numIncoming; ++i) {
+            if (!LLVMVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength,
+                                          seenPhis)) {
+                seenPhis.pop_back();
+                return false;
+            }
+        }
+
+        seenPhis.pop_back();
+        return true;
+    }
+
+    Assert(!llvm::isa<llvm::Constant>(v));
+
+    if (llvm::isa<llvm::CallInst>(v) || llvm::isa<llvm::LoadInst>(v) ||
+        !llvm::isa<llvm::Instruction>(v))
+        return false;
+
+    llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
+    if (shuffle != NULL) {
+        llvm::Value *indices = shuffle->getOperand(2);
+        if (LLVMVectorValuesAllEqual(indices, vectorLength, seenPhis))
+            // The easy case--just a smear of the same element across the
+            // whole vector.
+            return true;
+
+        // TODO: handle more general cases?
+        return false;
+    }
+
+#if 0
+    fprintf(stderr, "all equal: ");
+    v->dump();
+    fprintf(stderr, "\n");
+    llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
+    if (inst) {
+        inst->getParent()->dump();
+        fprintf(stderr, "\n");
+        fprintf(stderr, "\n");
+    }
+#endif
+
+    return false;
+}
+
+
diff --git a/llvmutil.h b/llvmutil.h
index 0322b49e..a1257084 100644
--- a/llvmutil.h
+++ b/llvmutil.h
@@ -38,12 +38,23 @@
 #ifndef ISPC_LLVMUTIL_H
 #define ISPC_LLVMUTIL_H 1
 
-#include "ispc.h"
 #include
 #include
 #include
 #include
 
+namespace llvm {
+    class PHINode;
+    class InsertElementInst;
+}
+
+// llvm::Type *s are no longer const in llvm 3.0
+#if defined(LLVM_3_0) || defined(LLVM_3_0svn) || defined(LLVM_3_1svn)
+#define LLVM_TYPE_CONST
+#else
+#define LLVM_TYPE_CONST const
+#endif
+
 /** This structure holds pointers to a variety of LLVM types; code
     elsewhere can use them from here, ratherthan needing to make more
@@ -99,6 +110,7 @@ extern llvm::Constant *LLVMTrue, *LLVMFalse;
     of LLVMTypes and the LLVMTrue/LLVMFalse constants. However, it can't
     be called until the compilation target is known.
 */
+struct Target;
 extern void InitLLVMUtil(llvm::LLVMContext *ctx, Target target);
 
 /** Returns an LLVM i8 constant of the given value */
@@ -205,4 +217,13 @@ extern llvm::Constant *LLVMMaskAllOn;
 /** LLVM constant value representing an 'all off' SIMD lane mask */
 extern llvm::Constant *LLVMMaskAllOff;
 
+/** Tests to see if all of the elements of the vector in the 'v' parameter
+    are equal. Like lValuesAreEqual(), this is a conservative test and may
+    return false for arrays where the values are actually all equal. */
+extern bool LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength,
+                                     std::vector<llvm::PHINode *> &seenPhis);
+
+void LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth,
+                            llvm::Value **elements);
+
 #endif // ISPC_LLVMUTIL_H
diff --git a/opt.cpp b/opt.cpp
index b1f5c7a0..cd62d342 100644
--- a/opt.cpp
+++ b/opt.cpp
@@ -921,58 +921,6 @@ char GatherScatterFlattenOpt::ID = 0;
 llvm::RegisterPass<GatherScatterFlattenOpt> gsf("gs-flatten",
                                 "Gather/Scatter Flatten Pass");
 
-/** Given an llvm::Value known to be an integer, return its value as
-    an int64_t.
-*/
-static int64_t
-lGetIntValue(llvm::Value *offset) {
-    llvm::ConstantInt *intOffset = llvm::dyn_cast<llvm::ConstantInt>(offset);
-    Assert(intOffset && (intOffset->getBitWidth() == 32 ||
-                         intOffset->getBitWidth() == 64));
-    return intOffset->getSExtValue();
-}
-
-/** This function takes chains of InsertElement instructions along the
-    lines of:
-
-    %v0 = insertelement undef, value_0, i32 index_0
-    %v1 = insertelement %v1, value_1, i32 index_1
-    ...
-    %vn = insertelement %vn-1, value_n-1, i32 index_n-1
-
-    and initializes the provided elements array such that the i'th
-    llvm::Value * in the array is the element that was inserted into the
-    i'th element of the vector.
-*/
-static void
-lFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth,
-                    llvm::Value **elements) {
-    for (int i = 0; i < vectorWidth; ++i)
-        elements[i] = NULL;
-
-    while (ie != NULL) {
-        int64_t iOffset = lGetIntValue(ie->getOperand(2));
-        Assert(iOffset >= 0 && iOffset < vectorWidth);
-        Assert(elements[iOffset] == NULL);
-
-        elements[iOffset] = ie->getOperand(1);
-
-        llvm::Value *insertBase = ie->getOperand(0);
-        ie = llvm::dyn_cast<llvm::InsertElementInst>(insertBase);
-        if (ie == NULL) {
-            if (llvm::isa<llvm::UndefValue>(insertBase))
-                return;
-
-            llvm::ConstantVector *cv =
-                llvm::dyn_cast<llvm::ConstantVector>(insertBase);
-            Assert(cv != NULL);
-            Assert(iOffset < (int)cv->getNumOperands());
-            elements[iOffset] = cv->getOperand(iOffset);
-        }
-    }
-}
-
-
 /** Check to make sure that this value is actually a pointer in the end.
     We need to make sure that given an expression like vec(offset) +
     ptr2int(ptr), lGetBasePointer() doesn't return vec(offset) for the base
@@ -1011,7 +959,7 @@ lGetBasePointer(llvm::Value *v) {
     llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v);
     if (ie != NULL) {
         llvm::Value *elements[ISPC_MAX_NVEC];
-        lFlattenInsertChain(ie, g->target.vectorWidth, elements);
+        LLVMFlattenInsertChain(ie, g->target.vectorWidth, elements);
 
         // Make sure none of the elements is undefined.
         // TODO: it's probably ok to allow undefined elements and return
@@ -1825,187 +1773,6 @@
 llvm::RegisterPass<GSImprovementsPass> gsi("gs-improvements",
                                 "Gather/Scatter Improvements Pass");
 
-/** Conservative test to see if two llvm::Values are equal. There are
-    (potentially many) cases where the two values actually are equal but
-    this will return false. However, if it does return true, the two
-    vectors definitely are equal.
-
-    @todo This seems to catch all of the cases we currently need it for in
-    practice, but it's be nice to make it a little more robust/general. In
-    general, though, a little something called the halting problem means we
-    won't get all of them.
-*/
-static bool
-lValuesAreEqual(llvm::Value *v0, llvm::Value *v1,
-                std::vector<llvm::PHINode *> &seenPhi0,
-                std::vector<llvm::PHINode *> &seenPhi1) {
-    // Thanks to the fact that LLVM hashes and returns the same pointer for
-    // constants (of all sorts, even constant expressions), this first test
-    // actually catches a lot of cases. LLVM's SSA form also helps a lot
-    // with this..
-    if (v0 == v1)
-        return true;
-
-    Assert(seenPhi0.size() == seenPhi1.size());
-    for (unsigned int i = 0; i < seenPhi0.size(); ++i)
-        if (v0 == seenPhi0[i] && v1 == seenPhi1[i])
-            return true;
-
-    llvm::BinaryOperator *bo0 = llvm::dyn_cast<llvm::BinaryOperator>(v0);
-    llvm::BinaryOperator *bo1 = llvm::dyn_cast<llvm::BinaryOperator>(v1);
-    if (bo0 != NULL && bo1 != NULL) {
-        if (bo0->getOpcode() != bo1->getOpcode())
-            return false;
-        return (lValuesAreEqual(bo0->getOperand(0), bo1->getOperand(0),
-                                seenPhi0, seenPhi1) &&
-                lValuesAreEqual(bo0->getOperand(1), bo1->getOperand(1),
-                                seenPhi0, seenPhi1));
-    }
-
-    llvm::PHINode *phi0 = llvm::dyn_cast<llvm::PHINode>(v0);
-    llvm::PHINode *phi1 = llvm::dyn_cast<llvm::PHINode>(v1);
-    if (phi0 != NULL && phi1 != NULL) {
-        if (phi0->getNumIncomingValues() != phi1->getNumIncomingValues())
-            return false;
-
-        seenPhi0.push_back(phi0);
-        seenPhi1.push_back(phi1);
-
-        unsigned int numIncoming = phi0->getNumIncomingValues();
-        // Check all of the incoming values: if all of them are all equal,
-        // then we're good.
-        bool anyFailure = false;
-        for (unsigned int i = 0; i < numIncoming; ++i) {
-            Assert(phi0->getIncomingBlock(i) == phi1->getIncomingBlock(i));
-            if (!lValuesAreEqual(phi0->getIncomingValue(i),
-                                 phi1->getIncomingValue(i), seenPhi0, seenPhi1)) {
-                anyFailure = true;
-                break;
-            }
-        }
-
-        seenPhi0.pop_back();
-        seenPhi1.pop_back();
-
-        return !anyFailure;
-    }
-
-    return false;
-}
-
-
-/** Tests to see if all of the elements of the vector in the 'v' parameter
-    are equal. Like lValuesAreEqual(), this is a conservative test and may
-    return false for arrays where the values are actually all equal. */
-static bool
-lVectorValuesAllEqual(llvm::Value *v, int vectorLength,
-                      std::vector<llvm::PHINode *> &seenPhis) {
-    if (llvm::isa<llvm::ConstantAggregateZero>(v))
-        return true;
-
-    llvm::ConstantVector *cv = llvm::dyn_cast<llvm::ConstantVector>(v);
-    if (cv != NULL)
-        return (cv->getSplatValue() != NULL);
-
-    llvm::BinaryOperator *bop = llvm::dyn_cast<llvm::BinaryOperator>(v);
-    if (bop != NULL)
-        return (lVectorValuesAllEqual(bop->getOperand(0), vectorLength,
-                                      seenPhis) &&
-                lVectorValuesAllEqual(bop->getOperand(1), vectorLength,
-                                      seenPhis));
-
-    llvm::CastInst *cast = llvm::dyn_cast<llvm::CastInst>(v);
-    if (cast != NULL)
-        return lVectorValuesAllEqual(cast->getOperand(0), vectorLength,
-                                     seenPhis);
-
-    llvm::InsertElementInst *ie = llvm::dyn_cast<llvm::InsertElementInst>(v);
-    if (ie != NULL) {
-        llvm::Value *elements[ISPC_MAX_NVEC];
-        lFlattenInsertChain(ie, vectorLength, elements);
-
-        // We will ignore any values of elements[] that are NULL; as they
-        // correspond to undefined values--we just want to see if all of
-        // the defined values have the same value.
-        int lastNonNull = 0;
-        while (lastNonNull < vectorLength && elements[lastNonNull] == NULL)
-            ++lastNonNull;
-
-        if (lastNonNull == vectorLength)
-            // all of them are undef!
-            return true;
-
-        for (int i = lastNonNull; i < vectorLength; ++i) {
-            if (elements[i] == NULL)
-                continue;
-
-            std::vector<llvm::PHINode *> seenPhi0;
-            std::vector<llvm::PHINode *> seenPhi1;
-            if (lValuesAreEqual(elements[lastNonNull], elements[i], seenPhi0,
-                                seenPhi1) == false)
-                return false;
-            lastNonNull = i;
-        }
-        return true;
-    }
-
-    llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(v);
-    if (phi) {
-        for (unsigned int i = 0; i < seenPhis.size(); ++i)
-            if (seenPhis[i] == phi)
-                return true;
-
-        seenPhis.push_back(phi);
-
-        unsigned int numIncoming = phi->getNumIncomingValues();
-        // Check all of the incoming values: if all of them are all equal,
-        // then we're good.
-        for (unsigned int i = 0; i < numIncoming; ++i) {
-            if (!lVectorValuesAllEqual(phi->getIncomingValue(i), vectorLength,
-                                       seenPhis)) {
-                seenPhis.pop_back();
-                return false;
-            }
-        }
-
-        seenPhis.pop_back();
-        return true;
-    }
-
-    Assert(!llvm::isa<llvm::Constant>(v));
-
-    if (llvm::isa<llvm::CallInst>(v) || llvm::isa<llvm::LoadInst>(v) ||
-        !llvm::isa<llvm::Instruction>(v))
-        return false;
-
-    llvm::ShuffleVectorInst *shuffle = llvm::dyn_cast<llvm::ShuffleVectorInst>(v);
-    if (shuffle != NULL) {
-        llvm::Value *indices = shuffle->getOperand(2);
-        if (lVectorValuesAllEqual(indices, vectorLength, seenPhis))
-            // The easy case--just a smear of the same element across the
-            // whole vector.
-            return true;
-
-        // TODO: handle more general cases?
-        return false;
-    }
-
-#if 0
-    fprintf(stderr, "all equal: ");
-    v->dump();
-    fprintf(stderr, "\n");
-    llvm::Instruction *inst = llvm::dyn_cast<llvm::Instruction>(v);
-    if (inst) {
-        inst->getParent()->dump();
-        fprintf(stderr, "\n");
-        fprintf(stderr, "\n");
-    }
-#endif
-
-    return false;
-}
-
-
 /** Given a vector of compile-time constant integer values, test to see if
     they are a linear sequence of constant integers starting from an
     arbirary value but then having a step of value "stride" between
@@ -2102,9 +1869,9 @@ lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
         // programIndex + unif -> ascending linear seqeuence
         // unif + programIndex -> ascending linear sequence
         return ((lVectorIsLinear(op0, vectorLength, stride, seenPhis) &&
-                 lVectorValuesAllEqual(op1, vectorLength, seenPhis)) ||
+                 LLVMVectorValuesAllEqual(op1, vectorLength, seenPhis)) ||
                 (lVectorIsLinear(op1, vectorLength, stride, seenPhis) &&
-                 lVectorValuesAllEqual(op0, vectorLength, seenPhis)));
+                 LLVMVectorValuesAllEqual(op0, vectorLength, seenPhis)));
     else if (bop->getOpcode() == llvm::Instruction::Sub)
         // For subtraction, we only match:
         //
@@ -2115,7 +1882,7 @@ lVectorIsLinear(llvm::Value *v, int vectorLength, int stride,
         // And generate code for that as a vector load + shuffle.
         return (lVectorIsLinear(bop->getOperand(0), vectorLength, stride,
                                 seenPhis) &&
-                lVectorValuesAllEqual(bop->getOperand(1), vectorLength,
+                LLVMVectorValuesAllEqual(bop->getOperand(1), vectorLength,
                                       seenPhis));
     else if (bop->getOpcode() == llvm::Instruction::Mul)
         // Multiplies are a bit trickier, so are handled in a separate
@@ -2313,7 +2080,7 @@ GSImprovementsPass::runOnBasicBlock(llvm::BasicBlock &bb) {
 
         {
             std::vector<llvm::PHINode *> seenPhis;
-            if (lVectorValuesAllEqual(offsets, g->target.vectorWidth, seenPhis)) {
+            if (LLVMVectorValuesAllEqual(offsets, g->target.vectorWidth, seenPhis)) {
                 // If all the offsets are equal, then compute the single
                 // pointer they all represent based on the first one of them
                 // (arbitrarily).