3
3
4
4
5
5
@triton .jit
6
- def triton_nms_IoU_kernel (boxes , output_ptr , threshold , num_boxes , BLOCK_SIZE : tl .constexpr ):
6
def _combine_bits(val0, val1):
    """Merge two int32 bit-mask values with a bitwise OR.

    Passed as the ``combine_fn`` of ``tl.reduce`` so that the individually
    shifted bits along the reduced axis are folded into one packed word.

    Args:
        val0: First operand; must have dtype ``tl.int32``.
        val1: Second operand; must have dtype ``tl.int32``.

    Returns:
        ``val0 | val1``.
    """
    # Reject any non-int32 operand before combining.
    tl.static_assert(val0.dtype == tl.int32, "input must be int32")
    tl.static_assert(val1.dtype == tl.int32, "input must be int32")
    merged = val0 | val1
    return merged
10
+
11
+
12
+ def triton_nms_IoU_kernel (boxes , output_ptr , threshold , num_boxes , stride_i , stride_j , BLOCK_SIZE : tl .constexpr ):
7
13
"""
8
14
This nms_kernel computes the suppressed mask of boxes [i, j].
9
15
mask[i, j]==1 means if we choose box i, the box j will be suppressed.
@@ -14,6 +20,8 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: t
14
20
output_ptr (tl.pointer): A pointer to the output tensor where the mask will be stored.
15
21
threshold (float): The IoU threshold for suppressing boxes.
16
22
num_boxes (int): The total number of boxes.
23
+ stride_i (int): The stride of the output tensor along the first dimension.
24
+ stride_j (int): The stride of the output tensor along the second dimension.
17
25
BLOCK_SIZE (tl.constexpr): The block size for the Triton kernel.
18
26
"""
19
27
@@ -59,14 +67,23 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: t
59
67
area_b = (col_block_x2 - col_block_x1 ) * (col_block_y2 - col_block_y1 )
60
68
union = area_a + area_b - intersection
61
69
62
- iou_keep_out_mask = ((intersection / union ) > threshold ).to (tl .int8 )
70
+ iou_keep_out_bit_mask = ((intersection / union ) > threshold ).to (tl .int32 )
71
+
72
+ shift_offsets = tl .arange (0 , BLOCK_SIZE ) % 32
73
+ shift_offsets = tl .flip (shift_offsets , 0 )[None , :]
74
+ shift_offsets = tl .broadcast_to (shift_offsets .to (tl .int32 ), [BLOCK_SIZE , BLOCK_SIZE ])
75
+ iou_keep_out_bit_mask = iou_keep_out_bit_mask << shift_offsets
76
+
77
+ iou_keep_out_bit_mask = tl .reshape (iou_keep_out_bit_mask , (BLOCK_SIZE , (BLOCK_SIZE + 32 - 1 ) // 32 , 32 ))
78
+ iou_keep_out_combined = tl .reduce (iou_keep_out_bit_mask , axis = 2 , combine_fn = _combine_bits )
63
79
80
+ iou_keep_out_combined = iou_keep_out_combined .to (tl .int64 )
64
81
output_block_ptr = tl .make_block_ptr (
65
82
output_ptr ,
66
- shape = (num_boxes , num_boxes ),
67
- strides = (num_boxes , 1 ),
68
- offsets = (row_block_start , col_block_start ),
69
- block_shape = (BLOCK_SIZE , BLOCK_SIZE ),
83
+ shape = (num_boxes , ( num_boxes + 32 - 1 ) // 32 ),
84
+ strides = (stride_i , stride_j ),
85
+ offsets = (row_block_start , 0 ),
86
+ block_shape = (BLOCK_SIZE , ( BLOCK_SIZE + 32 - 1 ) // 32 ),
70
87
order = (0 , 1 ),
71
88
)
72
- tl .store (output_block_ptr , iou_keep_out_mask , boundary_check = (0 , 1 ))
89
+ tl .store (output_block_ptr , iou_keep_out_combined , boundary_check = (0 , 1 ))
0 commit comments