partial support for reduce equal
This commit is contained in:
@@ -1022,10 +1022,76 @@ define i64 @__reduce_max_uint64(<1 x i64>) nounwind readnone alwaysinline {
|
|||||||
}
|
}
|
||||||
|
|
||||||
;;;; reduce equal
|
;;;; reduce equal
|
||||||
declare i1 @__reduce_equal_int32(<1 x i32> %vv, i32 * %samevalue, <1 x i1> %mask) nounwind alwaysinline;
|
define i32 @__shfl_reduce_or_step_i32_nvptx(i32, i32) nounwind readnone alwaysinline
|
||||||
declare i1 @__reduce_equal_float(<1 x float> %vv, float * %samevalue, <1 x i1> %mask) nounwind alwaysinline;
|
{
|
||||||
declare i1 @__reduce_equal_int64(<1 x i64> %vv, i64 * %samevalue, <1 x i1> %mask) nounwind alwaysinline;
|
%shfl = tail call i32 asm sideeffect
|
||||||
declare i1 @__reduce_equal_double(<1 x double> %vv, double * %samevalue, <1 x i1> %mask) nounwind alwaysinline;
|
"{.reg .u32 r0;
|
||||||
|
.reg .pred p;
|
||||||
|
shfl.bfly.b32 r0|p, $1, $2, 0;
|
||||||
|
@p or.b32 r0, r0, $3;
|
||||||
|
mov.u32 $0, r0;
|
||||||
|
}", "=r,r,r,r"(i32 %0, i32 %1, i32 %0) nounwind readnone alwaysinline
|
||||||
|
ret i32 %shfl
|
||||||
|
}
|
||||||
|
shfl64(__shfl_reduce_or_step, i64)
|
||||||
|
|
||||||
|
define(`reduce_equal',`
|
||||||
|
define i1 @__reduce_equal_$2(<1 x $1> %v0, $1 * %samevalue, <1 x i1> %maskv) nounwind alwaysinline
|
||||||
|
{
|
||||||
|
entry:
|
||||||
|
%vv = bitcast <1 x $1> %v0 to <1 x $3>
|
||||||
|
%mask = extractelement <1 x i1> %maskv, i32 0
|
||||||
|
%val0 = extractelement <1 x $3> %vv, i32 0
|
||||||
|
|
||||||
|
;; increment by one for zero value
|
||||||
|
%zero = icmp eq $3 %val0, 0
|
||||||
|
%val1 = select i1 %zero, $3 0, $3 %val0
|
||||||
|
|
||||||
|
;; for negative mask use zero
|
||||||
|
%val = select i1 %mask, $3 %val1, $3 0
|
||||||
|
|
||||||
|
;; reduce
|
||||||
|
%s0 = tail call $3 @__shfl_reduce_or_step_$3_nvptx($3 %val, i32 1)
|
||||||
|
%s1 = tail call $3 @__shfl_reduce_or_step_$3_nvptx($3 %s0, i32 2)
|
||||||
|
%s2 = tail call $3 @__shfl_reduce_or_step_$3_nvptx($3 %s1, i32 4)
|
||||||
|
%s3 = tail call $3 @__shfl_reduce_or_step_$3_nvptx($3 %s2, i32 8)
|
||||||
|
%s4 = tail call $3 @__shfl_reduce_or_step_$3_nvptx($3 %s3, i32 16)
|
||||||
|
|
||||||
|
;; find first active lane
|
||||||
|
%res1 = call i32 @__ballot_nvptx(i1 %mask)
|
||||||
|
%lane = call i32 @__count_trailing_zeros_i32(i32 %res1)
|
||||||
|
|
||||||
|
;; broadcast from this lane
|
||||||
|
%s5 = tail call $3 @__shfl_$3_nvptx($3 %s4, i32 %lane)
|
||||||
|
|
||||||
|
;; compare result to the original value
|
||||||
|
%cmp0 = icmp eq $3 %val, %s5
|
||||||
|
|
||||||
|
;; mask it if inactive
|
||||||
|
%negmask = xor i1 %mask, 1
|
||||||
|
%cmp = or i1 %cmp0, %negmask
|
||||||
|
|
||||||
|
;; compute final result
|
||||||
|
%res = call i32 @__ballot_nvptx(i1 %cmp)
|
||||||
|
%ret = icmp eq i32 %res, %res1
|
||||||
|
br i1 %ret, label %all_equal, label %retval
|
||||||
|
|
||||||
|
all_equal:
|
||||||
|
br i1 %mask, label %all_equal_store, label %retval
|
||||||
|
|
||||||
|
all_equal_store:
|
||||||
|
%vstore = extractelement <1 x $1> %v0, i32 0
|
||||||
|
store $1 %vstore, $1* %samevalue;
|
||||||
|
ret i1 %ret
|
||||||
|
|
||||||
|
retval:
|
||||||
|
ret i1 %ret
|
||||||
|
}
|
||||||
|
')
|
||||||
|
reduce_equal(i32, int32, i32);
|
||||||
|
reduce_equal(i64, int64, i64);
|
||||||
|
reduce_equal(float, float, i32);
|
||||||
|
reduce_equal(double, double, i64);
|
||||||
|
|
||||||
;;;;;;;;;;; shuffle
|
;;;;;;;;;;; shuffle
|
||||||
define(`shuffle1', `
|
define(`shuffle1', `
|
||||||
|
|||||||
Reference in New Issue
Block a user