diff --git a/func.cpp b/func.cpp index d5b1f3f9..b0bbc70f 100644 --- a/func.cpp +++ b/func.cpp @@ -240,78 +240,82 @@ Function::emitCode(FunctionEmitContext *ctx, llvm::Function *function, // thread index, and the thread count variables. llvm::Function::arg_iterator argIter = function->arg_begin(); llvm::Value *structParamPtr = argIter++; - llvm::Value *threadIndex = argIter++; - llvm::Value *threadCount = argIter++; - llvm::Value *taskIndex = argIter++; - llvm::Value *taskCount = argIter++; - llvm::Value *taskIndex0 = argIter++; - llvm::Value *taskIndex1 = argIter++; - llvm::Value *taskIndex2 = argIter++; - llvm::Value *taskCount0 = argIter++; - llvm::Value *taskCount1 = argIter++; - llvm::Value *taskCount2 = argIter++; // Copy the function parameter values from the structure into local // storage for (unsigned int i = 0; i < args.size(); ++i) - lCopyInTaskParameter(i, structParamPtr, args, ctx); + lCopyInTaskParameter(i, structParamPtr, args, ctx); if (type->isUnmasked == false) { - // Copy in the mask as well. - int nArgs = (int)args.size(); - // The mask is the last parameter in the argument structure - llvm::Value *ptr = ctx->AddElementOffset(structParamPtr, nArgs, NULL, - "task_struct_mask"); - llvm::Value *ptrval = ctx->LoadInst(ptr, "mask"); - ctx->SetFunctionMask(ptrval); + // Copy in the mask as well. + int nArgs = (int)args.size(); + // The mask is the last parameter in the argument structure + llvm::Value *ptr = ctx->AddElementOffset(structParamPtr, nArgs, NULL, + "task_struct_mask"); + llvm::Value *ptrval = ctx->LoadInst(ptr, "mask"); + ctx->SetFunctionMask(ptrval); } - // Copy threadIndex and threadCount into stack-allocated storage so - // that their symbols point to something reasonable. - threadIndexSym->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "threadIndex"); - ctx->StoreInst(threadIndex, threadIndexSym->storagePtr); + if (g->target->getISA() != Target::NVPTX64) + { + llvm::Value *threadIndex = argIter++; + llvm::Value *threadCount = argIter++; + llvm::Value *taskIndex = argIter++; + llvm::Value *taskCount = argIter++; + llvm::Value *taskIndex0 = argIter++; + llvm::Value *taskIndex1 = argIter++; + llvm::Value *taskIndex2 = argIter++; + llvm::Value *taskCount0 = argIter++; + llvm::Value *taskCount1 = argIter++; + llvm::Value *taskCount2 = argIter++; - threadCountSym->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "threadCount"); - ctx->StoreInst(threadCount, threadCountSym->storagePtr); + // Copy threadIndex and threadCount into stack-allocated storage so + // that their symbols point to something reasonable. + threadIndexSym->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "threadIndex"); + ctx->StoreInst(threadIndex, threadIndexSym->storagePtr); - // Copy taskIndex and taskCount into stack-allocated storage so - // that their symbols point to something reasonable. - taskIndexSym->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex"); - ctx->StoreInst(taskIndex, taskIndexSym->storagePtr); + threadCountSym->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "threadCount"); + ctx->StoreInst(threadCount, threadCountSym->storagePtr); - taskCountSym->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount"); - ctx->StoreInst(taskCount, taskCountSym->storagePtr); + // Copy taskIndex and taskCount into stack-allocated storage so + // that their symbols point to something reasonable. + taskIndexSym->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex"); + ctx->StoreInst(taskIndex, taskIndexSym->storagePtr); - /* nvptx map: - * programCount : llvm.nvvm.read.ptx.sreg.warpsize - * programIndex : llvm.ptx.read.laneid _or_ ed.ptx.sreg.tid.llvm.nvvm.read.ptx.sreg.tid.x & programCount - * taskIndex0 : llvm.nvvm.read.ptx.sreg.ctaid.x - * taskIndex1 : llvm.nvvm.read.ptx.sreg.ctaid.y - * taskIndex3 : llvm.nvvm.read.ptx.sreg.ctaid.z - * taskCount0 : llvm.nvvm.read.ptx.sreg.nctaid.x - * taskCount1 : llvm.nvvm.read.ptx.sreg.nctaid.y - * taskCount3 : llvm.nvvm.read.ptx.sreg.nctaid.z - */ - - // llvm.nvvm.read.ptx.sreg.ctaid.x - taskIndexSym0->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex0"); - ctx->StoreInst(taskIndex0, taskIndexSym0->storagePtr); - // llvm.nvvm.read.ptx.sreg.ctaid.y - taskIndexSym1->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex1"); - ctx->StoreInst(taskIndex1, taskIndexSym1->storagePtr); - // llvm.nvvm.read.ptx.sreg.ctaid.z - taskIndexSym2->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex2"); - ctx->StoreInst(taskIndex2, taskIndexSym2->storagePtr); - - // llvm.nvvm.read.ptx.sreg.nctaid.x - taskCountSym0->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount0"); - ctx->StoreInst(taskCount0, taskCountSym0->storagePtr); - // llvm.nvvm.read.ptx.sreg.nctaid.y - taskCountSym1->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount1"); - ctx->StoreInst(taskCount1, taskCountSym1->storagePtr); - // llvm.nvvm.read.ptx.sreg.nctaid.z - taskCountSym2->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount2"); - ctx->StoreInst(taskCount2, taskCountSym2->storagePtr); + taskCountSym->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount"); + ctx->StoreInst(taskCount, taskCountSym->storagePtr); + + /* nvptx map: + * programCount : llvm.nvvm.read.ptx.sreg.warpsize + * programIndex : llvm.ptx.read.laneid _or_ ed.ptx.sreg.tid.llvm.nvvm.read.ptx.sreg.tid.x & programCount + * taskIndex0 : llvm.nvvm.read.ptx.sreg.ctaid.x + * taskIndex1 : llvm.nvvm.read.ptx.sreg.ctaid.y + * taskIndex3 : llvm.nvvm.read.ptx.sreg.ctaid.z + * taskCount0 : llvm.nvvm.read.ptx.sreg.nctaid.x + * taskCount1 : llvm.nvvm.read.ptx.sreg.nctaid.y + * taskCount3 : llvm.nvvm.read.ptx.sreg.nctaid.z + */ + + // llvm.nvvm.read.ptx.sreg.ctaid.x + taskIndexSym0->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex0"); + ctx->StoreInst(taskIndex0, taskIndexSym0->storagePtr); + // llvm.nvvm.read.ptx.sreg.ctaid.y + taskIndexSym1->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex1"); + ctx->StoreInst(taskIndex1, taskIndexSym1->storagePtr); + // llvm.nvvm.read.ptx.sreg.ctaid.z + taskIndexSym2->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskIndex2"); + ctx->StoreInst(taskIndex2, taskIndexSym2->storagePtr); + + // llvm.nvvm.read.ptx.sreg.nctaid.x + taskCountSym0->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount0"); + ctx->StoreInst(taskCount0, taskCountSym0->storagePtr); + // llvm.nvvm.read.ptx.sreg.nctaid.y + taskCountSym1->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount1"); + ctx->StoreInst(taskCount1, taskCountSym1->storagePtr); + // llvm.nvvm.read.ptx.sreg.nctaid.z + taskCountSym2->storagePtr = ctx->AllocaInst(LLVMTypes::Int32Type, "taskCount2"); + ctx->StoreInst(taskCount2, taskCountSym2->storagePtr); + } } else { // Regular, non-task function diff --git a/type.cpp b/type.cpp index 516276f0..956a0ddd 100644 --- a/type.cpp +++ b/type.cpp @@ -2957,16 +2957,19 @@ FunctionType::LLVMFunctionType(llvm::LLVMContext *ctx, bool removeMask) const { // hold them until the task actually runs.) llvm::Type *st = llvm::StructType::get(*ctx, llvmArgTypes); callTypes.push_back(llvm::PointerType::getUnqual(st)); - callTypes.push_back(LLVMTypes::Int32Type); // threadIndex - callTypes.push_back(LLVMTypes::Int32Type); // threadCount - callTypes.push_back(LLVMTypes::Int32Type); // taskIndex - callTypes.push_back(LLVMTypes::Int32Type); // taskCount - callTypes.push_back(LLVMTypes::Int32Type); // taskIndex0 - callTypes.push_back(LLVMTypes::Int32Type); // taskIndex1 - callTypes.push_back(LLVMTypes::Int32Type); // taskIndex2 - callTypes.push_back(LLVMTypes::Int32Type); // taskCount0 - callTypes.push_back(LLVMTypes::Int32Type); // taskCount1 - callTypes.push_back(LLVMTypes::Int32Type); // taskCount2 + if (g->target->getISA() != Target::NVPTX64) + { + callTypes.push_back(LLVMTypes::Int32Type); // threadIndex + callTypes.push_back(LLVMTypes::Int32Type); // threadCount + callTypes.push_back(LLVMTypes::Int32Type); // taskIndex + callTypes.push_back(LLVMTypes::Int32Type); // taskCount + callTypes.push_back(LLVMTypes::Int32Type); // taskIndex0 + callTypes.push_back(LLVMTypes::Int32Type); // taskIndex1 + callTypes.push_back(LLVMTypes::Int32Type); // taskIndex2 + callTypes.push_back(LLVMTypes::Int32Type); // taskCount0 + callTypes.push_back(LLVMTypes::Int32Type); // taskCount1 + callTypes.push_back(LLVMTypes::Int32Type); // taskCount2 + } } else // Otherwise we already have the types of the arguments