diff --git a/stmt.cpp b/stmt.cpp index 974af871..25d1d3d5 100644 --- a/stmt.cpp +++ b/stmt.cpp @@ -1649,11 +1649,12 @@ ForeachStmt::EmitCode(FunctionEmitContext *ctx) const { // width. Set the mask and jump to the masked loop body. ctx->SetCurrentBasicBlock(bbAllInnerPartialOuter); { llvm::Value *mask; - if (extrasMaskPtrs.size() == 0) + if (nDims == 1) // 1D loop; we shouldn't ever get here anyway mask = LLVMMaskAllOff; else - mask = ctx->LoadInst(extrasMaskPtrs.back()); + mask = ctx->LoadInst(extrasMaskPtrs[nDims-2]); + ctx->SetInternalMask(mask); ctx->StoreInst(LLVMTrue, stepIndexAfterMaskedBodyPtr);