Allow 'continue' statements in 'foreach' loops.

This commit is contained in:
Matt Pharr
2011-12-03 09:31:02 -08:00
parent c3b55de1ad
commit a1c0b4f95a
17 changed files with 359 additions and 40 deletions

99
ctx.cpp
View File

@@ -401,22 +401,26 @@ FunctionEmitContext::EndIf() {
// or continue statements (and breakLanesPtr and continueLanesPtr // or continue statements (and breakLanesPtr and continueLanesPtr
// have their initial 'all off' values), so we don't need to check // have their initial 'all off' values), so we don't need to check
// for that here. // for that here.
if (breakLanesPtr != NULL) { if (continueLanesPtr != NULL) {
assert(continueLanesPtr != NULL); // We want to compute:
// newMask = (oldMask & ~(breakLanes | continueLanes)) // newMask = (oldMask & ~(breakLanes | continueLanes))
llvm::Value *oldMask = GetInternalMask(); llvm::Value *oldMask = GetInternalMask();
llvm::Value *breakLanes = LoadInst(breakLanesPtr, "break_lanes");
llvm::Value *continueLanes = LoadInst(continueLanesPtr, llvm::Value *continueLanes = LoadInst(continueLanesPtr,
"continue_lanes"); "continue_lanes");
llvm::Value *breakOrContinueLanes = llvm::Value *bcLanes = continueLanes;
BinaryOperator(llvm::Instruction::Or, breakLanes, continueLanes,
"break|continue_lanes"); if (breakLanesPtr != NULL) {
llvm::Value *notBreakOrContinue = NotOperator(breakOrContinueLanes, // breakLanesPtr will be NULL if we're inside a 'foreach' loop
"!(break|continue)_lanes"); llvm::Value *breakLanes = LoadInst(breakLanesPtr, "break_lanes");
bcLanes = BinaryOperator(llvm::Instruction::Or, breakLanes,
continueLanes, "break|continue_lanes");
}
llvm::Value *notBreakOrContinue =
NotOperator(bcLanes, "!(break|continue)_lanes");
llvm::Value *newMask = llvm::Value *newMask =
BinaryOperator(llvm::Instruction::And, oldMask, notBreakOrContinue, BinaryOperator(llvm::Instruction::And, oldMask,
"new_mask"); notBreakOrContinue, "new_mask");
SetInternalMask(newMask); SetInternalMask(newMask);
} }
} }
@@ -477,15 +481,20 @@ FunctionEmitContext::EndLoop() {
void void
FunctionEmitContext::StartForeach() { FunctionEmitContext::StartForeach(llvm::BasicBlock *ct) {
// Store the current values of various loop-related state so that we // Store the current values of various loop-related state so that we
// can restore it when we exit this loop. // can restore it when we exit this loop.
llvm::Value *oldMask = GetInternalMask(); llvm::Value *oldMask = GetInternalMask();
controlFlowInfo.push_back(CFInfo::GetForeach(breakTarget, continueTarget, breakLanesPtr, controlFlowInfo.push_back(CFInfo::GetForeach(breakTarget, continueTarget,
continueLanesPtr, oldMask, loopMask)); breakLanesPtr, continueLanesPtr,
continueLanesPtr = breakLanesPtr = NULL; oldMask, loopMask));
breakLanesPtr = NULL;
breakTarget = NULL; breakTarget = NULL;
continueTarget = NULL;
continueLanesPtr = AllocaInst(LLVMTypes::MaskType, "foreach_continue_lanes");
StoreInst(LLVMMaskAllOff, continueLanesPtr);
continueTarget = ct;
loopMask = NULL; loopMask = NULL;
} }
@@ -526,7 +535,8 @@ FunctionEmitContext::restoreMaskGivenReturns(llvm::Value *oldMask) {
void void
FunctionEmitContext::Break(bool doCoherenceCheck) { FunctionEmitContext::Break(bool doCoherenceCheck) {
if (breakTarget == NULL) { if (breakTarget == NULL) {
Error(currentPos, "\"break\" statement is illegal outside of for/while/do loops."); Error(currentPos, "\"break\" statement is illegal outside of "
"for/while/do loops.");
return; return;
} }
@@ -576,7 +586,8 @@ FunctionEmitContext::Break(bool doCoherenceCheck) {
void void
FunctionEmitContext::Continue(bool doCoherenceCheck) { FunctionEmitContext::Continue(bool doCoherenceCheck) {
if (!continueTarget) { if (!continueTarget) {
Error(currentPos, "\"continue\" statement illegal outside of for/while/do loops."); Error(currentPos, "\"continue\" statement illegal outside of "
"for/while/do/foreach loops.");
return; return;
} }
@@ -586,8 +597,8 @@ FunctionEmitContext::Continue(bool doCoherenceCheck) {
// loop or if we can tell that the mask is all on. // loop or if we can tell that the mask is all on.
AddInstrumentationPoint("continue: uniform CF, jumped"); AddInstrumentationPoint("continue: uniform CF, jumped");
if (ifsInLoopAllUniform() && doCoherenceCheck) if (ifsInLoopAllUniform() && doCoherenceCheck)
Warning(currentPos, "Coherent continue statement not necessary in fully uniform " Warning(currentPos, "Coherent continue statement not necessary in "
"control flow."); "fully uniform control flow.");
BranchInst(continueTarget); BranchInst(continueTarget);
bblock = NULL; bblock = NULL;
} }
@@ -638,26 +649,40 @@ FunctionEmitContext::ifsInLoopAllUniform() const {
void void
FunctionEmitContext::jumpIfAllLoopLanesAreDone(llvm::BasicBlock *target) { FunctionEmitContext::jumpIfAllLoopLanesAreDone(llvm::BasicBlock *target) {
// Check to see if (returned lanes | continued lanes | break lanes) is llvm::Value *allDone = NULL;
// equal to the value of mask at the start of the loop iteration. If assert(continueLanesPtr != NULL);
// so, everyone is done and we can jump to the given target if (breakLanesPtr == NULL) {
llvm::Value *returned = LoadInst(returnedLanesPtr, // In a foreach loop, break and return are illegal, and
"returned_lanes"); // breakLanesPtr is NULL. In this case, the mask is guaranteed to
llvm::Value *continued = LoadInst(continueLanesPtr, // be all on at the start of each iteration, so we only need to
"continue_lanes"); // check if all lanes have continued..
llvm::Value *breaked = LoadInst(breakLanesPtr, "break_lanes"); llvm::Value *continued = LoadInst(continueLanesPtr,
llvm::Value *returnedOrContinued = BinaryOperator(llvm::Instruction::Or, "continue_lanes");
returned, continued, allDone = All(continued);
"returned|continued"); }
llvm::Value *returnedOrContinuedOrBreaked = else {
BinaryOperator(llvm::Instruction::Or, returnedOrContinued, // Check to see if (returned lanes | continued lanes | break lanes) is
breaked, "returned|continued"); // equal to the value of mask at the start of the loop iteration. If
// so, everyone is done and we can jump to the given target
llvm::Value *returned = LoadInst(returnedLanesPtr,
"returned_lanes");
llvm::Value *continued = LoadInst(continueLanesPtr,
"continue_lanes");
llvm::Value *breaked = LoadInst(breakLanesPtr, "break_lanes");
llvm::Value *returnedOrContinued = BinaryOperator(llvm::Instruction::Or,
returned, continued,
"returned|continued");
llvm::Value *returnedOrContinuedOrBreaked =
BinaryOperator(llvm::Instruction::Or, returnedOrContinued,
breaked, "returned|continued");
// Do we match the mask at loop entry?
allDone = MasksAllEqual(returnedOrContinuedOrBreaked, loopMask);
}
// Do we match the mask at loop entry?
llvm::Value *allRCB = MasksAllEqual(returnedOrContinuedOrBreaked, loopMask);
llvm::BasicBlock *bAll = CreateBasicBlock("all_continued_or_breaked"); llvm::BasicBlock *bAll = CreateBasicBlock("all_continued_or_breaked");
llvm::BasicBlock *bNotAll = CreateBasicBlock("not_all_continued_or_breaked"); llvm::BasicBlock *bNotAll = CreateBasicBlock("not_all_continued_or_breaked");
BranchInst(bAll, bNotAll, allRCB); BranchInst(bAll, bNotAll, allDone);
// If so, have an extra basic block along the way to add // If so, have an extra basic block along the way to add
// instrumentation, if the user asked for it. // instrumentation, if the user asked for it.

7
ctx.h
View File

@@ -159,8 +159,11 @@ public:
finished. */ finished. */
void EndLoop(); void EndLoop();
/** */ /** Indicates that code generation for a 'foreach' or 'foreach_tiled'
void StartForeach(); loop is about to start. The provided basic block pointer indicates
where control flow should go if a 'continue' statement is executed
in the loop. */
void StartForeach(llvm::BasicBlock *continueTarget);
void EndForeach(); void EndForeach();
/** Emit code for a 'break' statement in a loop. If doCoherenceCheck /** Emit code for a 'break' statement in a loop. If doCoherenceCheck

View File

@@ -1962,6 +1962,14 @@ of iteration dimensions may be specified, with each one spanning a
different range of values. Within the ``foreach`` loop, the given different range of values. Within the ``foreach`` loop, the given
identifiers are available as ``const varying int32`` variables. identifiers are available as ``const varying int32`` variables.
It is illegal to have a ``break`` statement or a ``return`` statement
within a ``foreach`` loop; a compile-time error will be issued in this
case. (It is legal to have a ``break`` in a regular ``for`` loop that's
nested inside a ``foreach`` loop.) ``continue`` statements are legal in
``foreach`` loops; they have the same effect as in regular ``for`` loops:
a program instances that executes a ``continue`` statement effectively
skips over the rest of the loop body for the current iteration.
As a specific example, consdier the following ``foreach`` statement: As a specific example, consdier the following ``foreach`` statement:
:: ::

View File

@@ -1728,7 +1728,6 @@ ForeachStmt::EmitCode(FunctionEmitContext *ctx) const {
llvm::Value *oldMask = ctx->GetInternalMask(); llvm::Value *oldMask = ctx->GetInternalMask();
ctx->StartForeach();
ctx->SetDebugPos(pos); ctx->SetDebugPos(pos);
ctx->StartScope(); ctx->StartScope();
@@ -1803,6 +1802,8 @@ ForeachStmt::EmitCode(FunctionEmitContext *ctx) const {
ctx->StoreInst(LLVMMaskAllOn, extrasMaskPtrs[i]); ctx->StoreInst(LLVMMaskAllOn, extrasMaskPtrs[i]);
} }
ctx->StartForeach(bbStep[nDims-1]);
// On to the outermost loop's test // On to the outermost loop's test
ctx->BranchInst(bbTest[0]); ctx->BranchInst(bbTest[0]);
@@ -1886,6 +1887,8 @@ ForeachStmt::EmitCode(FunctionEmitContext *ctx) const {
// just generate its value when we need it in the loop body. // just generate its value when we need it in the loop body.
for (int i = 0; i < nDims; ++i) { for (int i = 0; i < nDims; ++i) {
ctx->SetCurrentBasicBlock(bbStep[i]); ctx->SetCurrentBasicBlock(bbStep[i]);
if (i == nDims-1)
ctx->RestoreContinuedLanes();
llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[i]); llvm::Value *counter = ctx->LoadInst(uniformCounterPtrs[i]);
llvm::Value *newCounter = llvm::Value *newCounter =
ctx->BinaryOperator(llvm::Instruction::Add, counter, ctx->BinaryOperator(llvm::Instruction::Add, counter,

20
tests/foreach-20.ispc Normal file
View File

@@ -0,0 +1,20 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
for (uniform int i = 0; i < programCount; ++i)
RET[i] = 0;
foreach (i = 0 ... programCount) {
if (i % 2)
continue;
RET[i] = 1;
}
}
export void result(uniform float RET[]) {
RET[programIndex] = (programIndex % 2) ? 0 : 1;
}

20
tests/foreach-21.ispc Normal file
View File

@@ -0,0 +1,20 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
for (uniform int i = 0; i < programCount; ++i)
RET[i] = 0;
foreach (i = 0 ... programCount) {
if (i % 2)
ccontinue;
RET[i] = 1;
}
}
export void result(uniform float RET[]) {
RET[programIndex] = (programIndex % 2) ? 0 : 1;
}

27
tests/foreach-22.ispc Normal file
View File

@@ -0,0 +1,27 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
uniform int x[21];
for (uniform int i = 0; i < 21; ++i)
x[i] = i;
foreach (i = 0 ... 21) {
if (i < 4)
ccontinue;
x[i] = 1;
}
RET[programIndex] = x[min(programIndex, 20)];
}
export void result(uniform float RET[]) {
RET[programIndex] = 1;
RET[0] = 0;
RET[1] = 1;
RET[2] = 2;
RET[3] = 3;
}

27
tests/foreach-23.ispc Normal file
View File

@@ -0,0 +1,27 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
uniform int x[21];
for (uniform int i = 0; i < 21; ++i)
x[i] = i;
foreach (i = 0 ... 21) {
if (i < 4)
continue;
x[i] = 1;
}
RET[programIndex] = x[min(programIndex, 20)];
}
export void result(uniform float RET[]) {
RET[programIndex] = 1;
RET[0] = 0;
RET[1] = 1;
RET[2] = 2;
RET[3] = 3;
}

27
tests/foreach-24.ispc Normal file
View File

@@ -0,0 +1,27 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
uniform int x[21];
for (uniform int i = 0; i < 21; ++i)
x[i] = i;
foreach (i = 0 ... 21) {
if (i < 4)
continue;
x[i] = 1;
}
RET[programIndex] = x[min(programIndex, 20)];
}
export void result(uniform float RET[]) {
RET[programIndex] = 1;
RET[0] = 0;
RET[1] = 1;
RET[2] = 2;
RET[3] = 3;
}

23
tests/foreach-25.ispc Normal file
View File

@@ -0,0 +1,23 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
uniform int x[21];
foreach (i = 0 ... 21) {
x[i] = 0;
if (x[i] != 12345)
continue;
x[i] = 1;
}
RET[programIndex] = x[min(programIndex, 20)];
}
export void result(uniform float RET[]) {
RET[programIndex] = 0;
}

24
tests/foreach-26.ispc Normal file
View File

@@ -0,0 +1,24 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
uniform int x[21];
for (uniform int i = 0; i < 21; ++i)
x[i] = i;
foreach (i = 0 ... 21) {
if (x[i] != 12345)
continue;
x[i] = 1;
}
RET[programIndex] = x[min(programIndex, 20)];
}
export void result(uniform float RET[]) {
RET[programIndex] = min(programIndex, 20);
}

23
tests/foreach-27.ispc Normal file
View File

@@ -0,0 +1,23 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_f(uniform float RET[], uniform float aFOO[]) {
uniform int x[21];
for (uniform int i = 0; i < 21; ++i)
x[i] = i;
foreach (i = 0 ... 21) {
if (x[i] == 12345)
continue;
x[i] = 1;
}
RET[programIndex] = x[min(programIndex, 20)];
}
export void result(uniform float RET[]) {
RET[programIndex] = 1;
}

21
tests/foreach-28.ispc Normal file
View File

@@ -0,0 +1,21 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float ub) {
float b = ub;
RET[programIndex] = 0;
foreach (i = 0 ... 10, j = 0 ... programCount + 1) {
if (b == 12345)
ccontinue;
if (j > 0)
++RET[j-1];
}
}
export void result(uniform float RET[]) {
RET[programIndex] = 10;
}

21
tests/foreach-29.ispc Normal file
View File

@@ -0,0 +1,21 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float ub) {
float b = ub;
RET[programIndex] = 0;
foreach (i = 0 ... 10, j = 0 ... programCount + 1) {
if (ub == 12345)
ccontinue;
if (j > 0)
++RET[j-1];
}
}
export void result(uniform float RET[]) {
RET[programIndex] = 10;
}

20
tests/foreach-30.ispc Normal file
View File

@@ -0,0 +1,20 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float ub) {
RET[programIndex] = 0;
foreach (i = 0 ... 10, j = 0 ... programCount + 1) {
if (i == 5)
continue;
if (j > 0)
++RET[j-1];
}
}
export void result(uniform float RET[]) {
RET[programIndex] = 9;
}

20
tests/foreach-31.ispc Normal file
View File

@@ -0,0 +1,20 @@
export uniform int width() { return programCount; }
uniform int foo(int i);
export void f_fu(uniform float RET[], uniform float aFOO[], uniform float ub) {
RET[programIndex] = 0;
foreach (i = 0 ... 10, j = 0 ... programCount + 1) {
if (i == 5)
ccontinue;
if (j > 0)
++RET[j-1];
}
}
export void result(uniform float RET[]) {
RET[programIndex] = 9;
}

View File

@@ -0,0 +1,7 @@
// Can't assign to type "const int32"
int foo() {
foreach (i = 0 ... 10) {
++i;
}
}