From 2e724b095e030b2d548758b965529735f184ca43 Mon Sep 17 00:00:00 2001 From: Ilia Filippov Date: Mon, 7 Oct 2013 15:43:31 +0400 Subject: [PATCH] support of operators --- expr.cpp | 135 ++++++++++++++++++++++++++++++------------ expr.h | 2 + lex.ll | 8 +++ parse.yy | 14 ++--- tests/operators.ispc | 70 ++++++++++++++++++++++ tests/operators1.ispc | 64 ++++++++++++++++++++ tests/operators2.ispc | 51 ++++++++++++++++ 7 files changed, 299 insertions(+), 45 deletions(-) create mode 100644 tests/operators.ispc create mode 100644 tests/operators1.ispc create mode 100644 tests/operators2.ispc diff --git a/expr.cpp b/expr.cpp index 614cb5e5..c92503e0 100644 --- a/expr.cpp +++ b/expr.cpp @@ -1660,6 +1660,64 @@ BinaryExpr::BinaryExpr(Op o, Expr *a, Expr *b, SourcePos p) arg1 = b; } +Expr *lCreateBinaryOperatorCall(const BinaryExpr::Op bop, + Expr *a0, Expr *a1, + const SourcePos &sp) +{ + if ((a0 == NULL) || (a1 == NULL)) { + return NULL; + } + Expr *arg0 = a0->TypeCheck(); + Expr *arg1 = a1->TypeCheck(); + if ((arg0 == NULL) || (arg1 == NULL)) { + return NULL; + } + const Type *type0 = arg0->GetType(); + const Type *type1 = arg1->GetType(); + + // If either operand is a reference, dereference it before we move + // forward + if (CastType(type0) != NULL) { + arg0 = new RefDerefExpr(arg0, arg0->pos); + type0 = arg0->GetType(); + } + if (CastType(type1) != NULL) { + arg1 = new RefDerefExpr(arg1, arg1->pos); + type1 = arg1->GetType(); + } + if ((type0 == NULL) || (type1 == NULL)) { + return NULL; + } + if (CastType(type0) != NULL || + CastType(type1) != NULL) { + std::string opName = std::string("operator") + lOpString(bop); + std::vector funs; + m->symbolTable->LookupFunction(opName.c_str(), &funs); + if (funs.size() == 0) { + Error(sp, "operator %s(%s, %s) is not defined.", + opName.c_str(), (type0->GetString()).c_str(), (type1->GetString()).c_str()); + return NULL; + } + Expr *func = new FunctionSymbolExpr(opName.c_str(), funs, sp); + ExprList *args = new ExprList(sp); + args->exprs.push_back(arg0); + args->exprs.push_back(arg1); + Expr *opCallExpr = new FunctionCallExpr(func, args, sp); + return opCallExpr; + } + return NULL; +} + + +Expr * MakeBinaryExpr(BinaryExpr::Op o, Expr *a, Expr *b, SourcePos p) { + Expr * op = lCreateBinaryOperatorCall(o, a, b, p); + if (op != NULL) { + return op; + } + op = new BinaryExpr(o, a, b, p); + return op; +} + /** Emit code for a && or || logical operator. In particular, the code here handles "short-circuit" evaluation, where the second expression @@ -2985,29 +3043,10 @@ AssignExpr::TypeCheck() { if (lvalueIsReference) lvalue = new RefDerefExpr(lvalue, lvalue->pos); - FunctionSymbolExpr *fse; - if ((fse = dynamic_cast(rvalue)) != NULL) { - // Special case to use the type of the LHS to resolve function - // overloads when we're assigning a function pointer where the - // function is overloaded. - const Type *lvalueType = lvalue->GetType(); - const FunctionType *ftype; - if (CastType(lvalueType) == NULL || - (ftype = CastType(lvalueType->GetBaseType())) == NULL) { - Error(lvalue->pos, "Can't assign function pointer to type \"%s\".", - lvalueType ? lvalueType->GetString().c_str() : ""); - return NULL; - } - - std::vector paramTypes; - for (int i = 0; i < ftype->GetNumParameters(); ++i) - paramTypes.push_back(ftype->GetParameterType(i)); - - if (!fse->ResolveOverloads(rvalue->pos, paramTypes)) { - Error(pos, "Unable to find overloaded function for function " - "pointer assignment."); - return NULL; - } + if (PossiblyResolveFunctionOverloads(rvalue, lvalue->GetType()) == false) { + Error(pos, "Unable to find overloaded function for function " + "pointer assignment."); + return NULL; } const Type *lhsType = lvalue->GetType(); @@ -3650,10 +3689,37 @@ FunctionCallExpr::GetLValue(FunctionEmitContext *ctx) const { return NULL; } } - + + +bool FullResolveOverloads(Expr * func, ExprList * args, + std::vector *argTypes, + std::vector *argCouldBeNULL, + std::vector *argIsConstant) { + for (unsigned int i = 0; i < args->exprs.size(); ++i) { + Expr *expr = args->exprs[i]; + if (expr == NULL) + return false; + const Type *t = expr->GetType(); + if (t == NULL) + return false; + argTypes->push_back(t); + argCouldBeNULL->push_back(lIsAllIntZeros(expr) || dynamic_cast(expr)); + argIsConstant->push_back(dynamic_cast(expr) || dynamic_cast(expr)); + } + return true; +} + const Type * FunctionCallExpr::GetType() const { + std::vector argTypes; + std::vector argCouldBeNULL, argIsConstant; + if (FullResolveOverloads(func, args, &argTypes, &argCouldBeNULL, &argIsConstant) == true) { + FunctionSymbolExpr *fse = dynamic_cast(func); + if (fse != NULL) { + fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL, &argIsConstant); + } + } const FunctionType *ftype = lGetFunctionType(func); return ftype ? ftype->GetReturnType() : NULL; } @@ -3689,20 +3755,9 @@ FunctionCallExpr::TypeCheck() { std::vector argTypes; std::vector argCouldBeNULL, argIsConstant; - for (unsigned int i = 0; i < args->exprs.size(); ++i) { - Expr *expr = args->exprs[i]; - if (expr == NULL) - return NULL; - const Type *t = expr->GetType(); - if (t == NULL) - return NULL; - - argTypes.push_back(t); - argCouldBeNULL.push_back(lIsAllIntZeros(expr) || - dynamic_cast(expr)); - argIsConstant.push_back(dynamic_cast(expr) || - dynamic_cast(expr)); + if (FullResolveOverloads(func, args, &argTypes, &argCouldBeNULL, &argIsConstant) == false) { + return NULL; } FunctionSymbolExpr *fse = dynamic_cast(func); @@ -7010,7 +7065,8 @@ TypeCastExpr::GetLValue(FunctionEmitContext *ctx) const { const Type * TypeCastExpr::GetType() const { - AssertPos(pos, type->HasUnboundVariability() == false); + // We have to switch off this assert after supporting of operators. + //AssertPos(pos, type->HasUnboundVariability() == false); return type; } @@ -8190,6 +8246,9 @@ FunctionSymbolExpr::ResolveOverloads(SourcePos argPos, const std::vector *argCouldBeNULL, const std::vector *argIsConstant) { const char *funName = candidateFunctions.front()->name.c_str(); + if (triedToResolve == true) { + return true; + } triedToResolve = true; diff --git a/expr.h b/expr.h index 42fdff45..f8b96abd 100644 --- a/expr.h +++ b/expr.h @@ -730,6 +730,8 @@ bool CanConvertTypes(const Type *fromType, const Type *toType, */ Expr *TypeConvertExpr(Expr *expr, const Type *toType, const char *errorMsgBase); +Expr * MakeBinaryExpr(BinaryExpr::Op o, Expr *a, Expr *b, SourcePos p); + /** Utility routine that emits code to initialize a symbol given an initializer expression. diff --git a/lex.ll b/lex.ll index 3655220f..87a80145 100644 --- a/lex.ll +++ b/lex.ll @@ -419,6 +419,14 @@ while { RT; return TOKEN_WHILE; } \"C\" { RT; return TOKEN_STRING_C_LITERAL; } \.\.\. { RT; return TOKEN_DOTDOTDOT; } +"operator*" { return TOKEN_IDENTIFIER; } +"operator+" { return TOKEN_IDENTIFIER; } +"operator-" { return TOKEN_IDENTIFIER; } +"operator<<" { return TOKEN_IDENTIFIER; } +"operator>>" { return TOKEN_IDENTIFIER; } +"operator/" { return TOKEN_IDENTIFIER; } +"operator%" { return TOKEN_IDENTIFIER; } + L?\"(\\.|[^\\"])*\" { lStringConst(&yylval, &yylloc); return TOKEN_STRING_LITERAL; } {IDENT} { diff --git a/parse.yy b/parse.yy index 933a3455..38c5ba77 100644 --- a/parse.yy +++ b/parse.yy @@ -468,27 +468,27 @@ cast_expression multiplicative_expression : cast_expression | multiplicative_expression '*' cast_expression - { $$ = new BinaryExpr(BinaryExpr::Mul, $1, $3, Union(@1, @3)); } + { $$ = MakeBinaryExpr(BinaryExpr::Mul, $1, $3, Union(@1, @3)); } | multiplicative_expression '/' cast_expression - { $$ = new BinaryExpr(BinaryExpr::Div, $1, $3, Union(@1, @3)); } + { $$ = MakeBinaryExpr(BinaryExpr::Div, $1, $3, Union(@1, @3)); } | multiplicative_expression '%' cast_expression - { $$ = new BinaryExpr(BinaryExpr::Mod, $1, $3, Union(@1, @3)); } + { $$ = MakeBinaryExpr(BinaryExpr::Mod, $1, $3, Union(@1, @3)); } ; additive_expression : multiplicative_expression | additive_expression '+' multiplicative_expression - { $$ = new BinaryExpr(BinaryExpr::Add, $1, $3, Union(@1, @3)); } + { $$ = MakeBinaryExpr(BinaryExpr::Add, $1, $3, Union(@1, @3)); } | additive_expression '-' multiplicative_expression - { $$ = new BinaryExpr(BinaryExpr::Sub, $1, $3, Union(@1, @3)); } + { $$ = MakeBinaryExpr(BinaryExpr::Sub, $1, $3, Union(@1, @3)); } ; shift_expression : additive_expression | shift_expression TOKEN_LEFT_OP additive_expression - { $$ = new BinaryExpr(BinaryExpr::Shl, $1, $3, Union(@1,@3)); } + { $$ = MakeBinaryExpr(BinaryExpr::Shl, $1, $3, Union(@1, @3)); } | shift_expression TOKEN_RIGHT_OP additive_expression - { $$ = new BinaryExpr(BinaryExpr::Shr, $1, $3, Union(@1,@3)); } + { $$ = MakeBinaryExpr(BinaryExpr::Shr, $1, $3, Union(@1, @3)); } ; relational_expression diff --git a/tests/operators.ispc b/tests/operators.ispc new file mode 100644 index 00000000..95502bdd --- /dev/null +++ b/tests/operators.ispc @@ -0,0 +1,70 @@ + +export uniform int width() { return programCount; } + +struct S { + float a; +}; + +// References "struct&" were put in random order to test them. +struct S operator*(struct S& rr, struct S rv) { + struct S c; + c.a = rr.a + rv.a + 2; + return c; +} + +struct S operator/(struct S& rr, struct S& rv) { + struct S c; + c.a = rr.a - rr.a + 2; + return c; +} + +struct S operator%(struct S rr, struct S& rv) { + struct S c; + c.a = rr.a + rv.a + 2; + return c; +} + +struct S operator+(struct S rr, struct S rv) { + struct S c; + c.a = rr.a / rv.a + 3; + return c; +} + +struct S operator-(struct S rr, struct S& rv) { + struct S c; + c.a = rr.a + rv.a + 2; + return c; +} + +struct S operator>>(struct S& rr, struct S rv) { + struct S c; + c.a = rr.a / rv.a + 3; + return c; +} + +struct S operator<<(struct S& rr, struct S& rv) { + struct S c; + c.a = rr.a + rv.a + 2; + return c; +} + +struct S a, a1; +struct S b, b1; +struct S d1, d2, d3, d4, d5, d6, d7; + +export void f_f(uniform float RET[], uniform float aFOO[]) { + a.a = aFOO[programIndex]; + b.a = -aFOO[programIndex]; + d1 = a * b; + d2 = a / b; + d3 = a % b; + d4 = a + b; + d5 = a - b; + d6 = a >> b; + d7 = a << b; + RET[programIndex] = d1.a + d2.a + d3.a + d4.a + d5.a + d6.a + d7.a; +} + +export void result(uniform float RET[4]) { + RET[programIndex] = 14; +} diff --git a/tests/operators1.ispc b/tests/operators1.ispc new file mode 100644 index 00000000..f52c4c35 --- /dev/null +++ b/tests/operators1.ispc @@ -0,0 +1,64 @@ + +export uniform int width() { return programCount; } + +struct S { + float a; +}; + +// References "struct&" were put in random order to test them. +struct S operator*(struct S& rr, struct S rv) { + struct S c; + c.a = rr.a + rv.a + 2; + return c; +} + +struct S operator/(struct S& rr, struct S& rv) { + struct S c; + c.a = rr.a - rr.a + 2; + return c; +} + +struct S operator%(struct S rr, struct S& rv) { + struct S c; + c.a = rr.a + rv.a + 2; + return c; +} + +struct S operator+(struct S rr, struct S rv) { + struct S c; + c.a = rr.a / rv.a + 3; + return c; +} + +struct S operator-(struct S rr, struct S& rv) { + struct S c; + c.a = rr.a + rv.a + 2; + return c; +} + +struct S operator>>(struct S& rr, struct S rv) { + struct S c; + c.a = rr.a / rv.a + 3; + return c; +} + +struct S operator<<(struct S& rr, struct S& rv) { + struct S c; + c.a = rr.a + rv.a + 2; + return c; +} + +struct S a; +struct S b; +struct S d; + +export void f_f(uniform float RET[], uniform float aFOO[]) { + a.a = 5; + b.a = -5; + d = a * b + b / a - a << (b - b) - a; + RET[programIndex] = d.a; +} + +export void result(uniform float RET[4]) { + RET[programIndex] = 12; +} diff --git a/tests/operators2.ispc b/tests/operators2.ispc new file mode 100644 index 00000000..b732b24a --- /dev/null +++ b/tests/operators2.ispc @@ -0,0 +1,51 @@ +int off; + +export uniform int width() { return programCount; } + +struct S { + float a; +}; + +struct S operator+(struct S rr, struct S rv) { + struct S c; + c.a = rr.a / rv.a + 3; + if (off == 1) + c.a = 22; + return c; +} + +struct S operator/(struct S rr, struct S rv) { + struct S c; + c.a = rr.a + rv.a + 10; + if (off == 1) + c.a = 33; + return c; +} + +struct S a; +struct S b; +struct S d; + +export void f_f(uniform float RET[], uniform float aFOO[]) { + int T = programIndex; + a.a = aFOO[programIndex]; + b.a = -aFOO[programIndex]; + if (programIndex == 3) + off = 1; + else + off = 0; + if (T % 2) + d = a + b; + else + d = a / b; + + RET[programIndex] = d.a; +} + +export void result(uniform float RET[4]) { + if (programIndex % 2) + RET[programIndex] = 2; + else + RET[programIndex] = 10; + RET[3] = 22; +}