Skip to content

Commit 35ac8f5

Browse files
committed
Update torchao after pytorch/pytorch#129940
Summary: Fixes torchao code after the bc breaking change in pytorch/pytorch#129940 Test Plan: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4 python test/integration/test_integration.py -k test_save_load_int4woqtensors_2_cpu Reviewers: Subscribers: Tasks: Tags:
1 parent 1029df3 commit 35ac8f5

File tree

5 files changed

+13
-10
lines changed

5 files changed

+13
-10
lines changed

test/dtypes/test_uint4.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,14 @@
44
PerChannelSymmetricWeightUInt4Tensor,
55
)
66
import unittest
7-
from unittest import TestCase, main
87
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
98
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
109

1110
from torch._export import capture_pre_autograd_graph
12-
from torch._export import dynamic_dim
1311
from torch.testing._internal.common_quantization import (
1412
NodeSpec as ns,
1513
QuantizationTestCase,
1614
)
17-
from torchao.quantization.utils import (
18-
compute_error,
19-
)
2015
from torchao.quantization.quant_api import (
2116
_replace_with_custom_fn_if_matches_filter,
2217
)
@@ -30,7 +25,6 @@
3025
QuantizationAnnotation,
3126
)
3227
import copy
33-
from packaging import version
3428

3529

3630
def _apply_weight_only_uint4_quant(model):
@@ -229,4 +223,4 @@ def forward(self, x):
229223
)
230224

231225
if __name__ == "__main__":
232-
main()
226+
unittest.main()

test/integration/test_integration.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@
8181
from torchao.utils import (
8282
TORCH_VERSION_AFTER_2_3,
8383
TORCH_VERSION_AFTER_2_4,
84+
TORCH_VERSION_AFTER_2_5,
8485
unwrap_tensor_subclass,
8586
is_fbcode,
8687
benchmark_model
@@ -734,6 +735,7 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
734735

735736
@parameterized.expand(COMMON_DEVICE_DTYPE)
736737
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
738+
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
737739
def test_int4_weight_only_quant_subclass(self, device, dtype):
738740
if dtype != torch.bfloat16:
739741
self.skipTest(f"Fails for {dtype}")
@@ -744,6 +746,7 @@ def test_int4_weight_only_quant_subclass(self, device, dtype):
744746

745747
@parameterized.expand(COMMON_DEVICE_DTYPE)
746748
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
749+
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
747750
def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
748751
if dtype != torch.bfloat16:
749752
self.skipTest(f"Fails for {dtype}")
@@ -1020,7 +1023,8 @@ def test_save_load_int8woqtensors(self, device, dtype):
10201023
self._test_handle_save_load_meta_impl(_int8wo_api, device, test_dtype=dtype)
10211024

10221025
@parameterized.expand(COMMON_DEVICE_DTYPE)
1023-
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
1026+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch 2.3+.")
1027+
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
10241028
@torch.no_grad()
10251029
def test_save_load_int4woqtensors(self, device, dtype):
10261030
if dtype != torch.bfloat16:
@@ -1500,7 +1504,7 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype):
15001504

15011505

15021506
class TestBenchmarkModel(unittest.TestCase):
1503-
1507+
15041508
class ToyLinearModel(torch.nn.Module):
15051509
def __init__(self, m=64, n=32, k=64):
15061510
super().__init__()

test/quantization/test_quant_api.py

+2
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from torchao.utils import (
4545
TORCH_VERSION_AFTER_2_3,
4646
TORCH_VERSION_AFTER_2_4,
47+
TORCH_VERSION_AFTER_2_5,
4748
)
4849
from pathlib import Path
4950
from torchao._models.llama.tokenizer import get_tokenizer
@@ -522,6 +523,7 @@ def test_quantized_tensor_subclass_8da4w(self):
522523
self.assertTrue(torch.equal(res, ref))
523524

524525
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
526+
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
525527
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
526528
def test_quantized_tensor_subclass_int4(self):
527529
# use 1024 so that we don't need padding

torchao/dtypes/affine_quantized_tensor.py

+2
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,8 @@ def __tensor_unflatten__(
461461

462462
@classmethod
463463
def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
464+
# assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
465+
# packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
464466
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
465467
scale = scale.reshape(int_data.shape[0], -1)
466468
zero_point = zero_point.reshape(int_data.shape[0], -1)

torchao/quantization/utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
348348
quant_min = 0
349349
quant_max = 2 ** n_bit - 1
350350

351-
return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
351+
int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
352+
return int_data
352353

353354
def groupwise_affine_dequantize_tensor_from_qparams(
354355
w_int4x8,

0 commit comments

Comments
 (0)