Short-circuit evaluation of && and || operators.
We now follow C's approach to evaluating these operators: the second operand is not evaluated if the value of the first one determines the overall result. Thus, they can now be used idiomatically in expressions such as (index < limit && array[index] > 0). For varying expressions, the mask is set appropriately when evaluating the second operand. (For expressions that can be determined to be both simple and safe to evaluate with the mask all off, we still evaluate both sides and compute the logical op result directly, which saves a number of branches and tests. However, the effect of this should never be visible to the programmer.) Issue #4.
This commit is contained in:
309
expr.cpp
309
expr.cpp
@@ -1405,13 +1405,274 @@ BinaryExpr::BinaryExpr(Op o, Expr *a, Expr *b, SourcePos p)
|
||||
}
|
||||
|
||||
|
||||
/** Emit code for a && or || logical operator. In particular, the code
|
||||
here handles "short-circuit" evaluation, where the second expression
|
||||
isn't evaluated if the value of the first one determines the value of
|
||||
the result.
|
||||
*/
|
||||
llvm::Value *
|
||||
lEmitLogicalOp(BinaryExpr::Op op, Expr *arg0, Expr *arg1,
|
||||
FunctionEmitContext *ctx, SourcePos pos) {
|
||||
|
||||
const Type *type0 = arg0->GetType(), *type1 = arg1->GetType();
|
||||
if (type0 == NULL || type1 == NULL) {
|
||||
Assert(m->errorCount > 0);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// There is overhead (branches, etc.), to short-circuiting, so if the
|
||||
// right side of the expression is a) relatively simple, and b) can be
|
||||
// safely executed with an all-off execution mask, then we just
|
||||
// evaluate both sides and then the logical operator in that case.
|
||||
// FIXME: not sure what we should do about vector types here...
|
||||
bool shortCircuit = (EstimateCost(arg1) > PREDICATE_SAFE_IF_STATEMENT_COST ||
|
||||
SafeToRunWithMaskAllOff(arg1) == false ||
|
||||
dynamic_cast<const VectorType *>(type0) != NULL ||
|
||||
dynamic_cast<const VectorType *>(type1) != NULL);
|
||||
if (shortCircuit == false) {
|
||||
// If one of the operands is uniform but the other is varying,
|
||||
// promote the uniform one to varying
|
||||
if (type0->IsUniformType() && type1->IsVaryingType()) {
|
||||
arg0 = TypeConvertExpr(arg0, AtomicType::VaryingBool, lOpString(op));
|
||||
Assert(arg0 != NULL);
|
||||
}
|
||||
if (type1->IsUniformType() && type0->IsVaryingType()) {
|
||||
arg1 = TypeConvertExpr(arg1, AtomicType::VaryingBool, lOpString(op));
|
||||
Assert(arg1 != NULL);
|
||||
}
|
||||
|
||||
llvm::Value *value0 = arg0->GetValue(ctx);
|
||||
llvm::Value *value1 = arg1->GetValue(ctx);
|
||||
if (value0 == NULL || value1 == NULL) {
|
||||
Assert(m->errorCount > 0);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (op == BinaryExpr::LogicalAnd)
|
||||
return ctx->BinaryOperator(llvm::Instruction::And, value0, value1,
|
||||
"logical_and");
|
||||
else {
|
||||
Assert(op == BinaryExpr::LogicalOr);
|
||||
return ctx->BinaryOperator(llvm::Instruction::Or, value0, value1,
|
||||
"logical_or");
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate temporary storage for the return value
|
||||
const Type *retType = Type::MoreGeneralType(type0, type1, pos, lOpString(op));
|
||||
LLVM_TYPE_CONST llvm::Type *llvmRetType = retType->LLVMType(g->ctx);
|
||||
llvm::Value *retPtr = ctx->AllocaInst(llvmRetType, "logical_op_mem");
|
||||
|
||||
llvm::BasicBlock *bbSkipEvalValue1 = ctx->CreateBasicBlock("skip_eval_1");
|
||||
llvm::BasicBlock *bbEvalValue1 = ctx->CreateBasicBlock("eval_1");
|
||||
llvm::BasicBlock *bbLogicalDone = ctx->CreateBasicBlock("logical_op_done");
|
||||
|
||||
// Evaluate the first operand
|
||||
llvm::Value *value0 = arg0->GetValue(ctx);
|
||||
if (value0 == NULL) {
|
||||
Assert(m->errorCount > 0);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (type0->IsUniformType()) {
|
||||
// Check to see if the value of the first operand is true or false
|
||||
llvm::Value *value0True =
|
||||
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ,
|
||||
value0, LLVMTrue);
|
||||
|
||||
if (op == BinaryExpr::LogicalOr) {
|
||||
// For ||, if value0 is true, then we skip evaluating value1
|
||||
// entirely.
|
||||
ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, value0True);
|
||||
|
||||
// If value0 is true, the complete result is true (either
|
||||
// uniform or varying)
|
||||
ctx->SetCurrentBasicBlock(bbSkipEvalValue1);
|
||||
llvm::Value *trueValue = retType->IsUniformType() ? LLVMTrue :
|
||||
LLVMMaskAllOn;
|
||||
ctx->StoreInst(trueValue, retPtr);
|
||||
ctx->BranchInst(bbLogicalDone);
|
||||
}
|
||||
else {
|
||||
Assert(op == BinaryExpr::LogicalAnd);
|
||||
|
||||
// Conversely, for &&, if value0 is false, we skip evaluating
|
||||
// value1.
|
||||
ctx->BranchInst(bbEvalValue1, bbSkipEvalValue1, value0True);
|
||||
|
||||
// In this case, the complete result is false (again, either a
|
||||
// uniform or varying false).
|
||||
ctx->SetCurrentBasicBlock(bbSkipEvalValue1);
|
||||
llvm::Value *falseValue = retType->IsUniformType() ? LLVMFalse :
|
||||
LLVMMaskAllOff;
|
||||
ctx->StoreInst(falseValue, retPtr);
|
||||
ctx->BranchInst(bbLogicalDone);
|
||||
}
|
||||
|
||||
// Both || and && are in the same situation if the first operand's
|
||||
// value didn't resolve the final result: they need to evaluate the
|
||||
// value of the second operand, which in turn gives the value for
|
||||
// the full expression.
|
||||
ctx->SetCurrentBasicBlock(bbEvalValue1);
|
||||
if (type1->IsUniformType() && retType->IsVaryingType()) {
|
||||
arg1 = TypeConvertExpr(arg1, AtomicType::VaryingBool, "logical op");
|
||||
Assert(arg1 != NULL);
|
||||
}
|
||||
|
||||
llvm::Value *value1 = arg1->GetValue(ctx);
|
||||
if (value1 == NULL) {
|
||||
Assert(m->errorCount > 0);
|
||||
return NULL;
|
||||
}
|
||||
ctx->StoreInst(value1, retPtr);
|
||||
ctx->BranchInst(bbLogicalDone);
|
||||
|
||||
// In all cases, we end up at the bbLogicalDone basic block;
|
||||
// loading the value stored in retPtr in turn gives the overall
|
||||
// result.
|
||||
ctx->SetCurrentBasicBlock(bbLogicalDone);
|
||||
return ctx->LoadInst(retPtr);
|
||||
}
|
||||
else {
|
||||
// Otherwise, the first operand is varying... Save the current
|
||||
// value of the mask so that we can restore it at the end.
|
||||
llvm::Value *oldMask = ctx->GetInternalMask();
|
||||
|
||||
// Convert the second operand to be varying as well, so that we can
|
||||
// perform logical vector ops with its value.
|
||||
if (type1->IsUniformType()) {
|
||||
arg1 = TypeConvertExpr(arg1, AtomicType::VaryingBool, "logical op");
|
||||
Assert(arg1 != NULL);
|
||||
type1 = arg1->GetType();
|
||||
}
|
||||
|
||||
if (op == BinaryExpr::LogicalOr) {
|
||||
// See if value0 is true for all currently executing
|
||||
// lanes--i.e. if (value0 & mask) == mask. If so, we don't
|
||||
// need to evaluate the second operand of the expression.
|
||||
llvm::Value *value0AndMask =
|
||||
ctx->BinaryOperator(llvm::Instruction::And, value0, oldMask,
|
||||
"op&mask");
|
||||
llvm::Value *equalsMask =
|
||||
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ,
|
||||
value0AndMask, oldMask, "value0&mask==mask");
|
||||
equalsMask = ctx->I1VecToBoolVec(equalsMask);
|
||||
llvm::Value *allMatch = ctx->All(equalsMask);
|
||||
ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, allMatch);
|
||||
|
||||
// value0 is true for all running lanes, so it can be used for
|
||||
// the final result
|
||||
ctx->SetCurrentBasicBlock(bbSkipEvalValue1);
|
||||
ctx->StoreInst(value0, retPtr);
|
||||
ctx->BranchInst(bbLogicalDone);
|
||||
|
||||
// Otherwise, we need to valuate arg1. However, first we need
|
||||
// to set the execution mask to be (oldMask & ~a); in other
|
||||
// words, only execute the instances where value0 is false.
|
||||
// For the instances where value0 was true, we need to inhibit
|
||||
// execution.
|
||||
ctx->SetCurrentBasicBlock(bbEvalValue1);
|
||||
llvm::Value *not0 = ctx->NotOperator(value0);
|
||||
ctx->SetInternalMaskAnd(oldMask, not0);
|
||||
|
||||
llvm::Value *value1 = arg1->GetValue(ctx);
|
||||
if (value1 == NULL) {
|
||||
Assert(m->errorCount > 0);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// We need to compute the result carefully, since vector
|
||||
// elements that were computed when the corresponding lane was
|
||||
// disabled have undefined values:
|
||||
// result = (value0 & old_mask) | (value1 & current_mask)
|
||||
llvm::Value *value1AndMask =
|
||||
ctx->BinaryOperator(llvm::Instruction::And, value1,
|
||||
ctx->GetInternalMask(), "op&mask");
|
||||
llvm::Value *result =
|
||||
ctx->BinaryOperator(llvm::Instruction::Or, value0AndMask,
|
||||
value1AndMask, "or_result");
|
||||
ctx->StoreInst(result, retPtr);
|
||||
ctx->BranchInst(bbLogicalDone);
|
||||
}
|
||||
else {
|
||||
Assert(op == BinaryExpr::LogicalAnd);
|
||||
|
||||
// If value0 is false for all currently running lanes, the
|
||||
// overall result must be false: this corresponds to checking
|
||||
// if (mask & ~value0) == mask.
|
||||
llvm::Value *notValue0 = ctx->NotOperator(value0, "not_value0");
|
||||
llvm::Value *notValue0AndMask =
|
||||
ctx->BinaryOperator(llvm::Instruction::And, notValue0, oldMask,
|
||||
"not_value0&mask");
|
||||
llvm::Value *equalsMask =
|
||||
ctx->CmpInst(llvm::Instruction::ICmp, llvm::CmpInst::ICMP_EQ,
|
||||
notValue0AndMask, oldMask, "not_value0&mask==mask");
|
||||
equalsMask = ctx->I1VecToBoolVec(equalsMask);
|
||||
llvm::Value *allMatch = ctx->All(equalsMask);
|
||||
ctx->BranchInst(bbSkipEvalValue1, bbEvalValue1, allMatch);
|
||||
|
||||
// value0 was false for all running lanes, so use its value as
|
||||
// the overall result.
|
||||
ctx->SetCurrentBasicBlock(bbSkipEvalValue1);
|
||||
ctx->StoreInst(value0, retPtr);
|
||||
ctx->BranchInst(bbLogicalDone);
|
||||
|
||||
// Otherwise we need to evaluate value1, but again with the
|
||||
// mask set to only be on for the lanes where value0 was true.
|
||||
// For the lanes where value0 was false, execution needs to be
|
||||
// disabled: mask = (mask & value0).
|
||||
ctx->SetCurrentBasicBlock(bbEvalValue1);
|
||||
ctx->SetInternalMaskAnd(oldMask, value0);
|
||||
|
||||
llvm::Value *value1 = arg1->GetValue(ctx);
|
||||
if (value1 == NULL) {
|
||||
Assert(m->errorCount > 0);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// And as in the || case, we compute the overall result by
|
||||
// masking off the valid lanes before we AND them together:
|
||||
// result = (value0 & old_mask) & (value1 & current_mask)
|
||||
llvm::Value *value0AndMask =
|
||||
ctx->BinaryOperator(llvm::Instruction::And, value0, oldMask,
|
||||
"op&mask");
|
||||
llvm::Value *value1AndMask =
|
||||
ctx->BinaryOperator(llvm::Instruction::And, value1,
|
||||
ctx->GetInternalMask(), "value1&mask");
|
||||
llvm::Value *result =
|
||||
ctx->BinaryOperator(llvm::Instruction::And, value0AndMask,
|
||||
value1AndMask, "or_result");
|
||||
ctx->StoreInst(result, retPtr);
|
||||
ctx->BranchInst(bbLogicalDone);
|
||||
}
|
||||
|
||||
// And finally we always end up in bbLogicalDone, where we restore
|
||||
// the old mask and return the computed result
|
||||
ctx->SetCurrentBasicBlock(bbLogicalDone);
|
||||
ctx->SetInternalMask(oldMask);
|
||||
return ctx->LoadInst(retPtr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
llvm::Value *
|
||||
BinaryExpr::GetValue(FunctionEmitContext *ctx) const {
|
||||
if (!arg0 || !arg1)
|
||||
if (!arg0 || !arg1) {
|
||||
Assert(m->errorCount > 0);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Handle these specially, since we want to short-circuit their evaluation...
|
||||
if (op == LogicalAnd || op == LogicalOr)
|
||||
return lEmitLogicalOp(op, arg0, arg1, ctx, pos);
|
||||
|
||||
llvm::Value *value0 = arg0->GetValue(ctx);
|
||||
llvm::Value *value1 = arg1->GetValue(ctx);
|
||||
if (value0 == NULL || value1 == NULL) {
|
||||
Assert(m->errorCount > 0);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
ctx->SetDebugPos(pos);
|
||||
|
||||
switch (op) {
|
||||
@@ -1441,12 +1702,6 @@ BinaryExpr::GetValue(FunctionEmitContext *ctx) const {
|
||||
return lEmitBinaryBitOp(op, value0, value1,
|
||||
arg0->GetType()->IsUnsignedType(), ctx);
|
||||
}
|
||||
case LogicalAnd:
|
||||
return ctx->BinaryOperator(llvm::Instruction::And, value0, value1,
|
||||
"logical_and");
|
||||
case LogicalOr:
|
||||
return ctx->BinaryOperator(llvm::Instruction::Or, value0, value1,
|
||||
"logical_or");
|
||||
case Comma:
|
||||
return value1;
|
||||
default:
|
||||
@@ -2017,12 +2272,15 @@ BinaryExpr::TypeCheck() {
|
||||
}
|
||||
case LogicalAnd:
|
||||
case LogicalOr: {
|
||||
// We need to type convert to a boolean type of the more general
|
||||
// shape of the two types
|
||||
bool isUniform = (type0->IsUniformType() && type1->IsUniformType());
|
||||
const AtomicType *boolType = isUniform ? AtomicType::UniformBool :
|
||||
AtomicType::VaryingBool;
|
||||
const Type *destType = NULL;
|
||||
// For now, we just type convert to boolean types, of the same
|
||||
// variability as the original types. (When generating code, it's
|
||||
// useful to have preserved the uniform/varying distinction.)
|
||||
const AtomicType *boolType0 = type0->IsUniformType() ?
|
||||
AtomicType::UniformBool : AtomicType::VaryingBool;
|
||||
const AtomicType *boolType1 = type1->IsUniformType() ?
|
||||
AtomicType::UniformBool : AtomicType::VaryingBool;
|
||||
|
||||
const Type *destType0 = NULL, *destType1 = NULL;
|
||||
const VectorType *vtype0 = dynamic_cast<const VectorType *>(type0);
|
||||
const VectorType *vtype1 = dynamic_cast<const VectorType *>(type1);
|
||||
if (vtype0 && vtype1) {
|
||||
@@ -2032,17 +2290,24 @@ BinaryExpr::TypeCheck() {
|
||||
"different sizes (%d vs. %d).", lOpString(op), sz0, sz1);
|
||||
return NULL;
|
||||
}
|
||||
destType = new VectorType(boolType, sz0);
|
||||
destType0 = new VectorType(boolType0, sz0);
|
||||
destType1 = new VectorType(boolType1, sz1);
|
||||
}
|
||||
else if (vtype0 != NULL) {
|
||||
destType0 = new VectorType(boolType0, vtype0->GetElementCount());
|
||||
destType1 = new VectorType(boolType1, vtype0->GetElementCount());
|
||||
}
|
||||
else if (vtype1 != NULL) {
|
||||
destType0 = new VectorType(boolType0, vtype1->GetElementCount());
|
||||
destType1 = new VectorType(boolType1, vtype1->GetElementCount());
|
||||
}
|
||||
else {
|
||||
destType0 = boolType0;
|
||||
destType1 = boolType1;
|
||||
}
|
||||
else if (vtype0)
|
||||
destType = new VectorType(boolType, vtype0->GetElementCount());
|
||||
else if (vtype1)
|
||||
destType = new VectorType(boolType, vtype1->GetElementCount());
|
||||
else
|
||||
destType = boolType;
|
||||
|
||||
arg0 = TypeConvertExpr(arg0, destType, lOpString(op));
|
||||
arg1 = TypeConvertExpr(arg1, destType, lOpString(op));
|
||||
arg0 = TypeConvertExpr(arg0, destType0, lOpString(op));
|
||||
arg1 = TypeConvertExpr(arg1, destType1, lOpString(op));
|
||||
if (arg0 == NULL || arg1 == NULL)
|
||||
return NULL;
|
||||
return this;
|
||||
|
||||
Reference in New Issue
Block a user