Skip to content

Commit 53d2486

Browse files
authored
fix lint (#1379)
Summary: run `ruff format` to fix lint on main branch Test Plan: CI Reviewers: Subscribers: Tasks: Tags:
1 parent 04d611a commit 53d2486

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchao/dtypes/uintx/semi_sparse_layout.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,15 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(
4444
# must pad
4545
row, col = tmp.shape
4646
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
47+
4748
tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp)
4849
# we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
4950
y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
5051
w_vals_int8,
5152
tmp_padded.t(),
5253
alpha=w_scales.to(torch.float32),
5354
out_dtype=torch.bfloat16,
54-
).t()[:row, :]
55+
).t()[:row, :]
5556
y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
5657
*x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
5758
)

0 commit comments

Comments
 (0)