From 0664f5a724bd1611110d67af942c3f28ca332fb3 Mon Sep 17 00:00:00 2001 From: Matt Pharr Date: Mon, 19 Mar 2012 11:46:32 -0700 Subject: [PATCH] Add LLVMExtractVectorInts() function, use it in the opt code. --- llvmutil.cpp | 70 ++++++++++++++++++++++++++++++++++++++-------- llvmutil.h | 29 +++++++++++++++++-- opt.cpp | 79 ++++++++-------------------------------------------- 3 files changed, 96 insertions(+), 82 deletions(-) diff --git a/llvmutil.cpp b/llvmutil.cpp index cfcdf113..55bc45a0 100644 --- a/llvmutil.cpp +++ b/llvmutil.cpp @@ -589,18 +589,6 @@ lGetIntValue(llvm::Value *offset) { } -/** This function takes chains of InsertElement instructions along the - lines of: - - %v0 = insertelement undef, value_0, i32 index_0 - %v1 = insertelement %v1, value_1, i32 index_1 - ... - %vn = insertelement %vn-1, value_n-1, i32 index_n-1 - - and initializes the provided elements array such that the i'th - llvm::Value * in the array is the element that was inserted into the - i'th element of the vector. -*/ void LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth, llvm::Value **elements) { @@ -612,17 +600,32 @@ LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth, Assert(iOffset >= 0 && iOffset < vectorWidth); Assert(elements[iOffset] == NULL); + // Get the scalar value from this insert elements[iOffset] = ie->getOperand(1); + // Do we have another insert? llvm::Value *insertBase = ie->getOperand(0); ie = llvm::dyn_cast(insertBase); if (ie == NULL) { if (llvm::isa(insertBase)) return; + // Get the value out of a constant vector if that's what we + // have llvm::ConstantVector *cv = llvm::dyn_cast(insertBase); + + // FIXME: this assert is a little questionable; we probably + // shouldn't fail in this case but should just return an + // incomplete result. 
But there aren't currently any known + // cases where we have anything other than an undef value or a + // constant vector at the base, so if that ever does happen, + // it'd be nice to know what happened so that perhaps we can + // handle it. + // FIXME: Also, should we handle ConstantDataVectors with + // LLVM3.1? What about ConstantAggregateZero values?? Assert(cv != NULL); + Assert(iOffset < (int)cv->getNumOperands()); elements[iOffset] = cv->getOperand((int32_t)iOffset); } @@ -630,6 +633,49 @@ LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth, } +bool +LLVMExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts) { + // Make sure we do in fact have a vector of integer values here + LLVM_TYPE_CONST llvm::VectorType *vt = + llvm::dyn_cast(v->getType()); + Assert(vt != NULL); + Assert(llvm::isa(vt->getElementType())); + + *nElts = (int)vt->getNumElements(); + + if (llvm::isa(v)) { + for (int i = 0; i < (int)vt->getNumElements(); ++i) + ret[i] = 0; + return true; + } + + // Deal with the fact that LLVM3.1 and previous versions have different + // representations for vectors of constant ints... +#ifdef LLVM_3_1svn + llvm::ConstantDataVector *cv = llvm::dyn_cast(v); + if (cv == NULL) + return false; + + for (int i = 0; i < (int)cv->getNumElements(); ++i) + ret[i] = cv->getElementAsInteger(i); + return true; +#else + llvm::ConstantVector *cv = llvm::dyn_cast(v); + if (cv == NULL) + return false; + + llvm::SmallVector elements; + cv->getVectorElements(elements); + for (int i = 0; i < (int)vt->getNumElements(); ++i) { + llvm::ConstantInt *ci = llvm::dyn_cast(elements[i]); + Assert(ci != NULL); + ret[i] = ci->getSExtValue(); + } + return true; +#endif // LLVM_3_1svn +} + + /** Tests to see if all of the elements of the vector in the 'v' parameter are equal. Like lValuesAreEqual(), this is a conservative test and may return false for arrays where the values are actually all equal. 
*/ diff --git a/llvmutil.h b/llvmutil.h index 41f98d96..8be95f99 100644 --- a/llvmutil.h +++ b/llvmutil.h @@ -231,8 +231,33 @@ extern llvm::Constant *LLVMMaskAllOff; extern bool LLVMVectorValuesAllEqual(llvm::Value *v, int vectorLength, std::vector &seenPhis); -void LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth, - llvm::Value **elements); +/** Given a vector-typed value v, if the vector is a vector with constant + element values, this function extracts those element values into the + ret[] array and returns the number of elements (i.e. the vector type's + width) in *nElts. It returns true if successful and false if the given + vector is not in fact a vector of constants. */ +extern bool LLVMExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts); + +/** This function takes chains of InsertElement instructions along the + lines of: + + %v0 = insertelement undef, value_0, i32 index_0 + %v1 = insertelement %v0, value_1, i32 index_1 + ... + %vn = insertelement %vn-1, value_n-1, i32 index_n-1 + + and initializes the provided elements array such that the i'th + llvm::Value * in the array is the element that was inserted into the + i'th element of the vector. + + When the chain of insertelement instructions comes to an end, the only + base case that this function handles is the initial value being a + constant vector. For anything more complex (e.g. some other arbitrary + value), it doesn't try to extract element values into the returned + array. 
+ */ +extern void LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth, + llvm::Value **elements); /** This is a utility routine for debugging that dumps out the given LLVM value as well as (recursively) all of the other values that it depends diff --git a/opt.cpp b/opt.cpp index e2efe9a5..cb708636 100644 --- a/opt.cpp +++ b/opt.cpp @@ -1647,51 +1647,11 @@ lExtractUniformsFromOffset(llvm::Value **basePtr, llvm::Value **offsetVector, #endif -static bool -lExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts) { - LLVM_TYPE_CONST llvm::VectorType *vt = - llvm::dyn_cast(v->getType()); - Assert(vt != NULL); - Assert(llvm::isa(vt->getElementType())); - - *nElts = (int)vt->getNumElements(); - - if (llvm::isa(v)) { - for (int i = 0; i < (int)vt->getNumElements(); ++i) - ret[i] = 0; - return true; - } - -#ifdef LLVM_3_1svn - llvm::ConstantDataVector *cv = llvm::dyn_cast(v); - if (cv == NULL) - return false; - - for (int i = 0; i < (int)cv->getNumElements(); ++i) - ret[i] = cv->getElementAsInteger(i); - return true; -#else - llvm::ConstantVector *cv = llvm::dyn_cast(v); - if (cv == NULL) - return false; - - llvm::SmallVector elements; - cv->getVectorElements(elements); - for (int i = 0; i < (int)vt->getNumElements(); ++i) { - llvm::ConstantInt *ci = llvm::dyn_cast(elements[i]); - Assert(ci != NULL); - ret[i] = ci->getSExtValue(); - } - return true; -#endif // LLVM_3_1svn -} - - static bool lVectorIs32BitInts(llvm::Value *v) { int nElts; int64_t elts[ISPC_MAX_NVEC]; - if (!lExtractVectorInts(v, elts, &nElts)) + if (!LLVMExtractVectorInts(v, elts, &nElts)) return false; for (int i = 0; i < nElts; ++i) @@ -3546,36 +3506,19 @@ lComputeBasePtr(llvm::CallInst *gatherInst, llvm::Instruction *insertBefore) { static void lExtractConstOffsets(const std::vector &coalesceGroup, int elementSize, std::vector *constOffsets) { - constOffsets->reserve(coalesceGroup.size() * g->target.vectorWidth); + int width = g->target.vectorWidth; + *constOffsets = 
std::vector(coalesceGroup.size() * width, 0); - for (int i = 0; i < (int)coalesceGroup.size(); ++i) { + int64_t *endPtr = &((*constOffsets)[0]); + for (int i = 0; i < (int)coalesceGroup.size(); ++i, endPtr += width) { llvm::Value *offsets = coalesceGroup[i]->getArgOperand(3); - -#ifdef LLVM_3_1svn - llvm::ConstantDataVector *cv = - llvm::dyn_cast(offsets); - Assert(cv != NULL); - - for (int j = 0; j < g->target.vectorWidth; ++j) { - Assert((cv->getElementAsInteger(j) % elementSize) == 0); - constOffsets->push_back((int64_t)cv->getElementAsInteger(j) / - elementSize); - } -#else - llvm::ConstantVector *cv = - llvm::dyn_cast(offsets); - Assert(cv != NULL); - - for (int j = 0; j < g->target.vectorWidth; ++j) { - llvm::ConstantInt *ci = - llvm::dyn_cast(cv->getOperand(j)); - Assert(ci != NULL); - int64_t value = ci->getValue().getSExtValue(); - Assert((value % elementSize) == 0); - constOffsets->push_back(value / elementSize); - } -#endif + int nElts; + bool ok = LLVMExtractVectorInts(offsets, endPtr, &nElts); + Assert(ok && nElts == width); } + + for (int i = 0; i < (int)constOffsets->size(); ++i) + (*constOffsets)[i] /= elementSize; }