From 4cd0cf1650317934bc8b3b93fed80c0b9020653f Mon Sep 17 00:00:00 2001 From: Matt Pharr Date: Tue, 3 Apr 2012 10:09:07 -0700 Subject: [PATCH] Revamp handling of function types, conversion to function ptr types. Implicit conversion to function types is now a more standard part of the type conversion infrastructure, rather than special cases of things like FunctionSymbolExpr immediately returning a pointer type, etc. Improved AddressOfExpr::TypeCheck() to actually issue errors in cases where it's illegal to take the address of an expression. Added AddressOfExpr::GetConstant() implementation that handles taking the address of functions. Issue #223. --- expr.cpp | 165 +++++++++++++++++++++++++++-------- expr.h | 1 + tests/funcptr-null-1.ispc | 2 +- tests/funcptr-null-3.ispc | 2 +- tests/funcptr-null-6.ispc | 2 +- tests/funcptr-uniform-2.ispc | 2 +- tests_errors/addr-of-1.ispc | 5 ++ type.cpp | 11 +++ 8 files changed, 149 insertions(+), 41 deletions(-) create mode 100644 tests_errors/addr-of-1.ispc diff --git a/expr.cpp b/expr.cpp index ebfac14f..f46347ee 100644 --- a/expr.cpp +++ b/expr.cpp @@ -212,11 +212,27 @@ lDoTypeConv(const Type *fromType, const Type *toType, Expr **expr, } if (dynamic_cast(fromType)) { - if (!failureOk) - Error(pos, "Can't convert function type \"%s\" to \"%s\" for %s.", - fromType->GetString().c_str(), - toType->GetString().c_str(), errorMsgBase); - return false; + if (dynamic_cast(toType) != NULL) { + // Convert function type to pointer to function type + if (expr != NULL) { + Expr *aoe = new AddressOfExpr(*expr, (*expr)->pos); + if (lDoTypeConv(aoe->GetType(), toType, &aoe, failureOk, + errorMsgBase, pos)) { + *expr = aoe; + return true; + } + } + else + return lDoTypeConv(PointerType::GetUniform(fromType), toType, NULL, + failureOk, errorMsgBase, pos); + } + else { + if (!failureOk) + Error(pos, "Can't convert function type \"%s\" to \"%s\" for %s.", + fromType->GetString().c_str(), + toType->GetString().c_str(), errorMsgBase); + return false; + } } if (dynamic_cast(toType)) { if (!failureOk) @@ -3434,10 +3450,15 @@ FunctionCallExpr::TypeCheck() { if (func == NULL) return NULL; - const PointerType *pt = - dynamic_cast(func->GetType()); - const FunctionType *ft = (pt == NULL) ? NULL : - dynamic_cast(pt->GetBaseType()); + const FunctionType *ft = + dynamic_cast(func->GetType()); + if (ft == NULL) { + const PointerType *pt = + dynamic_cast(func->GetType()); + ft = (pt == NULL) ? NULL : + dynamic_cast(pt->GetBaseType()); + } + if (ft == NULL) { Error(pos, "Valid function name must be used for function call."); return NULL; @@ -6774,6 +6795,34 @@ TypeCastExpr::GetBaseSymbol() const { } +static +llvm::Constant * +lConvertPointerConstant(llvm::Constant *c, const Type *constType) { + if (c == NULL || constType->IsUniformType()) + return c; + + // Handle conversion to int and then to vector of int or array of int + // (for varying and soa types, respectively) + llvm::Constant *intPtr = + llvm::ConstantExpr::getPtrToInt(c, LLVMTypes::PointerIntType); + Assert(constType->IsVaryingType() || constType->IsSOAType()); + int count = constType->IsVaryingType() ? g->target.vectorWidth : + constType->GetSOAWidth(); + + std::vector smear; + for (int i = 0; i < count; ++i) + smear.push_back(intPtr); + + if (constType->IsVaryingType()) + return llvm::ConstantVector::get(smear); + else { + LLVM_TYPE_CONST llvm::ArrayType *at = + llvm::ArrayType::get(LLVMTypes::PointerIntType, count); + return llvm::ConstantArray::get(at, smear); + } +} + + llvm::Constant * TypeCastExpr::GetConstant(const Type *constType) const { // We don't need to worry about most the basic cases where the type @@ -6781,11 +6830,18 @@ TypeCastExpr::GetConstant(const Type *constType) const { // TypeCastExpr::Optimize() method generally ends up doing the type // conversion and returning a ConstExpr, which in turn will have its // GetConstant() method called. However, because ConstExpr currently - // can't represent pointer values, we have to handle two cases here: - // 1. Null pointers (NULL, 0) valued initializers, and - // 2. Converting a uniform function pointer to a varying function - // pointer of the same type. - return expr->GetConstant(constType); + // can't represent pointer values, we have to handle a few cases + // related to pointers here: + // + // 1. Null pointer (NULL, 0) valued initializers + // 2. Converting function types to pointer-to-function types + // 3. And converting these from uniform to the varying/soa equivalents. + // + if (dynamic_cast(constType) == NULL) + return NULL; + + llvm::Constant *c = expr->GetConstant(constType->GetAsUniformType()); + return lConvertPointerConstant(c, constType); } @@ -7078,7 +7134,8 @@ AddressOfExpr::GetValue(FunctionEmitContext *ctx) const { return NULL; const Type *exprType = expr->GetType(); - if (dynamic_cast(exprType) != NULL) + if (dynamic_cast(exprType) != NULL || + dynamic_cast(exprType) != NULL) return expr->GetValue(ctx); else return expr->GetLValue(ctx); @@ -7093,8 +7150,18 @@ AddressOfExpr::GetType() const { const Type *exprType = expr->GetType(); if (dynamic_cast(exprType) != NULL) return PointerType::GetUniform(exprType->GetReferenceTarget()); - else - return expr->GetLValueType(); + + const Type *t = expr->GetLValueType(); + if (t != NULL) + return t; + else { + t = expr->GetType(); + if (t == NULL) { + Assert(m->errorCount > 0); + return NULL; + } + return PointerType::GetUniform(t); + } } @@ -7118,7 +7185,22 @@ AddressOfExpr::Print() const { Expr * AddressOfExpr::TypeCheck() { - return this; + const Type *exprType; + if (expr == NULL || (exprType = expr->GetType()) == NULL) { + Assert(m->errorCount > 0); + return NULL; + } + + if (dynamic_cast(exprType) != NULL|| + dynamic_cast(exprType) != NULL) { + return this; + } + + if (expr->GetLValueType() != NULL) + return this; + + Error(expr->pos, "Illegal to take address of non-lvalue or function."); + return NULL; } @@ -7134,6 +7216,29 @@ AddressOfExpr::EstimateCost() const { } +llvm::Constant * +AddressOfExpr::GetConstant(const Type *type) const { + const Type *exprType; + if (expr == NULL || (exprType = expr->GetType()) == NULL) { + Assert(m->errorCount > 0); + return NULL; + } + + const PointerType *pt = dynamic_cast(type); + if (pt == NULL) + return NULL; + + const FunctionType *ft = + dynamic_cast(pt->GetBaseType()); + if (ft != NULL) { + llvm::Constant *c = expr->GetConstant(ft); + return lConvertPointerConstant(c, type); + } + else + return NULL; +} + + /////////////////////////////////////////////////////////////////////////// // SizeOfExpr @@ -7313,8 +7418,7 @@ FunctionSymbolExpr::GetType() const { return NULL; } - return matchingFunc ? - new PointerType(matchingFunc->type, Variability::Uniform, true) : NULL; + return matchingFunc ? matchingFunc->type : NULL; } @@ -7364,27 +7468,14 @@ FunctionSymbolExpr::GetConstant(const Type *type) const { if (matchingFunc == NULL || matchingFunc->function == NULL) return NULL; - const FunctionType *ft; - if (dynamic_cast(type) == NULL || - (ft = dynamic_cast(type->GetBaseType())) == NULL) + const FunctionType *ft = dynamic_cast(type); + if (ft == NULL) return NULL; - LLVM_TYPE_CONST llvm::Type *llvmUnifType = - type->GetAsUniformType()->LLVMType(g->ctx); - if (llvmUnifType != matchingFunc->function->getType()) + if (Type::Equal(type, matchingFunc->type) == false) return NULL; - if (type->IsUniformType()) - return matchingFunc->function; - else { - llvm::Constant *intPtr = - llvm::ConstantExpr::getPtrToInt(matchingFunc->function, - LLVMTypes::PointerIntType); - std::vector smear; - for (int i = 0; i < g->target.vectorWidth; ++i) - smear.push_back(intPtr); - return llvm::ConstantVector::get(smear); - } + return matchingFunc->function; } diff --git a/expr.h b/expr.h index e7461a1a..f7d112b9 100644 --- a/expr.h +++ b/expr.h @@ -584,6 +584,7 @@ public: Expr *TypeCheck(); Expr *Optimize(); int EstimateCost() const; + llvm::Constant *GetConstant(const Type *type) const; Expr *expr; }; diff --git a/tests/funcptr-null-1.ispc b/tests/funcptr-null-1.ispc index 05798918..784b5ada 100644 --- a/tests/funcptr-null-1.ispc +++ b/tests/funcptr-null-1.ispc @@ -15,7 +15,7 @@ export void f_f(uniform float RET[], uniform float aFOO[]) { float a = aFOO[programIndex]; float b = aFOO[0]-1; uniform FuncType func = foo; - RET[programIndex] = (func ? func : bar)(a, b); + RET[programIndex] = (func ? func : &bar)(a, b); } export void result(uniform float RET[]) { diff --git a/tests/funcptr-null-3.ispc b/tests/funcptr-null-3.ispc index 8e228315..3fd74da0 100644 --- a/tests/funcptr-null-3.ispc +++ b/tests/funcptr-null-3.ispc @@ -14,7 +14,7 @@ static float bar(float a, float b) { export void f_f(uniform float RET[], uniform float aFOO[]) { float a = aFOO[programIndex]; float b = aFOO[0]-1; - FuncType func = foo; + FuncType func = &foo; if (a == 2) func = NULL; if (func != NULL) diff --git a/tests/funcptr-null-6.ispc b/tests/funcptr-null-6.ispc index cf92c4a7..45bcfcdd 100644 --- a/tests/funcptr-null-6.ispc +++ b/tests/funcptr-null-6.ispc @@ -16,7 +16,7 @@ export void f_f(uniform float RET[], uniform float aFOO[]) { float b = aFOO[0]-1; FuncType func = NULL; if (a == 2) - func = foo; + func = &foo; if (!func) RET[programIndex] = -1; else diff --git a/tests/funcptr-uniform-2.ispc b/tests/funcptr-uniform-2.ispc index 849c9492..59d54b40 100644 --- a/tests/funcptr-uniform-2.ispc +++ b/tests/funcptr-uniform-2.ispc @@ -14,7 +14,7 @@ static float bar(float a, float b) { export void f_f(uniform float RET[], uniform float aFOO[]) { float a = aFOO[programIndex]; float b = aFOO[0]-1; - uniform FuncType func = bar; + uniform FuncType func = &bar; if (aFOO[0] == 1) func = foo; RET[programIndex] = func(a, b); diff --git a/tests_errors/addr-of-1.ispc b/tests_errors/addr-of-1.ispc new file mode 100644 index 00000000..4d770f01 --- /dev/null +++ b/tests_errors/addr-of-1.ispc @@ -0,0 +1,5 @@ +// Illegal to take address of non-lvalue or function + +void foo() { + int *ptr = &(1+1); +} diff --git a/type.cpp b/type.cpp index 2fb0a678..8fffb682 100644 --- a/type.cpp +++ b/type.cpp @@ -2695,6 +2695,17 @@ Type::MoreGeneralType(const Type *t0, const Type *t1, SourcePos pos, const char bool forceVarying, int vecSize) { Assert(reason != NULL); + // First, if one or both types are function types, convert them to + // pointer to function types and then try again. + if (dynamic_cast(t0) || + dynamic_cast(t1)) { + if (dynamic_cast(t0)) + t0 = PointerType::GetUniform(t0); + if (dynamic_cast(t1)) + t1 = PointerType::GetUniform(t1); + return MoreGeneralType(t0, t1, pos, reason, forceVarying, vecSize); + } + // First, if we need to go varying, promote both of the types to be // varying. if (t0->IsVaryingType() || t1->IsVaryingType() || forceVarying) {