diff --git a/ctx.cpp b/ctx.cpp index 6f4b6bcf..f8fb2962 100644 --- a/ctx.cpp +++ b/ctx.cpp @@ -1379,6 +1379,19 @@ FunctionEmitContext::MasksAllEqual(llvm::Value *v1, llvm::Value *v2) { #endif } +llvm::Value * +FunctionEmitContext::ProgramIndex() { + llvm::SmallVector array; + for (int i = 0; i < g->target->getVectorWidth() ; ++i) { + llvm::Constant *C = LLVMInt32(i); + array.push_back(C); + } + + llvm::Constant* index = llvm::ConstantVector::get(array); + + return index; +} + llvm::Value * FunctionEmitContext::GetStringPtr(const std::string &str) { @@ -1729,26 +1742,48 @@ FunctionEmitContext::SmearUniform(llvm::Value *value, const char *name) { llvm::Value *ret = NULL; llvm::Type *eltType = value->getType(); + llvm::Type *vecType = NULL; llvm::PointerType *pt = llvm::dyn_cast(eltType); if (pt != NULL) { // Varying pointers are represented as vectors of i32/i64s - ret = llvm::UndefValue::get(LLVMTypes::VoidPointerVectorType); + vecType = LLVMTypes::VoidPointerVectorType; value = PtrToIntInst(value); } - else + else { // All other varying types are represented as vectors of the // underlying type. - ret = llvm::UndefValue::get(llvm::VectorType::get(eltType, - g->target->getVectorWidth())); - - for (int i = 0; i < g->target->getVectorWidth(); ++i) { - llvm::Twine n = llvm::Twine("smear.") + llvm::Twine(name ? name : "") + - llvm::Twine(i); - ret = InsertInst(ret, value, i, n.str().c_str()); + vecType = llvm::VectorType::get(eltType, g->target->getVectorWidth()); } + // Check for a constant case. 
+ if (llvm::Constant *const_val = llvm::dyn_cast(value)) { + ret = llvm::ConstantVector::getSplat( + g->target->getVectorWidth(), + const_val); + return ret; + } + + // Generate the following sequence: + // %broadcast_init.0 = insertelement <4 x i32> undef, i32 %val, i32 0 + // %broadcast.1 = shufflevector <4 x i32> %broadcast_init.0, <4 x i32> undef, + // <4 x i32> zeroinitializer + // + llvm::Value *undef1 = llvm::UndefValue::get(vecType); + llvm::Value *undef2 = llvm::UndefValue::get(vecType); + + // InsertElement + llvm::Twine tw1 = llvm::Twine("broadcast_init.") + llvm::Twine(name ? name : ""); + llvm::Value *insert = InsertInst(undef1, value, 0, tw1.str().c_str()); + + // ShuffleVector + llvm::Constant *zeroVec = llvm::ConstantVector::getSplat( + g->target->getVectorWidth(), + llvm::Constant::getNullValue(llvm::Type::getInt32Ty(*g->ctx))); + llvm::Twine tw2 = llvm::Twine("broadcast.") + llvm::Twine(name ? name : ""); + ret = ShuffleInst(insert, undef2, zeroVec, tw2.str().c_str()); + return ret; } @@ -3131,6 +3166,27 @@ FunctionEmitContext::InsertInst(llvm::Value *v, llvm::Value *eltVal, int elt, } +llvm::Value * +FunctionEmitContext::ShuffleInst(llvm::Value *v1, llvm::Value *v2, llvm::Value *mask, + const char *name) { + if (v1 == NULL || v2 == NULL || mask == NULL) { + AssertPos(currentPos, m->errorCount > 0); + return NULL; + } + + if (name == NULL) { + char buf[32]; + sprintf(buf, "_shuffle"); + name = LLVMGetName(v1, buf); + } + + llvm::Instruction *ii = new llvm::ShuffleVectorInst(v1, v2, mask, name, bblock); + + AddDebugPos(ii); + return ii; +} + + llvm::PHINode * FunctionEmitContext::PhiNode(llvm::Type *type, int count, const char *name) { diff --git a/ctx.h b/ctx.h index 7e262310..67e7efa6 100644 --- a/ctx.h +++ b/ctx.h @@ -295,6 +295,10 @@ public: that indicates whether the two masks are equal. */ llvm::Value *MasksAllEqual(llvm::Value *mask1, llvm::Value *mask2); + /** Generate ConstantVector, which contains ProgramIndex, i.e. 
+ < i32 0, i32 1, i32 2, i32 3> */ + llvm::Value *ProgramIndex(); + /** Given a string, create an anonymous global variable to hold its value and return the pointer to the string. */ llvm::Value *GetStringPtr(const std::string &str); @@ -500,6 +504,9 @@ public: llvm::Value *InsertInst(llvm::Value *v, llvm::Value *eltVal, int elt, const char *name = NULL); + llvm::Value *ShuffleInst(llvm::Value *v1, llvm::Value *v2, llvm::Value *mask, + const char *name = NULL); + llvm::PHINode *PhiNode(llvm::Type *type, int count, const char *name = NULL); llvm::Instruction *SelectInst(llvm::Value *test, llvm::Value *val0, diff --git a/llvmutil.cpp b/llvmutil.cpp index e8dd4f9c..26ab72a5 100644 --- a/llvmutil.cpp +++ b/llvmutil.cpp @@ -601,44 +601,74 @@ lGetIntValue(llvm::Value *offset) { void -LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth, +LLVMFlattenInsertChain(llvm::Value *inst, int vectorWidth, llvm::Value **elements) { - for (int i = 0; i < vectorWidth; ++i) + for (int i = 0; i < vectorWidth; ++i) { elements[i] = NULL; + } - while (ie != NULL) { - int64_t iOffset = lGetIntValue(ie->getOperand(2)); - Assert(iOffset >= 0 && iOffset < vectorWidth); - Assert(elements[iOffset] == NULL); + // Catch a pattern of InsertElement chain. + if (llvm::InsertElementInst *ie = + llvm::dyn_cast(inst)) { + while (ie != NULL) { + int64_t iOffset = lGetIntValue(ie->getOperand(2)); + Assert(iOffset >= 0 && iOffset < vectorWidth); + Assert(elements[iOffset] == NULL); - // Get the scalar value from this insert - elements[iOffset] = ie->getOperand(1); + // Get the scalar value from this insert + elements[iOffset] = ie->getOperand(1); - // Do we have another insert? - llvm::Value *insertBase = ie->getOperand(0); - ie = llvm::dyn_cast(insertBase); - if (ie == NULL) { - if (llvm::isa(insertBase)) - return; + // Do we have another insert? 
+ llvm::Value *insertBase = ie->getOperand(0); + ie = llvm::dyn_cast(insertBase); + if (ie == NULL) { + if (llvm::isa(insertBase)) { + return; + } - // Get the value out of a constant vector if that's what we - // have - llvm::ConstantVector *cv = - llvm::dyn_cast(insertBase); + // Get the value out of a constant vector if that's what we + // have + llvm::ConstantVector *cv = + llvm::dyn_cast(insertBase); - // FIXME: this assert is a little questionable; we probably - // shouldn't fail in this case but should just return an - // incomplete result. But there aren't currently any known - // cases where we have anything other than an undef value or a - // constant vector at the base, so if that ever does happen, - // it'd be nice to know what happend so that perhaps we can - // handle it. - // FIXME: Also, should we handle ConstantDataVectors with - // LLVM3.1? What about ConstantAggregateZero values?? - Assert(cv != NULL); + // FIXME: this assert is a little questionable; we probably + // shouldn't fail in this case but should just return an + // incomplete result. But there aren't currently any known + // cases where we have anything other than an undef value or a + // constant vector at the base, so if that ever does happen, + // it'd be nice to know what happened so that perhaps we can + // handle it. + // FIXME: Also, should we handle ConstantDataVectors with + // LLVM3.1? What about ConstantAggregateZero values?? 
+ Assert(cv != NULL); - Assert(iOffset < (int)cv->getNumOperands()); - elements[iOffset] = cv->getOperand((int32_t)iOffset); + Assert(iOffset < (int)cv->getNumOperands()); + elements[iOffset] = cv->getOperand((int32_t)iOffset); + } + } + } + // Catch a pattern of broadcast implemented as InsertElement + Shuffle: + // %broadcast_init.0 = insertelement <4 x i32> undef, i32 %val, i32 0 + // %broadcast.1 = shufflevector <4 x i32> %broadcast_init.0, <4 x i32> undef, + // <4 x i32> zeroinitializer + else if (llvm::ShuffleVectorInst *shuf = + llvm::dyn_cast(inst)) { + llvm::Value *indices = shuf->getOperand(2); + if (llvm::isa(indices)) { + llvm::Value *op = shuf->getOperand(0); + llvm::InsertElementInst *ie = llvm::dyn_cast(op); + if (ie != NULL && + llvm::isa(ie->getOperand(0))) { + llvm::ConstantInt *ci = + llvm::dyn_cast(ie->getOperand(2)); + + if (ci->isZero()) { + for (int i = 0; i < vectorWidth; ++i) { + elements[i] = ie->getOperand(1); + } + return; + } + } } } } @@ -694,10 +724,10 @@ lIsExactMultiple(llvm::Value *val, int baseValue, int vectorLength, else Assert(LLVMVectorValuesAllEqual(val)); - llvm::InsertElementInst *ie = llvm::dyn_cast(val); - if (ie != NULL) { + if (llvm::isa(val) || + llvm::isa(val)) { llvm::Value *elts[ISPC_MAX_NVEC]; - LLVMFlattenInsertChain(ie, g->target->getVectorWidth(), elts); + LLVMFlattenInsertChain(val, g->target->getVectorWidth(), elts); // We just need to check the scalar first value, since we know that // all elements are equal return lIsExactMultiple(elts[0], baseValue, vectorLength, @@ -1440,10 +1470,10 @@ lExtractFirstVectorElement(llvm::Value *v, // If we have a chain of insertelement instructions, then we can just // flatten them out and grab the value for the first one. 
- llvm::InsertElementInst *ie = llvm::dyn_cast(v); - if (ie != NULL) { + if (llvm::isa(v) || + llvm::isa(v)) { llvm::Value *elements[ISPC_MAX_NVEC]; - LLVMFlattenInsertChain(ie, vt->getNumElements(), elements); + LLVMFlattenInsertChain(v, vt->getNumElements(), elements); return elements[0]; } diff --git a/llvmutil.h b/llvmutil.h index d14a5000..c8d6f32b 100644 --- a/llvmutil.h +++ b/llvmutil.h @@ -264,8 +264,13 @@ extern bool LLVMExtractVectorInts(llvm::Value *v, int64_t ret[], int *nElts); constant vector. For anything more complex (e.g. some other arbitrary value, it doesn't try to extract element values into the returned array. + + This also handles the common broadcast pattern: + %broadcast_init.0 = insertelement <4 x i32> undef, i32 %val, i32 0 + %broadcast.1 = shufflevector <4 x i32> %broadcast_init.0, <4 x i32> undef, + <4 x i32> zeroinitializer */ -extern void LLVMFlattenInsertChain(llvm::InsertElementInst *ie, int vectorWidth, +extern void LLVMFlattenInsertChain(llvm::Value *inst, int vectorWidth, llvm::Value **elements); /** This is a utility routine for debugging that dumps out the given LLVM diff --git a/opt.cpp b/opt.cpp index b310c155..687aa507 100644 --- a/opt.cpp +++ b/opt.cpp @@ -1058,10 +1058,10 @@ lCheckForActualPointer(llvm::Value *v) { */ static llvm::Value * lGetBasePointer(llvm::Value *v) { - llvm::InsertElementInst *ie = llvm::dyn_cast(v); - if (ie != NULL) { + if (llvm::isa(v) || + llvm::isa(v)) { llvm::Value *elements[ISPC_MAX_NVEC]; - LLVMFlattenInsertChain(ie, g->target->getVectorWidth(), elements); + LLVMFlattenInsertChain(v, g->target->getVectorWidth(), elements); // Make sure none of the elements is undefined. 
// TODO: it's probably ok to allow undefined elements and return @@ -1080,9 +1080,12 @@ lGetBasePointer(llvm::Value *v) { } // This case comes up with global/static arrays - llvm::ConstantVector *cv = llvm::dyn_cast(v); - if (cv != NULL) + if (llvm::ConstantVector *cv = llvm::dyn_cast(v)) { return lCheckForActualPointer(cv->getSplatValue()); + } + else if (llvm::ConstantDataVector *cdv = llvm::dyn_cast(v)) { + return lCheckForActualPointer(cdv->getSplatValue()); + } return NULL; } diff --git a/stmt.cpp b/stmt.cpp index 0b789626..32fe672a 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -1993,11 +1993,7 @@ ForeachActiveStmt::EmitCode(FunctionEmitContext *ctx) const { // math...) // Get the "program index" vector value - llvm::Value *programIndex = - llvm::UndefValue::get(LLVMTypes::Int32VectorType); - for (int i = 0; i < g->target->getVectorWidth(); ++i) - programIndex = ctx->InsertInst(programIndex, LLVMInt32(i), i, - "prog_index"); + llvm::Value *programIndex = ctx->ProgramIndex(); // And smear the current lane out to a vector llvm::Value *firstSet32 =