1 parent 04d611a commit 53d2486
torchao/dtypes/uintx/semi_sparse_layout.py
@@ -44,14 +44,15 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(
     # must pad
     row, col = tmp.shape
     from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
+
     tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp)
     # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
     y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
         w_vals_int8,
         tmp_padded.t(),
         alpha=w_scales.to(torch.float32),
         out_dtype=torch.bfloat16,
-        ).t()[:row, :]
+    ).t()[:row, :]
     y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
         *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
     )
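The hunk pads the flattened int8 activation with SparseSemiStructuredTensorCUSPARSELT._pad_dense_input before the fused cuSPARSELt matmul, remembers the original row count, and slices the transposed result back with [:row, :] so the padding rows never reach the caller. A minimal sketch of that pad-then-trim pattern is below; it is not the torchao implementation: a plain torch.mm stands in for torch._cslt_sparse_mm (which needs a CUDA build with cuSPARSELt and a compressed semi-structured weight), and pad_to_multiple is a hypothetical helper that only roughly mirrors _pad_dense_input's row padding.

# Sketch only: torch.mm stands in for the fused torch._cslt_sparse_mm call,
# and pad_to_multiple is a hypothetical helper, not part of torchao.
import torch


def pad_to_multiple(dense: torch.Tensor, multiple: int = 8) -> torch.Tensor:
    # Pad the row dimension up to the next multiple of `multiple`,
    # roughly mirroring what _pad_dense_input does for cuSPARSELt.
    rows = dense.shape[0]
    pad = (-rows) % multiple
    if pad == 0:
        return dense
    return torch.nn.functional.pad(dense, (0, 0, 0, pad))


def fused_mm_with_padding(w: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # x: (rows, k) activation, w: (n, k) weight. Remember the original row
    # count, pad, multiply against the transposed activation, then trim the
    # padded rows off the transposed output, as the diff above does.
    row, _ = x.shape
    x_padded = pad_to_multiple(x)
    # In the real kernel this is torch._cslt_sparse_mm(w_compressed, x_padded.t(), ...)
    y = torch.mm(w, x_padded.t())
    return y.t()[:row, :]  # back to (rows, n); padding rows dropped


if __name__ == "__main__":
    x = torch.randn(10, 16)  # 10 rows, padded to 16 internally
    w = torch.randn(32, 16)
    out = fused_mm_with_padding(w, x)
    print(out.shape)  # torch.Size([10, 32])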