Skip to content

Commit 2be0acf

Browse files
committed
Fix Performance Issue
1 parent a023896 commit 2be0acf

File tree

6 files changed

+86
-20
lines changed

6 files changed

+86
-20
lines changed

torchvision/csrc/ops/cpu/nms_kernel.cpp

+40
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,50 @@ at::Tensor nms_kernel(
107107
return result;
108108
}
109109

110+
111+
// Greedy, sequential selection step of NMS over a bit-packed suppression mask.
//
// Parameters:
//   order             - tensor that is indexed with the surviving positions to
//                       build the result. NOTE(review): the XPU caller passes
//                       the descending-score argsort of the boxes - confirm for
//                       other callers.
//   iou_keep_out_mask - int64 tensor holding `num_boxes` rows of
//                       ceil(num_boxes / 32) words each. Only the low 32 bits
//                       of every word are inspected: box j lives in word
//                       j / 32 at bit 31 - (j % 32). A set bit in row i means
//                       that keeping box i suppresses box j.
//   num_boxes         - number of boxes, i.e. number of mask rows.
//
// Returns `order` gathered at the positions that were not suppressed, in
// ascending position order. Returns an empty tensor when `num_boxes <= 0`.
at::Tensor nms_kernel_postprocess(
    const at::Tensor& order,
    const at::Tensor& iou_keep_out_mask,
    const int64_t num_boxes) {
  // Every mask word carries 32 box bits; this is a property of the mask
  // layout, not of the platform's `unsigned long` width.
  constexpr int64_t kBitsPerWord = 32;

  // Nothing to select. Returning early also avoids indexing into empty
  // tensors / empty vectors below.
  if (num_boxes <= 0) {
    return at::empty({0}, order.options());
  }

  // Ceil-divide the number of boxes by the number of bits stored per word.
  const int64_t col_blocks = (num_boxes + kBitsPerWord - 1) / kBitsPerWord;

  // Rows are addressed as `base + i * col_blocks`, which is only valid for a
  // dense row-major buffer that really holds that many words.
  const at::Tensor mask = iou_keep_out_mask.contiguous();
  TORCH_CHECK(
      mask.numel() >= num_boxes * col_blocks,
      "nms_kernel_postprocess: iou_keep_out_mask has ",
      mask.numel(),
      " elements, expected at least ",
      num_boxes * col_blocks);
  const int64_t* mask_data_ptr = mask.data_ptr<int64_t>();

  // Accumulated suppression bits for all boxes; zero-initialised by the
  // vector constructor.
  std::vector<unsigned long long> remove_box(col_blocks, 0ULL);

  at::Tensor keep = at::empty(
      {num_boxes}, order.options().dtype(at::kLong).device(at::kCPU));
  int64_t* keep_data_ptr = keep.data_ptr<int64_t>();
  int64_t num_to_keep = 0;

  for (int64_t i = 0; i < num_boxes; i++) {
    const int64_t nblock = i / kBitsPerWord;
    // i modulo 32, counted from the most significant of the 32 used bits.
    const int64_t inblock = (kBitsPerWord - 1) - (i % kBitsPerWord);

    if (remove_box[nblock] & (1ULL << inblock)) {
      // Box i was suppressed by a previously kept box.
      continue;
    }

    keep_data_ptr[num_to_keep++] = i;

    // Merge the suppression row of box i. Words before `nblock` only describe
    // boxes that were already visited, so they can be skipped.
    const int64_t* row = mask_data_ptr + i * col_blocks;
    for (int64_t j = nblock; j < col_blocks; j++) {
      remove_box[j] |= static_cast<unsigned long long>(row[j]);
    }
  }

  return order.index(
      {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)});
}
146+
147+
148+
110149
} // namespace
111150

112151
// Registers the CPU implementations of the torchvision operators defined in
// this file: `torchvision::nms` -> nms_kernel and
// `torchvision::nms_kernel_postprocess` -> nms_kernel_postprocess.
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
113152
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel));
153+
m.impl(TORCH_SELECTIVE_NAME("torchvision::nms_kernel_postprocess"), TORCH_FN(nms_kernel_postprocess));
113154
}
115155

116156
} // namespace ops

torchvision/csrc/ops/nms.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) {
2222
m.set_python_module("torchvision._meta_registrations");
2323
m.def(TORCH_SELECTIVE_SCHEMA(
2424
"torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"));
25+
m.def(TORCH_SELECTIVE_SCHEMA(
26+
"torchvision::nms_kernel_postprocess(Tensor order, Tensor iou_keep_out_mask, int num_boxes) -> Tensor"));
2527
}
2628

2729
} // namespace ops

torchvision/ops/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
generalized_box_iou,
1111
masks_to_boxes,
1212
nms,
13+
nms_kernel_postprocess,
1314
remove_small_boxes,
1415
)
1516
from .ciou_loss import complete_box_iou_loss
@@ -37,6 +38,7 @@
3738
"DeformConv2d",
3839
"nms",
3940
"batched_nms",
41+
"nms_kernel_postprocess",
4042
"remove_small_boxes",
4143
"clip_boxes_to_image",
4244
"box_convert",

torchvision/ops/boxes.py

+4
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
4141
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
4242

4343

44+
def nms_kernel_postprocess(order: Tensor, iou_keep_out_mask: Tensor, num_boxes: int) -> Tensor:
    """
    Run the sequential selection step of NMS on a precomputed suppression mask.

    Thin wrapper around the ``torchvision::nms_kernel_postprocess`` operator.

    Args:
        order (Tensor): tensor that is indexed with the surviving positions to
            produce the result.
        iou_keep_out_mask (Tensor): bit-packed suppression mask with one row
            per box.
        num_boxes (int): number of boxes described by ``iou_keep_out_mask``.

    Returns:
        Tensor: the entries of ``order`` that were not suppressed.
    """
    return torch.ops.torchvision.nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes)
46+
47+
4448
def batched_nms(
4549
boxes: Tensor,
4650
scores: Tensor,

torchvision/ops/triton/nms.py

+24-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33

44

55
@triton.jit
6-
def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: tl.constexpr):
6+
def _combine_bits(val0, val1):
    """Bitwise-OR two int32 operands; used as the combine function of a reduction."""
    tl.static_assert(val0.dtype == tl.int32, "input must be int32")
    tl.static_assert(val1.dtype == tl.int32, "input must be int32")
    merged = val0 | val1
    return merged
10+
11+
12+
def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr):
713
"""
814
This nms_kernel computes the suppressed mask of boxes [i, j].
915
mask[i, j]==1 means if we choose box i, the box j will be suppressed.
@@ -14,6 +20,8 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: t
1420
output_ptr (tl.pointer): A pointer to the output tensor where the mask will be stored.
1521
threshold (float): The IoU threshold for suppressing boxes.
1622
num_boxes (int): The total number of boxes.
23+
stride_i (int): The stride of the output tensor along the first dimension.
24+
stride_j (int): The stride of the output tensor along the second dimension.
1725
BLOCK_SIZE (tl.constexpr): The block size for the Triton kernel.
1826
"""
1927

@@ -59,14 +67,23 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: t
5967
area_b = (col_block_x2 - col_block_x1) * (col_block_y2 - col_block_y1)
6068
union = area_a + area_b - intersection
6169

62-
iou_keep_out_mask = ((intersection / union) > threshold).to(tl.int8)
70+
iou_keep_out_bit_mask = ((intersection / union) > threshold).to(tl.int32)
71+
72+
shift_offsets = tl.arange(0, BLOCK_SIZE) % 32
73+
shift_offsets = tl.flip(shift_offsets, 0)[None, :]
74+
shift_offsets = tl.broadcast_to(shift_offsets.to(tl.int32), [BLOCK_SIZE, BLOCK_SIZE])
75+
iou_keep_out_bit_mask = iou_keep_out_bit_mask << shift_offsets
76+
77+
iou_keep_out_bit_mask = tl.reshape(iou_keep_out_bit_mask, (BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32, 32))
78+
iou_keep_out_combined = tl.reduce(iou_keep_out_bit_mask, axis=2, combine_fn=_combine_bits)
6379

80+
iou_keep_out_combined = iou_keep_out_combined.to(tl.int64)
6481
output_block_ptr = tl.make_block_ptr(
6582
output_ptr,
66-
shape=(num_boxes, num_boxes),
67-
strides=(num_boxes, 1),
68-
offsets=(row_block_start, col_block_start),
69-
block_shape=(BLOCK_SIZE, BLOCK_SIZE),
83+
shape=(num_boxes, (num_boxes + 32 - 1) // 32),
84+
strides=(stride_i, stride_j),
85+
offsets=(row_block_start, 0),
86+
block_shape=(BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32),
7087
order=(0, 1),
7188
)
72-
tl.store(output_block_ptr, iou_keep_out_mask, boundary_check=(0, 1))
89+
tl.store(output_block_ptr, iou_keep_out_combined, boundary_check=(0, 1))

torchvision/ops/xpu/nms.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import triton
3+
from torchvision.ops.boxes import nms_kernel_postprocess
34

45
from torchvision.ops.triton.nms import triton_nms_IoU_kernel
56

@@ -35,21 +36,21 @@ def xpu_triton_nms(boxes: torch.Tensor, scores: torch.Tensor, threshold: float)
3536
# Triton does not support argsort yet, thus it needs to fallback to ATen Calls
3637
order = torch.argsort(scores, descending=True)
3738
boxes = boxes[order]
38-
iou_keep_out_mask = torch.zeros(num_boxes, num_boxes, dtype=torch.bool, device=boxes.device)
39+
iou_keep_out_mask = torch.zeros(num_boxes, (num_boxes + 32 - 1) // 32, dtype=torch.int64, device=boxes.device)
3940

4041
grid = lambda meta: ( # noqa: E731
4142
triton.cdiv(num_boxes, meta["BLOCK_SIZE"]),
4243
triton.cdiv(num_boxes, meta["BLOCK_SIZE"]),
4344
)
44-
# TODO: We need to tune the config from different devices.
45-
triton_nms_IoU_kernel[grid](boxes, iou_keep_out_mask, threshold, num_boxes, BLOCK_SIZE=64, num_warps=8)
46-
47-
# # TODO: Need to improve performance for this reduction
48-
picked = []
49-
remove_box = torch.zeros(num_boxes, dtype=torch.bool, device=boxes.device)
50-
for i in range(num_boxes):
51-
if not (remove_box[i]):
52-
picked.append(order[i])
53-
remove_box[i:] |= iou_keep_out_mask[i][i:]
54-
55-
return torch.as_tensor(picked)
45+
triton_nms_IoU_kernel[grid](
46+
boxes,
47+
iou_keep_out_mask,
48+
threshold,
49+
num_boxes,
50+
iou_keep_out_mask.stride(0),
51+
iou_keep_out_mask.stride(1),
52+
BLOCK_SIZE=64,
53+
num_warps=4,
54+
)
55+
56+
return nms_kernel_postprocess(order.cpu(), iou_keep_out_mask.cpu(), num_boxes).to(order.device)

0 commit comments

Comments
 (0)