Allow 'continue' statements in 'foreach' loops.

This commit is contained in:
Matt Pharr
2011-12-03 09:31:02 -08:00
parent c3b55de1ad
commit a1c0b4f95a
17 changed files with 359 additions and 40 deletions

99
ctx.cpp
View File

@@ -401,22 +401,26 @@ FunctionEmitContext::EndIf() {
// or continue statements (and breakLanesPtr and continueLanesPtr
// have their initial 'all off' values), so we don't need to check
// for that here.
if (breakLanesPtr != NULL) {
assert(continueLanesPtr != NULL);
if (continueLanesPtr != NULL) {
// We want to compute:
// newMask = (oldMask & ~(breakLanes | continueLanes))
llvm::Value *oldMask = GetInternalMask();
llvm::Value *breakLanes = LoadInst(breakLanesPtr, "break_lanes");
llvm::Value *continueLanes = LoadInst(continueLanesPtr,
"continue_lanes");
llvm::Value *breakOrContinueLanes =
BinaryOperator(llvm::Instruction::Or, breakLanes, continueLanes,
"break|continue_lanes");
llvm::Value *notBreakOrContinue = NotOperator(breakOrContinueLanes,
"!(break|continue)_lanes");
llvm::Value *bcLanes = continueLanes;
if (breakLanesPtr != NULL) {
// breakLanesPtr will be NULL if we're inside a 'foreach' loop
llvm::Value *breakLanes = LoadInst(breakLanesPtr, "break_lanes");
bcLanes = BinaryOperator(llvm::Instruction::Or, breakLanes,
continueLanes, "break|continue_lanes");
}
llvm::Value *notBreakOrContinue =
NotOperator(bcLanes, "!(break|continue)_lanes");
llvm::Value *newMask =
BinaryOperator(llvm::Instruction::And, oldMask, notBreakOrContinue,
"new_mask");
BinaryOperator(llvm::Instruction::And, oldMask,
notBreakOrContinue, "new_mask");
SetInternalMask(newMask);
}
}
@@ -477,15 +481,20 @@ FunctionEmitContext::EndLoop() {
void
FunctionEmitContext::StartForeach() {
FunctionEmitContext::StartForeach(llvm::BasicBlock *ct) {
// Store the current values of various loop-related state so that we
// can restore it when we exit this loop.
llvm::Value *oldMask = GetInternalMask();
controlFlowInfo.push_back(CFInfo::GetForeach(breakTarget, continueTarget, breakLanesPtr,
continueLanesPtr, oldMask, loopMask));
continueLanesPtr = breakLanesPtr = NULL;
controlFlowInfo.push_back(CFInfo::GetForeach(breakTarget, continueTarget,
breakLanesPtr, continueLanesPtr,
oldMask, loopMask));
breakLanesPtr = NULL;
breakTarget = NULL;
continueTarget = NULL;
continueLanesPtr = AllocaInst(LLVMTypes::MaskType, "foreach_continue_lanes");
StoreInst(LLVMMaskAllOff, continueLanesPtr);
continueTarget = ct;
loopMask = NULL;
}
@@ -526,7 +535,8 @@ FunctionEmitContext::restoreMaskGivenReturns(llvm::Value *oldMask) {
void
FunctionEmitContext::Break(bool doCoherenceCheck) {
if (breakTarget == NULL) {
Error(currentPos, "\"break\" statement is illegal outside of for/while/do loops.");
Error(currentPos, "\"break\" statement is illegal outside of "
"for/while/do loops.");
return;
}
@@ -576,7 +586,8 @@ FunctionEmitContext::Break(bool doCoherenceCheck) {
void
FunctionEmitContext::Continue(bool doCoherenceCheck) {
if (!continueTarget) {
Error(currentPos, "\"continue\" statement illegal outside of for/while/do loops.");
Error(currentPos, "\"continue\" statement illegal outside of "
"for/while/do/foreach loops.");
return;
}
@@ -586,8 +597,8 @@ FunctionEmitContext::Continue(bool doCoherenceCheck) {
// loop or if we can tell that the mask is all on.
AddInstrumentationPoint("continue: uniform CF, jumped");
if (ifsInLoopAllUniform() && doCoherenceCheck)
Warning(currentPos, "Coherent continue statement not necessary in fully uniform "
"control flow.");
Warning(currentPos, "Coherent continue statement not necessary in "
"fully uniform control flow.");
BranchInst(continueTarget);
bblock = NULL;
}
@@ -638,26 +649,40 @@ FunctionEmitContext::ifsInLoopAllUniform() const {
void
FunctionEmitContext::jumpIfAllLoopLanesAreDone(llvm::BasicBlock *target) {
// Check to see if (returned lanes | continued lanes | break lanes) is
// equal to the value of mask at the start of the loop iteration. If
// so, everyone is done and we can jump to the given target
llvm::Value *returned = LoadInst(returnedLanesPtr,
"returned_lanes");
llvm::Value *continued = LoadInst(continueLanesPtr,
"continue_lanes");
llvm::Value *breaked = LoadInst(breakLanesPtr, "break_lanes");
llvm::Value *returnedOrContinued = BinaryOperator(llvm::Instruction::Or,
returned, continued,
"returned|continued");
llvm::Value *returnedOrContinuedOrBreaked =
BinaryOperator(llvm::Instruction::Or, returnedOrContinued,
breaked, "returned|continued");
llvm::Value *allDone = NULL;
assert(continueLanesPtr != NULL);
if (breakLanesPtr == NULL) {
// In a foreach loop, break and return are illegal, and
// breakLanesPtr is NULL. In this case, the mask is guaranteed to
// be all on at the start of each iteration, so we only need to
// check if all lanes have continued..
llvm::Value *continued = LoadInst(continueLanesPtr,
"continue_lanes");
allDone = All(continued);
}
else {
// Check to see if (returned lanes | continued lanes | break lanes) is
// equal to the value of mask at the start of the loop iteration. If
// so, everyone is done and we can jump to the given target
llvm::Value *returned = LoadInst(returnedLanesPtr,
"returned_lanes");
llvm::Value *continued = LoadInst(continueLanesPtr,
"continue_lanes");
llvm::Value *breaked = LoadInst(breakLanesPtr, "break_lanes");
llvm::Value *returnedOrContinued = BinaryOperator(llvm::Instruction::Or,
returned, continued,
"returned|continued");
llvm::Value *returnedOrContinuedOrBreaked =
BinaryOperator(llvm::Instruction::Or, returnedOrContinued,
breaked, "returned|continued");
// Do we match the mask at loop entry?
allDone = MasksAllEqual(returnedOrContinuedOrBreaked, loopMask);
}
// Do we match the mask at loop entry?
llvm::Value *allRCB = MasksAllEqual(returnedOrContinuedOrBreaked, loopMask);
llvm::BasicBlock *bAll = CreateBasicBlock("all_continued_or_breaked");
llvm::BasicBlock *bNotAll = CreateBasicBlock("not_all_continued_or_breaked");
BranchInst(bAll, bNotAll, allRCB);
BranchInst(bAll, bNotAll, allDone);
// If so, have an extra basic block along the way to add
// instrumentation, if the user asked for it.

7
ctx.h
View File

@@ -159,8 +159,11 @@ public:
finished. */
void EndLoop();
/** */
void StartForeach();
/** Indicates that code generation for a 'foreach' or 'foreach_tiled'
loop is about to start. The provided basic block pointer indicates
where control flow should go if a 'continue' statement is executed
in the loop. */
void StartForeach(llvm::BasicBlock *continueTarget);
void EndForeach();
/** Emit code for a 'break' statement in a loop. If doCoherenceCheck

View File

@@ -1962,6 +1962,14 @@ of iteration dimensions may be specified, with each one spanning a
different range of values. Within the ``foreach`` loop, the given
identifiers are available as ``const varying int32`` variables.
It is illegal to have a ``break`` statement or a ``return`` statement
within a ``foreach`` loop; a compile-time error will be issued in this
case. (It is legal to have a ``break`` in a regular ``for`` loop that's
nested inside a ``foreach`` loop.) ``continue`` statements are legal in
``foreach`` loops; they have the same effect as in regular ``for`` loops:
a program instances that executes a ``continue`` statement effectively
skips over the rest of the loop body for the current iteration.
As a specific example, consdier the following ``foreach`` statement:
::

View File

@@ -1728,7 +1728,6 @@ ForeachStmt::EmitCode(FunctionEmitContext *ctx) const {
llvm::Value *oldMask = ctx->GetInternalMask();
ctx->StartForeach();
ctx->SetDebugPos(pos);
ctx->StartScope();
@@ -1803,6 +1802,8 @@ ForeachStmt::EmitCode(FunctionEmitContext *ctx) const {
ctx->StoreInst(LLVMMaskAllOn, extrasMaskPtrs[i]);
}
ctx->StartForeach(bbStep[nDims-1]);
// On to the outermost loop's test
ctx->BranchInst(bbTest[0]);
@@ -1886,6 +1887,8 @@ ForeachStmt::EmitCode(FunctionEmitContext *ctx) const {
// just generate its value when we need it in the loop body.
for (int i = 0; i < nDims; ++i) {
ctx->SetCurrentBasicBlock(bbStep[i]);
if (i == nDims-1)
ctx->RestoreContinuedLanes();
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[i]);
llvm::Value *newCounter =
ctx->BinaryOperator(llvm::Instruction::Add, counter,

20
tests/foreach-20.ispc Normal file
View File

@@ -0,0 +1,20 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
for (uniform int i = 0; i < programCount; ++i)
RET[i] = 0;
foreach (i = 0 ... programCount) {
if (i % 2)
continue;
RET[i] = 1;
}
}
export void result(uniform float RET[]) {
RET[programIndex] = (programIndex % 2) ? 0 : 1;
}

20
tests/foreach-21.ispc Normal file
View File

@@ -0,0 +1,20 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
for (uniform int i = 0; i < programCount; ++i)
RET[i] = 0;
foreach (i = 0 ... programCount) {
if (i % 2)
ccontinue;
RET[i] = 1;
}
}
export void result(uniform float RET[]) {
RET[programIndex] = (programIndex % 2) ? 0 : 1;
}

27
tests/foreach-22.ispc Normal file
View File

@@ -0,0 +1,27 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
uniform int x[21];
for (uniform int i = 0; i < 21; ++i)
x[i] = i;
foreach (i = 0 ... 21) {
if (i < 4)
ccontinue;
x[i] = 1;
}
RET[programIndex] = x[min(programIndex, 20)];
}
export void result(uniform float RET[]) {
RET[programIndex] = 1;
RET[0] = 0;
RET[1] = 1;
RET[2] = 2;
RET[3] = 3;
}

27
tests/foreach-23.ispc Normal file
View File

@@ -0,0 +1,27 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
uniform int x[21];
for (uniform int i = 0; i < 21; ++i)
x[i] = i;
foreach (i = 0 ... 21) {
if (i < 4)
continue;
x[i] = 1;
}
RET[programIndex] = x[min(programIndex, 20)];
}
export void result(uniform float RET[]) {
RET[programIndex] = 1;
RET[0] = 0;
RET[1] = 1;
RET[2] = 2;
RET[3] = 3;
}

27
tests/foreach-24.ispc Normal file
View File

@@ -0,0 +1,27 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
uniform int x[21];
for (uniform int i = 0; i < 21; ++i)
x[i] = i;
foreach (i = 0 ... 21) {
if (i < 4)
continue;
x[i] = 1;
}
RET[programIndex] = x[min(programIndex, 20)];
}
export void result(uniform float RET[]) {
RET[programIndex] = 1;
RET[0] = 0;
RET[1] = 1;
RET[2] = 2;
RET[3] = 3;
}

23
tests/foreach-25.ispc Normal file
View File

@@ -0,0 +1,23 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
uniform int x[21];
foreach (i = 0 ... 21) {
x[i] = 0;
if (x[i] != 12345)
continue;
x[i] = 1;
}
RET[programIndex] = x[min(programIndex, 20)];
}
export void result(uniform float RET[]) {
RET[programIndex] = 0;
}

24
tests/foreach-26.ispc Normal file
View File

@@ -0,0 +1,24 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
uniform int x[21];
for (uniform int i = 0; i < 21; ++i)
x[i] = i;
foreach (i = 0 ... 21) {
if (x[i] != 12345)
continue;
x[i] = 1;
}
RET[programIndex] = x[min(programIndex, 20)];
}
export void result(uniform float RET[]) {
RET[programIndex] = min(programIndex, 20);
}

23
tests/foreach-27.ispc Normal file
View File

@@ -0,0 +1,23 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
uniform int x[21];
for (uniform int i = 0; i < 21; ++i)
x[i] = i;
foreach (i = 0 ... 21) {
if (x[i] == 12345)
continue;
x[i] = 1;
}
RET[programIndex] = x[min(programIndex, 20)];
}
export void result(uniform float RET[]) {
RET[programIndex] = 1;
}

21
tests/foreach-28.ispc Normal file
View File

@@ -0,0 +1,21 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float ub) {
float b = ub;
RET[programIndex] = 0;
foreach (i = 0 ... 10, j = 0 ... programCount + 1) {
if (b == 12345)
ccontinue;
if (j > 0)
++RET[j-1];
}
}
export void result(uniform float RET[]) {
RET[programIndex] = 10;
}

21
tests/foreach-29.ispc Normal file
View File

@@ -0,0 +1,21 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float ub) {
float b = ub;
RET[programIndex] = 0;
foreach (i = 0 ... 10, j = 0 ... programCount + 1) {
if (ub == 12345)
ccontinue;
if (j > 0)
++RET[j-1];
}
}
export void result(uniform float RET[]) {
RET[programIndex] = 10;
}

20
tests/foreach-30.ispc Normal file
View File

@@ -0,0 +1,20 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float ub) {
RET[programIndex] = 0;
foreach (i = 0 ... 10, j = 0 ... programCount + 1) {
if (i == 5)
continue;
if (j > 0)
++RET[j-1];
}
}
export void result(uniform float RET[]) {
RET[programIndex] = 9;
}

20
tests/foreach-31.ispc Normal file
View File

@@ -0,0 +1,20 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float ub) {
RET[programIndex] = 0;
foreach (i = 0 ... 10, j = 0 ... programCount + 1) {
if (i == 5)
ccontinue;
if (j > 0)
++RET[j-1];
}
}
export void result(uniform float RET[]) {
RET[programIndex] = 9;
}

View File

@@ -0,0 +1,7 @@
// Can't assign to type "const int32"
int foo() {
foreach (i = 0 ... 10) {
++i;
}
}