support of operators

This commit is contained in:
Ilia Filippov
2013-10-07 15:43:31 +04:00
parent 2741e3c1d0
commit 2e724b095e
7 changed files with 299 additions and 45 deletions

135
expr.cpp
View File

@@ -1660,6 +1660,64 @@ BinaryExpr::BinaryExpr(Op o, Expr *a, Expr *b, SourcePos p)
arg1 = b;
}
Expr *lCreateBinaryOperatorCall(const BinaryExpr::Op bop,
Expr *a0, Expr *a1,
const SourcePos &sp)
{
if ((a0 == NULL) || (a1 == NULL)) {
return NULL;
}
Expr *arg0 = a0->TypeCheck();
Expr *arg1 = a1->TypeCheck();
if ((arg0 == NULL) || (arg1 == NULL)) {
return NULL;
}
const Type *type0 = arg0->GetType();
const Type *type1 = arg1->GetType();
// If either operand is a reference, dereference it before we move
// forward
if (CastType<ReferenceType>(type0) != NULL) {
arg0 = new RefDerefExpr(arg0, arg0->pos);
type0 = arg0->GetType();
}
if (CastType<ReferenceType>(type1) != NULL) {
arg1 = new RefDerefExpr(arg1, arg1->pos);
type1 = arg1->GetType();
}
if ((type0 == NULL) || (type1 == NULL)) {
return NULL;
}
if (CastType<StructType>(type0) != NULL ||
CastType<StructType>(type1) != NULL) {
std::string opName = std::string("operator") + lOpString(bop);
std::vector<Symbol *> funs;
m->symbolTable->LookupFunction(opName.c_str(), &funs);
if (funs.size() == 0) {
Error(sp, "operator %s(%s, %s) is not defined.",
opName.c_str(), (type0->GetString()).c_str(), (type1->GetString()).c_str());
return NULL;
}
Expr *func = new FunctionSymbolExpr(opName.c_str(), funs, sp);
ExprList *args = new ExprList(sp);
args->exprs.push_back(arg0);
args->exprs.push_back(arg1);
Expr *opCallExpr = new FunctionCallExpr(func, args, sp);
return opCallExpr;
}
return NULL;
}
Expr * MakeBinaryExpr(BinaryExpr::Op o, Expr *a, Expr *b, SourcePos p) {
Expr * op = lCreateBinaryOperatorCall(o, a, b, p);
if (op != NULL) {
return op;
}
op = new BinaryExpr(o, a, b, p);
return op;
}
/** Emit code for a && or || logical operator. In particular, the code
here handles "short-circuit" evaluation, where the second expression
@@ -2985,29 +3043,10 @@ AssignExpr::TypeCheck() {
if (lvalueIsReference)
lvalue = new RefDerefExpr(lvalue, lvalue->pos);
FunctionSymbolExpr *fse;
if ((fse = dynamic_cast<FunctionSymbolExpr *>(rvalue)) != NULL) {
// Special case to use the type of the LHS to resolve function
// overloads when we're assigning a function pointer where the
// function is overloaded.
const Type *lvalueType = lvalue->GetType();
const FunctionType *ftype;
if (CastType<PointerType>(lvalueType) == NULL ||
(ftype = CastType<FunctionType>(lvalueType->GetBaseType())) == NULL) {
Error(lvalue->pos, "Can't assign function pointer to type \"%s\".",
lvalueType ? lvalueType->GetString().c_str() : "<unknown>");
return NULL;
}
std::vector<const Type *> paramTypes;
for (int i = 0; i < ftype->GetNumParameters(); ++i)
paramTypes.push_back(ftype->GetParameterType(i));
if (!fse->ResolveOverloads(rvalue->pos, paramTypes)) {
Error(pos, "Unable to find overloaded function for function "
"pointer assignment.");
return NULL;
}
if (PossiblyResolveFunctionOverloads(rvalue, lvalue->GetType()) == false) {
Error(pos, "Unable to find overloaded function for function "
"pointer assignment.");
return NULL;
}
const Type *lhsType = lvalue->GetType();
@@ -3650,10 +3689,37 @@ FunctionCallExpr::GetLValue(FunctionEmitContext *ctx) const {
return NULL;
}
}
bool FullResolveOverloads(Expr * func, ExprList * args,
std::vector<const Type *> *argTypes,
std::vector<bool> *argCouldBeNULL,
std::vector<bool> *argIsConstant) {
for (unsigned int i = 0; i < args->exprs.size(); ++i) {
Expr *expr = args->exprs[i];
if (expr == NULL)
return false;
const Type *t = expr->GetType();
if (t == NULL)
return false;
argTypes->push_back(t);
argCouldBeNULL->push_back(lIsAllIntZeros(expr) || dynamic_cast<NullPointerExpr *>(expr));
argIsConstant->push_back(dynamic_cast<ConstExpr *>(expr) || dynamic_cast<NullPointerExpr *>(expr));
}
return true;
}
const Type *
FunctionCallExpr::GetType() const {
std::vector<const Type *> argTypes;
std::vector<bool> argCouldBeNULL, argIsConstant;
if (FullResolveOverloads(func, args, &argTypes, &argCouldBeNULL, &argIsConstant) == true) {
FunctionSymbolExpr *fse = dynamic_cast<FunctionSymbolExpr *>(func);
if (fse != NULL) {
fse->ResolveOverloads(args->pos, argTypes, &argCouldBeNULL, &argIsConstant);
}
}
const FunctionType *ftype = lGetFunctionType(func);
return ftype ? ftype->GetReturnType() : NULL;
}
@@ -3689,20 +3755,9 @@ FunctionCallExpr::TypeCheck() {
std::vector<const Type *> argTypes;
std::vector<bool> argCouldBeNULL, argIsConstant;
for (unsigned int i = 0; i < args->exprs.size(); ++i) {
Expr *expr = args->exprs[i];
if (expr == NULL)
return NULL;
const Type *t = expr->GetType();
if (t == NULL)
return NULL;
argTypes.push_back(t);
argCouldBeNULL.push_back(lIsAllIntZeros(expr) ||
dynamic_cast<NullPointerExpr *>(expr));
argIsConstant.push_back(dynamic_cast<ConstExpr *>(expr) ||
dynamic_cast<NullPointerExpr *>(expr));
if (FullResolveOverloads(func, args, &argTypes, &argCouldBeNULL, &argIsConstant) == false) {
return NULL;
}
FunctionSymbolExpr *fse = dynamic_cast<FunctionSymbolExpr *>(func);
@@ -7010,7 +7065,8 @@ TypeCastExpr::GetLValue(FunctionEmitContext *ctx) const {
const Type *
TypeCastExpr::GetType() const {
AssertPos(pos, type->HasUnboundVariability() == false);
// We have to switch off this assert after supporting of operators.
//AssertPos(pos, type->HasUnboundVariability() == false);
return type;
}
@@ -8190,6 +8246,9 @@ FunctionSymbolExpr::ResolveOverloads(SourcePos argPos,
const std::vector<bool> *argCouldBeNULL,
const std::vector<bool> *argIsConstant) {
const char *funName = candidateFunctions.front()->name.c_str();
if (triedToResolve == true) {
return true;
}
triedToResolve = true;

2
expr.h
View File

@@ -730,6 +730,8 @@ bool CanConvertTypes(const Type *fromType, const Type *toType,
*/
Expr *TypeConvertExpr(Expr *expr, const Type *toType, const char *errorMsgBase);
Expr * MakeBinaryExpr(BinaryExpr::Op o, Expr *a, Expr *b, SourcePos p);
/** Utility routine that emits code to initialize a symbol given an
initializer expression.

8
lex.ll
View File

@@ -419,6 +419,14 @@ while { RT; return TOKEN_WHILE; }
\"C\" { RT; return TOKEN_STRING_C_LITERAL; }
\.\.\. { RT; return TOKEN_DOTDOTDOT; }
"operator*" { return TOKEN_IDENTIFIER; }
"operator+" { return TOKEN_IDENTIFIER; }
"operator-" { return TOKEN_IDENTIFIER; }
"operator<<" { return TOKEN_IDENTIFIER; }
"operator>>" { return TOKEN_IDENTIFIER; }
"operator/" { return TOKEN_IDENTIFIER; }
"operator%" { return TOKEN_IDENTIFIER; }
L?\"(\\.|[^\\"])*\" { lStringConst(&yylval, &yylloc); return TOKEN_STRING_LITERAL; }
{IDENT} {

View File

@@ -468,27 +468,27 @@ cast_expression
multiplicative_expression
: cast_expression
| multiplicative_expression '*' cast_expression
{ $$ = new BinaryExpr(BinaryExpr::Mul, $1, $3, Union(@1, @3)); }
{ $$ = MakeBinaryExpr(BinaryExpr::Mul, $1, $3, Union(@1, @3)); }
| multiplicative_expression '/' cast_expression
{ $$ = new BinaryExpr(BinaryExpr::Div, $1, $3, Union(@1, @3)); }
{ $$ = MakeBinaryExpr(BinaryExpr::Div, $1, $3, Union(@1, @3)); }
| multiplicative_expression '%' cast_expression
{ $$ = new BinaryExpr(BinaryExpr::Mod, $1, $3, Union(@1, @3)); }
{ $$ = MakeBinaryExpr(BinaryExpr::Mod, $1, $3, Union(@1, @3)); }
;
additive_expression
: multiplicative_expression
| additive_expression '+' multiplicative_expression
{ $$ = new BinaryExpr(BinaryExpr::Add, $1, $3, Union(@1, @3)); }
{ $$ = MakeBinaryExpr(BinaryExpr::Add, $1, $3, Union(@1, @3)); }
| additive_expression '-' multiplicative_expression
{ $$ = new BinaryExpr(BinaryExpr::Sub, $1, $3, Union(@1, @3)); }
{ $$ = MakeBinaryExpr(BinaryExpr::Sub, $1, $3, Union(@1, @3)); }
;
shift_expression
: additive_expression
| shift_expression TOKEN_LEFT_OP additive_expression
{ $$ = new BinaryExpr(BinaryExpr::Shl, $1, $3, Union(@1,@3)); }
{ $$ = MakeBinaryExpr(BinaryExpr::Shl, $1, $3, Union(@1, @3)); }
| shift_expression TOKEN_RIGHT_OP additive_expression
{ $$ = new BinaryExpr(BinaryExpr::Shr, $1, $3, Union(@1,@3)); }
{ $$ = MakeBinaryExpr(BinaryExpr::Shr, $1, $3, Union(@1, @3)); }
;
relational_expression

70
tests/operators.ispc Normal file
View File

@@ -0,0 +1,70 @@
export uniform int width() { return programCount; }
struct S {
float a;
};
// References "struct&" were put in random order to test them.
struct S operator*(struct S& rr, struct S rv) {
struct S c;
c.a = rr.a + rv.a + 2;
return c;
}
struct S operator/(struct S& rr, struct S& rv) {
struct S c;
c.a = rr.a - rr.a + 2;
return c;
}
struct S operator%(struct S rr, struct S& rv) {
struct S c;
c.a = rr.a + rv.a + 2;
return c;
}
struct S operator+(struct S rr, struct S rv) {
struct S c;
c.a = rr.a / rv.a + 3;
return c;
}
struct S operator-(struct S rr, struct S& rv) {
struct S c;
c.a = rr.a + rv.a + 2;
return c;
}
struct S operator>>(struct S& rr, struct S rv) {
struct S c;
c.a = rr.a / rv.a + 3;
return c;
}
struct S operator<<(struct S& rr, struct S& rv) {
struct S c;
c.a = rr.a + rv.a + 2;
return c;
}
struct S a, a1;
struct S b, b1;
struct S d1, d2, d3, d4, d5, d6, d7;
export void f_f(uniform float RET[], uniform float aFOO[]) {
a.a = aFOO[programIndex];
b.a = -aFOO[programIndex];
d1 = a * b;
d2 = a / b;
d3 = a % b;
d4 = a + b;
d5 = a - b;
d6 = a >> b;
d7 = a << b;
RET[programIndex] = d1.a + d2.a + d3.a + d4.a + d5.a + d6.a + d7.a;
}
export void result(uniform float RET[4]) {
RET[programIndex] = 14;
}

64
tests/operators1.ispc Normal file
View File

@@ -0,0 +1,64 @@
export uniform int width() { return programCount; }
struct S {
float a;
};
// References "struct&" were put in random order to test them.
struct S operator*(struct S& rr, struct S rv) {
struct S c;
c.a = rr.a + rv.a + 2;
return c;
}
struct S operator/(struct S& rr, struct S& rv) {
struct S c;
c.a = rr.a - rr.a + 2;
return c;
}
struct S operator%(struct S rr, struct S& rv) {
struct S c;
c.a = rr.a + rv.a + 2;
return c;
}
struct S operator+(struct S rr, struct S rv) {
struct S c;
c.a = rr.a / rv.a + 3;
return c;
}
struct S operator-(struct S rr, struct S& rv) {
struct S c;
c.a = rr.a + rv.a + 2;
return c;
}
struct S operator>>(struct S& rr, struct S rv) {
struct S c;
c.a = rr.a / rv.a + 3;
return c;
}
struct S operator<<(struct S& rr, struct S& rv) {
struct S c;
c.a = rr.a + rv.a + 2;
return c;
}
struct S a;
struct S b;
struct S d;
export void f_f(uniform float RET[], uniform float aFOO[]) {
a.a = 5;
b.a = -5;
d = a * b + b / a - a << (b - b) - a;
RET[programIndex] = d.a;
}
export void result(uniform float RET[4]) {
RET[programIndex] = 12;
}

51
tests/operators2.ispc Normal file
View File

@@ -0,0 +1,51 @@
int off;
export uniform int width() { return programCount; }
struct S {
float a;
};
struct S operator+(struct S rr, struct S rv) {
struct S c;
c.a = rr.a / rv.a + 3;
if (off == 1)
c.a = 22;
return c;
}
struct S operator/(struct S rr, struct S rv) {
struct S c;
c.a = rr.a + rv.a + 10;
if (off == 1)
c.a = 33;
return c;
}
struct S a;
struct S b;
struct S d;
export void f_f(uniform float RET[], uniform float aFOO[]) {
int T = programIndex;
a.a = aFOO[programIndex];
b.a = -aFOO[programIndex];
if (programIndex == 3)
off = 1;
else
off = 0;
if (T % 2)
d = a + b;
else
d = a / b;
RET[programIndex] = d.a;
}
export void result(uniform float RET[4]) {
if (programIndex % 2)
RET[programIndex] = 2;
else
RET[programIndex] = 10;
RET[3] = 22;
}