Incorporate per-lane offsets for varying data in the front-end.
Previously, it was only in the GatherScatterFlattenOpt optimization pass that we added the per-lane offsets when we were indexing into varying data. (Specifically, the case of float foo[]; int index; foo[index], where foo is an array of varying elements rather than uniform elements.) Now, this is done in the front-end as we're first emitting code. In addition to the basic ugliness of doing this in an optimization pass, it was also error-prone to do it there, since we no longer have access to all of the type information that's around in the front-end. No functionality or performance change.
This commit is contained in:
240
expr.cpp
240
expr.cpp
@@ -2522,6 +2522,78 @@ lCastUniformVectorBasePtr(llvm::Value *ptr, FunctionEmitContext *ctx) {
|
||||
}
|
||||
|
||||
|
||||
/** When computing pointer values, we need to apply a per-lane offset when
|
||||
we're indexing into varying data. Consider the following ispc code:
|
||||
|
||||
uniform float u[] = ...;
|
||||
float v[] = ...;
|
||||
int index = ...;
|
||||
float a = u[index];
|
||||
float b = v[index];
|
||||
|
||||
To compute the varying pointer that holds the addresses to load from
|
||||
for u[index], we basically just need to multiply index element-wise by
|
||||
sizeof(float) before doing the memory load. For v[index], we need to
|
||||
do the same scaling but also need to add per-lane offsets <0,
|
||||
sizeof(float), 2*sizeof(float), ...> so that the i'th lane loads the
|
||||
i'th of the varying values at its index value.
|
||||
|
||||
This function handles figuring out when this additional offset is
|
||||
needed and then incorporates it in the varying pointer value.
|
||||
*/
|
||||
// Parameters:
//   ctx         - emit context used to issue the insert-element, bitcast
//                 and GEP instructions below.
//   ptr         - the pointer value produced by the preceding indexing /
//                 member-access GEP.
//   returnType  - type of the indexing result; unless it is a varying
//                 atomic type, ptr is returned unchanged.
//   indexedType - type that was indexed into; if it is a VectorType, ptr
//                 is returned unchanged.
// Returns either the original ptr (when no per-lane offset applies) or the
// result of a GEP that adds the per-lane offsets <0, 1, 2, ...> after ptr
// has been bitcast to the corresponding uniform element type.
static llvm::Value *
|
||||
lAddVaryingOffsetsIfNeeded(FunctionEmitContext *ctx, llvm::Value *ptr,
|
||||
const Type *returnType, const Type *indexedType) {
|
||||
// If the result of the indexing isn't a varying atomic type, then
|
||||
// nothing to do here.
|
||||
if (returnType->IsVaryingType() == false ||
|
||||
dynamic_cast<const AtomicType *>(returnType) == NULL)
|
||||
return ptr;
|
||||
|
||||
// We should now have an array of pointer values, representing a
|
||||
// varying pointer.
|
||||
LLVM_TYPE_CONST llvm::ArrayType *at =
|
||||
llvm::dyn_cast<LLVM_TYPE_CONST llvm::ArrayType>(ptr->getType());
|
||||
if (at == NULL)
|
||||
return ptr;
|
||||
LLVM_TYPE_CONST llvm::PointerType *pt =
|
||||
llvm::dyn_cast<LLVM_TYPE_CONST llvm::PointerType>(at->getElementType());
|
||||
assert(pt != NULL);
|
||||
|
||||
// If the pointers are to uniform types (e.g. ptr->getType() ==
|
||||
// [8 x float *]), then we have the u[index] situation from the comment
|
||||
// above, and no additional offset is needed. Otherwise we have
|
||||
// pointers to varying atomic types--e.g. ptr->getType() ==
|
||||
// [8 x <8 x float> *]
|
||||
if (llvm::isa<LLVM_TYPE_CONST llvm::VectorType>(pt->getElementType()) == false)
|
||||
return ptr;
|
||||
|
||||
// But not so fast: if the reason we have a vector of pointers is that
|
||||
// we're indexing into an array of uniform short-vector types, then we
|
||||
// don't need the offsets.
|
||||
if (dynamic_cast<const VectorType *>(indexedType) != NULL)
|
||||
return ptr;
|
||||
|
||||
// Onward: compute the per lane offsets.
|
||||
// (Lane i of the offset vector is set to the integer i.)
llvm::Value *varyingOffsets =
|
||||
llvm::UndefValue::get(LLVMTypes::Int32VectorType);
|
||||
for (int i = 0; i < g->target.vectorWidth; ++i)
|
||||
varyingOffsets = ctx->InsertInst(varyingOffsets, LLVMInt32(i), i,
|
||||
"varying_delta");
|
||||
|
||||
// Cast the pointer to the corresponding uniform pointer
|
||||
// type--e.g. from [8 x <8 x float> *] to [8 x float *].
|
||||
LLVM_TYPE_CONST llvm::Type *unifType =
|
||||
returnType->GetAsUniformType()->LLVMType(g->ctx);
|
||||
LLVM_TYPE_CONST llvm::PointerType *ptrCastType =
|
||||
llvm::PointerType::get(llvm::ArrayType::get(unifType, 0), 0);
|
||||
ptr = ctx->BitCastInst(ptr, ptrCastType, "ptr2unif");
|
||||
|
||||
// And finally add the per-lane offsets.
|
||||
return ctx->GetElementPtrInst(ptr, LLVMInt32(0), varyingOffsets);
|
||||
}
|
||||
|
||||
|
||||
llvm::Value *
|
||||
IndexExpr::GetValue(FunctionEmitContext *ctx) const {
|
||||
const Type *arrayOrVectorType;
|
||||
@@ -2547,6 +2619,8 @@ IndexExpr::GetValue(FunctionEmitContext *ctx) const {
|
||||
ctx->StoreInst(val, ptr);
|
||||
ptr = lCastUniformVectorBasePtr(ptr, ctx);
|
||||
lvalue = ctx->GetElementPtrInst(ptr, LLVMInt32(0), index->GetValue(ctx));
|
||||
lvalue = lAddVaryingOffsetsIfNeeded(ctx, lvalue, GetType(),
|
||||
arrayOrVectorType);
|
||||
mask = LLVMMaskAllOn;
|
||||
}
|
||||
else {
|
||||
@@ -2593,19 +2667,20 @@ IndexExpr::GetBaseSymbol() const {
|
||||
|
||||
llvm::Value *
|
||||
IndexExpr::GetLValue(FunctionEmitContext *ctx) const {
|
||||
const Type *type;
|
||||
if (!arrayOrVector || !index || ((type = arrayOrVector->GetType()) == NULL))
|
||||
const Type *arrayOrVectorType;
|
||||
if (arrayOrVector == NULL || index == NULL ||
|
||||
((arrayOrVectorType = arrayOrVector->GetType()) == NULL))
|
||||
return NULL;
|
||||
|
||||
ctx->SetDebugPos(pos);
|
||||
llvm::Value *basePtr = NULL;
|
||||
if (dynamic_cast<const ArrayType *>(type) ||
|
||||
dynamic_cast<const VectorType *>(type))
|
||||
if (dynamic_cast<const ArrayType *>(arrayOrVectorType) ||
|
||||
dynamic_cast<const VectorType *>(arrayOrVectorType))
|
||||
basePtr = arrayOrVector->GetLValue(ctx);
|
||||
else {
|
||||
type = type->GetReferenceTarget();
|
||||
assert(dynamic_cast<const ArrayType *>(type) ||
|
||||
dynamic_cast<const VectorType *>(type));
|
||||
arrayOrVectorType = arrayOrVectorType->GetReferenceTarget();
|
||||
assert(dynamic_cast<const ArrayType *>(arrayOrVectorType) ||
|
||||
dynamic_cast<const VectorType *>(arrayOrVectorType));
|
||||
basePtr = arrayOrVector->GetValue(ctx);
|
||||
}
|
||||
if (!basePtr)
|
||||
@@ -2614,7 +2689,8 @@ IndexExpr::GetLValue(FunctionEmitContext *ctx) const {
|
||||
// If the array index is a compile time constant, check to see if it
|
||||
// may lead to an out-of-bounds access.
|
||||
ConstExpr *ce = dynamic_cast<ConstExpr *>(index);
|
||||
const SequentialType *seqType = dynamic_cast<const SequentialType *>(type);
|
||||
const SequentialType *seqType =
|
||||
dynamic_cast<const SequentialType *>(arrayOrVectorType);
|
||||
assert(seqType != NULL);
|
||||
int nElements = seqType->GetElementCount();
|
||||
if (ce != NULL && nElements > 0) {
|
||||
@@ -2630,7 +2706,11 @@ IndexExpr::GetLValue(FunctionEmitContext *ctx) const {
|
||||
basePtr = lCastUniformVectorBasePtr(basePtr, ctx);
|
||||
|
||||
ctx->SetDebugPos(pos);
|
||||
return ctx->GetElementPtrInst(basePtr, LLVMInt32(0), index->GetValue(ctx));
|
||||
llvm::Value *ptr = ctx->GetElementPtrInst(basePtr, LLVMInt32(0),
|
||||
index->GetValue(ctx));
|
||||
ptr = lAddVaryingOffsetsIfNeeded(ctx, ptr, GetType(), arrayOrVectorType);
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
|
||||
@@ -2731,27 +2811,32 @@ lIdentifierToVectorElement(char id) {
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
// StructMemberExpr
|
||||
|
||||
class StructMemberExpr : public MemberExpr
|
||||
{
|
||||
public:
|
||||
StructMemberExpr(Expr *e, const char *id, SourcePos p,
|
||||
SourcePos idpos, const StructType* structType);
|
||||
|
||||
const Type* GetType() const;
|
||||
SourcePos idpos, const StructType *structType);
|
||||
|
||||
const Type *GetType() const;
|
||||
int getElementNumber() const;
|
||||
const Type *getElementType() const;
|
||||
|
||||
private:
|
||||
const StructType* exprStructType;
|
||||
const StructType *exprStructType;
|
||||
};
|
||||
|
||||
|
||||
StructMemberExpr::StructMemberExpr(Expr *e, const char *id, SourcePos p,
|
||||
SourcePos idpos,
|
||||
const StructType* structType)
|
||||
const StructType *structType)
|
||||
: MemberExpr(e, id, p, idpos), exprStructType(structType) {
|
||||
}
|
||||
|
||||
const Type*
|
||||
|
||||
const Type *
|
||||
StructMemberExpr::GetType() const {
|
||||
// It's a struct, and the result type is the element
|
||||
// type, possibly promoted to varying if the struct type / lvalue
|
||||
@@ -2780,26 +2865,35 @@ StructMemberExpr::getElementNumber() const {
|
||||
return elementNumber;
|
||||
}
|
||||
|
||||
|
||||
// Returns the type of the member named by `identifier`, looked up on the
// uniform variant of the struct type this expression accesses.
const Type *
|
||||
StructMemberExpr::getElementType() const {
|
||||
return exprStructType->GetAsUniformType()->GetElementType(identifier);
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////
|
||||
// VectorMemberExpr
|
||||
|
||||
class VectorMemberExpr : public MemberExpr
|
||||
{
|
||||
public:
|
||||
VectorMemberExpr(Expr *e, const char *id, SourcePos p,
|
||||
SourcePos idpos, const VectorType* vectorType);
|
||||
|
||||
~VectorMemberExpr();
|
||||
|
||||
const Type* GetType() const;
|
||||
|
||||
llvm::Value* GetLValue(FunctionEmitContext* ctx) const;
|
||||
|
||||
llvm::Value* GetValue(FunctionEmitContext* ctx) const;
|
||||
const Type *GetType() const;
|
||||
llvm::Value *GetLValue(FunctionEmitContext* ctx) const;
|
||||
llvm::Value *GetValue(FunctionEmitContext* ctx) const;
|
||||
|
||||
int getElementNumber() const;
|
||||
const Type *getElementType() const;
|
||||
|
||||
private:
|
||||
const VectorType* exprVectorType;
|
||||
const VectorType* memberType;
|
||||
const VectorType *exprVectorType;
|
||||
const VectorType *memberType;
|
||||
};
|
||||
|
||||
|
||||
VectorMemberExpr::VectorMemberExpr(Expr *e, const char *id, SourcePos p,
|
||||
SourcePos idpos,
|
||||
const VectorType* vectorType)
|
||||
@@ -2808,11 +2902,8 @@ VectorMemberExpr::VectorMemberExpr(Expr *e, const char *id, SourcePos p,
|
||||
identifier.length());
|
||||
}
|
||||
|
||||
VectorMemberExpr::~VectorMemberExpr() {
|
||||
delete memberType;
|
||||
}
|
||||
|
||||
const Type*
|
||||
const Type *
|
||||
VectorMemberExpr::GetType() const {
|
||||
// For 1-element expressions, we have the base vector element
|
||||
// type. For n-element expressions, we have a shortvec type
|
||||
@@ -2826,7 +2917,7 @@ VectorMemberExpr::GetType() const {
|
||||
}
|
||||
|
||||
|
||||
llvm::Value*
|
||||
llvm::Value *
|
||||
VectorMemberExpr::GetLValue(FunctionEmitContext* ctx) const {
|
||||
if (identifier.length() == 1) {
|
||||
return MemberExpr::GetLValue(ctx);
|
||||
@@ -2836,11 +2927,12 @@ VectorMemberExpr::GetLValue(FunctionEmitContext* ctx) const {
|
||||
}
|
||||
|
||||
|
||||
llvm::Value*
|
||||
llvm::Value *
|
||||
VectorMemberExpr::GetValue(FunctionEmitContext* ctx) const {
|
||||
if (identifier.length() == 1) {
|
||||
return MemberExpr::GetValue(ctx);
|
||||
} else {
|
||||
}
|
||||
else {
|
||||
std::vector<int> indices;
|
||||
|
||||
for (size_t i = 0; i < identifier.size(); ++i) {
|
||||
@@ -2866,8 +2958,7 @@ VectorMemberExpr::GetValue(FunctionEmitContext* ctx) const {
|
||||
llvm::Value *ptmp =
|
||||
ctx->GetElementPtrInst(ltmp, 0, i, "new_offset");
|
||||
llvm::Value *initLValue =
|
||||
ctx->GetElementPtrInst(basePtr , 0,
|
||||
indices[i], "orig_offset");
|
||||
ctx->GetElementPtrInst(basePtr, 0, indices[i], "orig_offset");
|
||||
llvm::Value *initValue =
|
||||
ctx->LoadInst(initLValue, NULL, memberType->GetElementType(),
|
||||
"vec_element");
|
||||
@@ -2878,6 +2969,7 @@ VectorMemberExpr::GetValue(FunctionEmitContext* ctx) const {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int
|
||||
VectorMemberExpr::getElementNumber() const {
|
||||
int elementNumber = lIdentifierToVectorElement(identifier[0]);
|
||||
@@ -2887,43 +2979,51 @@ VectorMemberExpr::getElementNumber() const {
|
||||
return elementNumber;
|
||||
}
|
||||
|
||||
|
||||
// Returns the stored memberType as the element type of this vector member
// access.
const Type *
|
||||
VectorMemberExpr::getElementType() const {
|
||||
return memberType;
|
||||
}
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// ReferenceMemberExpr
|
||||
|
||||
class ReferenceMemberExpr : public MemberExpr
|
||||
{
|
||||
public:
|
||||
ReferenceMemberExpr(Expr *e, const char *id, SourcePos p,
|
||||
SourcePos idpos, const ReferenceType* referenceType);
|
||||
|
||||
const Type* GetType() const;
|
||||
const Type *GetType() const;
|
||||
llvm::Value *GetLValue(FunctionEmitContext* ctx) const;
|
||||
|
||||
int getElementNumber() const;
|
||||
|
||||
llvm::Value* GetLValue(FunctionEmitContext* ctx) const;
|
||||
const Type *getElementType() const;
|
||||
|
||||
private:
|
||||
const ReferenceType* exprReferenceType;
|
||||
MemberExpr* dereferencedExpr;
|
||||
const ReferenceType *exprReferenceType;
|
||||
MemberExpr *dereferencedExpr;
|
||||
};
|
||||
|
||||
ReferenceMemberExpr::ReferenceMemberExpr(Expr *e, const char *id, SourcePos p,
|
||||
SourcePos idpos,
|
||||
const ReferenceType* referenceType)
|
||||
const ReferenceType *referenceType)
|
||||
: MemberExpr(e, id, p, idpos), exprReferenceType(referenceType) {
|
||||
const Type* refTarget = exprReferenceType->GetReferenceTarget();
|
||||
const StructType* structType
|
||||
= dynamic_cast<const StructType *>(refTarget);
|
||||
const VectorType* vectorType
|
||||
= dynamic_cast<const VectorType *>(refTarget);
|
||||
const Type *refTarget = exprReferenceType->GetReferenceTarget();
|
||||
const StructType *structType = dynamic_cast<const StructType *>(refTarget);
|
||||
const VectorType *vectorType = dynamic_cast<const VectorType *>(refTarget);
|
||||
|
||||
if (structType != NULL) {
|
||||
if (structType != NULL)
|
||||
dereferencedExpr = new StructMemberExpr(e, id, p, idpos, structType);
|
||||
} else if (vectorType != NULL) {
|
||||
else if (vectorType != NULL)
|
||||
dereferencedExpr = new VectorMemberExpr(e, id, p, idpos, vectorType);
|
||||
} else {
|
||||
else
|
||||
dereferencedExpr = NULL;
|
||||
}
|
||||
}
|
||||
|
||||
const Type*
|
||||
|
||||
const Type *
|
||||
ReferenceMemberExpr::GetType() const {
|
||||
if (dereferencedExpr == NULL) {
|
||||
Error(pos, "Can't access member of non-struct/vector type \"%s\".",
|
||||
@@ -2934,6 +3034,7 @@ ReferenceMemberExpr::GetType() const {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int
|
||||
ReferenceMemberExpr::getElementNumber() const {
|
||||
if (dereferencedExpr == NULL) {
|
||||
@@ -2945,7 +3046,15 @@ ReferenceMemberExpr::getElementNumber() const {
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Value*
|
||||
|
||||
// Forwards to the getElementType() of the wrapped dereferenced member
// expression. Callers must only invoke this when dereferencedExpr is
// non-NULL; that precondition is asserted rather than reported as an error.
const Type *
|
||||
ReferenceMemberExpr::getElementType() const {
|
||||
assert(dereferencedExpr != NULL);
|
||||
return dereferencedExpr->getElementType();
|
||||
}
|
||||
|
||||
|
||||
llvm::Value *
|
||||
ReferenceMemberExpr::GetLValue(FunctionEmitContext* ctx) const {
|
||||
if (dereferencedExpr == NULL) {
|
||||
// FIXME: again I think typechecking should have caught this
|
||||
@@ -2965,29 +3074,35 @@ ReferenceMemberExpr::GetLValue(FunctionEmitContext* ctx) const {
|
||||
return NULL;
|
||||
|
||||
ctx->SetDebugPos(pos);
|
||||
return ctx->GetElementPtrInst(basePtr, 0, elementNumber);
|
||||
llvm::Value *ptr = ctx->GetElementPtrInst(basePtr, 0, elementNumber);
|
||||
|
||||
const Type *elementType = getElementType();
|
||||
ptr = lAddVaryingOffsetsIfNeeded(ctx, ptr, GetType(), elementType);
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
|
||||
MemberExpr*
|
||||
MemberExpr *
|
||||
MemberExpr::create(Expr *e, const char *id, SourcePos p, SourcePos idpos) {
|
||||
const Type* exprType;
|
||||
const Type *exprType;
|
||||
if (e == NULL || (exprType = e->GetType()) == NULL)
|
||||
return new MemberExpr(e, id, p, idpos);
|
||||
return NULL;
|
||||
|
||||
const StructType* structType = dynamic_cast<const StructType*>(exprType);
|
||||
const StructType *structType = dynamic_cast<const StructType*>(exprType);
|
||||
if (structType != NULL)
|
||||
return new StructMemberExpr(e, id, p, idpos, structType);
|
||||
|
||||
const VectorType* vectorType = dynamic_cast<const VectorType*>(exprType);
|
||||
const VectorType *vectorType = dynamic_cast<const VectorType*>(exprType);
|
||||
if (vectorType != NULL)
|
||||
return new VectorMemberExpr(e, id, p, idpos, vectorType);
|
||||
|
||||
const ReferenceType* referenceType = dynamic_cast<const ReferenceType*>(exprType);
|
||||
const ReferenceType *referenceType = dynamic_cast<const ReferenceType*>(exprType);
|
||||
if (referenceType != NULL)
|
||||
return new ReferenceMemberExpr(e, id, p, idpos, referenceType);
|
||||
|
||||
return new MemberExpr(e, id, p, idpos);
|
||||
|
||||
FATAL("Unexpected case in MemberExpr::create()");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
@@ -3024,6 +3139,8 @@ MemberExpr::GetValue(FunctionEmitContext *ctx) const {
|
||||
if (elementNumber == -1)
|
||||
return NULL;
|
||||
lvalue = ctx->GetElementPtrInst(ptr, 0, elementNumber);
|
||||
lvalue = lAddVaryingOffsetsIfNeeded(ctx, lvalue, GetType(), getElementType());
|
||||
|
||||
mask = LLVMMaskAllOn;
|
||||
}
|
||||
else {
|
||||
@@ -3074,7 +3191,10 @@ MemberExpr::GetLValue(FunctionEmitContext *ctx) const {
|
||||
return NULL;
|
||||
|
||||
ctx->SetDebugPos(pos);
|
||||
return ctx->GetElementPtrInst(basePtr, 0, elementNumber);
|
||||
llvm::Value *ptr = ctx->GetElementPtrInst(basePtr, 0, elementNumber);
|
||||
ptr = lAddVaryingOffsetsIfNeeded(ctx, ptr, GetType(), getElementType());
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user