diff --git a/func.cpp b/func.cpp index b975049b..dea45afc 100644 --- a/func.cpp +++ b/func.cpp @@ -132,9 +132,28 @@ Function::Function(Symbol *s, Stmt *c) { Assert(taskIndexSym); taskCountSym = m->symbolTable->LookupVariable("taskCount"); Assert(taskCountSym); + + taskIndexSym_x = m->symbolTable->LookupVariable("taskIndex_x"); + Assert(taskIndexSym_x); + taskIndexSym_y = m->symbolTable->LookupVariable("taskIndex_y"); + Assert(taskIndexSym_y); + taskIndexSym_z = m->symbolTable->LookupVariable("taskIndex_z"); + Assert(taskIndexSym_z); + + + taskCountSym_x = m->symbolTable->LookupVariable("taskCount_x"); + Assert(taskCountSym_x); + taskCountSym_y = m->symbolTable->LookupVariable("taskCount_y"); + Assert(taskCountSym_y); + taskCountSym_z = m->symbolTable->LookupVariable("taskCount_z"); + Assert(taskCountSym_z); } else + { threadIndexSym = threadCountSym = taskIndexSym = taskCountSym = NULL; + taskIndexSym_x = taskIndexSym_y = taskIndexSym_z = NULL; + taskCountSym_x = taskCountSym_y = taskCountSym_z = NULL; + } } @@ -225,6 +244,12 @@ Function::emitCode(FunctionEmitContext *ctx, llvm::Function *function, llvm::Value *threadCount = argIter++; llvm::Value *taskIndex = argIter++; llvm::Value *taskCount = argIter++; + llvm::Value *taskIndex_x = argIter++; + llvm::Value *taskIndex_y = argIter++; + llvm::Value *taskIndex_z = argIter++; + llvm::Value *taskCount_x = argIter++; + llvm::Value *taskCount_y = argIter++; + llvm::Value *taskCount_z = argIter++; // Copy the function parameter values from the structure into local // storage @@ -256,6 +281,20 @@ Function::emitCode(FunctionEmitContext *ctx, llvm::Function *function, taskCountSym->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount"); ctx->StoreInst(taskCount, taskCountSym->storagePtr); + + taskIndexSym_x->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex_x"); + ctx->StoreInst(taskIndex_x, taskIndexSym_x->storagePtr); + taskIndexSym_y->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex_y"); + ctx->StoreInst(taskIndex_y, taskIndexSym_y->storagePtr); + taskIndexSym_z->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex_z"); + ctx->StoreInst(taskIndex_z, taskIndexSym_z->storagePtr); + + taskCountSym_x->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount_x"); + ctx->StoreInst(taskCount_x, taskCountSym_x->storagePtr); + taskCountSym_y->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount_y"); + ctx->StoreInst(taskCount_y, taskCountSym_y->storagePtr); + taskCountSym_z->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount_z"); + ctx->StoreInst(taskCount_z, taskCountSym_z->storagePtr); } else { // Regular, non-task function diff --git a/func.h b/func.h index ac3e1447..ee44a6c5 100644 --- a/func.h +++ b/func.h @@ -60,7 +60,10 @@ private: Stmt *code; Symbol *maskSymbol; Symbol *threadIndexSym, *threadCountSym; - Symbol *taskIndexSym, *taskCountSym; + Symbol *taskIndexSym, *taskCountSym; + Symbol *taskIndexSym_x, *taskCountSym_x; + Symbol *taskIndexSym_y, *taskCountSym_y; + Symbol *taskIndexSym_z, *taskCountSym_z; }; #endif // ISPC_FUNC_H diff --git a/parse.yy b/parse.yy index 38c5ba77..1de4644f 100644 --- a/parse.yy +++ b/parse.yy @@ -2214,9 +2214,24 @@ static void lAddThreadIndexCountToSymbolTable(SourcePos pos) { Symbol *taskIndexSym = new Symbol("taskIndex", pos, type); m->symbolTable->AddVariable(taskIndexSym); - + Symbol *taskCountSym = new Symbol("taskCount", pos, type); m->symbolTable->AddVariable(taskCountSym); + + Symbol *taskIndexSym_x = new Symbol("taskIndex_x", pos, type); + m->symbolTable->AddVariable(taskIndexSym_x); + Symbol *taskIndexSym_y = new Symbol("taskIndex_y", pos, type); + m->symbolTable->AddVariable(taskIndexSym_y); + Symbol *taskIndexSym_z = new Symbol("taskIndex_z", pos, type); + m->symbolTable->AddVariable(taskIndexSym_z); + + + Symbol *taskCountSym_x = new Symbol("taskCount_x", pos, type); + m->symbolTable->AddVariable(taskCountSym_x); + Symbol *taskCountSym_y = new Symbol("taskCount_y", pos, type); + m->symbolTable->AddVariable(taskCountSym_y); + Symbol *taskCountSym_z = new Symbol("taskCount_z", pos, type); + m->symbolTable->AddVariable(taskCountSym_z); } diff --git a/type.cpp b/type.cpp index 5fa1845b..d36c63c2 100644 --- a/type.cpp +++ b/type.cpp @@ -2961,6 +2961,12 @@ FunctionType::LLVMFunctionType(llvm::LLVMContext *ctx, bool removeMask) const { callTypes.push_back(LLVMTypes::Int32Type); // threadCount callTypes.push_back(LLVMTypes::Int32Type); // taskIndex callTypes.push_back(LLVMTypes::Int32Type); // taskCount + callTypes.push_back(LLVMTypes::Int32Type); // taskIndex_x + callTypes.push_back(LLVMTypes::Int32Type); // taskIndex_y + callTypes.push_back(LLVMTypes::Int32Type); // taskIndex_z + callTypes.push_back(LLVMTypes::Int32Type); // taskCount_x + callTypes.push_back(LLVMTypes::Int32Type); // taskCount_y + callTypes.push_back(LLVMTypes::Int32Type); // taskCount_z } else // Otherwise we already have the types of the arguments