diff --git a/ast.cpp b/ast.cpp index 55f09c34..c89f00bb 100644 --- a/ast.cpp +++ b/ast.cpp @@ -180,7 +180,8 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, MemberExpr *me; TypeCastExpr *tce; ReferenceExpr *re; - DereferenceExpr *dre; + PtrDerefExpr *ptrderef; + RefDerefExpr *refderef; SizeOfExpr *soe; AddressOfExpr *aoe; NewExpr *newe; @@ -221,8 +222,12 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data); else if ((re = dynamic_cast(node)) != NULL) re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data); - else if ((dre = dynamic_cast(node)) != NULL) - dre->expr = (Expr *)WalkAST(dre->expr, preFunc, postFunc, data); + else if ((ptrderef = dynamic_cast(node)) != NULL) + ptrderef->expr = (Expr *)WalkAST(ptrderef->expr, preFunc, postFunc, + data); + else if ((refderef = dynamic_cast(node)) != NULL) + refderef->expr = (Expr *)WalkAST(refderef->expr, preFunc, postFunc, + data); else if ((soe = dynamic_cast(node)) != NULL) soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data); else if ((aoe = dynamic_cast(node)) != NULL) @@ -417,13 +422,9 @@ lCheckAllOffSafety(ASTNode *node, void *data) { return false; } - DereferenceExpr *de; - if ((de = dynamic_cast(node)) != NULL) { - const Type *exprType = de->expr->GetType(); - if (dynamic_cast(exprType) != NULL) { - *okPtr = false; - return false; - } + if (dynamic_cast(node) != NULL) { + *okPtr = false; + return false; } return true; diff --git a/expr.cpp b/expr.cpp index ecd6a8c5..17541012 100644 --- a/expr.cpp +++ b/expr.cpp @@ -401,7 +401,7 @@ lDoTypeConv(const Type *fromType, const Type *toType, Expr **expr, else { // convert from a reference T -> T if (expr != NULL) { - Expr *drExpr = new DereferenceExpr(*expr, pos); + Expr *drExpr = new RefDerefExpr(*expr, pos); if (lDoTypeConv(drExpr->GetType(), toType, &drExpr, failureOk, errorMsgBase, pos) == true) { *expr = drExpr; @@ -979,7 +979,7 @@ lEmitPrePostIncDec(UnaryExpr::Op op, Expr *expr, SourcePos pos, type = type->GetReferenceTarget(); lvalue = expr->GetValue(ctx); - Expr *deref = new DereferenceExpr(expr, expr->pos); + Expr *deref = new RefDerefExpr(expr, expr->pos); rvalue = deref->GetValue(ctx); } else { @@ -1239,7 +1239,7 @@ UnaryExpr::TypeCheck() { // don't do this for pre/post increment/decrement if (dynamic_cast(type)) { - expr = new DereferenceExpr(expr, pos); + expr = new RefDerefExpr(expr, pos); type = expr->GetType(); } @@ -2225,12 +2225,12 @@ BinaryExpr::TypeCheck() { // If either operand is a reference, dereference it before we move // forward if (dynamic_cast(type0) != NULL) { - arg0 = new DereferenceExpr(arg0, arg0->pos); + arg0 = new RefDerefExpr(arg0, arg0->pos); type0 = arg0->GetType(); Assert(type0 != NULL); } if (dynamic_cast(type1) != NULL) { - arg1 = new DereferenceExpr(arg1, arg1->pos); + arg1 = new RefDerefExpr(arg1, arg1->pos); type1 = arg1->GetType(); Assert(type1 != NULL); } @@ -2742,7 +2742,7 @@ AssignExpr::TypeCheck() { bool lvalueIsReference = dynamic_cast(lvalue->GetType()) != NULL; if (lvalueIsReference) - lvalue = new DereferenceExpr(lvalue, lvalue->pos); + lvalue = new RefDerefExpr(lvalue, lvalue->pos); FunctionSymbolExpr *fse; if ((fse = dynamic_cast(rvalue)) != NULL) { @@ -4637,7 +4637,7 @@ MemberExpr::create(Expr *e, const char *id, SourcePos p, SourcePos idpos, const ReferenceType *referenceType = dynamic_cast(exprType); if (referenceType != NULL) { - e = new DereferenceExpr(e, e->pos); + e = new RefDerefExpr(e, e->pos); exprType = e->GetType(); Assert(exprType != NULL); } @@ -6847,16 +6847,16 @@ ReferenceExpr::Print() const { /////////////////////////////////////////////////////////////////////////// -// DereferenceExpr +// DerefExpr -DereferenceExpr::DereferenceExpr(Expr *e, SourcePos p) +DerefExpr::DerefExpr(Expr *e, SourcePos p) : Expr(p) { expr = e; } llvm::Value * -DereferenceExpr::GetValue(FunctionEmitContext *ctx) const { +DerefExpr::GetValue(FunctionEmitContext *ctx) const { if (expr == NULL) return NULL; llvm::Value *ptr = expr->GetValue(ctx); @@ -6879,7 +6879,7 @@ DereferenceExpr::GetValue(FunctionEmitContext *ctx) const { llvm::Value * -DereferenceExpr::GetLValue(FunctionEmitContext *ctx) const { +DerefExpr::GetLValue(FunctionEmitContext *ctx) const { if (expr == NULL) return NULL; return expr->GetValue(ctx); @@ -6887,7 +6887,7 @@ DereferenceExpr::GetLValue(FunctionEmitContext *ctx) const { const Type * -DereferenceExpr::GetLValueType() const { +DerefExpr::GetLValueType() const { if (expr == NULL) return NULL; return expr->GetType(); @@ -6895,64 +6895,70 @@ DereferenceExpr::GetLValueType() const { Symbol * -DereferenceExpr::GetBaseSymbol() const { +DerefExpr::GetBaseSymbol() const { return expr ? expr->GetBaseSymbol() : NULL; } -const Type * -DereferenceExpr::GetType() const { +Expr * +DerefExpr::Optimize() { if (expr == NULL) return NULL; - const Type *exprType = expr->GetType(); - if (exprType == NULL) - return NULL; - if (dynamic_cast(exprType) != NULL) - return exprType->GetReferenceTarget(); - else { - Assert(dynamic_cast(exprType) != NULL); - if (exprType->IsUniformType()) - return exprType->GetBaseType(); - else - return exprType->GetBaseType()->GetAsVaryingType(); - } -} - - -Expr * -DereferenceExpr::TypeCheck() { - if (expr == NULL) { - Assert(m->errorCount > 0); - return NULL; - } - - if (dynamic_cast(expr->GetType()) == NULL && - dynamic_cast(expr->GetType()) == NULL) { - Error(pos, "Illegal to dereference non-pointer or reference " - "type \"%s\".", expr->GetType()->GetString().c_str()); - return NULL; - } - return this; } -Expr * -DereferenceExpr::Optimize() { - if (expr == NULL) +/////////////////////////////////////////////////////////////////////////// +// PtrDerefExpr + +PtrDerefExpr::PtrDerefExpr(Expr *e, SourcePos p) + : DerefExpr(e, p) { +} + + +const Type * +PtrDerefExpr::GetType() const { + const Type *type; + if (expr == NULL || (type = expr->GetType()) == NULL) { + Assert(m->errorCount > 0); return NULL; + } + Assert(dynamic_cast(type) != NULL); + + if (type->IsUniformType()) + return type->GetBaseType(); + else + return type->GetBaseType()->GetAsVaryingType(); +} + + +Expr * +PtrDerefExpr::TypeCheck() { + const Type *type; + if (expr == NULL || (type = expr->GetType()) == NULL) { + Assert(m->errorCount > 0); + return NULL; + } + + if (dynamic_cast(type) == NULL) { + Error(pos, "Illegal to dereference non-pointer type \"%s\".", + type->GetString().c_str()); + return NULL; + } + return this; } int -DereferenceExpr::EstimateCost() const { - if (expr == NULL) +PtrDerefExpr::EstimateCost() const { + const Type *type; + if (expr == NULL || (type = expr->GetType()) == NULL) { + Assert(m->errorCount > 0); return 0; + } - const Type *exprType = expr->GetType(); - if (dynamic_cast(exprType) && - exprType->IsVaryingType()) + if (type->IsVaryingType()) // Be pessimistic; some of these will later be optimized into // vector loads/stores.. return COST_GATHER + COST_DEREF; @@ -6962,7 +6968,7 @@ DereferenceExpr::EstimateCost() const { void -DereferenceExpr::Print() const { +PtrDerefExpr::Print() const { if (expr == NULL || GetType() == NULL) return; @@ -6973,6 +6979,65 @@ DereferenceExpr::Print() const { } +/////////////////////////////////////////////////////////////////////////// +// RefDerefExpr + +RefDerefExpr::RefDerefExpr(Expr *e, SourcePos p) + : DerefExpr(e, p) { +} + + +const Type * +RefDerefExpr::GetType() const { + const Type *type; + if (expr == NULL || (type = expr->GetType()) == NULL) { + Assert(m->errorCount > 0); + return NULL; + } + + Assert(dynamic_cast(type) != NULL); + return type->GetReferenceTarget(); +} + + +Expr * +RefDerefExpr::TypeCheck() { + const Type *type; + if (expr == NULL || (type = expr->GetType()) == NULL) { + Assert(m->errorCount > 0); + return NULL; + } + + // We only create RefDerefExprs internally for references in + // expressions, so we should never create one with a non-reference + // expression... + Assert(dynamic_cast(type) != NULL); + + return this; +} + + +int +RefDerefExpr::EstimateCost() const { + if (expr == NULL) + return 0; + + return COST_DEREF; +} + + +void +RefDerefExpr::Print() const { + if (expr == NULL || GetType() == NULL) + return; + + printf("[%s] deref-reference (", GetType()->GetString().c_str()); + expr->Print(); + printf(")"); + pos.Print(); +} + + /////////////////////////////////////////////////////////////////////////// // AddressOfExpr diff --git a/expr.h b/expr.h index e0d1348c..5c59ae83 100644 --- a/expr.h +++ b/expr.h @@ -1,5 +1,5 @@ /* - Copyright (c) 2010-2011, Intel Corporation + Copyright (c) 2010-2012, Intel Corporation All rights reserved. Redistribution and use in source and binary forms, with or without @@ -530,26 +530,48 @@ public: }; -/** @brief Expression that represents dereferencing a reference to get its - value. */ -class DereferenceExpr : public Expr { +/** @brief Common base class that provides shared functionality for + PtrDerefExpr and RefDerefExpr. */ +class DerefExpr : public Expr { public: - DereferenceExpr(Expr *e, SourcePos p); + DerefExpr(Expr *e, SourcePos p); llvm::Value *GetValue(FunctionEmitContext *ctx) const; llvm::Value *GetLValue(FunctionEmitContext *ctx) const; - const Type *GetType() const; const Type *GetLValueType() const; Symbol *GetBaseSymbol() const; - void Print() const; - Expr *TypeCheck(); Expr *Optimize(); - int EstimateCost() const; Expr *expr; }; +/** @brief Expression that represents dereferencing a pointer to get its + value. */ +class PtrDerefExpr : public DerefExpr { +public: + PtrDerefExpr(Expr *e, SourcePos p); + + const Type *GetType() const; + void Print() const; + Expr *TypeCheck(); + int EstimateCost() const; +}; + + +/** @brief Expression that represents dereferencing a reference to get its + value. */ +class RefDerefExpr : public DerefExpr { +public: + RefDerefExpr(Expr *e, SourcePos p); + + const Type *GetType() const; + void Print() const; + Expr *TypeCheck(); + int EstimateCost() const; +}; + + /** Expression that represents taking the address of an expression. */ class AddressOfExpr : public Expr { public: diff --git a/parse.yy b/parse.yy index f962d0f3..1fa8336f 100644 --- a/parse.yy +++ b/parse.yy @@ -400,7 +400,7 @@ unary_expression | '&' unary_expression { $$ = new AddressOfExpr($2, Union(@1, @2)); } | '*' unary_expression - { $$ = new DereferenceExpr($2, Union(@1, @2)); } + { $$ = new PtrDerefExpr($2, Union(@1, @2)); } | '+' cast_expression { $$ = $2; } | '-' cast_expression diff --git a/stmt.cpp b/stmt.cpp index 2e3e1da9..d9018f02 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -2426,7 +2426,7 @@ lProcessPrintArg(Expr *expr, FunctionEmitContext *ctx, std::string &argTypes) { return NULL; if (dynamic_cast(type) != NULL) { - expr = new DereferenceExpr(expr, expr->pos); + expr = new RefDerefExpr(expr, expr->pos); type = expr->GetType(); if (type == NULL) return NULL; diff --git a/tests_errors/deref-4.ispc b/tests_errors/deref-4.ispc index 79c277a8..aa5dbc9f 100644 --- a/tests_errors/deref-4.ispc +++ b/tests_errors/deref-4.ispc @@ -1,4 +1,4 @@ -// Illegal to dereference non-pointer or reference type "varying float" +// Illegal to dereference non-pointer type "varying float" float func(float a) { *a = 0; diff --git a/tests_errors/reference-deref.ispc b/tests_errors/reference-deref.ispc new file mode 100644 index 00000000..a0cd484f --- /dev/null +++ b/tests_errors/reference-deref.ispc @@ -0,0 +1,10 @@ +// Illegal to dereference non-pointer type "uniform float &". + +export void simple_reduction(uniform float vin[], uniform int w, uniform float & result) +{ + float sum = 0; + foreach (i = 0 ... w) { + sum += vin[i]; + } + *result = reduce_add(sum); // << I would expect this to produce a compiler error +}