Skip to content

Commit ad5698b

Browse files
committed
format code
1 parent f9f7b32 commit ad5698b

File tree

2 files changed

+4
-6
lines changed

2 files changed

+4
-6
lines changed

torchvision/ops/triton/nms.py

-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
1-
import torch
2-
import torchvision.ops
31
import triton
42
import triton.language as tl
5-
from torch import Tensor
6-
from torch._decomp import register_decomposition
7-
from torchvision import ops
83

94

105
@triton.jit

torchvision/ops/xpu/nms.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ def xpu_triton_nms(boxes: torch.Tensor, scores: torch.Tensor, threshold: float)
3737
boxes = boxes[order]
3838
iou_keep_out_mask = torch.zeros(num_boxes, num_boxes, dtype=torch.bool, device=boxes.device)
3939

40-
grid = lambda meta: (triton.cdiv(num_boxes, meta["BLOCK_SIZE"]), triton.cdiv(num_boxes, meta["BLOCK_SIZE"]))
40+
grid = lambda meta: ( # noqa: E731
41+
triton.cdiv(num_boxes, meta["BLOCK_SIZE"]),
42+
triton.cdiv(num_boxes, meta["BLOCK_SIZE"]),
43+
)
4144
# TODO: We need to tune the config from different devices.
4245
triton_nms_IoU_kernel[grid](boxes, iou_keep_out_mask, threshold, num_boxes, BLOCK_SIZE=64, num_warps=8)
4346

0 commit comments

Comments
 (0)