
Commit 7a3f5a0

Refactor the API for quant method argument for quantize function (#400)
Summary: Addressing feedback from #384 and #375

Test Plan: regression tests
  python test/quantization/test_quant_api.py
  python test/integration/test_integration.py

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 512b5d6 commit 7a3f5a0
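
In short, the commit replaces the string-based quant method argument of `quantize` with configuration functions, and spells out the helper names. A minimal before/after sketch pieced together from the diffs below; the toy module is illustrative and not part of the commit:

```python
import torch
from torch import nn
from torchao.quantization.quant_api import quantize, int8_weight_only

# Renames introduced by this commit (old -> new), per the diffs below:
#   int4wo       -> int4_weight_only
#   int8wo       -> int8_weight_only
#   int8da_int8w -> int8_dynamic_activation_int8_weight
#   int8da_int4w -> int8_dynamic_activation_int4_weight
#   groupsize=   -> group_size=
# String selectors such as quantize(m, "int8_weight_only") and
# register_apply_tensor_subclass are removed in favor of the callables.

m = nn.Sequential(nn.Linear(64, 64)).eval()  # illustrative toy model
# old: m = quantize(m, "int8_weight_only")
m = quantize(m, int8_weight_only())
```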

File tree: 11 files changed (+73 -109 lines)

README.md (+2 -2)

@@ -19,8 +19,8 @@ All with no intrusive code changes and minimal accuracy degradation.
 Quantizing your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/) and a HuggingFace inference example [here](scripts/hf_eval.py)
 
 ```python
-from torchao.quantization.quant_api import quantize
-m = quantize(m, "int4wo")
+from torchao.quantization.quant_api import quantize, int4_weight_only
+m = quantize(m, int4_weight_only())
 ```
 
 Benchmarks are run on a machine with a single A100 GPU using the script in `_models/llama` which generates text in a latency-optimized way (batchsize=1)
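
Expanded into a runnable form, the new README one-liner might look like the sketch below. The group_size value, bfloat16 dtype, and CUDA device are assumptions taken from the int4 tests and the quantization README elsewhere in this commit, not from the README snippet itself:

```python
import torch
from torch import nn
from torchao.quantization.quant_api import quantize, int4_weight_only

# int4 weight-only targets the tinygemm kernel, so bfloat16 weights on CUDA are assumed here
m = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda").eval()
m = quantize(m, int4_weight_only(group_size=32))

x = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
y = m(x)
```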

test/dtypes/test_affine_quantized.py (+3 -3)

@@ -2,7 +2,7 @@
     TestCase,
     run_tests,
 )
-from torchao.quantization.quant_api import int4wo
+from torchao.quantization.quant_api import int4_weight_only
 import torch
 import unittest
 
@@ -12,8 +12,8 @@ class TestAffineQuantized(TestCase):
     def test_tensor_core_layout_transpose(self):
         t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
         shape = t.shape
-        apply_int4wo_quant = int4wo(groupsize=32)
-        aqt = apply_int4wo_quant(t)
+        apply_int4_weight_only_quant = int4_weight_only(group_size=32)
+        aqt = apply_int4_weight_only_quant(t)
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
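
As the updated test above shows, int4_weight_only(group_size=...) returns an "apply" callable that can also be used directly on a bare weight tensor. A condensed sketch; the isinstance check is an extra assumption based on the AffineQuantizedTensor export elsewhere in this commit, the test itself only checks shape:

```python
import torch
from torchao.quantization.quant_api import int4_weight_only
from torchao.dtypes import AffineQuantizedTensor

t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
apply_int4_weight_only_quant = int4_weight_only(group_size=32)
aqt = apply_int4_weight_only_quant(t)

assert isinstance(aqt, AffineQuantizedTensor)  # assumed, based on the dtypes export in this commit
assert aqt.shape == t.shape                    # quantization preserves the logical shape
```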

test/integration/test_integration.py (+12 -9)

@@ -20,9 +20,9 @@
     DynamicallyPerAxisQuantizedLinear,
 )
 from torchao.quantization.quant_api import (
-    int4wo,
-    int8wo,
-    int8da_int8w,
+    int4_weight_only,
+    int8_weight_only,
+    int8_dynamic_activation_int8_weight,
     quantize,
     _replace_with_custom_fn_if_matches_filter,
 )
@@ -98,21 +98,21 @@
 
 def _int8wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8wo())
+        quantize(mod, int8_weight_only())
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_woqtensors(mod)
 
 def _int8da_int8w_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8da_int8w())
+        quantize(mod, int8_dynamic_activation_int8_weight())
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod)
 
 def _int4wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int4wo())
+        quantize(mod, int4_weight_only())
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int4_woqtensors(mod)
@@ -832,7 +832,10 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
 
         def api(mod):
             if TORCH_VERSION_AFTER_2_4:
-                quantize(mod, int4wo(**kwargs))
+                kwargs_copy = kwargs.copy()
+                kwargs_copy["group_size"] = groupsize
+                del kwargs_copy["groupsize"]
+                quantize(mod, int4_weight_only(**kwargs_copy))
                 unwrap_tensor_subclass(mod)
             else:
                 change_linear_weights_to_int4_woqtensors(mod, **kwargs)
@@ -853,7 +856,7 @@ def test_dynamic_quant(self):
         m = nn.Sequential(nn.Linear(K, N))
 
         y_ref = m(x)
-        quantize(m, int8da_int8w())
+        quantize(m, int8_dynamic_activation_int8_weight())
         y_test = m(x)
 
         sqnr = compute_error(y_ref, y_test)
@@ -1463,7 +1466,7 @@ def test_get_model_size_aqt(self, api, test_device, test_dtype):
         api(model)
         size2 = torchao.utils.get_model_size_in_bytes(model)
         self.assertTrue(size2 < size)
-
+
 
 
 
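
The integration tests keep both code paths behind a version gate: the new quantize plus config-callable API (with unwrap_tensor_subclass) for torch 2.4+, and the legacy change_linear_weights_to_* helpers otherwise. A condensed sketch of that pattern; the import locations for TORCH_VERSION_AFTER_2_4 and the legacy helper are assumptions based on the other diffs in this commit, since this diff does not show them:

```python
from torchao.quantization.quant_api import (
    quantize,
    int8_weight_only,
    change_linear_weights_to_int8_woqtensors,  # legacy pre-2.4 helper, import path assumed per the quantization README
)
from torchao.utils import TORCH_VERSION_AFTER_2_4, unwrap_tensor_subclass  # assumed import path

def _int8wo_api(mod):
    if TORCH_VERSION_AFTER_2_4:
        # new path: tensor-subclass based int8 weight-only quantization
        quantize(mod, int8_weight_only())
        # workaround for torch.export / AOTI, per the quantization README in this commit
        unwrap_tensor_subclass(mod)
    else:
        # pre-2.4 path: weight-swap helper
        change_linear_weights_to_int8_woqtensors(mod)
```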

test/quantization/test_quant_api.py (+18 -33)

@@ -36,10 +36,10 @@
     _replace_with_custom_fn_if_matches_filter,
     Quantizer,
     TwoStepQuantizer,
-    int8da_int4w,
-    int4wo,
-    int8wo,
-    int8da_int8w,
+    int8_dynamic_activation_int4_weight,
+    int4_weight_only,
+    int8_weight_only,
+    int8_dynamic_activation_int8_weight,
 )
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_3,
@@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:
 
 class TorchCompileDynamicQuantizer(Quantizer):
     def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
-        quantize(model, int8da_int8w())
+        quantize(model, int8_dynamic_activation_int8_weight())
         return model
 
 class ToyLinearModel(torch.nn.Module):
@@ -152,7 +152,7 @@ class TestQuantFlow(TestCase):
     def test_dynamic_quant_gpu_singleline(self):
        m = ToyLinearModel().eval()
        example_inputs = m.example_inputs()
-        m = quantize(m, int8da_int8w())
+        m = quantize(m, int8_dynamic_activation_int8_weight())
        quantized = m(*example_inputs)
        # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
        # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self):
         )
         m = ToyLinearModel().eval().cpu()
         def api(model):
-            model = quantize(model, int8wo())
+            model = quantize(model, int8_weight_only())
             unwrap_tensor_subclass(model)
 
         api(m)
@@ -335,7 +335,7 @@ def test_8da4w_quantizer_eval(self):
         )
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_gptq_quantizer_int4wo(self):
+    def test_gptq_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
         torchao._models.llama.model.use_index_put_for_kv_cache = True
@@ -397,7 +397,7 @@ def test_gptq_quantizer_int4wo(self):
         )
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_quantizer_int4wo(self):
+    def test_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao._models._eval import TransformerEvalWrapper
         precision = torch.bfloat16
@@ -499,11 +499,11 @@ def test_eval_wrapper_llama3(self):
     # TODO: move to a separate test file
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     def test_quantized_tensor_subclass_8da4w(self):
-        groupsize = 32
+        group_size = 32
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        m = quantize(m, int8da_int4w(groupsize=groupsize))
+        m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size))
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -514,7 +514,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
-        quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
         m_copy = quantizer.quantize(m_copy)
         assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
         assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)
@@ -531,13 +531,13 @@ def test_quantized_tensor_subclass_int4(self):
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
 
-        groupsize = 32
-        m = quantize(m, int4wo(groupsize=groupsize))
+        group_size = 32
+        m = quantize(m, int4_weight_only(group_size=group_size))
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
 
         # reference
-        _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+        _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size)
 
         res = m(*example_inputs)
         ref = m_copy(*example_inputs)
@@ -552,7 +552,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
-        m = quantize(m, int8wo())
+        m = quantize(m, int8_weight_only())
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -575,7 +575,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
-        m = quantize(m, int8da_int8w())
+        m = quantize(m, int8_dynamic_activation_int8_weight())
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -602,29 +602,14 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         # make sure it compiles
         torch._export.aot_compile(m_unwrapped, example_inputs)
 
-    def test_register_apply_tensor_subclass(self):
-        from torchao import register_apply_tensor_subclass
-        def apply_my_dtype(weight):
-            return weight * 2
-
-        m = ToyLinearModel().eval()
-        example_inputs = m.example_inputs()
-        with self.assertRaisesRegex(ValueError, "not supported"):
-            quantize(m, "my_dtype")
-
-        register_apply_tensor_subclass("my_dtype", apply_my_dtype)
-        # make sure it runs
-        quantize(m, "my_dtype")
-        m(*example_inputs)
-
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_save_load(self):
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16)
 
-        m = quantize(m, "int8_weight_only")
+        m = quantize(m, int8_weight_only())
         ref = m(*example_inputs)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(m.state_dict(), f)
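
A sketch of the save/load flow the updated test above exercises, using the new int8_weight_only callable. The toy module, CPU placement, and the load_state_dict(..., assign=True) restore step are assumptions for illustration; the actual test additionally gates on CUDA and torch 2.4+ and its restore code is not shown in this diff:

```python
import copy
import tempfile
import torch
from torch import nn
from torchao.quantization.quant_api import quantize, int8_weight_only

m = nn.Sequential(nn.Linear(64, 64)).eval()
m_copy = copy.deepcopy(m)
example_input = torch.randn(1, 64)

m = quantize(m, int8_weight_only())   # was: quantize(m, "int8_weight_only")
ref = m(example_input)

with tempfile.NamedTemporaryFile() as f:
    torch.save(m.state_dict(), f)
    f.seek(0)
    state_dict = torch.load(f)

# assign=True keeps the quantized tensor subclass when restoring (assumed restore recipe)
m_copy.load_state_dict(state_dict, assign=True)
res = m_copy(example_input)
```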

torchao/__init__.py (+0 -2)

@@ -31,13 +31,11 @@
 from torchao.quantization import (
     autoquant,
     quantize,
-    register_apply_tensor_subclass,
 )
 from . import dtypes
 
 __all__ = [
     "dtypes",
     "autoquant",
     "quantize",
-    "register_apply_tensor_subclass",
 ]

torchao/dtypes/__init__.py (+1 -1)

@@ -1,7 +1,7 @@
 from .nf4tensor import NF4Tensor, to_nf4
 # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
 from .uint4 import UInt4Tensor
-from .aqt import AffineQuantizedTensor, to_affine_quantized
+from .affine_quantized_tensor import AffineQuantizedTensor, to_affine_quantized
 
 __all__ = [
     "NF4Tensor",
File renamed without changes.

torchao/quantization/README.md (+10 -12)

@@ -80,7 +80,7 @@ from torch._inductor.runtime.runtime_utils import do_bench_gpu
 import copy
 from torchao.quantization.quant_api import (
     quantize,
-    int4wo,
+    int4_weight_only,
 )
 
 class ToyLinearModel(torch.nn.Module):
@@ -104,8 +104,8 @@ example_inputs = m.example_inputs(dtype=dtype, device="cuda")
 
 m_bf16 = torch.compile(m_bf16, mode='max-autotune')
 # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
-groupsize = 32
-m = quantize(m, int4wo(groupsize=groupsize))
+group_size = 32
+m = quantize(m, int4_weight_only(group_size=group_size))
 
 torch._inductor.config.force_fuse_int_mm_with_mul = True
 torch._inductor.config.use_mixed_mm = True
@@ -152,7 +152,7 @@ for n, m in model.named_modules():
 The model/tensor subclass should also be compatible with AOTI and torch.export, currently we can support
 `torch.export.export` and `torch.aot_compile` with the following workaround:
 ```
-from torchao.quantization.utils import unwrap_tensor_subclass
+from torchao.utils import unwrap_tensor_subclass
 m_unwrapped = unwrap_tensor_subclass(m)
 
 
@@ -169,11 +169,10 @@ torch._export.aot_compile(m_unwrapped, example_inputs)
 ```python
 # Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor
 torch._inductor.config.force_fuse_int_mm_with_mul = True
-from torchao.quantization import quant_api
 
 # for torch 2.4+
-from torchao.quantization.quant_api import quantize
-quantize(model, "int8_dynamic")
+from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
+quantize(model, int8_dynamic_activation_int8_weight())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
@@ -184,9 +183,8 @@ change_linear_weights_to_int8_dqtensors(model)
 
 ```python
 # for torch 2.4+
-from torchao.quantization.quant_api import quantize
-from torchao.quantization.quant_api import int8wo
-quantize(model, "int8_weight_only")
+from torchao.quantization import quantize, int8_weight_only
+quantize(model, int8_weight_only())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
@@ -200,8 +198,8 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
 
 ```python
 # for torch 2.4+
-from torchao.quantization.quant_api import quantize
-quantize(model, "int4_weight_only")
+from torchao.quantization import quantize, int4_weight_only
+quantize(model, int4_weight_only())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
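
Put side by side, the three torch 2.4+ snippets above now read as follows. A consolidated sketch: the toy model and deep copies are illustrative, and the int4 path assumes bfloat16 weights on CUDA as noted earlier in this README:

```python
import copy
import torch
from torch import nn
from torchao.quantization import (
    quantize,
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    int4_weight_only,
)

base = nn.Sequential(nn.Linear(1024, 1024)).eval()

# int8 dynamic activation + int8 weight quantization
m_a8w8 = quantize(copy.deepcopy(base), int8_dynamic_activation_int8_weight())

# int8 weight-only quantization
m_w8 = quantize(copy.deepcopy(base), int8_weight_only())

# int4 weight-only quantization (tinygemm kernel: bfloat16 weights on CUDA assumed)
m_w4 = quantize(
    copy.deepcopy(base).to(torch.bfloat16).to("cuda"),
    int4_weight_only(group_size=32),
)
```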

torchao/quantization/__init__.py (+4 -1)

@@ -32,5 +32,8 @@
     "dequantize_affine",
     "choose_qprams_affine",
     "quantize",
-    "register_apply_tensor_subclass",
+    "int8_dynamic_act_int4_weight",
+    "int8_dynamic_act_int8_weight",
+    "int4_weight_only",
+    "int8_weight_only",
 ]
