Incorporate per-lane offsets for varying data in the front-end.

Previously, it was only in the GatherScatterFlattenOpt optimization pass that
we added the per-lane offsets when we were indexing into varying data.
(Specifically, the case of float foo[]; int index; foo[index], where foo
is an array of varying elements rather than uniform elements.)  Now, this
is done in the front-end as we're first emitting code.

In addition to the basic ugliness of doing this in an optimization pass, 
it was also error-prone to do it there, since we no longer have access
to all of the type information that's around in the front-end.

No functionality or performance change.
This commit is contained in:
Matt Pharr
2011-11-03 13:15:07 -07:00
parent 6084d6aeaf
commit 43a2d510bf
7 changed files with 247 additions and 129 deletions

240
expr.cpp
View File

@@ -2522,6 +2522,78 @@ lCastUniformVectorBasePtr(llvm::Value *ptr, FunctionEmitContext *ctx) {
}
/** When computing pointer values, we need to apply a per-lane offset when
we're indexing into varying data. Consdier the following ispc code:
uniform float u[] = ...;
float v[] = ...;
int index = ...;
float a = u[index];
float b = v[index];
To compute the varying pointer that holds the addresses to load from
for u[index], we basically just need to multiply index element-wise by
sizeof(float) before doing the memory load. For v[index], we need to
do the same scaling but also need to add per-lane offsets <0,
sizeof(float), 2*sizeof(float), ...> so that the i'th lane loads the
i'th of the varying values at its index value.
This function handles figuring out when this additional offset is
needed and then incorporates it in the varying pointer value.
*/
static llvm::Value *
lAddVaryingOffsetsIfNeeded(FunctionEmitContext *ctx, llvm::Value *ptr,
const Type *returnType, const Type *indexedType) {
// If the result of the indexing isn't a varying atomic type, then
// nothing to do here.
if (returnType->IsVaryingType() == false ||
dynamic_cast<const AtomicType *>(returnType) == NULL)
return ptr;
// We should now have an array of pointer values, represing in a
// varying pointer.
LLVM_TYPE_CONST llvm::ArrayType *at =
llvm::dyn_cast<LLVM_TYPE_CONST llvm::ArrayType>(ptr->getType());
if (at == NULL)
return ptr;
LLVM_TYPE_CONST llvm::PointerType *pt =
llvm::dyn_cast<LLVM_TYPE_CONST llvm::PointerType>(at->getElementType());
assert(pt != NULL);
// If the pointers are to uniform types (e.g. ptr->getType() ==
// [8 x float *]), then we have the u[index] situation from the comment
// above, and no additional offset is needed. Otherwise we have
// pointers to varying atomic types--e.g. ptr->getType() ==
// [8 x <8 x float> *]
if (llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(pt->getElementType()) == false)
return ptr;
// But not so fast: if the reason we have a vector of pointers is that
// we're indexing into an array of uniform short-vector types, then we
// don't need the offsets.
if (dynamic_cast<const VectorType *>(indexedType) != NULL)
return ptr;
// Onward: compute the per lane offsets.
llvm::Value *varyingOffsets =
llvm::UndefValue::get(LLVMTypes::Int32VectorType);
for (int i = 0; i < g->target.vectorWidth; ++i)
varyingOffsets = ctx->InsertInst(varyingOffsets, LLVMInt32(i), i,
"varying_delta");
// Cast the pointer to the corresponding uniform pointer
// type--e.g. from [8 x <8 x float> *] to [8 x float *].
LLVM_TYPE_CONST llvm::Type *unifType =
returnType->GetAsUniformType()->LLVMType(g->ctx);
LLVM_TYPE_CONST llvm::PointerType *ptrCastType =
llvm::PointerType::get(llvm::ArrayType::get(unifType, 0), 0);
ptr = ctx->BitCastInst(ptr, ptrCastType, "ptr2unif");
// And finally add the per-lane offsets.
return ctx->GetElementPtrInst(ptr, LLVMInt32(0), varyingOffsets);
}
llvm::Value *
IndexExpr::GetValue(FunctionEmitContext *ctx) const {
const Type *arrayOrVectorType;
@@ -2547,6 +2619,8 @@ IndexExpr::GetValue(FunctionEmitContext *ctx) const {
ctx->StoreInst(val, ptr);
ptr = lCastUniformVectorBasePtr(ptr, ctx);
lvalue = ctx->GetElementPtrInst(ptr, LLVMInt32(0), index->GetValue(ctx));
lvalue = lAddVaryingOffsetsIfNeeded(ctx, lvalue, GetType(),
arrayOrVectorType);
mask = LLVMMaskAllOn;
}
else {
@@ -2593,19 +2667,20 @@ IndexExpr::GetBaseSymbol() const {
llvm::Value *
IndexExpr::GetLValue(FunctionEmitContext *ctx) const {
const Type *type;
if (!arrayOrVector || !index || ((type = arrayOrVector->GetType()) == NULL))
const Type *arrayOrVectorType;
if (arrayOrVector == NULL || index == NULL ||
((arrayOrVectorType = arrayOrVector->GetType()) == NULL))
return NULL;
ctx->SetDebugPos(pos);
llvm::Value *basePtr = NULL;
if (dynamic_cast<const ArrayType *>(type) ||
dynamic_cast<const VectorType *>(type))
if (dynamic_cast<const ArrayType *>(arrayOrVectorType) ||
dynamic_cast<const VectorType *>(arrayOrVectorType))
basePtr = arrayOrVector->GetLValue(ctx);
else {
type = type->GetReferenceTarget();
assert(dynamic_cast<const ArrayType *>(type) ||
dynamic_cast<const VectorType *>(type));
arrayOrVectorType = arrayOrVectorType->GetReferenceTarget();
assert(dynamic_cast<const ArrayType *>(arrayOrVectorType) ||
dynamic_cast<const VectorType *>(arrayOrVectorType));
basePtr = arrayOrVector->GetValue(ctx);
}
if (!basePtr)
@@ -2614,7 +2689,8 @@ IndexExpr::GetLValue(FunctionEmitContext *ctx) const {
// If the array index is a compile time constant, check to see if it
// may lead to an out-of-bounds access.
ConstExpr *ce = dynamic_cast<ConstExpr *>(index);
const SequentialType *seqType = dynamic_cast<const SequentialType *>(type);
const SequentialType *seqType =
dynamic_cast<const SequentialType *>(arrayOrVectorType);
assert(seqType != NULL);
int nElements = seqType->GetElementCount();
if (ce != NULL && nElements > 0) {
@@ -2630,7 +2706,11 @@ IndexExpr::GetLValue(FunctionEmitContext *ctx) const {
basePtr = lCastUniformVectorBasePtr(basePtr, ctx);
ctx->SetDebugPos(pos);
return ctx->GetElementPtrInst(basePtr, LLVMInt32(0), index->GetValue(ctx));
llvm::Value *ptr = ctx->GetElementPtrInst(basePtr, LLVMInt32(0),
index->GetValue(ctx));
ptr = lAddVaryingOffsetsIfNeeded(ctx, ptr, GetType(), arrayOrVectorType);
return ptr;
}
@@ -2731,27 +2811,32 @@ lIdentifierToVectorElement(char id) {
}
}
//////////////////////////////////////////////////
// StructMemberExpr
class StructMemberExpr : public MemberExpr
{
public:
StructMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, const StructType* structType);
const Type* GetType() const;
SourcePos idpos, const StructType *structType);
const Type *GetType() const;
int getElementNumber() const;
const Type *getElementType() const;
private:
const StructType* exprStructType;
const StructType *exprStructType;
};
StructMemberExpr::StructMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos,
const StructType* structType)
const StructType *structType)
: MemberExpr(e, id, p, idpos), exprStructType(structType) {
}
const Type*
const Type *
StructMemberExpr::GetType() const {
// It's a struct, and the result type is the element
// type, possibly promoted to varying if the struct type / lvalue
@@ -2780,26 +2865,35 @@ StructMemberExpr::getElementNumber() const {
return elementNumber;
}
const Type *
StructMemberExpr::getElementType() const {
return exprStructType->GetAsUniformType()->GetElementType(identifier);
}
//////////////////////////////////////////////////
// VectorMemberExpr
class VectorMemberExpr : public MemberExpr
{
public:
VectorMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, const VectorType* vectorType);
~VectorMemberExpr();
const Type* GetType() const;
llvm::Value* GetLValue(FunctionEmitContext* ctx) const;
llvm::Value* GetValue(FunctionEmitContext* ctx) const;
const Type *GetType() const;
llvm::Value *GetLValue(FunctionEmitContext* ctx) const;
llvm::Value *GetValue(FunctionEmitContext* ctx) const;
int getElementNumber() const;
const Type *getElementType() const;
private:
const VectorType* exprVectorType;
const VectorType* memberType;
const VectorType *exprVectorType;
const VectorType *memberType;
};
VectorMemberExpr::VectorMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos,
const VectorType* vectorType)
@@ -2808,11 +2902,8 @@ VectorMemberExpr::VectorMemberExpr(Expr *e, const char *id, SourcePos p,
identifier.length());
}
VectorMemberExpr::~VectorMemberExpr() {
delete memberType;
}
const Type*
const Type *
VectorMemberExpr::GetType() const {
// For 1-element expressions, we have the base vector element
// type. For n-element expressions, we have a shortvec type
@@ -2826,7 +2917,7 @@ VectorMemberExpr::GetType() const {
}
llvm::Value*
llvm::Value *
VectorMemberExpr::GetLValue(FunctionEmitContext* ctx) const {
if (identifier.length() == 1) {
return MemberExpr::GetLValue(ctx);
@@ -2836,11 +2927,12 @@ VectorMemberExpr::GetLValue(FunctionEmitContext* ctx) const {
}
llvm::Value*
llvm::Value *
VectorMemberExpr::GetValue(FunctionEmitContext* ctx) const {
if (identifier.length() == 1) {
return MemberExpr::GetValue(ctx);
} else {
}
else {
std::vector<int> indices;
for (size_t i = 0; i < identifier.size(); ++i) {
@@ -2866,8 +2958,7 @@ VectorMemberExpr::GetValue(FunctionEmitContext* ctx) const {
llvm::Value *ptmp =
ctx->GetElementPtrInst(ltmp, 0, i, "new_offset");
llvm::Value *initLValue =
ctx->GetElementPtrInst(basePtr , 0,
indices[i], "orig_offset");
ctx->GetElementPtrInst(basePtr, 0, indices[i], "orig_offset");
llvm::Value *initValue =
ctx->LoadInst(initLValue, NULL, memberType->GetElementType(),
"vec_element");
@@ -2878,6 +2969,7 @@ VectorMemberExpr::GetValue(FunctionEmitContext* ctx) const {
}
}
int
VectorMemberExpr::getElementNumber() const {
int elementNumber = lIdentifierToVectorElement(identifier[0]);
@@ -2887,43 +2979,51 @@ VectorMemberExpr::getElementNumber() const {
return elementNumber;
}
const Type *
VectorMemberExpr::getElementType() const {
return memberType;
}
///////////////////////////////////////////////////////////////////////////
// ReferenceMemberExpr
class ReferenceMemberExpr : public MemberExpr
{
public:
ReferenceMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos, const ReferenceType* referenceType);
const Type* GetType() const;
const Type *GetType() const;
llvm::Value *GetLValue(FunctionEmitContext* ctx) const;
int getElementNumber() const;
llvm::Value* GetLValue(FunctionEmitContext* ctx) const;
const Type *getElementType() const;
private:
const ReferenceType* exprReferenceType;
MemberExpr* dereferencedExpr;
const ReferenceType *exprReferenceType;
MemberExpr *dereferencedExpr;
};
ReferenceMemberExpr::ReferenceMemberExpr(Expr *e, const char *id, SourcePos p,
SourcePos idpos,
const ReferenceType* referenceType)
const ReferenceType *referenceType)
: MemberExpr(e, id, p, idpos), exprReferenceType(referenceType) {
const Type* refTarget = exprReferenceType->GetReferenceTarget();
const StructType* structType
= dynamic_cast<const StructType *>(refTarget);
const VectorType* vectorType
= dynamic_cast<const VectorType *>(refTarget);
const Type *refTarget = exprReferenceType->GetReferenceTarget();
const StructType *structType = dynamic_cast<const StructType *>(refTarget);
const VectorType *vectorType = dynamic_cast<const VectorType *>(refTarget);
if (structType != NULL) {
if (structType != NULL)
dereferencedExpr = new StructMemberExpr(e, id, p, idpos, structType);
} else if (vectorType != NULL) {
else if (vectorType != NULL)
dereferencedExpr = new VectorMemberExpr(e, id, p, idpos, vectorType);
} else {
else
dereferencedExpr = NULL;
}
}
const Type*
const Type *
ReferenceMemberExpr::GetType() const {
if (dereferencedExpr == NULL) {
Error(pos, "Can't access member of non-struct/vector type \"%s\".",
@@ -2934,6 +3034,7 @@ ReferenceMemberExpr::GetType() const {
}
}
int
ReferenceMemberExpr::getElementNumber() const {
if (dereferencedExpr == NULL) {
@@ -2945,7 +3046,15 @@ ReferenceMemberExpr::getElementNumber() const {
}
}
llvm::Value*
const Type *
ReferenceMemberExpr::getElementType() const {
assert(dereferencedExpr != NULL);
return dereferencedExpr->getElementType();
}
llvm::Value *
ReferenceMemberExpr::GetLValue(FunctionEmitContext* ctx) const {
if (dereferencedExpr == NULL) {
// FIXME: again I think typechecking should have caught this
@@ -2965,29 +3074,35 @@ ReferenceMemberExpr::GetLValue(FunctionEmitContext* ctx) const {
return NULL;
ctx->SetDebugPos(pos);
return ctx->GetElementPtrInst(basePtr, 0, elementNumber);
llvm::Value *ptr = ctx->GetElementPtrInst(basePtr, 0, elementNumber);
const Type *elementType = getElementType();
ptr = lAddVaryingOffsetsIfNeeded(ctx, ptr, GetType(), elementType);
return ptr;
}
MemberExpr*
MemberExpr *
MemberExpr::create(Expr *e, const char *id, SourcePos p, SourcePos idpos) {
const Type* exprType;
const Type *exprType;
if (e == NULL || (exprType = e->GetType()) == NULL)
return new MemberExpr(e, id, p, idpos);
return NULL;
const StructType* structType = dynamic_cast<const StructType*>(exprType);
const StructType *structType = dynamic_cast<const StructType*>(exprType);
if (structType != NULL)
return new StructMemberExpr(e, id, p, idpos, structType);
const VectorType* vectorType = dynamic_cast<const VectorType*>(exprType);
const VectorType *vectorType = dynamic_cast<const VectorType*>(exprType);
if (vectorType != NULL)
return new VectorMemberExpr(e, id, p, idpos, vectorType);
const ReferenceType* referenceType = dynamic_cast<const ReferenceType*>(exprType);
const ReferenceType *referenceType = dynamic_cast<const ReferenceType*>(exprType);
if (referenceType != NULL)
return new ReferenceMemberExpr(e, id, p, idpos, referenceType);
return new MemberExpr(e, id, p, idpos);
FATAL("Unexpected case in MemberExpr::create()");
return NULL;
}
@@ -3024,6 +3139,8 @@ MemberExpr::GetValue(FunctionEmitContext *ctx) const {
if (elementNumber == -1)
return NULL;
lvalue = ctx->GetElementPtrInst(ptr, 0, elementNumber);
lvalue = lAddVaryingOffsetsIfNeeded(ctx, lvalue, GetType(), getElementType());
mask = LLVMMaskAllOn;
}
else {
@@ -3074,7 +3191,10 @@ MemberExpr::GetLValue(FunctionEmitContext *ctx) const {
return NULL;
ctx->SetDebugPos(pos);
return ctx->GetElementPtrInst(basePtr, 0, elementNumber);
llvm::Value *ptr = ctx->GetElementPtrInst(basePtr, 0, elementNumber);
ptr = lAddVaryingOffsetsIfNeeded(ctx, ptr, GetType(), getElementType());
return ptr;
}