Stop using dynamic_cast for Types.

We now have a set of template functions CastType<AtomicType>, etc., that in
turn use a new typeId field in each Type instance, allowing them to be inlined
and to be quite efficient.

This improves front-end performance for a particular large program by 28%.
This commit is contained in:
Matt Pharr
2012-05-04 11:12:33 -07:00
parent c756c855ea
commit 944c53bff1
11 changed files with 539 additions and 425 deletions

12
ast.cpp
View File

@@ -356,10 +356,10 @@ lCheckAllOffSafety(ASTNode *node, void *data) {
return false; return false;
const Type *type = fce->func->GetType(); const Type *type = fce->func->GetType();
const PointerType *pt = dynamic_cast<const PointerType *>(type); const PointerType *pt = CastType<PointerType>(type);
if (pt != NULL) if (pt != NULL)
type = pt->GetBaseType(); type = pt->GetBaseType();
const FunctionType *ftype = dynamic_cast<const FunctionType *>(type); const FunctionType *ftype = CastType<FunctionType>(type);
Assert(ftype != NULL); Assert(ftype != NULL);
if (ftype->isSafe == false) { if (ftype->isSafe == false) {
@@ -405,7 +405,7 @@ lCheckAllOffSafety(ASTNode *node, void *data) {
const Type *type = ie->baseExpr->GetType(); const Type *type = ie->baseExpr->GetType();
if (type == NULL) if (type == NULL)
return true; return true;
if (dynamic_cast<const ReferenceType *>(type) != NULL) if (CastType<ReferenceType>(type) != NULL)
type = type->GetReferenceTarget(); type = type->GetReferenceTarget();
ConstExpr *ce = dynamic_cast<ConstExpr *>(ie->index); ConstExpr *ce = dynamic_cast<ConstExpr *>(ie->index);
@@ -415,16 +415,14 @@ lCheckAllOffSafety(ASTNode *node, void *data) {
return false; return false;
} }
const PointerType *pointerType = const PointerType *pointerType = CastType<PointerType>(type);
dynamic_cast<const PointerType *>(type);
if (pointerType != NULL) { if (pointerType != NULL) {
// pointer[index] -> can't be sure -> not safe // pointer[index] -> can't be sure -> not safe
*okPtr = false; *okPtr = false;
return false; return false;
} }
const SequentialType *seqType = const SequentialType *seqType = CastType<SequentialType>(type);
dynamic_cast<const SequentialType *>(type);
Assert(seqType != NULL); Assert(seqType != NULL);
int nElements = seqType->GetElementCount(); int nElements = seqType->GetElementCount();
if (nElements == 0) { if (nElements == 0) {

72
ctx.cpp
View File

@@ -1194,7 +1194,7 @@ FunctionEmitContext::CurrentLanesReturned(Expr *expr, bool doCoherenceCheck) {
llvm::Value *retVal = expr->GetValue(this); llvm::Value *retVal = expr->GetValue(this);
if (retVal != NULL) { if (retVal != NULL) {
if (returnType->IsUniformType() || if (returnType->IsUniformType() ||
dynamic_cast<const ReferenceType *>(returnType) != NULL) CastType<ReferenceType>(returnType) != NULL)
StoreInst(retVal, returnValuePtr); StoreInst(retVal, returnValuePtr);
else { else {
// Use a masked store to store the value of the expression // Use a masked store to store the value of the expression
@@ -2063,10 +2063,10 @@ FunctionEmitContext::GetElementPtrInst(llvm::Value *basePtr, llvm::Value *index,
// Regularize to a standard pointer type for basePtr's type // Regularize to a standard pointer type for basePtr's type
const PointerType *ptrType; const PointerType *ptrType;
if (dynamic_cast<const ReferenceType *>(ptrRefType) != NULL) if (CastType<ReferenceType>(ptrRefType) != NULL)
ptrType = PointerType::GetUniform(ptrRefType->GetReferenceTarget()); ptrType = PointerType::GetUniform(ptrRefType->GetReferenceTarget());
else { else {
ptrType = dynamic_cast<const PointerType *>(ptrRefType); ptrType = CastType<PointerType>(ptrRefType);
Assert(ptrType != NULL); Assert(ptrType != NULL);
} }
@@ -2133,10 +2133,10 @@ FunctionEmitContext::GetElementPtrInst(llvm::Value *basePtr, llvm::Value *index0
// Regaularize the pointer type for basePtr // Regaularize the pointer type for basePtr
const PointerType *ptrType = NULL; const PointerType *ptrType = NULL;
if (dynamic_cast<const ReferenceType *>(ptrRefType) != NULL) if (CastType<ReferenceType>(ptrRefType) != NULL)
ptrType = PointerType::GetUniform(ptrRefType->GetReferenceTarget()); ptrType = PointerType::GetUniform(ptrRefType->GetReferenceTarget());
else { else {
ptrType = dynamic_cast<const PointerType *>(ptrRefType); ptrType = CastType<PointerType>(ptrRefType);
Assert(ptrType != NULL); Assert(ptrType != NULL);
} }
@@ -2184,7 +2184,7 @@ FunctionEmitContext::GetElementPtrInst(llvm::Value *basePtr, llvm::Value *index0
// Now index into the second dimension with index1. First figure // Now index into the second dimension with index1. First figure
// out the type of ptr0. // out the type of ptr0.
const Type *baseType = ptrType->GetBaseType(); const Type *baseType = ptrType->GetBaseType();
const SequentialType *st = dynamic_cast<const SequentialType *>(baseType); const SequentialType *st = CastType<SequentialType>(baseType);
Assert(st != NULL); Assert(st != NULL);
bool ptr0IsUniform = bool ptr0IsUniform =
@@ -2211,10 +2211,10 @@ FunctionEmitContext::AddElementOffset(llvm::Value *fullBasePtr, int elementNum,
const PointerType *ptrType = NULL; const PointerType *ptrType = NULL;
if (ptrRefType != NULL) { if (ptrRefType != NULL) {
// Normalize references to uniform pointers // Normalize references to uniform pointers
if (dynamic_cast<const ReferenceType *>(ptrRefType) != NULL) if (CastType<ReferenceType>(ptrRefType) != NULL)
ptrType = PointerType::GetUniform(ptrRefType->GetReferenceTarget()); ptrType = PointerType::GetUniform(ptrRefType->GetReferenceTarget());
else else
ptrType = dynamic_cast<const PointerType *>(ptrRefType); ptrType = CastType<PointerType>(ptrRefType);
Assert(ptrType != NULL); Assert(ptrType != NULL);
} }
@@ -2240,8 +2240,8 @@ FunctionEmitContext::AddElementOffset(llvm::Value *fullBasePtr, int elementNum,
// want it. // want it.
if (resultPtrType != NULL) { if (resultPtrType != NULL) {
Assert(ptrType != NULL); Assert(ptrType != NULL);
const CollectionType *ct = const CollectionType *ct =
dynamic_cast<const CollectionType *>(ptrType->GetBaseType()); CastType<CollectionType>(ptrType->GetBaseType());
Assert(ct != NULL); Assert(ct != NULL);
*resultPtrType = new PointerType(ct->GetElementType(elementNum), *resultPtrType = new PointerType(ct->GetElementType(elementNum),
ptrType->GetVariability(), ptrType->GetVariability(),
@@ -2261,8 +2261,7 @@ FunctionEmitContext::AddElementOffset(llvm::Value *fullBasePtr, int elementNum,
else { else {
// Otherwise do the math to find the offset and add it to the given // Otherwise do the math to find the offset and add it to the given
// varying pointers // varying pointers
const StructType *st = const StructType *st = CastType<StructType>(ptrType->GetBaseType());
dynamic_cast<const StructType *>(ptrType->GetBaseType());
llvm::Value *offset = NULL; llvm::Value *offset = NULL;
if (st != NULL) if (st != NULL)
// If the pointer is to a structure, Target::StructOffset() gives // If the pointer is to a structure, Target::StructOffset() gives
@@ -2273,8 +2272,8 @@ FunctionEmitContext::AddElementOffset(llvm::Value *fullBasePtr, int elementNum,
// Otherwise we should have a vector or array here and the offset // Otherwise we should have a vector or array here and the offset
// is given by the element number times the size of the element // is given by the element number times the size of the element
// type of the vector. // type of the vector.
const SequentialType *st = const SequentialType *st =
dynamic_cast<const SequentialType *>(ptrType->GetBaseType()); CastType<SequentialType>(ptrType->GetBaseType());
Assert(st != NULL); Assert(st != NULL);
llvm::Value *size = llvm::Value *size =
g->target.SizeOf(st->GetElementType()->LLVMType(g->ctx), bblock); g->target.SizeOf(st->GetElementType()->LLVMType(g->ctx), bblock);
@@ -2340,7 +2339,7 @@ FunctionEmitContext::LoadInst(llvm::Value *ptr, const char *name) {
static llvm::Value * static llvm::Value *
lFinalSliceOffset(FunctionEmitContext *ctx, llvm::Value *ptr, lFinalSliceOffset(FunctionEmitContext *ctx, llvm::Value *ptr,
const PointerType **ptrType) { const PointerType **ptrType) {
Assert(dynamic_cast<const PointerType *>(*ptrType) != NULL); Assert(CastType<PointerType>(*ptrType) != NULL);
llvm::Value *slicePtr = ctx->ExtractInst(ptr, 0, LLVMGetName(ptr, "_ptr")); llvm::Value *slicePtr = ctx->ExtractInst(ptr, 0, LLVMGetName(ptr, "_ptr"));
llvm::Value *sliceOffset = ctx->ExtractInst(ptr, 1, LLVMGetName(ptr, "_offset")); llvm::Value *sliceOffset = ctx->ExtractInst(ptr, 1, LLVMGetName(ptr, "_offset"));
@@ -2377,8 +2376,7 @@ FunctionEmitContext::loadUniformFromSOA(llvm::Value *ptr, llvm::Value *mask,
const char *name) { const char *name) {
const Type *unifType = ptrType->GetBaseType()->GetAsUniformType(); const Type *unifType = ptrType->GetBaseType()->GetAsUniformType();
const CollectionType *ct = const CollectionType *ct = CastType<CollectionType>(ptrType->GetBaseType());
dynamic_cast<const CollectionType *>(ptrType->GetBaseType());
if (ct != NULL) { if (ct != NULL) {
// If we have a struct/array, we need to decompose it into // If we have a struct/array, we need to decompose it into
// individual element loads to fill in the result structure since // individual element loads to fill in the result structure since
@@ -2420,10 +2418,10 @@ FunctionEmitContext::LoadInst(llvm::Value *ptr, llvm::Value *mask,
name = LLVMGetName(ptr, "_load"); name = LLVMGetName(ptr, "_load");
const PointerType *ptrType; const PointerType *ptrType;
if (dynamic_cast<const ReferenceType *>(ptrRefType) != NULL) if (CastType<ReferenceType>(ptrRefType) != NULL)
ptrType = PointerType::GetUniform(ptrRefType->GetReferenceTarget()); ptrType = PointerType::GetUniform(ptrRefType->GetReferenceTarget());
else { else {
ptrType = dynamic_cast<const PointerType *>(ptrRefType); ptrType = CastType<PointerType>(ptrRefType);
Assert(ptrType != NULL); Assert(ptrType != NULL);
} }
@@ -2440,8 +2438,8 @@ FunctionEmitContext::LoadInst(llvm::Value *ptr, llvm::Value *mask,
// atomic types, we need to make sure that the compiler emits // atomic types, we need to make sure that the compiler emits
// unaligned vector loads, so we specify a reduced alignment here. // unaligned vector loads, so we specify a reduced alignment here.
int align = 0; int align = 0;
const AtomicType *atomicType = const AtomicType *atomicType =
dynamic_cast<const AtomicType *>(ptrType->GetBaseType()); CastType<AtomicType>(ptrType->GetBaseType());
if (atomicType != NULL && atomicType->IsVaryingType()) if (atomicType != NULL && atomicType->IsVaryingType())
// We actually just want to align to the vector element // We actually just want to align to the vector element
// alignment, but can't easily get that here, so just tell LLVM // alignment, but can't easily get that here, so just tell LLVM
@@ -2473,7 +2471,7 @@ FunctionEmitContext::gather(llvm::Value *ptr, const PointerType *ptrType,
llvm::Type *llvmReturnType = returnType->LLVMType(g->ctx); llvm::Type *llvmReturnType = returnType->LLVMType(g->ctx);
const CollectionType *collectionType = const CollectionType *collectionType =
dynamic_cast<const CollectionType *>(ptrType->GetBaseType()); CastType<CollectionType>(ptrType->GetBaseType());
if (collectionType != NULL) { if (collectionType != NULL) {
// For collections, recursively gather element wise to find the // For collections, recursively gather element wise to find the
// result. // result.
@@ -2508,7 +2506,7 @@ FunctionEmitContext::gather(llvm::Value *ptr, const PointerType *ptrType,
// Figure out which gather function to call based on the size of // Figure out which gather function to call based on the size of
// the elements. // the elements.
const PointerType *pt = dynamic_cast<const PointerType *>(returnType); const PointerType *pt = CastType<PointerType>(returnType);
const char *funcName = NULL; const char *funcName = NULL;
if (pt != NULL) if (pt != NULL)
funcName = g->target.is32Bit ? "__pseudo_gather32_32" : funcName = g->target.is32Bit ? "__pseudo_gather32_32" :
@@ -2631,12 +2629,11 @@ FunctionEmitContext::maskedStore(llvm::Value *value, llvm::Value *ptr,
return; return;
} }
Assert(dynamic_cast<const PointerType *>(ptrType) != NULL); Assert(CastType<PointerType>(ptrType) != NULL);
Assert(ptrType->IsUniformType()); Assert(ptrType->IsUniformType());
const Type *valueType = ptrType->GetBaseType(); const Type *valueType = ptrType->GetBaseType();
const CollectionType *collectionType = const CollectionType *collectionType = CastType<CollectionType>(valueType);
dynamic_cast<const CollectionType *>(valueType);
if (collectionType != NULL) { if (collectionType != NULL) {
// Assigning a structure / array / vector. Handle each element // Assigning a structure / array / vector. Handle each element
// individually with what turns into a recursive call to // individually with what turns into a recursive call to
@@ -2660,7 +2657,7 @@ FunctionEmitContext::maskedStore(llvm::Value *value, llvm::Value *ptr,
// Figure out if we need a 8, 16, 32 or 64-bit masked store. // Figure out if we need a 8, 16, 32 or 64-bit masked store.
llvm::Function *maskedStoreFunc = NULL; llvm::Function *maskedStoreFunc = NULL;
const PointerType *pt = dynamic_cast<const PointerType *>(valueType); const PointerType *pt = CastType<PointerType>(valueType);
if (pt != NULL) { if (pt != NULL) {
if (pt->IsSlice()) { if (pt->IsSlice()) {
// Masked store of (varying) slice pointer. // Masked store of (varying) slice pointer.
@@ -2714,7 +2711,7 @@ FunctionEmitContext::maskedStore(llvm::Value *value, llvm::Value *ptr,
Type::Equal(valueType, AtomicType::VaryingBool) || Type::Equal(valueType, AtomicType::VaryingBool) ||
Type::Equal(valueType, AtomicType::VaryingInt32) || Type::Equal(valueType, AtomicType::VaryingInt32) ||
Type::Equal(valueType, AtomicType::VaryingUInt32) || Type::Equal(valueType, AtomicType::VaryingUInt32) ||
dynamic_cast<const EnumType *>(valueType) != NULL) { CastType<EnumType>(valueType) != NULL) {
maskedStoreFunc = m->module->getFunction("__pseudo_masked_store_32"); maskedStoreFunc = m->module->getFunction("__pseudo_masked_store_32");
ptr = BitCastInst(ptr, LLVMTypes::Int32VectorPointerType, ptr = BitCastInst(ptr, LLVMTypes::Int32VectorPointerType,
LLVMGetName(ptr, "_to_int32vecptr")); LLVMGetName(ptr, "_to_int32vecptr"));
@@ -2755,12 +2752,12 @@ void
FunctionEmitContext::scatter(llvm::Value *value, llvm::Value *ptr, FunctionEmitContext::scatter(llvm::Value *value, llvm::Value *ptr,
const Type *valueType, const Type *origPt, const Type *valueType, const Type *origPt,
llvm::Value *mask) { llvm::Value *mask) {
const PointerType *ptrType = dynamic_cast<const PointerType *>(origPt); const PointerType *ptrType = CastType<PointerType>(origPt);
Assert(ptrType != NULL); Assert(ptrType != NULL);
Assert(ptrType->IsVaryingType()); Assert(ptrType->IsVaryingType());
const CollectionType *srcCollectionType = const CollectionType *srcCollectionType =
dynamic_cast<const CollectionType *>(valueType); CastType<CollectionType>(valueType);
if (srcCollectionType != NULL) { if (srcCollectionType != NULL) {
// We're scattering a collection type--we need to keep track of the // We're scattering a collection type--we need to keep track of the
// source type (the type of the data values to be stored) and the // source type (the type of the data values to be stored) and the
@@ -2771,7 +2768,7 @@ FunctionEmitContext::scatter(llvm::Value *value, llvm::Value *ptr,
// same struct type, versus scattering into an array of varying // same struct type, versus scattering into an array of varying
// instances of the struct type, etc. // instances of the struct type, etc.
const CollectionType *dstCollectionType = const CollectionType *dstCollectionType =
dynamic_cast<const CollectionType *>(ptrType->GetBaseType()); CastType<CollectionType>(ptrType->GetBaseType());
Assert(dstCollectionType != NULL); Assert(dstCollectionType != NULL);
// Scatter the collection elements individually // Scatter the collection elements individually
@@ -2816,11 +2813,10 @@ FunctionEmitContext::scatter(llvm::Value *value, llvm::Value *ptr,
ptr = lFinalSliceOffset(this, ptr, &ptrType); ptr = lFinalSliceOffset(this, ptr, &ptrType);
} }
const PointerType *pt = dynamic_cast<const PointerType *>(valueType); const PointerType *pt = CastType<PointerType>(valueType);
// And everything should be a pointer or atomic from here on out... // And everything should be a pointer or atomic from here on out...
Assert(pt != NULL || Assert(pt != NULL || CastType<AtomicType>(valueType) != NULL);
dynamic_cast<const AtomicType *>(valueType) != NULL);
llvm::Type *type = value->getType(); llvm::Type *type = value->getType();
const char *funcName = NULL; const char *funcName = NULL;
@@ -2896,10 +2892,10 @@ FunctionEmitContext::StoreInst(llvm::Value *value, llvm::Value *ptr,
} }
const PointerType *ptrType; const PointerType *ptrType;
if (dynamic_cast<const ReferenceType *>(ptrRefType) != NULL) if (CastType<ReferenceType>(ptrRefType) != NULL)
ptrType = PointerType::GetUniform(ptrRefType->GetReferenceTarget()); ptrType = PointerType::GetUniform(ptrRefType->GetReferenceTarget());
else { else {
ptrType = dynamic_cast<const PointerType *>(ptrRefType); ptrType = CastType<PointerType>(ptrRefType);
Assert(ptrType != NULL); Assert(ptrType != NULL);
} }
@@ -2936,7 +2932,7 @@ FunctionEmitContext::storeUniformToSOA(llvm::Value *value, llvm::Value *ptr,
Assert(Type::EqualIgnoringConst(ptrType->GetBaseType()->GetAsUniformType(), Assert(Type::EqualIgnoringConst(ptrType->GetBaseType()->GetAsUniformType(),
valueType)); valueType));
const CollectionType *ct = dynamic_cast<const CollectionType *>(valueType); const CollectionType *ct = CastType<CollectionType>(valueType);
if (ct != NULL) { if (ct != NULL) {
// Handle collections element wise... // Handle collections element wise...
for (int i = 0; i < ct->GetElementCount(); ++i) { for (int i = 0; i < ct->GetElementCount(); ++i) {
@@ -3418,7 +3414,7 @@ llvm::Value *
FunctionEmitContext::addVaryingOffsetsIfNeeded(llvm::Value *ptr, FunctionEmitContext::addVaryingOffsetsIfNeeded(llvm::Value *ptr,
const Type *ptrType) { const Type *ptrType) {
// This should only be called for varying pointers // This should only be called for varying pointers
const PointerType *pt = dynamic_cast<const PointerType *>(ptrType); const PointerType *pt = CastType<PointerType>(ptrType);
Assert(pt && pt->IsVaryingType()); Assert(pt && pt->IsVaryingType());
const Type *baseType = ptrType->GetBaseType(); const Type *baseType = ptrType->GetBaseType();

View File

@@ -136,7 +136,7 @@ DeclSpecs::GetBaseType(SourcePos pos) const {
} }
if (vectorSize > 0) { if (vectorSize > 0) {
const AtomicType *atomicType = dynamic_cast<const AtomicType *>(retType); const AtomicType *atomicType = CastType<AtomicType>(retType);
if (atomicType == NULL) { if (atomicType == NULL) {
Error(pos, "Only atomic types (int, float, ...) are legal for vector " Error(pos, "Only atomic types (int, float, ...) are legal for vector "
"types."); "types.");
@@ -148,7 +148,7 @@ DeclSpecs::GetBaseType(SourcePos pos) const {
retType = lApplyTypeQualifiers(typeQualifiers, retType, pos); retType = lApplyTypeQualifiers(typeQualifiers, retType, pos);
if (soaWidth > 0) { if (soaWidth > 0) {
const StructType *st = dynamic_cast<const StructType *>(retType); const StructType *st = CastType<StructType>(retType);
if (st == NULL) { if (st == NULL) {
Error(pos, "Illegal to provide soa<%d> qualifier with non-struct " Error(pos, "Illegal to provide soa<%d> qualifier with non-struct "
@@ -238,7 +238,7 @@ Declarator::InitFromDeclSpecs(DeclSpecs *ds) {
storageClass = ds->storageClass; storageClass = ds->storageClass;
if (ds->declSpecList.size() > 0 && if (ds->declSpecList.size() > 0 &&
dynamic_cast<const FunctionType *>(type) == NULL) { CastType<FunctionType>(type) == NULL) {
Error(pos, "__declspec specifiers for non-function type \"%s\" are " Error(pos, "__declspec specifiers for non-function type \"%s\" are "
"not used.", type->GetString().c_str()); "not used.", type->GetString().c_str());
} }
@@ -351,7 +351,7 @@ Declarator::InitFromType(const Type *baseType, DeclSpecs *ds) {
return; return;
} }
// The parser should disallow this already, but double check. // The parser should disallow this already, but double check.
if (dynamic_cast<const ReferenceType *>(baseType) != NULL) { if (CastType<ReferenceType>(baseType) != NULL) {
Error(pos, "References to references are illegal."); Error(pos, "References to references are illegal.");
return; return;
} }
@@ -370,7 +370,7 @@ Declarator::InitFromType(const Type *baseType, DeclSpecs *ds) {
Error(pos, "Arrays of \"void\" type are illegal."); Error(pos, "Arrays of \"void\" type are illegal.");
return; return;
} }
if (dynamic_cast<const ReferenceType *>(baseType)) { if (CastType<ReferenceType>(baseType)) {
Error(pos, "Arrays of references (type \"%s\") are illegal.", Error(pos, "Arrays of references (type \"%s\") are illegal.",
baseType->GetString().c_str()); baseType->GetString().c_str());
return; return;
@@ -434,7 +434,7 @@ Declarator::InitFromType(const Type *baseType, DeclSpecs *ds) {
decl->type = NULL; decl->type = NULL;
} }
const ArrayType *at = dynamic_cast<const ArrayType *>(decl->type); const ArrayType *at = CastType<ArrayType>(decl->type);
if (at != NULL) { if (at != NULL) {
// As in C, arrays are passed to functions as pointers to // As in C, arrays are passed to functions as pointers to
// their element type. We'll just immediately make this // their element type. We'll just immediately make this
@@ -454,13 +454,13 @@ Declarator::InitFromType(const Type *baseType, DeclSpecs *ds) {
// Make sure there are no unsized arrays (other than the // Make sure there are no unsized arrays (other than the
// first dimension) in function parameter lists. // first dimension) in function parameter lists.
at = dynamic_cast<const ArrayType *>(targetType); at = CastType<ArrayType>(targetType);
while (at != NULL) { while (at != NULL) {
if (at->GetElementCount() == 0) if (at->GetElementCount() == 0)
Error(decl->pos, "Arrays with unsized dimensions in " Error(decl->pos, "Arrays with unsized dimensions in "
"dimensions after the first one are illegal in " "dimensions after the first one are illegal in "
"function parameter lists."); "function parameter lists.");
at = dynamic_cast<const ArrayType *>(at->GetElementType()); at = CastType<ArrayType>(at->GetElementType());
} }
} }
@@ -497,7 +497,7 @@ Declarator::InitFromType(const Type *baseType, DeclSpecs *ds) {
return; return;
} }
if (dynamic_cast<const FunctionType *>(returnType) != NULL) { if (CastType<FunctionType>(returnType) != NULL) {
Error(pos, "Illegal to return function type from function."); Error(pos, "Illegal to return function type from function.");
return; return;
} }
@@ -596,7 +596,7 @@ Declaration::GetVariableDeclarations() const {
if (Type::Equal(decl->type, AtomicType::Void)) if (Type::Equal(decl->type, AtomicType::Void))
Error(decl->pos, "\"void\" type variable illegal in declaration."); Error(decl->pos, "\"void\" type variable illegal in declaration.");
else if (dynamic_cast<const FunctionType *>(decl->type) == NULL) { else if (CastType<FunctionType>(decl->type) == NULL) {
decl->type = decl->type->ResolveUnboundVariability(Variability::Varying); decl->type = decl->type->ResolveUnboundVariability(Variability::Varying);
Symbol *sym = new Symbol(decl->name, decl->pos, decl->type, Symbol *sym = new Symbol(decl->name, decl->pos, decl->type,
decl->storageClass); decl->storageClass);
@@ -621,8 +621,7 @@ Declaration::DeclareFunctions() {
continue; continue;
} }
const FunctionType *ftype = const FunctionType *ftype = CastType<FunctionType>(decl->type);
dynamic_cast<const FunctionType *>(decl->type);
if (ftype == NULL) if (ftype == NULL)
continue; continue;
@@ -690,8 +689,7 @@ GetStructTypesNamesPositions(const std::vector<StructDeclaration *> &sd,
} }
for (int i = 0; i < (int)elementTypes->size() - 1; ++i) { for (int i = 0; i < (int)elementTypes->size() - 1; ++i) {
const ArrayType *arrayType = const ArrayType *arrayType = CastType<ArrayType>((*elementTypes)[i]);
dynamic_cast<const ArrayType *>((*elementTypes)[i]);
if (arrayType != NULL && arrayType->GetElementCount() == 0) if (arrayType != NULL && arrayType->GetElementCount() == 0)
Error((*elementPositions)[i], "Unsized arrays aren't allowed except " Error((*elementPositions)[i], "Unsized arrays aren't allowed except "

482
expr.cpp

File diff suppressed because it is too large Load Diff

View File

@@ -100,7 +100,7 @@ Function::Function(Symbol *s, Stmt *c) {
printf("\n\n\n"); printf("\n\n\n");
} }
const FunctionType *type = dynamic_cast<const FunctionType *>(sym->type); const FunctionType *type = CastType<FunctionType>(sym->type);
Assert(type != NULL); Assert(type != NULL);
for (int i = 0; i < type->GetNumParameters(); ++i) { for (int i = 0; i < type->GetNumParameters(); ++i) {
@@ -111,7 +111,7 @@ Function::Function(Symbol *s, Stmt *c) {
args.push_back(sym); args.push_back(sym);
const Type *t = type->GetParameterType(i); const Type *t = type->GetParameterType(i);
if (sym != NULL && dynamic_cast<const ReferenceType *>(t) == NULL) if (sym != NULL && CastType<ReferenceType>(t) == NULL)
sym->parentFunction = this; sym->parentFunction = this;
} }
@@ -132,7 +132,7 @@ Function::Function(Symbol *s, Stmt *c) {
const Type * const Type *
Function::GetReturnType() const { Function::GetReturnType() const {
const FunctionType *type = dynamic_cast<const FunctionType *>(sym->type); const FunctionType *type = CastType<FunctionType>(sym->type);
Assert(type != NULL); Assert(type != NULL);
return type->GetReturnType(); return type->GetReturnType();
} }
@@ -140,7 +140,7 @@ Function::GetReturnType() const {
const FunctionType * const FunctionType *
Function::GetType() const { Function::GetType() const {
const FunctionType *type = dynamic_cast<const FunctionType *>(sym->type); const FunctionType *type = CastType<FunctionType>(sym->type);
Assert(type != NULL); Assert(type != NULL);
return type; return type;
} }
@@ -205,7 +205,7 @@ Function::emitCode(FunctionEmitContext *ctx, llvm::Function *function,
#if 0 #if 0
llvm::BasicBlock *entryBBlock = ctx->GetCurrentBasicBlock(); llvm::BasicBlock *entryBBlock = ctx->GetCurrentBasicBlock();
#endif #endif
const FunctionType *type = dynamic_cast<const FunctionType *>(sym->type); const FunctionType *type = CastType<FunctionType>(sym->type);
Assert(type != NULL); Assert(type != NULL);
if (type->isTask == true) { if (type->isTask == true) {
// For tasks, we there should always be three parmeters: the // For tasks, we there should always be three parmeters: the
@@ -431,7 +431,7 @@ Function::GenerateIR() {
// If the function is 'export'-qualified, emit a second version of // If the function is 'export'-qualified, emit a second version of
// it without a mask parameter and without name mangling so that // it without a mask parameter and without name mangling so that
// the application can call it // the application can call it
const FunctionType *type = dynamic_cast<const FunctionType *>(sym->type); const FunctionType *type = CastType<FunctionType>(sym->type);
Assert(type != NULL); Assert(type != NULL);
if (type->isExported) { if (type->isExported) {
if (!type->isTask) { if (!type->isTask) {

View File

@@ -368,7 +368,7 @@ Module::AddGlobalVariable(const std::string &name, const Type *type, Expr *initE
if (type == NULL) if (type == NULL)
return; return;
const ArrayType *at = dynamic_cast<const ArrayType *>(type); const ArrayType *at = CastType<ArrayType>(type);
if (at != NULL && at->TotalElementCount() == 0) { if (at != NULL && at->TotalElementCount() == 0) {
Error(pos, "Illegal to declare a global variable with unsized " Error(pos, "Illegal to declare a global variable with unsized "
"array dimensions that aren't set with an initializer " "array dimensions that aren't set with an initializer "
@@ -517,7 +517,7 @@ Module::AddGlobalVariable(const std::string &name, const Type *type, Expr *initE
*/ */
static bool static bool
lRecursiveCheckValidParamType(const Type *t) { lRecursiveCheckValidParamType(const Type *t) {
const StructType *st = dynamic_cast<const StructType *>(t); const StructType *st = CastType<StructType>(t);
if (st != NULL) { if (st != NULL) {
for (int i = 0; i < st->GetElementCount(); ++i) for (int i = 0; i < st->GetElementCount(); ++i)
if (lRecursiveCheckValidParamType(st->GetElementType(i))) if (lRecursiveCheckValidParamType(st->GetElementType(i)))
@@ -525,11 +525,11 @@ lRecursiveCheckValidParamType(const Type *t) {
return false; return false;
} }
const SequentialType *seqt = dynamic_cast<const SequentialType *>(t); const SequentialType *seqt = CastType<SequentialType>(t);
if (seqt != NULL) if (seqt != NULL)
return lRecursiveCheckValidParamType(seqt->GetElementType()); return lRecursiveCheckValidParamType(seqt->GetElementType());
const PointerType *pt = dynamic_cast<const PointerType *>(t); const PointerType *pt = CastType<PointerType>(t);
if (pt != NULL) { if (pt != NULL) {
if (pt->IsSlice() || pt->IsVaryingType()) if (pt->IsSlice() || pt->IsVaryingType())
return true; return true;
@@ -550,7 +550,7 @@ lCheckForVaryingParameter(const Type *type, const std::string &name,
SourcePos pos) { SourcePos pos) {
if (lRecursiveCheckValidParamType(type)) { if (lRecursiveCheckValidParamType(type)) {
const Type *t = type->GetBaseType(); const Type *t = type->GetBaseType();
if (dynamic_cast<const StructType *>(t)) if (CastType<StructType>(t))
Error(pos, "Struct parameter \"%s\" with varying member(s) is illegal " Error(pos, "Struct parameter \"%s\" with varying member(s) is illegal "
"in an exported function.", name.c_str()); "in an exported function.", name.c_str());
else else
@@ -568,7 +568,7 @@ static void
lCheckForStructParameters(const FunctionType *ftype, SourcePos pos) { lCheckForStructParameters(const FunctionType *ftype, SourcePos pos) {
for (int i = 0; i < ftype->GetNumParameters(); ++i) { for (int i = 0; i < ftype->GetNumParameters(); ++i) {
const Type *type = ftype->GetParameterType(i); const Type *type = ftype->GetParameterType(i);
if (dynamic_cast<const StructType *>(type) != NULL) { if (CastType<StructType>(type) != NULL) {
Error(pos, "Passing structs to/from application functions is " Error(pos, "Passing structs to/from application functions is "
"currently broken. Use a pointer or const pointer to the " "currently broken. Use a pointer or const pointer to the "
"struct instead for now."); "struct instead for now.");
@@ -615,7 +615,7 @@ Module::AddFunctionDeclaration(const std::string &name,
// different, return an error--overloading by return type isn't // different, return an error--overloading by return type isn't
// allowed. // allowed.
const FunctionType *ofType = const FunctionType *ofType =
dynamic_cast<const FunctionType *>(overloadFunc->type); CastType<FunctionType>(overloadFunc->type);
Assert(ofType != NULL); Assert(ofType != NULL);
if (ofType->GetNumParameters() == functionType->GetNumParameters()) { if (ofType->GetNumParameters() == functionType->GetNumParameters()) {
int i; int i;
@@ -737,9 +737,9 @@ Module::AddFunctionDeclaration(const std::string &name,
// default.) Set parameter attributes accordingly. (Only for // default.) Set parameter attributes accordingly. (Only for
// uniform pointers, since varying pointers are int vectors...) // uniform pointers, since varying pointers are int vectors...)
if (!functionType->isTask && if (!functionType->isTask &&
((dynamic_cast<const PointerType *>(argType) != NULL && ((CastType<PointerType>(argType) != NULL &&
argType->IsUniformType()) || argType->IsUniformType()) ||
dynamic_cast<const ReferenceType *>(argType) != NULL)) { CastType<ReferenceType>(argType) != NULL)) {
// NOTE: LLVM indexes function parameters starting from 1. // NOTE: LLVM indexes function parameters starting from 1.
// This is unintuitive. // This is unintuitive.
@@ -962,7 +962,7 @@ lEmitStructDecl(const StructType *st, std::vector<const StructType *> *emittedSt
// Otherwise first make sure any contained structs have been declared. // Otherwise first make sure any contained structs have been declared.
for (int i = 0; i < st->GetElementCount(); ++i) { for (int i = 0; i < st->GetElementCount(); ++i) {
const StructType *elementStructType = const StructType *elementStructType =
dynamic_cast<const StructType *>(st->GetElementType(i)); CastType<StructType>(st->GetElementType(i));
if (elementStructType != NULL) if (elementStructType != NULL)
lEmitStructDecl(elementStructType, emittedStructs, file); lEmitStructDecl(elementStructType, emittedStructs, file);
} }
@@ -1084,7 +1084,7 @@ lAddTypeIfNew(const Type *type, std::vector<const T *> *exportedTypes) {
if (Type::Equal((*exportedTypes)[i], type)) if (Type::Equal((*exportedTypes)[i], type))
return; return;
const T *castType = dynamic_cast<const T *>(type); const T *castType = CastType<T>(type);
Assert(castType != NULL); Assert(castType != NULL);
exportedTypes->push_back(castType); exportedTypes->push_back(castType);
} }
@@ -1099,13 +1099,13 @@ lGetExportedTypes(const Type *type,
std::vector<const StructType *> *exportedStructTypes, std::vector<const StructType *> *exportedStructTypes,
std::vector<const EnumType *> *exportedEnumTypes, std::vector<const EnumType *> *exportedEnumTypes,
std::vector<const VectorType *> *exportedVectorTypes) { std::vector<const VectorType *> *exportedVectorTypes) {
const ArrayType *arrayType = dynamic_cast<const ArrayType *>(type); const ArrayType *arrayType = CastType<ArrayType>(type);
const StructType *structType = dynamic_cast<const StructType *>(type); const StructType *structType = CastType<StructType>(type);
if (dynamic_cast<const ReferenceType *>(type) != NULL) if (CastType<ReferenceType>(type) != NULL)
lGetExportedTypes(type->GetReferenceTarget(), exportedStructTypes, lGetExportedTypes(type->GetReferenceTarget(), exportedStructTypes,
exportedEnumTypes, exportedVectorTypes); exportedEnumTypes, exportedVectorTypes);
else if (dynamic_cast<const PointerType *>(type) != NULL) else if (CastType<PointerType>(type) != NULL)
lGetExportedTypes(type->GetBaseType(), exportedStructTypes, lGetExportedTypes(type->GetBaseType(), exportedStructTypes,
exportedEnumTypes, exportedVectorTypes); exportedEnumTypes, exportedVectorTypes);
else if (arrayType != NULL) else if (arrayType != NULL)
@@ -1117,12 +1117,12 @@ lGetExportedTypes(const Type *type,
lGetExportedTypes(structType->GetElementType(i), exportedStructTypes, lGetExportedTypes(structType->GetElementType(i), exportedStructTypes,
exportedEnumTypes, exportedVectorTypes); exportedEnumTypes, exportedVectorTypes);
} }
else if (dynamic_cast<const EnumType *>(type) != NULL) else if (CastType<EnumType>(type) != NULL)
lAddTypeIfNew(type, exportedEnumTypes); lAddTypeIfNew(type, exportedEnumTypes);
else if (dynamic_cast<const VectorType *>(type) != NULL) else if (CastType<VectorType>(type) != NULL)
lAddTypeIfNew(type, exportedVectorTypes); lAddTypeIfNew(type, exportedVectorTypes);
else else
Assert(dynamic_cast<const AtomicType *>(type) != NULL); Assert(CastType<AtomicType>(type) != NULL);
} }
@@ -1135,7 +1135,7 @@ lGetExportedParamTypes(const std::vector<Symbol *> &funcs,
std::vector<const EnumType *> *exportedEnumTypes, std::vector<const EnumType *> *exportedEnumTypes,
std::vector<const VectorType *> *exportedVectorTypes) { std::vector<const VectorType *> *exportedVectorTypes) {
for (unsigned int i = 0; i < funcs.size(); ++i) { for (unsigned int i = 0; i < funcs.size(); ++i) {
const FunctionType *ftype = dynamic_cast<const FunctionType *>(funcs[i]->type); const FunctionType *ftype = CastType<FunctionType>(funcs[i]->type);
// Handle the return type // Handle the return type
lGetExportedTypes(ftype->GetReturnType(), exportedStructTypes, lGetExportedTypes(ftype->GetReturnType(), exportedStructTypes,
exportedEnumTypes, exportedVectorTypes); exportedEnumTypes, exportedVectorTypes);
@@ -1152,7 +1152,7 @@ static void
lPrintFunctionDeclarations(FILE *file, const std::vector<Symbol *> &funcs) { lPrintFunctionDeclarations(FILE *file, const std::vector<Symbol *> &funcs) {
fprintf(file, "#ifdef __cplusplus\nextern \"C\" {\n#endif // __cplusplus\n"); fprintf(file, "#ifdef __cplusplus\nextern \"C\" {\n#endif // __cplusplus\n");
for (unsigned int i = 0; i < funcs.size(); ++i) { for (unsigned int i = 0; i < funcs.size(); ++i) {
const FunctionType *ftype = dynamic_cast<const FunctionType *>(funcs[i]->type); const FunctionType *ftype = CastType<FunctionType>(funcs[i]->type);
Assert(ftype); Assert(ftype);
std::string decl = ftype->GetCDeclaration(funcs[i]->name); std::string decl = ftype->GetCDeclaration(funcs[i]->name);
fprintf(file, " extern %s;\n", decl.c_str()); fprintf(file, " extern %s;\n", decl.c_str());
@@ -1163,7 +1163,7 @@ lPrintFunctionDeclarations(FILE *file, const std::vector<Symbol *> &funcs) {
static bool static bool
lIsExported(const Symbol *sym) { lIsExported(const Symbol *sym) {
const FunctionType *ft = dynamic_cast<const FunctionType *>(sym->type); const FunctionType *ft = CastType<FunctionType>(sym->type);
Assert(ft); Assert(ft);
return ft->isExported; return ft->isExported;
} }
@@ -1171,7 +1171,7 @@ lIsExported(const Symbol *sym) {
static bool static bool
lIsExternC(const Symbol *sym) { lIsExternC(const Symbol *sym) {
const FunctionType *ft = dynamic_cast<const FunctionType *>(sym->type); const FunctionType *ft = CastType<FunctionType>(sym->type);
Assert(ft); Assert(ft);
return ft->isExternC; return ft->isExternC;
} }

View File

@@ -550,7 +550,7 @@ rate_qualified_type_specifier
$$ = NULL; $$ = NULL;
else { else {
int soaWidth = (int)$1; int soaWidth = (int)$1;
const StructType *st = dynamic_cast<const StructType *>($2); const StructType *st = CastType<StructType>($2);
if (st == NULL) { if (st == NULL) {
Error(@1, "\"soa\" qualifier is illegal with non-struct type \"%s\".", Error(@1, "\"soa\" qualifier is illegal with non-struct type \"%s\".",
$2->GetString().c_str()); $2->GetString().c_str());
@@ -895,7 +895,7 @@ struct_or_union_specifier
st = new UndefinedStructType($2, Variability::Unbound, false, @2); st = new UndefinedStructType($2, Variability::Unbound, false, @2);
m->symbolTable->AddType($2, st, @2); m->symbolTable->AddType($2, st, @2);
} }
else if (dynamic_cast<const StructType *>(st) == NULL) else if (CastType<StructType>(st) == NULL)
Error(@2, "Type \"%s\" is not a struct type! (%s)", $2, Error(@2, "Type \"%s\" is not a struct type! (%s)", $2,
st->GetString().c_str()); st->GetString().c_str());
$$ = st; $$ = st;
@@ -1060,7 +1060,7 @@ enum_specifier
$$ = NULL; $$ = NULL;
} }
else { else {
const EnumType *enumType = dynamic_cast<const EnumType *>(type); const EnumType *enumType = CastType<EnumType>(type);
if (enumType == NULL) { if (enumType == NULL) {
Error(@2, "Type \"%s\" is not an enum type (%s).", $2, Error(@2, "Type \"%s\" is not an enum type (%s).", $2,
type->GetString().c_str()); type->GetString().c_str());
@@ -1858,8 +1858,7 @@ function_definition
{ {
if ($2 != NULL) { if ($2 != NULL) {
$2->InitFromDeclSpecs($1); $2->InitFromDeclSpecs($1);
const FunctionType *funcType = const FunctionType *funcType = CastType<FunctionType>($2->type);
dynamic_cast<const FunctionType *>($2->type);
if (funcType == NULL) if (funcType == NULL)
Assert(m->errorCount > 0); Assert(m->errorCount > 0);
else { else {
@@ -1987,7 +1986,7 @@ lAddDeclaration(DeclSpecs *ds, Declarator *decl) {
decl->type = decl->type->ResolveUnboundVariability(Variability::Varying); decl->type = decl->type->ResolveUnboundVariability(Variability::Varying);
const FunctionType *ft = dynamic_cast<const FunctionType *>(decl->type); const FunctionType *ft = CastType<FunctionType>(decl->type);
if (ft != NULL) { if (ft != NULL) {
bool isInline = (ds->typeQualifiers & TYPEQUAL_INLINE); bool isInline = (ds->typeQualifiers & TYPEQUAL_INLINE);
m->AddFunctionDeclaration(decl->name, ft, ds->storageClass, m->AddFunctionDeclaration(decl->name, ft, ds->storageClass,

View File

@@ -122,7 +122,7 @@ DeclStmt::DeclStmt(const std::vector<VariableDeclaration> &v, SourcePos p)
static bool static bool
lHasUnsizedArrays(const Type *type) { lHasUnsizedArrays(const Type *type) {
const ArrayType *at = dynamic_cast<const ArrayType *>(type); const ArrayType *at = CastType<ArrayType>(type);
if (at == NULL) if (at == NULL)
return false; return false;
@@ -297,8 +297,8 @@ DeclStmt::TypeCheck() {
// the int->float type conversion is in there and we don't return // the int->float type conversion is in there and we don't return
// an int as the constValue later... // an int as the constValue later...
const Type *type = vars[i].sym->type; const Type *type = vars[i].sym->type;
if (dynamic_cast<const AtomicType *>(type) != NULL || if (CastType<AtomicType>(type) != NULL ||
dynamic_cast<const EnumType *>(type) != NULL) { CastType<EnumType>(type) != NULL) {
// If it's an expr list with an atomic type, we'll later issue // If it's an expr list with an atomic type, we'll later issue
// an error. Need to leave vars[i].init as is in that case so // an error. Need to leave vars[i].init as is in that case so
// it is in fact caught later, though. // it is in fact caught later, though.
@@ -2461,7 +2461,7 @@ lEncodeType(const Type *t) {
if (Type::Equal(t, AtomicType::VaryingUInt64)) return 'V'; if (Type::Equal(t, AtomicType::VaryingUInt64)) return 'V';
if (Type::Equal(t, AtomicType::UniformDouble)) return 'd'; if (Type::Equal(t, AtomicType::UniformDouble)) return 'd';
if (Type::Equal(t, AtomicType::VaryingDouble)) return 'D'; if (Type::Equal(t, AtomicType::VaryingDouble)) return 'D';
if (dynamic_cast<const PointerType *>(t) != NULL) { if (CastType<PointerType>(t) != NULL) {
if (t->IsUniformType()) if (t->IsUniformType())
return 'p'; return 'p';
else else
@@ -2481,7 +2481,7 @@ lProcessPrintArg(Expr *expr, FunctionEmitContext *ctx, std::string &argTypes) {
if (type == NULL) if (type == NULL)
return NULL; return NULL;
if (dynamic_cast<const ReferenceType *>(type) != NULL) { if (CastType<ReferenceType>(type) != NULL) {
expr = new RefDerefExpr(expr, expr->pos); expr = new RefDerefExpr(expr, expr->pos);
type = expr->GetType(); type = expr->GetType();
if (type == NULL) if (type == NULL)
@@ -2732,7 +2732,7 @@ DeleteStmt::EmitCode(FunctionEmitContext *ctx) const {
} }
// Typechecking should catch this // Typechecking should catch this
Assert(dynamic_cast<const PointerType *>(exprType) != NULL); Assert(CastType<PointerType>(exprType) != NULL);
if (exprType->IsUniformType()) { if (exprType->IsUniformType()) {
// For deletion of a uniform pointer, we just need to cast the // For deletion of a uniform pointer, we just need to cast the
@@ -2772,7 +2772,7 @@ DeleteStmt::TypeCheck() {
if (expr == NULL || ((exprType = expr->GetType()) == NULL)) if (expr == NULL || ((exprType = expr->GetType()) == NULL))
return NULL; return NULL;
if (dynamic_cast<const PointerType *>(exprType) == NULL) { if (CastType<PointerType>(exprType) == NULL) {
Error(pos, "Illegal to delete non-pointer type \"%s\".", Error(pos, "Illegal to delete non-pointer type \"%s\".",
exprType->GetString().c_str()); exprType->GetString().c_str());
return NULL; return NULL;

View File

@@ -136,7 +136,7 @@ SymbolTable::LookupVariable(const char *name) {
bool bool
SymbolTable::AddFunction(Symbol *symbol) { SymbolTable::AddFunction(Symbol *symbol) {
const FunctionType *ft = dynamic_cast<const FunctionType *>(symbol->type); const FunctionType *ft = CastType<FunctionType>(symbol->type);
Assert(ft != NULL); Assert(ft != NULL);
if (LookupFunction(symbol->name.c_str(), ft) != NULL) if (LookupFunction(symbol->name.c_str(), ft) != NULL)
// A function of the same name and type has already been added to // A function of the same name and type has already been added to
@@ -182,7 +182,7 @@ SymbolTable::LookupFunction(const char *name, const FunctionType *type) {
bool bool
SymbolTable::AddType(const char *name, const Type *type, SourcePos pos) { SymbolTable::AddType(const char *name, const Type *type, SourcePos pos) {
const Type *t = LookupType(name); const Type *t = LookupType(name);
if (t != NULL && dynamic_cast<const UndefinedStructType *>(t) == NULL) { if (t != NULL && CastType<UndefinedStructType>(t) == NULL) {
// If we have a previous declaration of anything other than an // If we have a previous declaration of anything other than an
// UndefinedStructType with this struct name, issue an error. If // UndefinedStructType with this struct name, issue an error. If
// we have an UndefinedStructType, then we'll fall through to the // we have an UndefinedStructType, then we'll fall through to the
@@ -270,7 +270,7 @@ SymbolTable::closestTypeMatch(const char *str, bool structsVsEnums) const {
for (iter = types.begin(); iter != types.end(); ++iter) { for (iter = types.begin(); iter != types.end(); ++iter) {
// Skip over either StructTypes or EnumTypes, depending on the // Skip over either StructTypes or EnumTypes, depending on the
// value of the structsVsEnums parameter // value of the structsVsEnums parameter
bool isEnum = (dynamic_cast<const EnumType *>(iter->second) != NULL); bool isEnum = (CastType<EnumType>(iter->second) != NULL);
if (isEnum && structsVsEnums) if (isEnum && structsVsEnums)
continue; continue;
else if (!isEnum && !structsVsEnums) else if (!isEnum && !structsVsEnums)

137
type.cpp
View File

@@ -184,7 +184,7 @@ const AtomicType *AtomicType::Void =
AtomicType::AtomicType(BasicType bt, Variability v, bool ic) AtomicType::AtomicType(BasicType bt, Variability v, bool ic)
: basicType(bt), variability(v), isConst(ic) { : Type(ATOMIC_TYPE), basicType(bt), variability(v), isConst(ic) {
} }
@@ -532,7 +532,7 @@ AtomicType::GetDIType(llvm::DIDescriptor scope) const {
// EnumType // EnumType
EnumType::EnumType(SourcePos p) EnumType::EnumType(SourcePos p)
: pos(p) { : Type(ENUM_TYPE), pos(p) {
// name = "/* (anonymous) */"; // name = "/* (anonymous) */";
isConst = false; isConst = false;
variability = Variability(Variability::Unbound); variability = Variability(Variability::Unbound);
@@ -540,7 +540,7 @@ EnumType::EnumType(SourcePos p)
EnumType::EnumType(const char *n, SourcePos p) EnumType::EnumType(const char *n, SourcePos p)
: pos(p), name(n) { : Type(ENUM_TYPE), pos(p), name(n) {
isConst = false; isConst = false;
variability = Variability(Variability::Unbound); variability = Variability(Variability::Unbound);
} }
@@ -817,7 +817,7 @@ PointerType *PointerType::Void =
PointerType::PointerType(const Type *t, Variability v, bool ic, bool is, PointerType::PointerType(const Type *t, Variability v, bool ic, bool is,
bool fr) bool fr)
: variability(v), isConst(ic), isSlice(is), isFrozen(fr) { : Type(POINTER_TYPE), variability(v), isConst(ic), isSlice(is), isFrozen(fr) {
baseType = t; baseType = t;
} }
@@ -1083,7 +1083,7 @@ PointerType::LLVMType(llvm::LLVMContext *ctx) const {
switch (variability.type) { switch (variability.type) {
case Variability::Uniform: { case Variability::Uniform: {
llvm::Type *ptype = NULL; llvm::Type *ptype = NULL;
const FunctionType *ftype = dynamic_cast<const FunctionType *>(baseType); const FunctionType *ftype = CastType<FunctionType>(baseType);
if (ftype != NULL) if (ftype != NULL)
// Get the type of the function variant that takes the mask as the // Get the type of the function variant that takes the mask as the
// last parameter--i.e. we don't allow taking function pointers of // last parameter--i.e. we don't allow taking function pointers of
@@ -1155,7 +1155,7 @@ const Type *SequentialType::GetElementType(int index) const {
// ArrayType // ArrayType
ArrayType::ArrayType(const Type *c, int a) ArrayType::ArrayType(const Type *c, int a)
: child(c), numElements(a) { : SequentialType(ARRAY_TYPE), child(c), numElements(a) {
// 0 -> unsized array. // 0 -> unsized array.
Assert(numElements >= 0); Assert(numElements >= 0);
Assert(Type::Equal(c, AtomicType::Void) == false); Assert(Type::Equal(c, AtomicType::Void) == false);
@@ -1217,11 +1217,11 @@ ArrayType::IsConstType() const {
const Type * const Type *
ArrayType::GetBaseType() const { ArrayType::GetBaseType() const {
const Type *type = child; const Type *type = child;
const ArrayType *at = dynamic_cast<const ArrayType *>(type); const ArrayType *at = CastType<ArrayType>(type);
// Keep walking until we reach a child that isn't itself an array // Keep walking until we reach a child that isn't itself an array
while (at) { while (at) {
type = at->child; type = at->child;
at = dynamic_cast<const ArrayType *>(type); at = CastType<ArrayType>(type);
} }
return type; return type;
} }
@@ -1338,7 +1338,7 @@ ArrayType::GetString() const {
else else
buf[0] = '\0'; buf[0] = '\0';
s += std::string("[") + std::string(buf) + std::string("]"); s += std::string("[") + std::string(buf) + std::string("]");
at = dynamic_cast<const ArrayType *>(at->child); at = CastType<ArrayType>(at->child);
} }
return s; return s;
} }
@@ -1381,7 +1381,7 @@ ArrayType::GetCDeclaration(const std::string &name) const {
else else
buf[0] = '\0'; buf[0] = '\0';
s += std::string("[") + std::string(buf) + std::string("]"); s += std::string("[") + std::string(buf) + std::string("]");
at = dynamic_cast<const ArrayType *>(at->child); at = CastType<ArrayType>(at->child);
} }
if (soaWidth > 0) { if (soaWidth > 0) {
@@ -1396,7 +1396,7 @@ ArrayType::GetCDeclaration(const std::string &name) const {
int int
ArrayType::TotalElementCount() const { ArrayType::TotalElementCount() const {
const ArrayType *ct = dynamic_cast<const ArrayType *>(child); const ArrayType *ct = CastType<ArrayType>(child);
if (ct != NULL) if (ct != NULL)
return numElements * ct->TotalElementCount(); return numElements * ct->TotalElementCount();
else else
@@ -1425,7 +1425,7 @@ ArrayType::GetSizedArray(int sz) const {
const Type * const Type *
ArrayType::SizeUnsizedArrays(const Type *type, Expr *initExpr) { ArrayType::SizeUnsizedArrays(const Type *type, Expr *initExpr) {
const ArrayType *at = dynamic_cast<const ArrayType *>(type); const ArrayType *at = CastType<ArrayType>(type);
if (at == NULL) if (at == NULL)
return type; return type;
@@ -1437,7 +1437,7 @@ ArrayType::SizeUnsizedArrays(const Type *type, Expr *initExpr) {
// length of the expression list // length of the expression list
if (at->GetElementCount() == 0) { if (at->GetElementCount() == 0) {
type = at->GetSizedArray(exprList->exprs.size()); type = at->GetSizedArray(exprList->exprs.size());
at = dynamic_cast<const ArrayType *>(type); at = CastType<ArrayType>(type);
} }
// Is there another nested level of expression lists? If not, bail out // Is there another nested level of expression lists? If not, bail out
@@ -1449,7 +1449,7 @@ ArrayType::SizeUnsizedArrays(const Type *type, Expr *initExpr) {
return type; return type;
const Type *nextType = at->GetElementType(); const Type *nextType = at->GetElementType();
const ArrayType *nextArrayType = dynamic_cast<const ArrayType *>(nextType); const ArrayType *nextArrayType = CastType<ArrayType>(nextType);
if (nextArrayType != NULL && nextArrayType->GetElementCount() == 0) { if (nextArrayType != NULL && nextArrayType->GetElementCount() == 0) {
// If the recursive call to SizeUnsizedArrays at the bottom of the // If the recursive call to SizeUnsizedArrays at the bottom of the
// function is going to size an unsized dimension, make sure that // function is going to size an unsized dimension, make sure that
@@ -1485,7 +1485,7 @@ ArrayType::SizeUnsizedArrays(const Type *type, Expr *initExpr) {
// VectorType // VectorType
VectorType::VectorType(const AtomicType *b, int a) VectorType::VectorType(const AtomicType *b, int a)
: base(b), numElements(a) { : SequentialType(VECTOR_TYPE), base(b), numElements(a) {
Assert(numElements > 0); Assert(numElements > 0);
Assert(base != NULL); Assert(base != NULL);
} }
@@ -2111,8 +2111,7 @@ StructType::checkIfCanBeSOA(const StructType *st) {
bool ok = true; bool ok = true;
for (int i = 0; i < (int)st->elementTypes.size(); ++i) { for (int i = 0; i < (int)st->elementTypes.size(); ++i) {
const Type *eltType = st->elementTypes[i]; const Type *eltType = st->elementTypes[i];
const StructType *childStructType = const StructType *childStructType = CastType<StructType>(eltType);
dynamic_cast<const StructType *>(eltType);
if (childStructType != NULL) if (childStructType != NULL)
ok &= checkIfCanBeSOA(childStructType); ok &= checkIfCanBeSOA(childStructType);
@@ -2124,7 +2123,7 @@ StructType::checkIfCanBeSOA(const StructType *st) {
eltType->IsUniformType() ? "uniform" : "varying"); eltType->IsUniformType() ? "uniform" : "varying");
ok = false; ok = false;
} }
else if (dynamic_cast<const ReferenceType *>(eltType)) { else if (CastType<ReferenceType>(eltType)) {
Error(st->elementPositions[i], "Unable to apply SOA conversion to " Error(st->elementPositions[i], "Unable to apply SOA conversion to "
"struct due to member \"%s\" with reference type \"%s\".", "struct due to member \"%s\" with reference type \"%s\".",
st->elementNames[i].c_str(), eltType->GetString().c_str()); st->elementNames[i].c_str(), eltType->GetString().c_str());
@@ -2141,7 +2140,7 @@ StructType::checkIfCanBeSOA(const StructType *st) {
UndefinedStructType::UndefinedStructType(const std::string &n, UndefinedStructType::UndefinedStructType(const std::string &n,
const Variability var, bool ic, const Variability var, bool ic,
SourcePos p) SourcePos p)
: name(n), variability(var), isConst(ic), pos(p) { : Type(UNDEFINED_STRUCT_TYPE), name(n), variability(var), isConst(ic), pos(p) {
Assert(name != ""); Assert(name != "");
if (variability != Variability::Unbound) { if (variability != Variability::Unbound) {
// Create a new opaque LLVM struct type for this struct name // Create a new opaque LLVM struct type for this struct name
@@ -2303,7 +2302,7 @@ UndefinedStructType::GetDIType(llvm::DIDescriptor scope) const {
// ReferenceType // ReferenceType
ReferenceType::ReferenceType(const Type *t) ReferenceType::ReferenceType(const Type *t)
: targetType(t) { : Type(REFERENCE_TYPE), targetType(t) {
} }
@@ -2493,7 +2492,7 @@ ReferenceType::GetCDeclaration(const std::string &name) const {
return ""; return "";
} }
const ArrayType *at = dynamic_cast<const ArrayType *>(targetType); const ArrayType *at = CastType<ArrayType>(targetType);
if (at != NULL) { if (at != NULL) {
if (at->GetElementCount() == 0) { if (at->GetElementCount() == 0) {
// emit unsized arrays as pointers to the base type.. // emit unsized arrays as pointers to the base type..
@@ -2553,8 +2552,8 @@ ReferenceType::GetDIType(llvm::DIDescriptor scope) const {
FunctionType::FunctionType(const Type *r, const std::vector<const Type *> &a, FunctionType::FunctionType(const Type *r, const std::vector<const Type *> &a,
SourcePos p) SourcePos p)
: isTask(false), isExported(false), isExternC(false), returnType(r), : Type(FUNCTION_TYPE), isTask(false), isExported(false), isExternC(false),
paramTypes(a), paramNames(std::vector<std::string>(a.size(), "")), returnType(r), paramTypes(a), paramNames(std::vector<std::string>(a.size(), "")),
paramDefaults(std::vector<Expr *>(a.size(), NULL)), paramDefaults(std::vector<Expr *>(a.size(), NULL)),
paramPositions(std::vector<SourcePos>(a.size(), p)) { paramPositions(std::vector<SourcePos>(a.size(), p)) {
Assert(returnType != NULL); Assert(returnType != NULL);
@@ -2568,8 +2567,8 @@ FunctionType::FunctionType(const Type *r, const std::vector<const Type *> &a,
const std::vector<Expr *> &ad, const std::vector<Expr *> &ad,
const std::vector<SourcePos> &ap, const std::vector<SourcePos> &ap,
bool it, bool is, bool ec) bool it, bool is, bool ec)
: isTask(it), isExported(is), isExternC(ec), returnType(r), paramTypes(a), : Type(FUNCTION_TYPE), isTask(it), isExported(is), isExternC(ec), returnType(r),
paramNames(an), paramDefaults(ad), paramPositions(ap) { paramTypes(a), paramNames(an), paramDefaults(ad), paramPositions(ap) {
Assert(paramTypes.size() == paramNames.size() && Assert(paramTypes.size() == paramNames.size() &&
paramNames.size() == paramDefaults.size() && paramNames.size() == paramDefaults.size() &&
paramDefaults.size() == paramPositions.size()); paramDefaults.size() == paramPositions.size());
@@ -2733,9 +2732,9 @@ FunctionType::GetCDeclaration(const std::string &fname) const {
// Convert pointers to arrays to unsized arrays, which are more clear // Convert pointers to arrays to unsized arrays, which are more clear
// to print out for multidimensional arrays (i.e. "float foo[][4] " // to print out for multidimensional arrays (i.e. "float foo[][4] "
// versus "float (foo *)[4]"). // versus "float (foo *)[4]").
const PointerType *pt = dynamic_cast<const PointerType *>(type); const PointerType *pt = CastType<PointerType>(type);
if (pt != NULL && if (pt != NULL &&
dynamic_cast<const ArrayType *>(pt->GetBaseType()) != NULL) { CastType<ArrayType>(pt->GetBaseType()) != NULL) {
type = new ArrayType(pt->GetBaseType(), 0); type = new ArrayType(pt->GetBaseType(), 0);
} }
@@ -2906,7 +2905,7 @@ Type::GetAsUnsignedType() const {
*/ */
static const Type * static const Type *
lVectorConvert(const Type *type, SourcePos pos, const char *reason, int vecSize) { lVectorConvert(const Type *type, SourcePos pos, const char *reason, int vecSize) {
const VectorType *vt = dynamic_cast<const VectorType *>(type); const VectorType *vt = CastType<VectorType>(type);
if (vt) { if (vt) {
if (vt->GetElementCount() != vecSize) { if (vt->GetElementCount() != vecSize) {
Error(pos, "Implicit conversion between from vector type " Error(pos, "Implicit conversion between from vector type "
@@ -2917,7 +2916,7 @@ lVectorConvert(const Type *type, SourcePos pos, const char *reason, int vecSize)
return vt; return vt;
} }
else { else {
const AtomicType *at = dynamic_cast<const AtomicType *>(type); const AtomicType *at = CastType<AtomicType>(type);
if (!at) { if (!at) {
Error(pos, "Non-atomic type \"%s\" can't be converted to vector type " Error(pos, "Non-atomic type \"%s\" can't be converted to vector type "
"for %s.", type->GetString().c_str(), reason); "for %s.", type->GetString().c_str(), reason);
@@ -2935,11 +2934,10 @@ Type::MoreGeneralType(const Type *t0, const Type *t1, SourcePos pos, const char
// First, if one or both types are function types, convert them to // First, if one or both types are function types, convert them to
// pointer to function types and then try again. // pointer to function types and then try again.
if (dynamic_cast<const FunctionType *>(t0) || if (CastType<FunctionType>(t0) || CastType<FunctionType>(t1)) {
dynamic_cast<const FunctionType *>(t1)) { if (CastType<FunctionType>(t0))
if (dynamic_cast<const FunctionType *>(t0))
t0 = PointerType::GetUniform(t0); t0 = PointerType::GetUniform(t0);
if (dynamic_cast<const FunctionType *>(t1)) if (CastType<FunctionType>(t1))
t1 = PointerType::GetUniform(t1); t1 = PointerType::GetUniform(t1);
return MoreGeneralType(t0, t1, pos, reason, forceVarying, vecSize); return MoreGeneralType(t0, t1, pos, reason, forceVarying, vecSize);
} }
@@ -2967,8 +2965,7 @@ Type::MoreGeneralType(const Type *t0, const Type *t1, SourcePos pos, const char
// If they're function types, it's hopeless if they didn't match in the // If they're function types, it's hopeless if they didn't match in the
// Type::Equal() call above. Fail here so that we don't get into // Type::Equal() call above. Fail here so that we don't get into
// trouble calling GetAsConstType()... // trouble calling GetAsConstType()...
if (dynamic_cast<const FunctionType *>(t0) || if (CastType<FunctionType>(t0) || CastType<FunctionType>(t1)) {
dynamic_cast<const FunctionType *>(t1)) {
Error(pos, "Incompatible function types \"%s\" and \"%s\" in %s.", Error(pos, "Incompatible function types \"%s\" and \"%s\" in %s.",
t0->GetString().c_str(), t1->GetString().c_str(), reason); t0->GetString().c_str(), t1->GetString().c_str(), reason);
return NULL; return NULL;
@@ -2979,8 +2976,8 @@ Type::MoreGeneralType(const Type *t0, const Type *t1, SourcePos pos, const char
if (Type::EqualIgnoringConst(t0, t1)) if (Type::EqualIgnoringConst(t0, t1))
return t0->GetAsNonConstType(); return t0->GetAsNonConstType();
const PointerType *pt0 = dynamic_cast<const PointerType *>(t0); const PointerType *pt0 = CastType<PointerType>(t0);
const PointerType *pt1 = dynamic_cast<const PointerType *>(t1); const PointerType *pt1 = CastType<PointerType>(t1);
if (pt0 != NULL && pt1 != NULL) { if (pt0 != NULL && pt1 != NULL) {
if (PointerType::IsVoidPointer(pt0)) if (PointerType::IsVoidPointer(pt0))
return pt1; return pt1;
@@ -2994,8 +2991,8 @@ Type::MoreGeneralType(const Type *t0, const Type *t1, SourcePos pos, const char
} }
} }
const VectorType *vt0 = dynamic_cast<const VectorType *>(t0); const VectorType *vt0 = CastType<VectorType>(t0);
const VectorType *vt1 = dynamic_cast<const VectorType *>(t1); const VectorType *vt1 = CastType<VectorType>(t1);
if (vt0 && vt1) { if (vt0 && vt1) {
// both are vectors; convert their base types and make a new vector // both are vectors; convert their base types and make a new vector
// type, as long as their lengths match // type, as long as their lengths match
@@ -3012,7 +3009,7 @@ Type::MoreGeneralType(const Type *t0, const Type *t1, SourcePos pos, const char
// The 'more general' version of the two vector element types must // The 'more general' version of the two vector element types must
// be an AtomicType (that's all that vectors can hold...) // be an AtomicType (that's all that vectors can hold...)
const AtomicType *at = dynamic_cast<const AtomicType *>(t); const AtomicType *at = CastType<AtomicType>(t);
Assert(at != NULL); Assert(at != NULL);
return new VectorType(at, vt0->GetElementCount()); return new VectorType(at, vt0->GetElementCount());
@@ -3027,7 +3024,7 @@ Type::MoreGeneralType(const Type *t0, const Type *t1, SourcePos pos, const char
if (!t) if (!t)
return NULL; return NULL;
const AtomicType *at = dynamic_cast<const AtomicType *>(t); const AtomicType *at = CastType<AtomicType>(t);
Assert(at != NULL); Assert(at != NULL);
return new VectorType(at, vt0->GetElementCount()); return new VectorType(at, vt0->GetElementCount());
} }
@@ -3039,18 +3036,18 @@ Type::MoreGeneralType(const Type *t0, const Type *t1, SourcePos pos, const char
if (!t) if (!t)
return NULL; return NULL;
const AtomicType *at = dynamic_cast<const AtomicType *>(t); const AtomicType *at = CastType<AtomicType>(t);
Assert(at != NULL); Assert(at != NULL);
return new VectorType(at, vt1->GetElementCount()); return new VectorType(at, vt1->GetElementCount());
} }
// TODO: what do we need to do about references here, if anything?? // TODO: what do we need to do about references here, if anything??
const AtomicType *at0 = dynamic_cast<const AtomicType *>(t0->GetReferenceTarget()); const AtomicType *at0 = CastType<AtomicType>(t0->GetReferenceTarget());
const AtomicType *at1 = dynamic_cast<const AtomicType *>(t1->GetReferenceTarget()); const AtomicType *at1 = CastType<AtomicType>(t1->GetReferenceTarget());
const EnumType *et0 = dynamic_cast<const EnumType *>(t0->GetReferenceTarget()); const EnumType *et0 = CastType<EnumType>(t0->GetReferenceTarget());
const EnumType *et1 = dynamic_cast<const EnumType *>(t1->GetReferenceTarget()); const EnumType *et1 = CastType<EnumType>(t1->GetReferenceTarget());
if (et0 != NULL && et1 != NULL) { if (et0 != NULL && et1 != NULL) {
// Two different enum types -> make them uint32s... // Two different enum types -> make them uint32s...
Assert(et0->IsVaryingType() == et1->IsVaryingType()); Assert(et0->IsVaryingType() == et1->IsVaryingType());
@@ -3098,9 +3095,9 @@ Type::MoreGeneralType(const Type *t0, const Type *t1, SourcePos pos, const char
bool bool
Type::IsBasicType(const Type *type) { Type::IsBasicType(const Type *type) {
return (dynamic_cast<const AtomicType *>(type) != NULL || return (CastType<AtomicType>(type) != NULL ||
dynamic_cast<const EnumType *>(type) != NULL || CastType<EnumType>(type) != NULL ||
dynamic_cast<const PointerType *>(type) != NULL); CastType<PointerType>(type) != NULL);
} }
@@ -3110,16 +3107,16 @@ lCheckTypeEquality(const Type *a, const Type *b, bool ignoreConst) {
return false; return false;
if (ignoreConst == true) { if (ignoreConst == true) {
if (dynamic_cast<const FunctionType *>(a) == NULL) if (CastType<FunctionType>(a) == NULL)
a = a->GetAsNonConstType(); a = a->GetAsNonConstType();
if (dynamic_cast<const FunctionType *>(b) == NULL) if (CastType<FunctionType>(b) == NULL)
b = b->GetAsNonConstType(); b = b->GetAsNonConstType();
} }
else if (a->IsConstType() != b->IsConstType()) else if (a->IsConstType() != b->IsConstType())
return false; return false;
const AtomicType *ata = dynamic_cast<const AtomicType *>(a); const AtomicType *ata = CastType<AtomicType>(a);
const AtomicType *atb = dynamic_cast<const AtomicType *>(b); const AtomicType *atb = CastType<AtomicType>(b);
if (ata != NULL && atb != NULL) { if (ata != NULL && atb != NULL) {
return ((ata->basicType == atb->basicType) && return ((ata->basicType == atb->basicType) &&
(ata->GetVariability() == atb->GetVariability())); (ata->GetVariability() == atb->GetVariability()));
@@ -3128,33 +3125,31 @@ lCheckTypeEquality(const Type *a, const Type *b, bool ignoreConst) {
// For all of the other types, we need to see if we have the same two // For all of the other types, we need to see if we have the same two
// general types. If so, then we dig into the details of the type and // general types. If so, then we dig into the details of the type and
// see if all of the relevant bits are equal... // see if all of the relevant bits are equal...
const EnumType *eta = dynamic_cast<const EnumType *>(a); const EnumType *eta = CastType<EnumType>(a);
const EnumType *etb = dynamic_cast<const EnumType *>(b); const EnumType *etb = CastType<EnumType>(b);
if (eta != NULL && etb != NULL) if (eta != NULL && etb != NULL)
// Kind of goofy, but this sufficies to check // Kind of goofy, but this sufficies to check
return (eta->pos == etb->pos && return (eta->pos == etb->pos &&
eta->GetVariability() == etb->GetVariability()); eta->GetVariability() == etb->GetVariability());
const ArrayType *arta = dynamic_cast<const ArrayType *>(a); const ArrayType *arta = CastType<ArrayType>(a);
const ArrayType *artb = dynamic_cast<const ArrayType *>(b); const ArrayType *artb = CastType<ArrayType>(b);
if (arta != NULL && artb != NULL) if (arta != NULL && artb != NULL)
return (arta->GetElementCount() == artb->GetElementCount() && return (arta->GetElementCount() == artb->GetElementCount() &&
lCheckTypeEquality(arta->GetElementType(), artb->GetElementType(), lCheckTypeEquality(arta->GetElementType(), artb->GetElementType(),
ignoreConst)); ignoreConst));
const VectorType *vta = dynamic_cast<const VectorType *>(a); const VectorType *vta = CastType<VectorType>(a);
const VectorType *vtb = dynamic_cast<const VectorType *>(b); const VectorType *vtb = CastType<VectorType>(b);
if (vta != NULL && vtb != NULL) if (vta != NULL && vtb != NULL)
return (vta->GetElementCount() == vtb->GetElementCount() && return (vta->GetElementCount() == vtb->GetElementCount() &&
lCheckTypeEquality(vta->GetElementType(), vtb->GetElementType(), lCheckTypeEquality(vta->GetElementType(), vtb->GetElementType(),
ignoreConst)); ignoreConst));
const StructType *sta = dynamic_cast<const StructType *>(a); const StructType *sta = CastType<StructType>(a);
const StructType *stb = dynamic_cast<const StructType *>(b); const StructType *stb = CastType<StructType>(b);
const UndefinedStructType *usta = const UndefinedStructType *usta = CastType<UndefinedStructType>(a);
dynamic_cast<const UndefinedStructType *>(a); const UndefinedStructType *ustb = CastType<UndefinedStructType>(b);
const UndefinedStructType *ustb =
dynamic_cast<const UndefinedStructType *>(b);
if ((sta != NULL || usta != NULL) && (stb != NULL || ustb != NULL)) { if ((sta != NULL || usta != NULL) && (stb != NULL || ustb != NULL)) {
// Report both defuned and undefined structs as equal if their // Report both defuned and undefined structs as equal if their
// names are the same. // names are the same.
@@ -3166,8 +3161,8 @@ lCheckTypeEquality(const Type *a, const Type *b, bool ignoreConst) {
return (namea == nameb); return (namea == nameb);
} }
const PointerType *pta = dynamic_cast<const PointerType *>(a); const PointerType *pta = CastType<PointerType>(a);
const PointerType *ptb = dynamic_cast<const PointerType *>(b); const PointerType *ptb = CastType<PointerType>(b);
if (pta != NULL && ptb != NULL) if (pta != NULL && ptb != NULL)
return (pta->IsUniformType() == ptb->IsUniformType() && return (pta->IsUniformType() == ptb->IsUniformType() &&
pta->IsSlice() == ptb->IsSlice() && pta->IsSlice() == ptb->IsSlice() &&
@@ -3175,14 +3170,14 @@ lCheckTypeEquality(const Type *a, const Type *b, bool ignoreConst) {
lCheckTypeEquality(pta->GetBaseType(), ptb->GetBaseType(), lCheckTypeEquality(pta->GetBaseType(), ptb->GetBaseType(),
ignoreConst)); ignoreConst));
const ReferenceType *rta = dynamic_cast<const ReferenceType *>(a); const ReferenceType *rta = CastType<ReferenceType>(a);
const ReferenceType *rtb = dynamic_cast<const ReferenceType *>(b); const ReferenceType *rtb = CastType<ReferenceType>(b);
if (rta != NULL && rtb != NULL) if (rta != NULL && rtb != NULL)
return (lCheckTypeEquality(rta->GetReferenceTarget(), return (lCheckTypeEquality(rta->GetReferenceTarget(),
rtb->GetReferenceTarget(), ignoreConst)); rtb->GetReferenceTarget(), ignoreConst));
const FunctionType *fta = dynamic_cast<const FunctionType *>(a); const FunctionType *fta = CastType<FunctionType>(a);
const FunctionType *ftb = dynamic_cast<const FunctionType *>(b); const FunctionType *ftb = CastType<FunctionType>(b);
if (fta != NULL && ftb != NULL) { if (fta != NULL && ftb != NULL) {
// Both the return types and all of the argument types must match // Both the return types and all of the argument types must match
// for function types to match // for function types to match

148
type.h
View File

@@ -72,6 +72,21 @@ struct Variability {
}; };
/** Enumerant that records each of the types that inherit from the Type
baseclass. */
enum TypeId {
ATOMIC_TYPE,
ENUM_TYPE,
POINTER_TYPE,
ARRAY_TYPE,
VECTOR_TYPE,
STRUCT_TYPE,
UNDEFINED_STRUCT_TYPE,
REFERENCE_TYPE,
FUNCTION_TYPE
};
/** @brief Interface class that defines the type abstraction. /** @brief Interface class that defines the type abstraction.
Abstract base class that defines the interface that must be implemented Abstract base class that defines the interface that must be implemented
@@ -231,6 +246,14 @@ public:
(i.e. not an aggregation of multiple instances of a type or (i.e. not an aggregation of multiple instances of a type or
types.) */ types.) */
static bool IsBasicType(const Type *type); static bool IsBasicType(const Type *type);
/** Indicates which Type implementation this type is. This value can
be used to determine the actual type much more efficiently than
using dynamic_cast. */
const TypeId typeId;
protected:
Type(TypeId id) : typeId(id) { }
}; };
@@ -452,6 +475,9 @@ public:
index must be between 0 and GetElementCount()-1. index must be between 0 and GetElementCount()-1.
*/ */
virtual const Type *GetElementType(int index) const = 0; virtual const Type *GetElementType(int index) const = 0;
protected:
CollectionType(TypeId id) : Type(id) { }
}; };
@@ -473,6 +499,9 @@ public:
the same type. the same type.
*/ */
const Type *GetElementType(int index) const; const Type *GetElementType(int index) const;
protected:
SequentialType(TypeId id) : CollectionType(id) { }
}; };
@@ -686,6 +715,8 @@ private:
const Variability variability; const Variability variability;
const bool isConst; const bool isConst;
const SourcePos pos; const SourcePos pos;
mutable const StructType *oppositeConstStructType;
}; };
@@ -732,8 +763,6 @@ private:
const Variability variability; const Variability variability;
const bool isConst; const bool isConst;
const SourcePos pos; const SourcePos pos;
mutable const StructType *oppositeConstStructType;
}; };
@@ -875,8 +904,119 @@ private:
const std::vector<SourcePos> paramPositions; const std::vector<SourcePos> paramPositions;
}; };
inline bool IsReferenceType(const Type *t) {
return dynamic_cast<const ReferenceType *>(t) != NULL; /* Efficient dynamic casting of Types. First, we specify a default
template function that returns NULL, indicating a failed cast, for
arbitrary types. */
template <typename T> inline const T *
CastType(const Type *type) {
return NULL;
} }
/* Now we have template specializaitons for the Types implemented in this
file. Each one checks the Type::typeId member and then performs the
corresponding static cast if it's safe as per the typeId.
*/
template <> inline const AtomicType *
CastType(const Type *type) {
if (type != NULL && type->typeId == ATOMIC_TYPE)
return (const AtomicType *)type;
else
return NULL;
}
template <> inline const EnumType *
CastType(const Type *type) {
if (type != NULL && type->typeId == ENUM_TYPE)
return (const EnumType *)type;
else
return NULL;
}
template <> inline const PointerType *
CastType(const Type *type) {
if (type != NULL && type->typeId == POINTER_TYPE)
return (const PointerType *)type;
else
return NULL;
}
template <> inline const ArrayType *
CastType(const Type *type) {
if (type != NULL && type->typeId == ARRAY_TYPE)
return (const ArrayType *)type;
else
return NULL;
}
template <> inline const VectorType *
CastType(const Type *type) {
if (type != NULL && type->typeId == VECTOR_TYPE)
return (const VectorType *)type;
else
return NULL;
}
template <> inline const SequentialType *
CastType(const Type *type) {
// Note that this function must be updated if other sequential type
// implementations are added.
if (type != NULL &&
(type->typeId == ARRAY_TYPE || type->typeId == VECTOR_TYPE))
return (const SequentialType *)type;
else
return NULL;
}
template <> inline const CollectionType *
CastType(const Type *type) {
// Similarly a new collection type implementation requires updating
// this function.
if (type != NULL &&
(type->typeId == ARRAY_TYPE || type->typeId == VECTOR_TYPE ||
type->typeId == STRUCT_TYPE))
return (const CollectionType *)type;
else
return NULL;
}
template <> inline const StructType *
CastType(const Type *type) {
if (type != NULL && type->typeId == STRUCT_TYPE)
return (const StructType *)type;
else
return NULL;
}
template <> inline const UndefinedStructType *
CastType(const Type *type) {
if (type != NULL && type->typeId == UNDEFINED_STRUCT_TYPE)
return (const UndefinedStructType *)type;
else
return NULL;
}
template <> inline const ReferenceType *
CastType(const Type *type) {
if (type != NULL && type->typeId == REFERENCE_TYPE)
return (const ReferenceType *)type;
else
return NULL;
}
template <> inline const FunctionType *
CastType(const Type *type) {
if (type != NULL && type->typeId == FUNCTION_TYPE)
return (const FunctionType *)type;
else
return NULL;
}
inline bool IsReferenceType(const Type *t) {
return CastType<ReferenceType>(t) != NULL;
}
#endif // ISPC_TYPE_H #endif // ISPC_TYPE_H