diff --git a/builtins/target-nvptx.ll b/builtins/target-nvptx.ll
index fd314d3b..1b5a6392 100644
--- a/builtins/target-nvptx.ll
+++ b/builtins/target-nvptx.ll
@@ -1386,3 +1386,27 @@ extract_insert(i64, int64)
 extract_insert(float, float)
 extract_insert(double, double)
 
+;; Pointer flavour of the extract helper: element 0 of the input vector is
+;; converted to i64 and handed together with the lane operand to
+;; __shfl_i64_nvptx and the i64 that comes back is converted to i8* again.
+define i8* @__extract_void(<1 x i8*>, i32) nounwind readnone alwaysinline {
+  %val = extractelement <1 x i8*> %0, i32 0
+  %b64 = ptrtoint i8* %val to i64
+  %extract64 = tail call i64 @__shfl_i64_nvptx(i64 %b64, i32 %1)
+  %extract = inttoptr i64 %extract64 to i8*
+  ret i8* %extract
+}
+
+;; Pointer flavour of the insert helper: returns the input vector with
+;; element 0 replaced by the i8* operand when the value returned by
+;; __laneidx equals the lane operand and with element 0 unchanged otherwise.
+define <1 x i8*> @__insert_void(<1 x i8*>, i32,
+                                i8*) nounwind readnone alwaysinline {
+  %orig = extractelement <1 x i8*> %0, i32 0
+  %lane = call i32 @__laneidx()
+  %c = icmp eq i32 %lane, %1
+  %val = select i1 %c, i8* %2, i8* %orig
+  %insert = insertelement <1 x i8*> %0, i8* %val, i32 0
+  ret <1 x i8*> %insert
+}
+
diff --git a/ctx.cpp b/ctx.cpp
index 1f6e5e53..9925d51c 100644
--- a/ctx.cpp
+++ b/ctx.cpp
@@ -1390,16 +1390,77 @@ FunctionEmitContext::LaneMask(llvm::Value *v) {
     return CallInst(fmm, NULL, v, LLVMGetName(v, "_movmsk"));
 }
 
+/** Appends to funcName the type suffix ("_int8", "_int16", "_int32",
+    "_int64", "_float" or "_double") that matches the type of the given
+    vector value.  Returns false, leaving funcName unchanged, if the type
+    is none of the six vector types checked here. */
+bool lAppendInsertExtractName(llvm::Value *vector, std::string &funcName)
+{
+    llvm::Type *type = vector->getType();
+    if (type == LLVMTypes::Int8VectorType)
+        funcName += "_int8";
+    else if (type == LLVMTypes::Int16VectorType)
+        funcName += "_int16";
+    else if (type == LLVMTypes::Int32VectorType)
+        funcName += "_int32";
+    else if (type == LLVMTypes::Int64VectorType)
+        funcName += "_int64";
+    else if (type == LLVMTypes::FloatVectorType)
+        funcName += "_float";
+    else if (type == LLVMTypes::DoubleVectorType)
+        funcName += "_double";
+    else
+        return false;
+    return true;
+}
+
+/** Emits a call to the "__insert" builtin selected by the type of vector
+    (see lAppendInsertExtractName()), passing vector, lane and scalar,
+    and returns the value of that call.  lane must have type
+    LLVMTypes::Int32Type. */
 llvm::Value*
 FunctionEmitContext::Insert(llvm::Value *vector, llvm::Value *lane, llvm::Value *scalar)
 {
-    return NULL;
+    std::string funcName = "__insert";
+    // Keep the call outside of assert(): it has the side effect of
+    // completing funcName and must still run when NDEBUG is defined.
+    const bool knownType = lAppendInsertExtractName(vector, funcName);
+    assert(knownType);
+    (void)knownType;
+    assert(lane->getType() == LLVMTypes::Int32Type);
+
+    llvm::Function *func = m->module->getFunction(funcName.c_str());
+    assert(func != NULL);
+    std::vector<llvm::Value *> args;
+    args.push_back(vector);
+    args.push_back(lane);
+    args.push_back(scalar);
+    llvm::Value *ret = llvm::CallInst::Create(func, args, LLVMGetName(vector, funcName.c_str()), GetCurrentBasicBlock());
+    return ret;
 }
 
+/** Emits a call to the "__extract" builtin selected by the type of
+    vector (see lAppendInsertExtractName()), passing vector and lane,
+    and returns the value of that call.  lane must have type
+    LLVMTypes::Int32Type. */
 llvm::Value*
 FunctionEmitContext::Extract(llvm::Value *vector, llvm::Value *lane)
 {
-    return NULL;
+    std::string funcName = "__extract";
+    // Keep the call outside of assert(): it has the side effect of
+    // completing funcName and must still run when NDEBUG is defined.
+    const bool knownType = lAppendInsertExtractName(vector, funcName);
+    assert(knownType);
+    (void)knownType;
+    assert(lane->getType() == LLVMTypes::Int32Type);
+
+    llvm::Function *func = m->module->getFunction(funcName.c_str());
+    assert(func != NULL);
+    std::vector<llvm::Value *> args;
+    args.push_back(vector);
+    args.push_back(lane);
+    llvm::Value *ret = llvm::CallInst::Create(func, args, LLVMGetName(vector, funcName.c_str()), GetCurrentBasicBlock());
+    return ret;
 }
 
 
diff --git a/stmt.cpp b/stmt.cpp
index b30a0000..4c9a8f05 100644
--- a/stmt.cpp
+++ b/stmt.cpp
@@ -2355,8 +2355,6 @@ ForeachUniqueStmt::ForeachUniqueStmt(const char *iterName, Expr *e,
     sym = m->symbolTable->LookupVariable(iterName);
     expr = e;
     stmts = s;
-    if (g->target->getISA() == Target::NVPTX)
-      Error(pos, "\"foreach_unique\" is not yetsupported with \"nvptx\" target.");
 }
 
 
@@ -2442,10 +2440,19 @@ ForeachUniqueStmt::EmitCode(FunctionEmitContext *ctx) const {
 
         // And load the corresponding element value from the temporary
         // memory storing the value of the varying expr.
-        llvm::Value *uniqueValuePtr =
+        llvm::Value *uniqueValue;
+        if (g->target->getISA() != Target::NVPTX)
+        {
+          llvm::Value *uniqueValuePtr =
             ctx->GetElementPtrInst(exprMem, LLVMInt64(0), firstSet, exprPtrType,
-                                   "unique_index_ptr");
-        llvm::Value *uniqueValue = ctx->LoadInst(uniqueValuePtr, "unique_value");
+                "unique_index_ptr");
+          uniqueValue = ctx->LoadInst(uniqueValuePtr, "unique_value");
+        }
+        else /* in case of PTX target, use __shfl PTX intrinsics via __insert/__extract function */
+        {
+          llvm::Value *firstSet32 = ctx->TruncInst(firstSet, LLVMTypes::Int32Type);
+          uniqueValue = ctx->Extract(exprValue, firstSet32);
+        }
 
         // If it's a varying pointer type, need to convert from the int
         // type we store in the vector to the actual pointer type