Skip to content

Commit 88440e8

Browse files
authored
Unbreak build after #621 (#826)
1 parent 0601b5c commit 88440e8

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torchao/dtypes/affine_quantized_tensor.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
22
from typing import Tuple, Optional, Union
3-
import torchao.ops
43
from collections import defaultdict
54
import functools
65
import math
@@ -1425,6 +1424,8 @@ def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor,
14251424

14261425
def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias):
14271426
from torchao.sparsity.marlin import marlin_24_workspace, const
1427+
from torchao.ops import marlin_24_gemm
1428+
14281429
assert isinstance(weight_tensor, AffineQuantizedTensor)
14291430

14301431
sparse_w_int4 = weight_tensor.layout_tensor.int_data
@@ -1441,7 +1442,7 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b
14411442
size_k = input_2d.shape[1]
14421443
workspace_24 = marlin_24_workspace(original_shape[1])
14431444

1444-
out = torchao.ops.marlin_24_gemm(
1445+
out = marlin_24_gemm(
14451446
input_2d, sparse_w_int4, meta, scale,
14461447
workspace_24, num_bits, size_m, size_n, size_k
14471448
)

0 commit comments

Comments
 (0)