Initial plumbing to add CollectionType base-class as common ancestor

to StructTypes, ArrayTypes, and VectorTypes.  Issue #37.
This commit is contained in:
Matt Pharr
2011-06-29 07:41:35 +01:00
parent b4068efcfb
commit 214fb3197a
6 changed files with 83 additions and 52 deletions

12
ctx.cpp
View File

@@ -1359,10 +1359,10 @@ FunctionEmitContext::gather(llvm::Value *lvalue, const Type *type,
// If we're gathering structures, do an element-wise gather // If we're gathering structures, do an element-wise gather
// recursively. // recursively.
llvm::Value *retValue = llvm::UndefValue::get(retType); llvm::Value *retValue = llvm::UndefValue::get(retType);
for (int i = 0; i < st->NumElements(); ++i) { for (int i = 0; i < st->GetElementCount(); ++i) {
llvm::Value *eltPtrs = GetElementPtrInst(lvalue, 0, i); llvm::Value *eltPtrs = GetElementPtrInst(lvalue, 0, i);
// This in turn will be another gather // This in turn will be another gather
llvm::Value *eltValues = LoadInst(eltPtrs, st->GetMemberType(i), llvm::Value *eltValues = LoadInst(eltPtrs, st->GetElementType(i),
name); name);
retValue = InsertInst(retValue, eltValues, i, "set_value"); retValue = InsertInst(retValue, eltValues, i, "set_value");
} }
@@ -1519,12 +1519,12 @@ FunctionEmitContext::maskedStore(llvm::Value *rvalue, llvm::Value *lvalue,
const StructType *structType = dynamic_cast<const StructType *>(rvalueType); const StructType *structType = dynamic_cast<const StructType *>(rvalueType);
if (structType != NULL) { if (structType != NULL) {
// Assigning a structure // Assigning a structure
for (int i = 0; i < structType->NumElements(); ++i) { for (int i = 0; i < structType->GetElementCount(); ++i) {
llvm::Value *eltValue = ExtractInst(rvalue, i, "rvalue_member"); llvm::Value *eltValue = ExtractInst(rvalue, i, "rvalue_member");
llvm::Value *eltLValue = GetElementPtrInst(lvalue, 0, i, llvm::Value *eltLValue = GetElementPtrInst(lvalue, 0, i,
"struct_lvalue_ptr"); "struct_lvalue_ptr");
StoreInst(eltValue, eltLValue, storeMask, StoreInst(eltValue, eltLValue, storeMask,
structType->GetMemberType(i)); structType->GetElementType(i));
} }
return; return;
} }
@@ -1598,10 +1598,10 @@ FunctionEmitContext::scatter(llvm::Value *rvalue, llvm::Value *lvalue,
const StructType *structType = dynamic_cast<const StructType *>(rvalueType); const StructType *structType = dynamic_cast<const StructType *>(rvalueType);
if (structType) { if (structType) {
// Scatter the struct elements individually // Scatter the struct elements individually
for (int i = 0; i < structType->NumElements(); ++i) { for (int i = 0; i < structType->GetElementCount(); ++i) {
llvm::Value *lv = GetElementPtrInst(lvalue, 0, i); llvm::Value *lv = GetElementPtrInst(lvalue, 0, i);
llvm::Value *rv = ExtractInst(rvalue, i); llvm::Value *rv = ExtractInst(rvalue, i);
scatter(rv, lv, storeMask, structType->GetMemberType(i)); scatter(rv, lv, storeMask, structType->GetElementType(i));
} }
return; return;
} }

View File

@@ -1526,7 +1526,7 @@ AssignExpr::GetValue(FunctionEmitContext *ctx) const {
if (st != NULL) { if (st != NULL) {
bool anyUniform = false; bool anyUniform = false;
for (int i = 0; i < st->NumElements(); ++i) { for (int i = 0; i < st->NumElements(); ++i) {
if (st->GetMemberType(i)->IsUniformType()) if (st->GetElementType(i)->IsUniformType())
anyUniform = true; anyUniform = true;
} }
@@ -2498,10 +2498,10 @@ ExprList::GetConstant(const Type *type) const {
// same number of elements in the ExprList as the struct has // same number of elements in the ExprList as the struct has
// members (and the various elements line up with the shape of the // members (and the various elements line up with the shape of the
// corresponding struct elements). // corresponding struct elements).
if ((int)exprs.size() != structType->NumElements()) { if ((int)exprs.size() != structType->GetElementCount()) {
Error(pos, "Initializer list for struct \"%s\" must have %d " Error(pos, "Initializer list for struct \"%s\" must have %d "
"elements (has %d).", structType->GetString().c_str(), "elements (has %d).", structType->GetString().c_str(),
(int)exprs.size(), structType->NumElements()); (int)exprs.size(), structType->GetElementCount());
return NULL; return NULL;
} }
@@ -2509,7 +2509,7 @@ ExprList::GetConstant(const Type *type) const {
for (unsigned int i = 0; i < exprs.size(); ++i) { for (unsigned int i = 0; i < exprs.size(); ++i) {
if (exprs[i] == NULL) if (exprs[i] == NULL)
return NULL; return NULL;
const Type *elementType = structType->GetMemberType(i); const Type *elementType = structType->GetElementType(i);
llvm::Constant *c = exprs[i]->GetConstant(elementType); llvm::Constant *c = exprs[i]->GetConstant(elementType);
if (c == NULL) if (c == NULL)
// If this list element couldn't convert to the right // If this list element couldn't convert to the right
@@ -2832,7 +2832,7 @@ MemberExpr::GetType() const {
// Otherwise it's a struct, and the result type is the element // Otherwise it's a struct, and the result type is the element
// type, possibly promoted to varying if the struct type / lvalue // type, possibly promoted to varying if the struct type / lvalue
// is varying. // is varying.
const Type *elementType = structType->GetMemberType(identifier); const Type *elementType = structType->GetElementType(identifier);
if (!elementType) if (!elementType)
Error(identifierPos, "Element name \"%s\" not present in struct type \"%s\".%s", Error(identifierPos, "Element name \"%s\" not present in struct type \"%s\".%s",
identifier.c_str(), structType->GetString().c_str(), identifier.c_str(), structType->GetString().c_str(),
@@ -2912,7 +2912,7 @@ MemberExpr::getElementNumber() const {
} }
} }
else { else {
elementNumber = structType->GetMemberNumber(identifier); elementNumber = structType->GetElementNumber(identifier);
if (elementNumber == -1) if (elementNumber == -1)
Error(identifierPos, "Element name \"%s\" not present in struct type \"%s\".%s", Error(identifierPos, "Element name \"%s\" not present in struct type \"%s\".%s",
identifier.c_str(), structType->GetString().c_str(), identifier.c_str(), structType->GetString().c_str(),
@@ -3004,7 +3004,7 @@ MemberExpr::getCandidateNearMatches() const {
return ""; return "";
std::vector<std::string> elementNames; std::vector<std::string> elementNames;
for (int i = 0; i < structType->NumElements(); ++i) for (int i = 0; i < structType->GetElementCount(); ++i)
elementNames.push_back(structType->GetElementName(i)); elementNames.push_back(structType->GetElementName(i));
std::vector<std::string> alternates = MatchStrings(identifier, elementNames); std::vector<std::string> alternates = MatchStrings(identifier, elementNames);
if (!alternates.size()) if (!alternates.size())
@@ -3904,9 +3904,9 @@ lUniformValueToVarying(FunctionEmitContext *ctx, llvm::Value *value,
// needed) and populate the return struct // needed) and populate the return struct
const StructType *structType = dynamic_cast<const StructType *>(type); const StructType *structType = dynamic_cast<const StructType *>(type);
if (structType != NULL) { if (structType != NULL) {
for (int i = 0; i < structType->NumElements(); ++i) { for (int i = 0; i < structType->GetElementCount(); ++i) {
llvm::Value *v = ctx->ExtractInst(value, i, "struct_element"); llvm::Value *v = ctx->ExtractInst(value, i, "struct_element");
v = lUniformValueToVarying(ctx, v, structType->GetMemberType(i)); v = lUniformValueToVarying(ctx, v, structType->GetElementType(i));
retValue = ctx->InsertInst(retValue, v, i, "set_struct_element"); retValue = ctx->InsertInst(retValue, v, i, "set_struct_element");
} }
return retValue; return retValue;

View File

@@ -248,8 +248,8 @@ lRecursiveCheckVarying(const Type *t) {
const StructType *st = dynamic_cast<const StructType *>(t); const StructType *st = dynamic_cast<const StructType *>(t);
if (st) { if (st) {
for (int i = 0; i < st->NumElements(); ++i) for (int i = 0; i < st->GetElementCount(); ++i)
if (lRecursiveCheckVarying(st->GetMemberType(i))) if (lRecursiveCheckVarying(st->GetElementType(i)))
return true; return true;
} }
return false; return false;
@@ -1041,8 +1041,8 @@ Module::writeObjectFileOrAssembly(OutputType outputType, const char *outFileName
static void static void
lRecursiveAddStructs(const StructType *structType, lRecursiveAddStructs(const StructType *structType,
std::vector<const StructType *> &structParamTypes) { std::vector<const StructType *> &structParamTypes) {
for (int i = 0; i < structType->NumElements(); ++i) { for (int i = 0; i < structType->GetElementCount(); ++i) {
const Type *elementBaseType = structType->GetMemberType(i)->GetBaseType(); const Type *elementBaseType = structType->GetElementType(i)->GetBaseType();
const StructType *elementStructType = const StructType *elementStructType =
dynamic_cast<const StructType *>(elementBaseType); dynamic_cast<const StructType *>(elementBaseType);
if (elementStructType != NULL) { if (elementStructType != NULL) {
@@ -1112,9 +1112,9 @@ lEmitStructDecls(std::vector<const StructType *> &structTypes, FILE *file) {
StructDAGNode *node = new StructDAGNode; StructDAGNode *node = new StructDAGNode;
structToNode[st] = node; structToNode[st] = node;
for (int j = 0; j < st->NumElements(); ++j) { for (int j = 0; j < st->GetElementCount(); ++j) {
const StructType *elementStructType = const StructType *elementStructType =
dynamic_cast<const StructType *>(st->GetMemberType(j)); dynamic_cast<const StructType *>(st->GetElementType(j));
// If this element is a struct type and we haven't already // If this element is a struct type and we haven't already
// processed it for the current struct type, then upate th // processed it for the current struct type, then upate th
// dependencies and record that this element type has other // dependencies and record that this element type has other
@@ -1144,8 +1144,8 @@ lEmitStructDecls(std::vector<const StructType *> &structTypes, FILE *file) {
for (unsigned int i = 0; i < sortedTypes.size(); ++i) { for (unsigned int i = 0; i < sortedTypes.size(); ++i) {
const StructType *st = sortedTypes[i]; const StructType *st = sortedTypes[i];
fprintf(file, "struct %s {\n", st->GetStructName().c_str()); fprintf(file, "struct %s {\n", st->GetStructName().c_str());
for (int j = 0; j < st->NumElements(); ++j) { for (int j = 0; j < st->GetElementCount(); ++j) {
const Type *type = st->GetMemberType(j)->GetAsNonConstType(); const Type *type = st->GetElementType(j)->GetAsNonConstType();
std::string d = type->GetCDeclaration(st->GetElementName(j)); std::string d = type->GetCDeclaration(st->GetElementName(j));
fprintf(file, " %s;\n", d.c_str()); fprintf(file, " %s;\n", d.c_str());
} }
@@ -1210,8 +1210,8 @@ lGetVectorsFromStructs(const std::vector<const StructType *> &structParamTypes,
std::vector<const VectorType *> *vectorParamTypes) { std::vector<const VectorType *> *vectorParamTypes) {
for (unsigned int i = 0; i < structParamTypes.size(); ++i) { for (unsigned int i = 0; i < structParamTypes.size(); ++i) {
const StructType *structType = structParamTypes[i]; const StructType *structType = structParamTypes[i];
for (int j = 0; j < structType->NumElements(); ++j) { for (int j = 0; j < structType->GetElementCount(); ++j) {
const Type *elementType = structType->GetMemberType(j); const Type *elementType = structType->GetElementType(j);
const ArrayType *at = dynamic_cast<const ArrayType *>(elementType); const ArrayType *at = dynamic_cast<const ArrayType *>(elementType);
if (at) if (at)

View File

@@ -233,16 +233,16 @@ lInitSymbol(llvm::Value *lvalue, const char *symName, const Type *type,
// The { ... } case; make sure we have the same number of // The { ... } case; make sure we have the same number of
// expressions in the ExprList as we have struct members // expressions in the ExprList as we have struct members
int nInits = exprList->exprs.size(); int nInits = exprList->exprs.size();
if (nInits != st->NumElements()) if (nInits != st->GetElementCount())
Error(initExpr->pos, Error(initExpr->pos,
"Initializer for struct \"%s\" requires %d values; %d provided.", "Initializer for struct \"%s\" requires %d values; %d provided.",
symName, st->NumElements(), nInits); symName, st->GetElementCount(), nInits);
else { else {
// Initialize each struct member with the corresponding // Initialize each struct member with the corresponding
// value from the ExprList // value from the ExprList
for (int i = 0; i < nInits; ++i) { for (int i = 0; i < nInits; ++i) {
llvm::Value *ep = ctx->GetElementPtrInst(lvalue, 0, i, "structelement"); llvm::Value *ep = ctx->GetElementPtrInst(lvalue, 0, i, "structelement");
lInitSymbol(ep, symName, st->GetMemberType(i), exprList->exprs[i], lInitSymbol(ep, symName, st->GetElementType(i), exprList->exprs[i],
ctx, pos); ctx, pos);
} }
} }
@@ -251,9 +251,9 @@ lInitSymbol(llvm::Value *lvalue, const char *symName, const Type *type,
initExpr->GetType()->IsBoolType()) { initExpr->GetType()->IsBoolType()) {
// Otherwise initialize all of the struct elements in turn with // Otherwise initialize all of the struct elements in turn with
// the initExpr. // the initExpr.
for (int i = 0; i < st->NumElements(); ++i) { for (int i = 0; i < st->GetElementCount(); ++i) {
llvm::Value *ep = ctx->GetElementPtrInst(lvalue, 0, i, "structelement"); llvm::Value *ep = ctx->GetElementPtrInst(lvalue, 0, i, "structelement");
lInitSymbol(ep, symName, st->GetMemberType(i), initExpr, ctx, pos); lInitSymbol(ep, symName, st->GetElementType(i), initExpr, ctx, pos);
} }
} }
else { else {

View File

@@ -410,6 +410,14 @@ AtomicType::GetDIType(llvm::DIDescriptor scope) const {
} }
///////////////////////////////////////////////////////////////////////////
// SequentialType
const Type *SequentialType::GetElementType(int index) const {
return GetElementType();
}
/////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////
// ArrayType // ArrayType
@@ -1035,8 +1043,8 @@ StructType::GetSOAType(int width) const {
std::vector<const Type *> et; std::vector<const Type *> et;
// The SOA version of a structure is just a structure that holds SOAed // The SOA version of a structure is just a structure that holds SOAed
// versions of its elements // versions of its elements
for (int i = 0; i < NumElements(); ++i) { for (int i = 0; i < GetElementCount(); ++i) {
const Type *t = GetMemberType(i); const Type *t = GetElementType(i);
et.push_back(t->GetSOAType(width)); et.push_back(t->GetSOAType(width));
} }
return new StructType(name, et, elementNames, elementPositions, return new StructType(name, et, elementNames, elementPositions,
@@ -1125,8 +1133,8 @@ StructType::GetCDeclaration(const std::string &n) const {
const llvm::Type * const llvm::Type *
StructType::LLVMType(llvm::LLVMContext *ctx) const { StructType::LLVMType(llvm::LLVMContext *ctx) const {
std::vector<const llvm::Type *> llvmTypes; std::vector<const llvm::Type *> llvmTypes;
for (int i = 0; i < NumElements(); ++i) { for (int i = 0; i < GetElementCount(); ++i) {
const Type *type = GetMemberType(i); const Type *type = GetElementType(i);
llvmTypes.push_back(type->LLVMType(ctx)); llvmTypes.push_back(type->LLVMType(ctx));
} }
return llvm::StructType::get(*ctx, llvmTypes); return llvm::StructType::get(*ctx, llvmTypes);
@@ -1146,7 +1154,7 @@ StructType::GetDIType(llvm::DIDescriptor scope) const {
// alignment and size, using that to figure out its offset w.r.t. the // alignment and size, using that to figure out its offset w.r.t. the
// start of the structure. // start of the structure.
for (unsigned int i = 0; i < elementTypes.size(); ++i) { for (unsigned int i = 0; i < elementTypes.size(); ++i) {
llvm::DIType eltType = GetMemberType(i)->GetDIType(scope); llvm::DIType eltType = GetElementType(i)->GetDIType(scope);
uint64_t eltAlign = eltType.getAlignInBits(); uint64_t eltAlign = eltType.getAlignInBits();
uint64_t eltSize = eltType.getSizeInBits(); uint64_t eltSize = eltType.getSizeInBits();
@@ -1197,7 +1205,7 @@ StructType::GetDIType(llvm::DIDescriptor scope) const {
const Type * const Type *
StructType::GetMemberType(int i) const { StructType::GetElementType(int i) const {
assert(i < (int)elementTypes.size()); assert(i < (int)elementTypes.size());
// If the struct is uniform qualified, then each member comes out with // If the struct is uniform qualified, then each member comes out with
// the same type as in the original source file. If it's varying, then // the same type as in the original source file. If it's varying, then
@@ -1209,7 +1217,7 @@ StructType::GetMemberType(int i) const {
const Type * const Type *
StructType::GetMemberType(const std::string &n) const { StructType::GetElementType(const std::string &n) const {
for (unsigned int i = 0; i < elementNames.size(); ++i) for (unsigned int i = 0; i < elementNames.size(); ++i)
if (elementNames[i] == n) { if (elementNames[i] == n) {
const Type *ret = isUniform ? elementTypes[i] : const Type *ret = isUniform ? elementTypes[i] :
@@ -1221,7 +1229,7 @@ StructType::GetMemberType(const std::string &n) const {
int int
StructType::GetMemberNumber(const std::string &n) const { StructType::GetElementNumber(const std::string &n) const {
for (unsigned int i = 0; i < elementNames.size(); ++i) for (unsigned int i = 0; i < elementNames.size(); ++i)
if (elementNames[i] == n) if (elementNames[i] == n)
return i; return i;
@@ -1775,10 +1783,10 @@ Type::Equal(const Type *a, const Type *b) {
const StructType *sta = dynamic_cast<const StructType *>(a); const StructType *sta = dynamic_cast<const StructType *>(a);
const StructType *stb = dynamic_cast<const StructType *>(b); const StructType *stb = dynamic_cast<const StructType *>(b);
if (sta && stb) { if (sta && stb) {
if (sta->NumElements() != stb->NumElements()) if (sta->GetElementCount() != stb->GetElementCount())
return false; return false;
for (int i = 0; i < sta->NumElements(); ++i) for (int i = 0; i < sta->GetElementCount(); ++i)
if (!Equal(sta->GetMemberType(i), stb->GetMemberType(i))) if (!Equal(sta->GetElementType(i), stb->GetElementType(i)))
return false; return false;
return true; return true;
} }

45
type.h
View File

@@ -243,19 +243,42 @@ private:
}; };
/** @brief Abstract base class for tpyes that represent sequences /** @brief Abstract base class for types that represent collections of
other types.
This is a common base class that StructTypes, ArrayTypes, and
VectorTypes all inherit from.
*/
class CollectionType : public Type {
public:
/** Returns the total number of elements in the collection. */
virtual int GetElementCount() const = 0;
/** Returns the type of the element given by index. (The value of
index must be between 0 and GetElementCount()-1.
*/
virtual const Type *GetElementType(int index) const = 0;
};
/** @brief Abstract base class for types that represent sequences
SequentialType is an abstract base class that adds interface routines SequentialType is an abstract base class that adds interface routines
for types that represent linear sequences of other types (i.e., arrays for types that represent linear sequences of other types (i.e., arrays
and vectors). and vectors).
*/ */
class SequentialType : public Type { class SequentialType : public CollectionType {
public: public:
/** Returns the total number of elements in the sequence. */ /** Returns the Type of the elements that the sequence stores; for
virtual int GetElementCount() const = 0; SequentialTypes, all elements have the same type . */
/** Returns the Type of the elements that the sequence stores. */
virtual const Type *GetElementType() const = 0; virtual const Type *GetElementType() const = 0;
/** SequentialType provides an implementation of this CollectionType
method, just passing the query on to the GetElementType(void)
implementation, since all of the elements of a SequentialType have
the same type.
*/
const Type *GetElementType(int index) const;
}; };
@@ -439,7 +462,7 @@ private:
/** @brief Representation of a structure holding a number of members. /** @brief Representation of a structure holding a number of members.
*/ */
class StructType : public Type { class StructType : public CollectionType {
public: public:
StructType(const std::string &name, const std::vector<const Type *> &elts, StructType(const std::string &name, const std::vector<const Type *> &elts,
const std::vector<std::string> &eltNames, const std::vector<std::string> &eltNames,
@@ -469,21 +492,21 @@ public:
/** Returns the type of the structure element with the given name (if any). /** Returns the type of the structure element with the given name (if any).
Returns NULL if there is no such named element. */ Returns NULL if there is no such named element. */
const Type *GetMemberType(const std::string &name) const; const Type *GetElementType(const std::string &name) const;
/** Returns the type of the i'th structure element. The value of \c i must /** Returns the type of the i'th structure element. The value of \c i must
be between 0 and NumElements()-1. */ be between 0 and NumElements()-1. */
const Type *GetMemberType(int i) const; const Type *GetElementType(int i) const;
/** Returns which structure element number (starting from zero) that /** Returns which structure element number (starting from zero) that
has the given name. If there is no such element, return -1. */ has the given name. If there is no such element, return -1. */
int GetMemberNumber(const std::string &name) const; int GetElementNumber(const std::string &name) const;
/** Returns the name of the i'th element of the structure. */ /** Returns the name of the i'th element of the structure. */
const std::string GetElementName(int i) const { return elementNames[i]; } const std::string GetElementName(int i) const { return elementNames[i]; }
/** Returns the total number of elements in the structure. */ /** Returns the total number of elements in the structure. */
int NumElements() const { return int(elementTypes.size()); } int GetElementCount() const { return int(elementTypes.size()); }
/** Returns the name of the structure type. (e.g. struct Foo -> "Foo".) */ /** Returns the name of the structure type. (e.g. struct Foo -> "Foo".) */
const std::string &GetStructName() const { return name; } const std::string &GetStructName() const { return name; }