From 926b3b9ee3d1a6d08160d17ca36f8198020411c5 Mon Sep 17 00:00:00 2001 From: Matt Pharr Date: Mon, 9 Jul 2012 15:13:30 -0700 Subject: [PATCH] Fix bugs with mask-handling for switch/do/for/while statements. All of these pass the current mask to FunctionEmitContext::SetBlockEntryMask() so that when a break/continue/return is encountered, it can test to see if all lanes have followed that path and then return; this in turn ensures that we never run statements with an all-off execution mask. These functions were passing the function internal mask, not the full mask, and thus could end up executing code with the mask all off if some lanes were disabled by an outer function. (The new tests test this case.) --- stmt.cpp | 6 +++--- tests/entry-mask-do.ispc | 30 ++++++++++++++++++++++++++++++ tests/entry-mask-for.ispc | 30 ++++++++++++++++++++++++++++++ tests/entry-mask-switch.ispc | 30 ++++++++++++++++++++++++++++++ 4 files changed, 93 insertions(+), 3 deletions(-) create mode 100644 tests/entry-mask-do.ispc create mode 100644 tests/entry-mask-for.ispc create mode 100644 tests/entry-mask-switch.ispc diff --git a/stmt.cpp b/stmt.cpp index 7a547149..04807faf 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -830,7 +830,7 @@ void DoStmt::EmitCode(FunctionEmitContext *ctx) const { // And now emit code for the loop body ctx->SetCurrentBasicBlock(bloop); - ctx->SetBlockEntryMask(ctx->GetInternalMask()); + ctx->SetBlockEntryMask(ctx->GetFullMask()); ctx->SetDebugPos(pos); // FIXME: in the StmtList::EmitCode() method takes starts/stops a new // scope around the statements in the list. So if the body is just a @@ -1047,7 +1047,7 @@ ForStmt::EmitCode(FunctionEmitContext *ctx) const { // On to emitting the code for the loop body. ctx->SetCurrentBasicBlock(bloop); - ctx->SetBlockEntryMask(ctx->GetInternalMask()); + ctx->SetBlockEntryMask(ctx->GetFullMask()); ctx->AddInstrumentationPoint("for loop body"); if (!dynamic_cast(stmts)) ctx->StartScope(); @@ -2555,7 +2555,7 @@ SwitchStmt::EmitCode(FunctionEmitContext *ctx) const { bool isUniformCF = (type->IsUniformType() && lHasVaryingBreakOrContinue(stmts) == false); ctx->StartSwitch(isUniformCF, bbDone); - ctx->SetBlockEntryMask(ctx->GetInternalMask()); + ctx->SetBlockEntryMask(ctx->GetFullMask()); ctx->SwitchInst(exprValue, svi.defaultBlock ? svi.defaultBlock : bbDone, svi.caseBlocks, svi.nextBlock); diff --git a/tests/entry-mask-do.ispc b/tests/entry-mask-do.ispc new file mode 100644 index 00000000..dd2254b3 --- /dev/null +++ b/tests/entry-mask-do.ispc @@ -0,0 +1,30 @@ + +export uniform int width() { return programCount; } + +int * uniform ptr = NULL; + +int func(int v) { + int ret; + // print("%\n", v); + do { + if (v == 0) { + ret = 1; + break; + } + *ptr = 1; + } while(true); + return ret; +} + +export void f_f(uniform float RET[], uniform float aFOO[]) { + int count = 10; + if (programIndex & 1) + count = 0x7ffffff; + RET[programIndex] = 0; + if (!(programIndex & 1)) + RET[programIndex] = func(programIndex & 1); +} + +export void result(uniform float RET[]) { + RET[programIndex] = !(programIndex & 1); +} diff --git a/tests/entry-mask-for.ispc b/tests/entry-mask-for.ispc new file mode 100644 index 00000000..a04642c7 --- /dev/null +++ b/tests/entry-mask-for.ispc @@ -0,0 +1,30 @@ + +export uniform int width() { return programCount; } + +int * uniform ptr = NULL; + +int func(int v) { + int ret; + // print("%\n", v); + for (;;) { + if (v == 0) { + ret = 1; + break; + } + *ptr = 1; + } + return ret; +} + +export void f_f(uniform float RET[], uniform float aFOO[]) { + int count = 10; + if (programIndex & 1) + count = 0x7ffffff; + RET[programIndex] = 0; + if (!(programIndex & 1)) + RET[programIndex] = func(programIndex & 1); +} + +export void result(uniform float RET[]) { + RET[programIndex] = !(programIndex & 1); +} diff --git a/tests/entry-mask-switch.ispc b/tests/entry-mask-switch.ispc new file mode 100644 index 00000000..541f1487 --- /dev/null +++ b/tests/entry-mask-switch.ispc @@ -0,0 +1,30 @@ + +export uniform int width() { return programCount; } + +int * uniform ptr = NULL; + +int func(int v) { + int ret; + // print("%\n", v); + switch (v) { + case 0: + ret = 1; + break; + case 1: + *ptr = 1; + } + return ret; +} + +export void f_f(uniform float RET[], uniform float aFOO[]) { + int count = 10; + if (programIndex & 1) + count = 0x7ffffff; + RET[programIndex] = 0; + if (!(programIndex & 1)) + RET[programIndex] = func(programIndex & 1); +} + +export void result(uniform float RET[]) { + RET[programIndex] = !(programIndex & 1); +}