Commit 2ce9bc4

Renaming quantize to quantize_

Summary: Addressing feedback on the `quantize` API from pytorch#391 (comment). This API modifies the model in place, so the name should reflect that. In-place model quantization matters especially for LLMs, where it can be hard to fit the original model in memory; we typically load the model onto the meta device and then load the quantized weights.

Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:
Subscribers:
Tasks:
Tags:

1 parent 5d22ad2 commit 2ce9bc4
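
In practice the rename means call sites stop reassigning the return value. A minimal sketch of before/after usage (assumes torch 2.4+ and a torchao build that includes this change; the toy module, shapes, and the int8_weight_only choice are illustrative, not taken from the commit):

import torch
import torch.nn as nn
from torchao import quantize_
from torchao.quantization import int8_weight_only

m = nn.Sequential(nn.Linear(32, 64), nn.Linear(64, 32))

# before this commit: m = quantize(m, int8_weight_only())
# after: the trailing underscore signals in-place mutation; no reassignment needed
quantize_(m, int8_weight_only())
out = m(torch.randn(2, 32))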

File tree: 8 files changed (+37, -37 lines changed)


test/integration/test_integration.py
Lines changed: 9 additions & 9 deletions

@@ -23,7 +23,7 @@
     int4_weight_only,
     int8_weight_only,
     int8_dynamic_activation_int8_weight,
-    quantize,
+    quantize_,
     _replace_with_custom_fn_if_matches_filter,
 )
 # APIs to be deprecated (used for torch 2.2.2 and 2.3)
@@ -98,21 +98,21 @@
 
 def _int8wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_weight_only(), set_inductor_config=False)
+        quantize_(mod, int8_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_woqtensors(mod)
 
 def _int8da_int8w_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
+        quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod)
 
 def _int4wo_api(mod):
     if TORCH_VERSION_AFTER_2_4:
-        quantize(mod, int4_weight_only(), set_inductor_config=False)
+        quantize_(mod, int4_weight_only(), set_inductor_config=False)
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int4_woqtensors(mod)
@@ -127,8 +127,8 @@ def _int4wo_api(mod):
 def undo_recommended_configs():
     torch._inductor.config.coordinate_descent_tuning = False
     torch._inductor.config.coordinate_descent_check_all_directions = False
-    torch._inductor.config.force_fuse_int_mm_with_mul = False
-    torch._inductor.config.fx_graph_cache = False
+    torch._inductor.config.force_fuse_int_mm_with_mul = False
+    torch._inductor.config.fx_graph_cache = False
     torch._inductor.config.triton.unique_kernel_names = False
     torch.set_float32_matmul_precision("highest")
 
@@ -844,7 +844,7 @@ def api(mod):
         kwargs_copy = kwargs.copy()
         kwargs_copy["group_size"] = groupsize
         del kwargs_copy["groupsize"]
-        quantize(mod, int4_weight_only(**kwargs_copy))
+        quantize_(mod, int4_weight_only(**kwargs_copy))
         unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int4_woqtensors(mod, **kwargs)
@@ -865,7 +865,7 @@ def test_dynamic_quant(self):
         m = nn.Sequential(nn.Linear(K, N))
 
         y_ref = m(x)
-        quantize(m, int8_dynamic_activation_int8_weight())
+        quantize_(m, int8_dynamic_activation_int8_weight())
         y_test = m(x)
 
         sqnr = compute_error(y_ref, y_test)
@@ -1259,7 +1259,7 @@ def test_autoquant_manual(self, device, dtype):
         out3 = mod(example_input)
         sqnr2 = SQNR(out, out3)
         self.assertTrue(sqnr2 >= 30)
-
+
 
     @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
         [

test/prototype/test_quant_llm.py
Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@
 )
 from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6
 from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
-from torchao.quantization.quant_api import quantize
+from torchao.quantization.quant_api import quantize_
 
 
 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
@@ -91,7 +91,7 @@ def test_quant_llm_quantize(self, ebits, mbits, bias):
 
         linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
         fpx_linear = copy.deepcopy(linear)
-        quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))
+        quantize_(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))
 
         x = torch.randn(N, IC, device=device, dtype=torch.half)
         expected = fpx_linear(x)

test/quantization/test_quant_api.py
Lines changed: 9 additions & 9 deletions

@@ -31,7 +31,7 @@
     Int8WeightOnlyQuantizedLinearWeight,
     Int4WeightOnlyQuantizedLinearWeight,
 )
-from torchao import quantize
+from torchao import quantize_
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     Quantizer,
@@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:
 
 class TorchCompileDynamicQuantizer(Quantizer):
     def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
-        quantize(model, int8_dynamic_activation_int8_weight())
+        quantize_(model, int8_dynamic_activation_int8_weight())
         return model
 
 class ToyLinearModel(torch.nn.Module):
@@ -152,7 +152,7 @@ class TestQuantFlow(TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
-        m = quantize(m, int8_dynamic_activation_int8_weight())
+        quantize_(m, int8_dynamic_activation_int8_weight())
         quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self):
         )
         m = ToyLinearModel().eval().cpu()
         def api(model):
-            model = quantize(model, int8_weight_only())
+            quantize_(model, int8_weight_only())
             unwrap_tensor_subclass(model)
 
         api(m)
@@ -501,7 +501,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size))
+        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -530,7 +530,7 @@ def test_quantized_tensor_subclass_int4(self):
         example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
 
         group_size = 32
-        m = quantize(m, int4_weight_only(group_size=group_size))
+        quantize_(m, int4_weight_only(group_size=group_size))
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
 
@@ -550,7 +550,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
-        m = quantize(m, int8_weight_only())
+        quantize_(m, int8_weight_only())
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -573,7 +573,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
-        m = quantize(m, int8_dynamic_activation_int8_weight())
+        quantize_(m, int8_dynamic_activation_int8_weight())
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -607,7 +607,7 @@ def test_quantized_tensor_subclass_save_load(self):
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16)
 
-        m = quantize(m, int8_weight_only())
+        quantize_(m, int8_weight_only())
         ref = m(*example_inputs)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(m.state_dict(), f)

torchao/__init__.py
Lines changed: 2 additions & 2 deletions

@@ -30,14 +30,14 @@
 
 from torchao.quantization import (
     autoquant,
-    quantize,
+    quantize_,
 )
 from . import dtypes
 
 __all__ = [
     "dtypes",
     "autoquant",
-    "quantize",
+    "quantize_",
 ]
 
 # test-pytorchbot

torchao/quantization/README.md
Lines changed: 5 additions & 5 deletions

@@ -74,7 +74,7 @@ from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.dtypes import to_affine_quantized
 import copy
 from torchao.quantization.quant_api import (
-    quantize,
+    quantize_,
     int4_weight_only,
 )
 
@@ -101,7 +101,7 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
 # apply int4 weight only quant (compatible with tinygemm int4 weight only quant mm kernel in torchao)
 group_size = 32
 # only works for torch 2.4+
-m = quantize(m, int4_weight_only(group_size=group_size))
+quantize_(m, int4_weight_only(group_size=group_size))
 
 # temporary workaround for tensor subclass + torch.compile
 from torchao.utils import unwrap_tensor_subclass
@@ -168,7 +168,7 @@ torch._inductor.config.force_fuse_int_mm_with_mul = True
 
 # for torch 2.4+
 from torchao.quantization import quantize, int8_dynamic_activation_int8_weight
-quantize(model, int8_dynamic_activation_int8_weight())
+quantize_(model, int8_dynamic_activation_int8_weight())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
@@ -180,7 +180,7 @@ change_linear_weights_to_int8_dqtensors(model)
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize, int8_weight_only
-quantize(model, int8_weight_only())
+quantize_(model, int8_weight_only())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
@@ -195,7 +195,7 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize, int4_weight_only
-quantize(model, int4_weight_only())
+quantize_(model, int4_weight_only())
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors

torchao/quantization/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@
     "quantize_affine",
     "dequantize_affine",
     "choose_qprams_affine",
-    "quantize",
+    "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
     "int4_weight_only",

torchao/quantization/quant_api.py
Lines changed: 7 additions & 7 deletions

@@ -54,7 +54,7 @@
     "Int4WeightOnlyQuantizer",
     "autoquant",
     "_get_subclass_inserter",
-    "quantize",
+    "quantize_",
     "int8_dynamic_activation_int4_weight",
     "int8_dynamic_activation_int8_weight",
     "int4_weight_only",
@@ -259,8 +259,8 @@ def insert_subclass(lin):
 
     return insert_subclass
 
-def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True) -> torch.nn.Module:
-    """Convert the weight of linear modules in the model with `apply_tensor_subclass`
+def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
+    """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace
 
     Args:
         model (torch.nn.Module): input model
@@ -273,7 +273,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True) -> torch.nn.Module:
 
         import torch
         import torch.nn as nn
-        from torchao import quantize
+        from torchao import quantize_
 
         # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to
         # optimized execution paths or kernels (e.g. int4 tinygemm kernel)
@@ -286,7 +286,7 @@ def quantize(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True) -> torch.nn.Module:
         from torchao.quantization.quant_api import int4_weight_only
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
-        m = quantize(m, int4_weight_only(group_size=32))
+        quantize_(m, int4_weight_only(group_size=32))
 
         # 2. write your own new apply_tensor_subclass
         # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor
@@ -305,7 +305,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
             return isinstance(module, nn.Linear)
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
-        m = quantize(m, apply_weight_quant, filter_fn)
+        quantize_(m, apply_weight_quant, filter_fn)
 
     """
    if set_inductor_config:
@@ -315,7 +315,7 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         _get_linear_subclass_inserter(apply_tensor_subclass),
         _is_linear if filter_fn is None else filter_fn,
     )
-    return model
+
 
 def int8_dynamic_activation_int4_weight(group_size=32):
     """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear

tutorials/quantize_vit/run_vit_b_quant.py
Lines changed: 2 additions & 2 deletions

@@ -19,9 +19,9 @@
 # for APIs for earlier torch version and other quantization techniques
 
 # for torch 2.4+
-from torchao.quantization.quant_api import quantize
+from torchao.quantization.quant_api import quantize_
 from torchao.quantization.quant_api import int8_dynamic_activation_int8_weight
-quantize(model, int8_dynamic_activation_int8_weight())
+quantize_(model, int8_dynamic_activation_int8_weight())
 ## Quantization code - end
 
 ## compilation configs
