diff --git a/builtins/target-nvptx64.ll b/builtins/target-nvptx64.ll index 9714b0b2..063fdc7b 100644 --- a/builtins/target-nvptx64.ll +++ b/builtins/target-nvptx64.ll @@ -71,6 +71,11 @@ define float @__shfl_xor_float(float, i32) nounwind readnone alwaysinline %shfl = tail call float asm sideeffect "shfl.bfly.b32 $0, $1, $2, 0x1f;", "=f,f,r"(float %0, i32 %1) nounwind readnone alwaysinline ret float %shfl } +define i32 @__shfl_xor_i32(i32, i32) nounwind readnone alwaysinline +{ + %shfl = tail call i32 asm sideeffect "shfl.bfly.b32 $0, $1, $2, 0x1f;", "=r,r,r"(i32 %0, i32 %1) nounwind readnone alwaysinline + ret i32 %shfl +} define float @__fminf(float,float) nounwind readnone alwaysinline { %min = tail call float asm sideeffect "min.f32 $0, $1, $2;", "=f,f,f"(float %0, float %1) nounwind readnone alwaysinline @@ -667,9 +672,20 @@ define float @__reduce_max_float(<1 x float>) nounwind readnone ret float %call1.4 } -define i32 @__reduce_add_int32(<1 x i32> %v) nounwind readnone { - %r = extractelement <1 x i32> %v, i32 0 - ret i32 %r +define i32 @__reduce_add_int32(<1 x i32>) nounwind readnone +{ + %value = extractelement <1 x i32> %0, i32 0 + %call = tail call i32 @__shfl_xor_i32(i32 %value, i32 16) + %call1 = add i32 %call, %value + %call.1 = tail call i32 @__shfl_xor_i32(i32 %call1, i32 8) + %call1.1 =add i32 %call1, %call.1 + %call.2 = tail call i32 @__shfl_xor_i32(i32 %call1.1, i32 4) + %call1.2 = add i32 %call1.1, %call.2 + %call.3 = tail call i32 @__shfl_xor_i32(i32 %call1.2, i32 2) + %call1.3 = add i32 %call1.2, %call.3 + %call.4 = tail call i32 @__shfl_xor_i32(i32 %call1.3, i32 1) + %call1.4 = add i32 %call1.3, %call.4 + ret i32 %call1.4 } define i32 @__reduce_min_int32(<1 x i32>) nounwind readnone { diff --git a/ctx.cpp b/ctx.cpp index 0642fcda..983a996b 100644 --- a/ctx.cpp +++ b/ctx.cpp @@ -1406,7 +1406,7 @@ FunctionEmitContext::MasksAllEqual(llvm::Value *v1, llvm::Value *v2) { llvm::Value * FunctionEmitContext::ProgramIndexVector(bool is32bits) { - if (g->target->getISA() != Target::NVPTX64) + if (!g->target->isPTX()) //g->target->getISA() != Target::NVPTX64) { llvm::SmallVector array; for (int i = 0; i < g->target->getVectorWidth() ; ++i) { @@ -1424,7 +1424,7 @@ FunctionEmitContext::ProgramIndexVector(bool is32bits) { llvm::Function *func_warpsz = m->module->getFunction("__warpsize"); llvm::Value *__tid_x = CallInst(func_tid_x, NULL, std::vector(), "laneIdxForEach"); llvm::Value *__warpsz = CallInst(func_warpsz, NULL, std::vector(), "warpSZForEach"); - llvm::Value *__warpszm1 = BinaryOperator(llvm::Instruction::And, __warpsz, LLVMInt32(-1), "__warpszm1"); + llvm::Value *__warpszm1 = BinaryOperator(llvm::Instruction::Add, __warpsz, LLVMInt32(-1), "__warpszm1"); llvm::Value *laneIdx = BinaryOperator(llvm::Instruction::And, __tid_x, __warpszm1, "__laneidx"); llvm::Value *index = InsertInst(llvm::UndefValue::get(LLVMTypes::Int32VectorType), laneIdx, 0, "__laneIdxV"); return index;