From 20044f5749121715f23576a971996926edc54f23 Mon Sep 17 00:00:00 2001 From: Matt Pharr Date: Thu, 22 Mar 2012 06:48:02 -0700 Subject: [PATCH] Distinguish between dereferencing pointers and references. We now have separate Expr implementations for dereferencing pointers and automatically dereferencing references. This is in particular necessary so that we can detect attempts to dereference references with the '*' operator in programs and issue an error in that case. Fixes issue #192. --- ast.cpp | 21 ++-- expr.cpp | 171 +++++++++++++++++++++--------- expr.h | 40 +++++-- parse.yy | 2 +- stmt.cpp | 2 +- tests_errors/deref-4.ispc | 2 +- tests_errors/reference-deref.ispc | 10 ++ 7 files changed, 173 insertions(+), 75 deletions(-) create mode 100644 tests_errors/reference-deref.ispc diff --git a/ast.cpp b/ast.cpp index 55f09c34..c89f00bb 100644 --- a/ast.cpp +++ b/ast.cpp @@ -180,7 +180,8 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, MemberExpr *me; TypeCastExpr *tce; ReferenceExpr *re; - DereferenceExpr *dre; + PtrDerefExpr *ptrderef; + RefDerefExpr *refderef; SizeOfExpr *soe; AddressOfExpr *aoe; NewExpr *newe; @@ -221,8 +222,12 @@ WalkAST(ASTNode *node, ASTPreCallBackFunc preFunc, ASTPostCallBackFunc postFunc, tce->expr = (Expr *)WalkAST(tce->expr, preFunc, postFunc, data); else if ((re = dynamic_cast(node)) != NULL) re->expr = (Expr *)WalkAST(re->expr, preFunc, postFunc, data); - else if ((dre = dynamic_cast(node)) != NULL) - dre->expr = (Expr *)WalkAST(dre->expr, preFunc, postFunc, data); + else if ((ptrderef = dynamic_cast(node)) != NULL) + ptrderef->expr = (Expr *)WalkAST(ptrderef->expr, preFunc, postFunc, + data); + else if ((refderef = dynamic_cast(node)) != NULL) + refderef->expr = (Expr *)WalkAST(refderef->expr, preFunc, postFunc, + data); else if ((soe = dynamic_cast(node)) != NULL) soe->expr = (Expr *)WalkAST(soe->expr, preFunc, postFunc, data); else if ((aoe = dynamic_cast(node)) != NULL) @@ -417,13 +422,9 @@ lCheckAllOffSafety(ASTNode *node, void *data) { return false; } - DereferenceExpr *de; - if ((de = dynamic_cast(node)) != NULL) { - const Type *exprType = de->expr->GetType(); - if (dynamic_cast(exprType) != NULL) { - *okPtr = false; - return false; - } + if (dynamic_cast(node) != NULL) { + *okPtr = false; + return false; } return true; diff --git a/expr.cpp b/expr.cpp index ecd6a8c5..17541012 100644 --- a/expr.cpp +++ b/expr.cpp @@ -401,7 +401,7 @@ lDoTypeConv(const Type *fromType, const Type *toType, Expr **expr, else { // convert from a reference T -> T if (expr != NULL) { - Expr *drExpr = new DereferenceExpr(*expr, pos); + Expr *drExpr = new RefDerefExpr(*expr, pos); if (lDoTypeConv(drExpr->GetType(), toType, &drExpr, failureOk, errorMsgBase, pos) == true) { *expr = drExpr; @@ -979,7 +979,7 @@ lEmitPrePostIncDec(UnaryExpr::Op op, Expr *expr, SourcePos pos, type = type->GetReferenceTarget(); lvalue = expr->GetValue(ctx); - Expr *deref = new DereferenceExpr(expr, expr->pos); + Expr *deref = new RefDerefExpr(expr, expr->pos); rvalue = deref->GetValue(ctx); } else { @@ -1239,7 +1239,7 @@ UnaryExpr::TypeCheck() { // don't do this for pre/post increment/decrement if (dynamic_cast(type)) { - expr = new DereferenceExpr(expr, pos); + expr = new RefDerefExpr(expr, pos); type = expr->GetType(); } @@ -2225,12 +2225,12 @@ BinaryExpr::TypeCheck() { // If either operand is a reference, dereference it before we move // forward if (dynamic_cast(type0) != NULL) { - arg0 = new DereferenceExpr(arg0, arg0->pos); + arg0 = new RefDerefExpr(arg0, arg0->pos); type0 = arg0->GetType(); Assert(type0 != NULL); } if (dynamic_cast(type1) != NULL) { - arg1 = new DereferenceExpr(arg1, arg1->pos); + arg1 = new RefDerefExpr(arg1, arg1->pos); type1 = arg1->GetType(); Assert(type1 != NULL); } @@ -2742,7 +2742,7 @@ AssignExpr::TypeCheck() { bool lvalueIsReference = dynamic_cast(lvalue->GetType()) != NULL; if (lvalueIsReference) - lvalue = new DereferenceExpr(lvalue, lvalue->pos); + lvalue = new RefDerefExpr(lvalue, lvalue->pos); FunctionSymbolExpr *fse; if ((fse = dynamic_cast(rvalue)) != NULL) { @@ -4637,7 +4637,7 @@ MemberExpr::create(Expr *e, const char *id, SourcePos p, SourcePos idpos, const ReferenceType *referenceType = dynamic_cast(exprType); if (referenceType != NULL) { - e = new DereferenceExpr(e, e->pos); + e = new RefDerefExpr(e, e->pos); exprType = e->GetType(); Assert(exprType != NULL); } @@ -6847,16 +6847,16 @@ ReferenceExpr::Print() const { /////////////////////////////////////////////////////////////////////////// -// DereferenceExpr +// DerefExpr -DereferenceExpr::DereferenceExpr(Expr *e, SourcePos p) +DerefExpr::DerefExpr(Expr *e, SourcePos p) : Expr(p) { expr = e; } llvm::Value * -DereferenceExpr::GetValue(FunctionEmitContext *ctx) const { +DerefExpr::GetValue(FunctionEmitContext *ctx) const { if (expr == NULL) return NULL; llvm::Value *ptr = expr->GetValue(ctx); @@ -6879,7 +6879,7 @@ DereferenceExpr::GetValue(FunctionEmitContext *ctx) const { llvm::Value * -DereferenceExpr::GetLValue(FunctionEmitContext *ctx) const { +DerefExpr::GetLValue(FunctionEmitContext *ctx) const { if (expr == NULL) return NULL; return expr->GetValue(ctx); @@ -6887,7 +6887,7 @@ DereferenceExpr::GetLValue(FunctionEmitContext *ctx) const { const Type * -DereferenceExpr::GetLValueType() const { +DerefExpr::GetLValueType() const { if (expr == NULL) return NULL; return expr->GetType(); @@ -6895,64 +6895,70 @@ DereferenceExpr::GetLValueType() const { Symbol * -DereferenceExpr::GetBaseSymbol() const { +DerefExpr::GetBaseSymbol() const { return expr ? expr->GetBaseSymbol() : NULL; } -const Type * -DereferenceExpr::GetType() const { +Expr * +DerefExpr::Optimize() { if (expr == NULL) return NULL; - const Type *exprType = expr->GetType(); - if (exprType == NULL) - return NULL; - if (dynamic_cast(exprType) != NULL) - return exprType->GetReferenceTarget(); - else { - Assert(dynamic_cast(exprType) != NULL); - if (exprType->IsUniformType()) - return exprType->GetBaseType(); - else - return exprType->GetBaseType()->GetAsVaryingType(); - } -} - - -Expr * -DereferenceExpr::TypeCheck() { - if (expr == NULL) { - Assert(m->errorCount > 0); - return NULL; - } - - if (dynamic_cast(expr->GetType()) == NULL && - dynamic_cast(expr->GetType()) == NULL) { - Error(pos, "Illegal to dereference non-pointer or reference " - "type \"%s\".", expr->GetType()->GetString().c_str()); - return NULL; - } - return this; } -Expr * -DereferenceExpr::Optimize() { - if (expr == NULL) +/////////////////////////////////////////////////////////////////////////// +// PtrDerefExpr + +PtrDerefExpr::PtrDerefExpr(Expr *e, SourcePos p) + : DerefExpr(e, p) { +} + + +const Type * +PtrDerefExpr::GetType() const { + const Type *type; + if (expr == NULL || (type = expr->GetType()) == NULL) { + Assert(m->errorCount > 0); return NULL; + } + Assert(dynamic_cast(type) != NULL); + + if (type->IsUniformType()) + return type->GetBaseType(); + else + return type->GetBaseType()->GetAsVaryingType(); +} + + +Expr * +PtrDerefExpr::TypeCheck() { + const Type *type; + if (expr == NULL || (type = expr->GetType()) == NULL) { + Assert(m->errorCount > 0); + return NULL; + } + + if (dynamic_cast(type) == NULL) { + Error(pos, "Illegal to dereference non-pointer type \"%s\".", + type->GetString().c_str()); + return NULL; + } + return this; } int -DereferenceExpr::EstimateCost() const { - if (expr == NULL) +PtrDerefExpr::EstimateCost() const { + const Type *type; + if (expr == NULL || (type = expr->GetType()) == NULL) { + Assert(m->errorCount > 0); return 0; + } - const Type *exprType = expr->GetType(); - if (dynamic_cast(exprType) && - exprType->IsVaryingType()) + if (type->IsVaryingType()) // Be pessimistic; some of these will later be optimized into // vector loads/stores.. return COST_GATHER + COST_DEREF; @@ -6962,7 +6968,7 @@ DereferenceExpr::EstimateCost() const { void -DereferenceExpr::Print() const { +PtrDerefExpr::Print() const { if (expr == NULL || GetType() == NULL) return; @@ -6973,6 +6979,65 @@ DereferenceExpr::Print() const { } +/////////////////////////////////////////////////////////////////////////// +// RefDerefExpr + +RefDerefExpr::RefDerefExpr(Expr *e, SourcePos p) + : DerefExpr(e, p) { +} + + +const Type * +RefDerefExpr::GetType() const { + const Type *type; + if (expr == NULL || (type = expr->GetType()) == NULL) { + Assert(m->errorCount > 0); + return NULL; + } + + Assert(dynamic_cast(type) != NULL); + return type->GetReferenceTarget(); +} + + +Expr * +RefDerefExpr::TypeCheck() { + const Type *type; + if (expr == NULL || (type = expr->GetType()) == NULL) { + Assert(m->errorCount > 0); + return NULL; + } + + // We only create RefDerefExprs internally for references in + // expressions, so we should never create one with a non-reference + // expression... + Assert(dynamic_cast(type) != NULL); + + return this; +} + + +int +RefDerefExpr::EstimateCost() const { + if (expr == NULL) + return 0; + + return COST_DEREF; +} + + +void +RefDerefExpr::Print() const { + if (expr == NULL || GetType() == NULL) + return; + + printf("[%s] deref-reference (", GetType()->GetString().c_str()); + expr->Print(); + printf(")"); + pos.Print(); +} + + /////////////////////////////////////////////////////////////////////////// // AddressOfExpr diff --git a/expr.h b/expr.h index e0d1348c..5c59ae83 100644 --- a/expr.h +++ b/expr.h @@ -1,5 +1,5 @@ /* - Copyright (c) 2010-2011, Intel Corporation + Copyright (c) 2010-2012, Intel Corporation All rights reserved. Redistribution and use in source and binary forms, with or without @@ -530,26 +530,48 @@ public: }; -/** @brief Expression that represents dereferencing a reference to get its - value. */ -class DereferenceExpr : public Expr { +/** @brief Common base class that provides shared functionality for + PtrDerefExpr and RefDerefExpr. */ +class DerefExpr : public Expr { public: - DereferenceExpr(Expr *e, SourcePos p); + DerefExpr(Expr *e, SourcePos p); llvm::Value *GetValue(FunctionEmitContext *ctx) const; llvm::Value *GetLValue(FunctionEmitContext *ctx) const; - const Type *GetType() const; const Type *GetLValueType() const; Symbol *GetBaseSymbol() const; - void Print() const; - Expr *TypeCheck(); Expr *Optimize(); - int EstimateCost() const; Expr *expr; }; +/** @brief Expression that represents dereferencing a pointer to get its + value. */ +class PtrDerefExpr : public DerefExpr { +public: + PtrDerefExpr(Expr *e, SourcePos p); + + const Type *GetType() const; + void Print() const; + Expr *TypeCheck(); + int EstimateCost() const; +}; + + +/** @brief Expression that represents dereferencing a reference to get its + value. */ +class RefDerefExpr : public DerefExpr { +public: + RefDerefExpr(Expr *e, SourcePos p); + + const Type *GetType() const; + void Print() const; + Expr *TypeCheck(); + int EstimateCost() const; +}; + + /** Expression that represents taking the address of an expression. */ class AddressOfExpr : public Expr { public: diff --git a/parse.yy b/parse.yy index f962d0f3..1fa8336f 100644 --- a/parse.yy +++ b/parse.yy @@ -400,7 +400,7 @@ unary_expression | '&' unary_expression { $$ = new AddressOfExpr($2, Union(@1, @2)); } | '*' unary_expression - { $$ = new DereferenceExpr($2, Union(@1, @2)); } + { $$ = new PtrDerefExpr($2, Union(@1, @2)); } | '+' cast_expression { $$ = $2; } | '-' cast_expression diff --git a/stmt.cpp b/stmt.cpp index 2e3e1da9..d9018f02 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -2426,7 +2426,7 @@ lProcessPrintArg(Expr *expr, FunctionEmitContext *ctx, std::string &argTypes) { return NULL; if (dynamic_cast(type) != NULL) { - expr = new DereferenceExpr(expr, expr->pos); + expr = new RefDerefExpr(expr, expr->pos); type = expr->GetType(); if (type == NULL) return NULL; diff --git a/tests_errors/deref-4.ispc b/tests_errors/deref-4.ispc index 79c277a8..aa5dbc9f 100644 --- a/tests_errors/deref-4.ispc +++ b/tests_errors/deref-4.ispc @@ -1,4 +1,4 @@ -// Illegal to dereference non-pointer or reference type "varying float" +// Illegal to dereference non-pointer type "varying float" float func(float a) { *a = 0; diff --git a/tests_errors/reference-deref.ispc b/tests_errors/reference-deref.ispc new file mode 100644 index 00000000..a0cd484f --- /dev/null +++ b/tests_errors/reference-deref.ispc @@ -0,0 +1,10 @@ +// Illegal to dereference non-pointer type "uniform float &". + +export void simple_reduction(uniform float vin[], uniform int w, uniform float & result) +{ + float sum = 0; + foreach (i = 0 ... w) { + sum += vin[i]; + } + *result = reduce_add(sum); // << I would expect this to produce a compiler error +}