From e0cf83310214af8c5689c269855f27a8cf41ed59 Mon Sep 17 00:00:00 2001 From: "Su, Tong" Date: Fri, 25 Oct 2024 14:37:57 +0800 Subject: [PATCH 1/9] Init XPU support for NMS kernel --- torchvision/ops/__init__.py | 3 +- torchvision/ops/triton/__init__.py | 0 torchvision/ops/triton/nms.py | 77 ++++++++++++++++++++++++++++++ torchvision/ops/xpu/__init__.py | 5 ++ torchvision/ops/xpu/nms.py | 52 ++++++++++++++++++++ 5 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 torchvision/ops/triton/__init__.py create mode 100644 torchvision/ops/triton/nms.py create mode 100644 torchvision/ops/xpu/__init__.py create mode 100644 torchvision/ops/xpu/nms.py diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 827505b842d..f750b2ee2db 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -26,9 +26,10 @@ from .roi_align import roi_align, RoIAlign from .roi_pool import roi_pool, RoIPool from .stochastic_depth import stochastic_depth, StochasticDepth +from .xpu import _register_xpu_ops _register_custom_op() - +_register_xpu_ops() __all__ = [ "masks_to_boxes", diff --git a/torchvision/ops/triton/__init__.py b/torchvision/ops/triton/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/torchvision/ops/triton/nms.py b/torchvision/ops/triton/nms.py new file mode 100644 index 00000000000..20c97fc0e05 --- /dev/null +++ b/torchvision/ops/triton/nms.py @@ -0,0 +1,77 @@ +import torch +import torchvision.ops +import triton +import triton.language as tl +from torch import Tensor +from torch._decomp import register_decomposition +from torchvision import ops + + +@triton.jit +def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: tl.constexpr): + """ + This nms_kernel computes the supressed mask of boxes [i, j]. + mask[i, j]==1 means if we choose box 1, the box j will be supressed. + The output is a mask of size [num_boxes, num_boxes]. + + Args: + boxes (tl.tensor): A tensor containing the bounding boxes with shape (num_boxes, 4). + output_ptr (tl.pointer): A pointer to the output tensor where the mask will be stored. + threshold (float): The IoU threshold for suppressing boxes. + num_boxes (int): The total number of boxes. + BLOCK_SIZE (tl.constexpr): The block size for the Triton kernel. + """ + + # The Triton kernel is a 2D block kernel. The block size is BLOCK_SIZE x BLOCK_SIZE. + # Each kernel will compute the IoU of boxes[row: row + BLOCK_SIZE, col: col + BLOCK_SIZE] + row_block_pid = tl.program_id(axis=0) + col_block_pid = tl.program_id(axis=1) + + row_block_start = row_block_pid * BLOCK_SIZE + col_block_start = col_block_pid * BLOCK_SIZE + + row_block_offsets = row_block_start + tl.arange(0, BLOCK_SIZE) + col_block_offsets = col_block_start + tl.arange(0, BLOCK_SIZE) + + row_block_mask = row_block_offsets < num_boxes + col_block_mask = col_block_offsets < num_boxes + + # Since Triton does not support tensor slicing yet, we need to load point elements individiually + # Every row_block is loaded as a 1 dim tensor of size [BLOCK_SIZE] + # We then expand 1 dim for row. So that the row block dim would be [BLOCK_SIZE, 1] + row_block_x1 = tl.load(boxes + row_block_offsets * 4 + 0, mask=row_block_mask)[:, None] + row_block_y1 = tl.load(boxes + row_block_offsets * 4 + 1, mask=row_block_mask)[:, None] + row_block_x2 = tl.load(boxes + row_block_offsets * 4 + 2, mask=row_block_mask)[:, None] + row_block_y2 = tl.load(boxes + row_block_offsets * 4 + 3, mask=row_block_mask)[:, None] + + # Expand 1 dim for col. So that the col block dim would be [1, BLOCK_SIZE] + col_block_x1 = tl.load(boxes + col_block_offsets * 4 + 0, mask=col_block_mask)[None, :] + col_block_y1 = tl.load(boxes + col_block_offsets * 4 + 1, mask=col_block_mask)[None, :] + col_block_x2 = tl.load(boxes + col_block_offsets * 4 + 2, mask=col_block_mask)[None, :] + col_block_y2 = tl.load(boxes + col_block_offsets * 4 + 3, mask=col_block_mask)[None, :] + + # Together, the minimum / maximum will broadcast and form into a [BLOCK_SIZE, BLOCK_SIZE] matrix + left = tl.maximum(row_block_x1, col_block_x1) + right = tl.minimum(row_block_x2, col_block_x2) + top = tl.maximum(row_block_y1, col_block_y1) + bottom = tl.minimum(row_block_y2, col_block_y2) + + width = tl.maximum(right - left, 0) + height = tl.maximum(bottom - top, 0) + + intersection = width * height + area_a = (row_block_x2 - row_block_x1) * (row_block_y2 - row_block_y1) + area_b = (col_block_x2 - col_block_x1) * (col_block_y2 - col_block_y1) + union = area_a + area_b - intersection + + iou_keep_out_mask = ((intersection / union) > threshold).to(tl.int8) + + output_block_ptr = tl.make_block_ptr( + output_ptr, + shape=(num_boxes, num_boxes), + strides=(num_boxes, 1), + offsets=(row_block_start, col_block_start), + block_shape=(BLOCK_SIZE, BLOCK_SIZE), + order=(0, 1), + ) + tl.store(output_block_ptr, iou_keep_out_mask, boundary_check=(0, 1)) diff --git a/torchvision/ops/xpu/__init__.py b/torchvision/ops/xpu/__init__.py new file mode 100644 index 00000000000..863d37a4c4e --- /dev/null +++ b/torchvision/ops/xpu/__init__.py @@ -0,0 +1,5 @@ +from .nms import xpu_triton_nms + + +def _register_xpu_ops(): + pass diff --git a/torchvision/ops/xpu/nms.py b/torchvision/ops/xpu/nms.py new file mode 100644 index 00000000000..de2a6af306b --- /dev/null +++ b/torchvision/ops/xpu/nms.py @@ -0,0 +1,52 @@ +import torch +import triton + +from torchvision.ops.triton.nms import triton_nms_IoU_kernel + + +@torch.library.register_kernel("torchvision::nms", "xpu") +def xpu_triton_nms(boxes: torch.Tensor, scores: torch.Tensor, threshold: float) -> torch.Tensor: + """ + Performs non-maximum suppression (NMS) on the boxes according + to their intersection-over-union (IoU). + + NMS iteratively removes lower scoring boxes which have an + IoU greater than ``iou_threshold`` with another (higher scoring) + box. + + If multiple boxes have the exact same score and satisfy the IoU + criterion with respect to a reference box, the selected box is + not guaranteed to be the same between CPU and GPU. This is similar + to the behavior of argsort in PyTorch when repeated values are present. + + Args: + boxes (Tensor[N, 4])): boxes to perform NMS on. They + are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and + ``0 <= y1 < y2``. + scores (Tensor[N]): scores for each one of the boxes + iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold + + Returns: + Tensor: int64 tensor with the indices of the elements that have been kept + by NMS, sorted in decreasing order of scores + """ + num_boxes = boxes.shape[0] + + # Triton does not support argsort yet, thus it needs to fallback to ATen Calls + order = torch.argsort(scores, descending=True) + boxes = boxes[order] + iou_keep_out_mask = torch.zeros(num_boxes, num_boxes, dtype=torch.bool, device=boxes.device) + + grid = lambda meta: (triton.cdiv(num_boxes, meta["BLOCK_SIZE"]), triton.cdiv(num_boxes, meta["BLOCK_SIZE"])) + # TODO: We need to tune the config from different devices. + triton_nms_IoU_kernel[grid](boxes, iou_keep_out_mask, threshold, num_boxes, BLOCK_SIZE=64, num_warps=8) + + # # TODO: Need to improve performance for this reduction + picked = [] + remove_box = torch.zeros(num_boxes, dtype=torch.bool, device=boxes.device) + for i in range(num_boxes): + if not (remove_box[i]): + picked.append(order[i]) + remove_box[i:] |= iou_keep_out_mask[i][i:] + + return torch.as_tensor(picked) From d79010ad76ea8ff28cad20625756a0c648d3b938 Mon Sep 17 00:00:00 2001 From: "Su, Tong" Date: Fri, 25 Oct 2024 14:38:10 +0800 Subject: [PATCH 2/9] Init test for NMS kernel --- test/common_utils.py | 7 +++++++ test/conftest.py | 8 ++++++++ test/test_ops.py | 1 + 3 files changed, 16 insertions(+) diff --git a/test/common_utils.py b/test/common_utils.py index 99c7931587d..a0dac03ab0f 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -29,6 +29,7 @@ IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1" CUDA_NOT_AVAILABLE_MSG = "CUDA device not available" MPS_NOT_AVAILABLE_MSG = "MPS device not available" +XPU_NOT_AVAILABLE_MSG = "XPU device not available" OSS_CI_GPU_NO_CUDA_MSG = "We're in an OSS GPU machine, and this test doesn't need cuda." @@ -141,6 +142,12 @@ def needs_mps(test_func): return pytest.mark.needs_mps(test_func) +def needs_xpu(test_func): + import pytest # noqa + + return pytest.mark.needs_xpu(test_func) + + def _create_data(height=3, width=3, channels=3, device="cpu"): # TODO: When all relevant tests are ported to pytest, turn this into a module-level fixture tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device) diff --git a/test/conftest.py b/test/conftest.py index a9768598ded..984cba981b9 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -11,6 +11,7 @@ IN_RE_WORKER, MPS_NOT_AVAILABLE_MSG, OSS_CI_GPU_NO_CUDA_MSG, + XPU_NOT_AVAILABLE_MSG, ) @@ -18,6 +19,7 @@ def pytest_configure(config): # register an additional marker (see pytest_collection_modifyitems) config.addinivalue_line("markers", "needs_cuda: mark for tests that rely on a CUDA device") config.addinivalue_line("markers", "needs_mps: mark for tests that rely on a MPS device") + config.addinivalue_line("markers", "needs_xpu: mark for tests that rely on a XPU device") config.addinivalue_line("markers", "dont_collect: mark for tests that should not be collected") config.addinivalue_line("markers", "opcheck_only_one: only opcheck one parametrization") @@ -43,12 +45,18 @@ def pytest_collection_modifyitems(items): # and the ones with device == 'cpu' won't have the mark. needs_cuda = item.get_closest_marker("needs_cuda") is not None needs_mps = item.get_closest_marker("needs_mps") is not None + needs_xpu = item.get_closest_marker("needs_xpu") is not None if needs_cuda and not torch.cuda.is_available(): # In general, we skip cuda tests on machines without a GPU # There are special cases though, see below item.add_marker(pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)) + if needs_xpu and not torch.xpu.is_available(): + # In general, we skip xpu tests on machines without a GPU + # There are special cases though, see below + item.add_marker(pytest.mark.skip(reason=XPU_NOT_AVAILABLE_MSG)) + if needs_mps and not torch.backends.mps.is_available(): item.add_marker(pytest.mark.skip(reason=MPS_NOT_AVAILABLE_MSG)) diff --git a/test/test_ops.py b/test/test_ops.py index 1ba7a2c9efa..4519ed967a6 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -831,6 +831,7 @@ def test_qnms(self, iou, scale, zero_point): ( pytest.param("cuda", marks=pytest.mark.needs_cuda), pytest.param("mps", marks=pytest.mark.needs_mps), + pytest.param("xpu", marks=pytest.mark.needs_xpu), ), ) @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) From 01df93ae02733f00c85fa8e9f27d1d0ed190d92b Mon Sep 17 00:00:00 2001 From: "Su, Tong" Date: Mon, 25 Nov 2024 20:21:03 +0800 Subject: [PATCH 3/9] format code --- torchvision/ops/triton/nms.py | 5 ----- torchvision/ops/xpu/nms.py | 5 ++++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/torchvision/ops/triton/nms.py b/torchvision/ops/triton/nms.py index 20c97fc0e05..40943934697 100644 --- a/torchvision/ops/triton/nms.py +++ b/torchvision/ops/triton/nms.py @@ -1,10 +1,5 @@ -import torch -import torchvision.ops import triton import triton.language as tl -from torch import Tensor -from torch._decomp import register_decomposition -from torchvision import ops @triton.jit diff --git a/torchvision/ops/xpu/nms.py b/torchvision/ops/xpu/nms.py index de2a6af306b..209109c3762 100644 --- a/torchvision/ops/xpu/nms.py +++ b/torchvision/ops/xpu/nms.py @@ -37,7 +37,10 @@ def xpu_triton_nms(boxes: torch.Tensor, scores: torch.Tensor, threshold: float) boxes = boxes[order] iou_keep_out_mask = torch.zeros(num_boxes, num_boxes, dtype=torch.bool, device=boxes.device) - grid = lambda meta: (triton.cdiv(num_boxes, meta["BLOCK_SIZE"]), triton.cdiv(num_boxes, meta["BLOCK_SIZE"])) + grid = lambda meta: ( # noqa: E731 + triton.cdiv(num_boxes, meta["BLOCK_SIZE"]), + triton.cdiv(num_boxes, meta["BLOCK_SIZE"]), + ) # TODO: We need to tune the config from different devices. triton_nms_IoU_kernel[grid](boxes, iou_keep_out_mask, threshold, num_boxes, BLOCK_SIZE=64, num_warps=8) From 607c8396ab766472df2f425bc18038a041e6f96c Mon Sep 17 00:00:00 2001 From: "Su, Tong" Date: Tue, 3 Dec 2024 17:48:37 +0800 Subject: [PATCH 4/9] Fix Performance Issue --- torchvision/csrc/ops/cpu/nms_kernel.cpp | 40 +++++++++++++++++++++++++ torchvision/csrc/ops/nms.cpp | 2 ++ torchvision/ops/__init__.py | 2 ++ torchvision/ops/boxes.py | 4 +++ torchvision/ops/triton/nms.py | 31 ++++++++++++++----- torchvision/ops/xpu/nms.py | 27 +++++++++-------- 6 files changed, 86 insertions(+), 20 deletions(-) diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index 50479066cbd..7756d577fbd 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -107,10 +107,50 @@ at::Tensor nms_kernel( return result; } + +at::Tensor nms_kernel_postprocess( + const at::Tensor& order, + const at::Tensor& iou_keep_out_mask, + const int64_t num_boxes) { + + // ceil div to 32. Which is the size of ulong type. + const int col_blocks = (num_boxes + 32 - 1) / 32; + std::vector remove_box(col_blocks); + std::memset(&remove_box[0], 0, sizeof(unsigned long) * col_blocks); + + + at::Tensor keep = at::empty({num_boxes}, order.options().dtype(at::kLong).device(at::kCPU)); + int64_t * keep_data_ptr = keep.data_ptr(); + + unsigned long long* iou_keep_out_mask_data_ptr = (unsigned long long*)iou_keep_out_mask.data_ptr(); + int num_to_keep = 0; + unsigned long long* iou_keep_out_mask_data_ptr0 = (unsigned long long*)iou_keep_out_mask[0].data_ptr(); + unsigned long long*iou_keep_out_mask_data_ptr1 = (unsigned long long*)iou_keep_out_mask[1].data_ptr(); + + // Note that the iou_keep_out_mask has the shape of (N, N//32) + for (int64_t i = 0; i < num_boxes; i++) { + int nblock = i / 32; + // module 32 + int inblock = (31 - i) & (32 -1); + + if (!(remove_box[nblock] & (1UL << inblock))){ + keep_data_ptr[num_to_keep++]=i; + unsigned long long*p = iou_keep_out_mask_data_ptr + i*col_blocks; + for (int j = nblock; j < col_blocks; j++){ + remove_box[j] |= p[j]; + } + } + } + return order.index({keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)}); +} + + + } // namespace TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); + m.impl(TORCH_SELECTIVE_NAME("torchvision::nms_kernel_postprocess"), TORCH_FN(nms_kernel_postprocess)); } } // namespace ops diff --git a/torchvision/csrc/ops/nms.cpp b/torchvision/csrc/ops/nms.cpp index 5ecf8812f1b..f1eb9c0ee0f 100644 --- a/torchvision/csrc/ops/nms.cpp +++ b/torchvision/csrc/ops/nms.cpp @@ -22,6 +22,8 @@ TORCH_LIBRARY_FRAGMENT(torchvision, m) { m.set_python_module("torchvision._meta_registrations"); m.def(TORCH_SELECTIVE_SCHEMA( "torchvision::nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::nms_kernel_postprocess(Tensor order, Tensor iou_keep_out_mask, int num_boxes) -> Tensor")); } } // namespace ops diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index f750b2ee2db..bb944347ce0 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -10,6 +10,7 @@ generalized_box_iou, masks_to_boxes, nms, + nms_kernel_postprocess, remove_small_boxes, ) from .ciou_loss import complete_box_iou_loss @@ -37,6 +38,7 @@ "DeformConv2d", "nms", "batched_nms", + "nms_kernel_postprocess", "remove_small_boxes", "clip_boxes_to_image", "box_convert", diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 96631278d48..89cdc2f761d 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -41,6 +41,10 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: return torch.ops.torchvision.nms(boxes, scores, iou_threshold) +def nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes) -> Tensor: + return torch.ops.torchvision.nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes) + + def batched_nms( boxes: Tensor, scores: Tensor, diff --git a/torchvision/ops/triton/nms.py b/torchvision/ops/triton/nms.py index 40943934697..797b82283f8 100644 --- a/torchvision/ops/triton/nms.py +++ b/torchvision/ops/triton/nms.py @@ -3,7 +3,13 @@ @triton.jit -def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: tl.constexpr): +def _combine_bits(val0, val1): + tl.static_assert(val0.dtype == tl.int32, "input must be int32") + tl.static_assert(val1.dtype == tl.int32, "input must be int32") + return val0 | val1 + + +def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr): """ This nms_kernel computes the supressed mask of boxes [i, j]. mask[i, j]==1 means if we choose box 1, the box j will be supressed. @@ -14,6 +20,8 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: t output_ptr (tl.pointer): A pointer to the output tensor where the mask will be stored. threshold (float): The IoU threshold for suppressing boxes. num_boxes (int): The total number of boxes. + stride_i (int): The stride of the output tensor along the first dimension. + stride_j (int): The stride of the output tensor along the second dimension. BLOCK_SIZE (tl.constexpr): The block size for the Triton kernel. """ @@ -59,14 +67,23 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, BLOCK_SIZE: t area_b = (col_block_x2 - col_block_x1) * (col_block_y2 - col_block_y1) union = area_a + area_b - intersection - iou_keep_out_mask = ((intersection / union) > threshold).to(tl.int8) + iou_keep_out_bit_mask = ((intersection / union) > threshold).to(tl.int32) + + shift_offsets = tl.arange(0, BLOCK_SIZE) % 32 + shift_offsets = tl.flip(shift_offsets, 0)[None, :] + shift_offsets = tl.broadcast_to(shift_offsets.to(tl.int32), [BLOCK_SIZE, BLOCK_SIZE]) + iou_keep_out_bit_mask = iou_keep_out_bit_mask << shift_offsets + + iou_keep_out_bit_mask = tl.reshape(iou_keep_out_bit_mask, (BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32, 32)) + iou_keep_out_combined = tl.reduce(iou_keep_out_bit_mask, axis=2, combine_fn=_combine_bits) + iou_keep_out_combined = iou_keep_out_combined.to(tl.int64) output_block_ptr = tl.make_block_ptr( output_ptr, - shape=(num_boxes, num_boxes), - strides=(num_boxes, 1), - offsets=(row_block_start, col_block_start), - block_shape=(BLOCK_SIZE, BLOCK_SIZE), + shape=(num_boxes, (num_boxes + 32 - 1) // 32), + strides=(stride_i, stride_j), + offsets=(row_block_start, 0), + block_shape=(BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32), order=(0, 1), ) - tl.store(output_block_ptr, iou_keep_out_mask, boundary_check=(0, 1)) + tl.store(output_block_ptr, iou_keep_out_combined, boundary_check=(0, 1)) diff --git a/torchvision/ops/xpu/nms.py b/torchvision/ops/xpu/nms.py index 209109c3762..0e46b322ed7 100644 --- a/torchvision/ops/xpu/nms.py +++ b/torchvision/ops/xpu/nms.py @@ -1,5 +1,6 @@ import torch import triton +from torchvision.ops.boxes import nms_kernel_postprocess from torchvision.ops.triton.nms import triton_nms_IoU_kernel @@ -35,21 +36,21 @@ def xpu_triton_nms(boxes: torch.Tensor, scores: torch.Tensor, threshold: float) # Triton does not support argsort yet, thus it needs to fallback to ATen Calls order = torch.argsort(scores, descending=True) boxes = boxes[order] - iou_keep_out_mask = torch.zeros(num_boxes, num_boxes, dtype=torch.bool, device=boxes.device) + iou_keep_out_mask = torch.zeros(num_boxes, (num_boxes + 32 - 1) // 32, dtype=torch.int64, device=boxes.device) grid = lambda meta: ( # noqa: E731 triton.cdiv(num_boxes, meta["BLOCK_SIZE"]), triton.cdiv(num_boxes, meta["BLOCK_SIZE"]), ) - # TODO: We need to tune the config from different devices. - triton_nms_IoU_kernel[grid](boxes, iou_keep_out_mask, threshold, num_boxes, BLOCK_SIZE=64, num_warps=8) - - # # TODO: Need to improve performance for this reduction - picked = [] - remove_box = torch.zeros(num_boxes, dtype=torch.bool, device=boxes.device) - for i in range(num_boxes): - if not (remove_box[i]): - picked.append(order[i]) - remove_box[i:] |= iou_keep_out_mask[i][i:] - - return torch.as_tensor(picked) + triton_nms_IoU_kernel[grid]( + boxes, + iou_keep_out_mask, + threshold, + num_boxes, + iou_keep_out_mask.stride(0), + iou_keep_out_mask.stride(1), + BLOCK_SIZE=64, + num_warps=4, + ) + + return nms_kernel_postprocess(order.cpu(), iou_keep_out_mask.cpu(), num_boxes).to(order.device) From 7f277f60349e7ac60ad646f9cc3ab8f5b3d89244 Mon Sep 17 00:00:00 2001 From: "Su, Tong" Date: Mon, 16 Dec 2024 08:44:46 +0000 Subject: [PATCH 5/9] Fix runtime issue --- torchvision/ops/triton/nms.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torchvision/ops/triton/nms.py b/torchvision/ops/triton/nms.py index 797b82283f8..c80b252e949 100644 --- a/torchvision/ops/triton/nms.py +++ b/torchvision/ops/triton/nms.py @@ -9,6 +9,7 @@ def _combine_bits(val0, val1): return val0 | val1 +@triton.jit def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr): """ This nms_kernel computes the supressed mask of boxes [i, j]. @@ -76,13 +77,16 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, str iou_keep_out_bit_mask = tl.reshape(iou_keep_out_bit_mask, (BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32, 32)) iou_keep_out_combined = tl.reduce(iou_keep_out_bit_mask, axis=2, combine_fn=_combine_bits) - iou_keep_out_combined = iou_keep_out_combined.to(tl.int64) + + # The bits are combined along the col, thus we need to change the col block offsets + # For the row offset, it will remain the same. + combined_col_blk_offsets = col_block_pid * ((BLOCK_SIZE + 31) // 32) output_block_ptr = tl.make_block_ptr( output_ptr, shape=(num_boxes, (num_boxes + 32 - 1) // 32), strides=(stride_i, stride_j), - offsets=(row_block_start, 0), + offsets=(row_block_start, combined_col_blk_offsets), block_shape=(BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32), order=(0, 1), ) From 96745dfd6b4292f05f694e8ea155a35d74c8e306 Mon Sep 17 00:00:00 2001 From: "Su, Tong" Date: Mon, 16 Dec 2024 09:25:19 +0000 Subject: [PATCH 6/9] delete unused code --- torchvision/ops/__init__.py | 3 --- torchvision/ops/xpu/__init__.py | 5 ----- 2 files changed, 8 deletions(-) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index bb944347ce0..b19a84ab65d 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -10,7 +10,6 @@ generalized_box_iou, masks_to_boxes, nms, - nms_kernel_postprocess, remove_small_boxes, ) from .ciou_loss import complete_box_iou_loss @@ -27,10 +26,8 @@ from .roi_align import roi_align, RoIAlign from .roi_pool import roi_pool, RoIPool from .stochastic_depth import stochastic_depth, StochasticDepth -from .xpu import _register_xpu_ops _register_custom_op() -_register_xpu_ops() __all__ = [ "masks_to_boxes", diff --git a/torchvision/ops/xpu/__init__.py b/torchvision/ops/xpu/__init__.py index 863d37a4c4e..e69de29bb2d 100644 --- a/torchvision/ops/xpu/__init__.py +++ b/torchvision/ops/xpu/__init__.py @@ -1,5 +0,0 @@ -from .nms import xpu_triton_nms - - -def _register_xpu_ops(): - pass From 3e75978a0c176f175836dba80f63691bb87e4328 Mon Sep 17 00:00:00 2001 From: "Su, Tong" Date: Mon, 16 Dec 2024 09:25:26 +0000 Subject: [PATCH 7/9] Add comments for the code --- torchvision/csrc/ops/cpu/nms_kernel.cpp | 64 ++++++++++++++----------- torchvision/ops/boxes.py | 19 ++++++-- torchvision/ops/xpu/nms.py | 10 +++- 3 files changed, 60 insertions(+), 33 deletions(-) diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index 7756d577fbd..6bb03a355d4 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -108,39 +108,49 @@ at::Tensor nms_kernel( } +/** + * @brief Post-processes the results of the Non-Maximum Suppression (NMS) algorithm. + * + * This function iterates over the boxes and determines which ones to keep based on the IOU (Intersection Over Union) keep-out mask. + * It uses a 32-bitmask to efficiently track and suppress overlapping boxes. + * + * @param order A tensor containing the order of the boxes. + * @param iou_keep_out_mask A tensor containing the IOU keep-out mask. This mask has the shape (N, N//32), where N is the number of boxes. + * The datatype MUST be int32. + * @param num_boxes The total number of boxes. + * @return A tensor containing the indices of the boxes to keep. + */ + at::Tensor nms_kernel_postprocess( const at::Tensor& order, const at::Tensor& iou_keep_out_mask, const int64_t num_boxes) { - - // ceil div to 32. Which is the size of ulong type. - const int col_blocks = (num_boxes + 32 - 1) / 32; - std::vector remove_box(col_blocks); - std::memset(&remove_box[0], 0, sizeof(unsigned long) * col_blocks); - - - at::Tensor keep = at::empty({num_boxes}, order.options().dtype(at::kLong).device(at::kCPU)); - int64_t * keep_data_ptr = keep.data_ptr(); - - unsigned long long* iou_keep_out_mask_data_ptr = (unsigned long long*)iou_keep_out_mask.data_ptr(); - int num_to_keep = 0; - unsigned long long* iou_keep_out_mask_data_ptr0 = (unsigned long long*)iou_keep_out_mask[0].data_ptr(); - unsigned long long*iou_keep_out_mask_data_ptr1 = (unsigned long long*)iou_keep_out_mask[1].data_ptr(); - - // Note that the iou_keep_out_mask has the shape of (N, N//32) - for (int64_t i = 0; i < num_boxes; i++) { - int nblock = i / 32; - // module 32 - int inblock = (31 - i) & (32 -1); - - if (!(remove_box[nblock] & (1UL << inblock))){ - keep_data_ptr[num_to_keep++]=i; - unsigned long long*p = iou_keep_out_mask_data_ptr + i*col_blocks; - for (int j = nblock; j < col_blocks; j++){ - remove_box[j] |= p[j]; - } + // Calculate the number of 32-bit blocks needed to cover all boxes + const int col_blocks = (num_boxes + 32 - 1) / 32; + std::vector remove_box(col_blocks); + std::memset(&remove_box[0], 0, sizeof(unsigned long) * col_blocks); + + + at::Tensor keep = at::empty({num_boxes}, order.options().dtype(at::kLong).device(at::kCPU)); + int64_t * keep_data_ptr = keep.data_ptr(); + + unsigned long long* iou_keep_out_mask_data_ptr = (unsigned long long*)iou_keep_out_mask.data_ptr(); + int num_to_keep = 0; + // Note that the iou_keep_out_mask has the shape of (N, N//32) + // The following function iterate over each box to check if it should be kept + for (int64_t i = 0; i < num_boxes; i++) { + int nblock = i / 32; + // This is equivalent to module: 31 - i % 32 + int inblock = (31 - i) & (32 -1); + + if (!(remove_box[nblock] & (1UL << inblock))){ + keep_data_ptr[num_to_keep++]=i; + unsigned long long*p = iou_keep_out_mask_data_ptr + i*col_blocks; + for (int j = nblock; j < col_blocks; j++){ + remove_box[j] |= p[j]; } } + } return order.index({keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)}); } diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 89cdc2f761d..5d8356a5cae 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -41,10 +41,6 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: return torch.ops.torchvision.nms(boxes, scores, iou_threshold) -def nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes) -> Tensor: - return torch.ops.torchvision.nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes) - - def batched_nms( boxes: Tensor, scores: Tensor, @@ -142,6 +138,21 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor: return keep +def _nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes) -> Tensor: + """ + Post-processes the results of the non-maximum suppression (NMS) kernel. + Args: + order (Tensor): A tensor containing the order of the boxes. + iou_keep_out_mask (Tensor): A tensor containing the mask of boxes to keep based on IoU. + The datatype is int32. + num_boxes (int): The number of boxes. + Returns: + Tensor: A tensor containing the post-processed results of the NMS kernel. + """ + + return torch.ops.torchvision.nms_kernel_postprocess(order, iou_keep_out_mask, num_boxes) + + def clip_boxes_to_image(boxes: Tensor, size: Tuple[int, int]) -> Tensor: """ Clip boxes so that they lie inside an image of size ``size``. diff --git a/torchvision/ops/xpu/nms.py b/torchvision/ops/xpu/nms.py index 0e46b322ed7..7b43e0f8b59 100644 --- a/torchvision/ops/xpu/nms.py +++ b/torchvision/ops/xpu/nms.py @@ -1,6 +1,6 @@ import torch import triton -from torchvision.ops.boxes import nms_kernel_postprocess +from torchvision.ops.boxes import _nms_kernel_postprocess from torchvision.ops.triton.nms import triton_nms_IoU_kernel @@ -42,6 +42,10 @@ def xpu_triton_nms(boxes: torch.Tensor, scores: torch.Tensor, threshold: float) triton.cdiv(num_boxes, meta["BLOCK_SIZE"]), triton.cdiv(num_boxes, meta["BLOCK_SIZE"]), ) + + # This triton kernel will calcualte the IoU matrix for all the input boxes (iou_keep_out_mask). + # The iou_keep_out_mask is defined as a 32-bit long bitmask matrix. So the matrix shape is [N, N//32]. + # Each item [i, j] will be interpreted as whether we should keep box j when we choose box i. triton_nms_IoU_kernel[grid]( boxes, iou_keep_out_mask, @@ -53,4 +57,6 @@ def xpu_triton_nms(boxes: torch.Tensor, scores: torch.Tensor, threshold: float) num_warps=4, ) - return nms_kernel_postprocess(order.cpu(), iou_keep_out_mask.cpu(), num_boxes).to(order.device) + # The postprocess will calculate the final indices of the boxes that should be kept. + # It is a serialized process, and we choose to run it on CPU for more generalization. + return _nms_kernel_postprocess(order.cpu(), iou_keep_out_mask.cpu(), num_boxes).to(order.device) From 3cb98951e331844c87930d2f8b30e49589409aae Mon Sep 17 00:00:00 2001 From: "Su, Tong" Date: Mon, 16 Dec 2024 09:26:56 +0000 Subject: [PATCH 8/9] delete unused code --- torchvision/ops/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index b19a84ab65d..827505b842d 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -29,13 +29,13 @@ _register_custom_op() + __all__ = [ "masks_to_boxes", "deform_conv2d", "DeformConv2d", "nms", "batched_nms", - "nms_kernel_postprocess", "remove_small_boxes", "clip_boxes_to_image", "box_convert", From e366724b50b9a441b25ac9311b7ef90ef43d3cdd Mon Sep 17 00:00:00 2001 From: "Su, Tong" Date: Mon, 16 Dec 2024 09:45:00 +0000 Subject: [PATCH 9/9] Add comments on code --- torchvision/ops/triton/nms.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchvision/ops/triton/nms.py b/torchvision/ops/triton/nms.py index c80b252e949..545c850e6d0 100644 --- a/torchvision/ops/triton/nms.py +++ b/torchvision/ops/triton/nms.py @@ -13,8 +13,8 @@ def _combine_bits(val0, val1): def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, stride_j, BLOCK_SIZE: tl.constexpr): """ This nms_kernel computes the supressed mask of boxes [i, j]. - mask[i, j]==1 means if we choose box 1, the box j will be supressed. - The output is a mask of size [num_boxes, num_boxes]. + mask[i, j]==1 means if we choose box i, the box j will be supressed. + The output is a mask of size [num_boxes, num_boxes//32], where each item is int32. Args: boxes (tl.tensor): A tensor containing the bounding boxes with shape (num_boxes, 4). @@ -24,6 +24,9 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, str stride_i (int): The stride of the output tensor along the first dimension. stride_j (int): The stride of the output tensor along the second dimension. BLOCK_SIZE (tl.constexpr): The block size for the Triton kernel. + Returns: + Tensor (int32): Tensor with size [num_boxes, num_boxes//32]. It indicates that if `box i` is + choosen, whether box `j` could be choosen. The value `1` means it cannot be choosen. """ # The Triton kernel is a 2D block kernel. The block size is BLOCK_SIZE x BLOCK_SIZE. @@ -75,6 +78,8 @@ def triton_nms_IoU_kernel(boxes, output_ptr, threshold, num_boxes, stride_i, str shift_offsets = tl.broadcast_to(shift_offsets.to(tl.int32), [BLOCK_SIZE, BLOCK_SIZE]) iou_keep_out_bit_mask = iou_keep_out_bit_mask << shift_offsets + # The process of combine bits. Note that the Triton seems having problem when the dtype is int64. + # Thus choosing 32 bits as the mask. And convert it to int64 at the end to avoid further potential overflow. iou_keep_out_bit_mask = tl.reshape(iou_keep_out_bit_mask, (BLOCK_SIZE, (BLOCK_SIZE + 32 - 1) // 32, 32)) iou_keep_out_combined = tl.reduce(iou_keep_out_bit_mask, axis=2, combine_fn=_combine_bits) iou_keep_out_combined = iou_keep_out_combined.to(tl.int64)