Commit d1e15b4

Add decorator for custom op and inductor decomp registration (#434)
Summary: This PR adds a decorator that registers a custom op and also an inductor decomposition for it. The goal is for the torch.export path to be able to see high-level ops like quantize_affine instead of their decomposed forms, because some backends like xnnpack want to work with these higher-level ops. This is a redo of #408; the difference is that in this PR we can preserve the enums on the Python side.

Test Plan:
regression tests:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

also need to check performance with:
python tutorials/quantize_vit/run_vit_b_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent f22e8e8 commit d1e15b4

File tree

3 files changed: +173 -22 lines changed

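To make the goal concrete before the diffs: after this change, the affine quantization primitives survive graph capture as single torch.ops.quant.* nodes instead of being decomposed. A hedged sketch of that (the toy module and shapes are illustrative, not from this PR; assumes a torch 2.5+ nightly, where the decorator added below is active):

import torch
from torch._export import capture_pre_autograd_graph
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
)

class ToyQuant(torch.nn.Module):
    def forward(self, x):
        block_size = list(x.shape)  # per-tensor quantization
        scale, zero_point = choose_qparams_affine(
            x, MappingType.ASYMMETRIC, block_size, torch.uint8
        )
        return quantize_affine(x, block_size, scale, zero_point, torch.uint8)

x = torch.randn(2, 8)
m = capture_pre_autograd_graph(ToyQuant(), (x,))
targets = [n.target for n in m.graph.nodes]
# xnnpack-style backends can now pattern-match on the high-level ops
assert torch.ops.quant.choose_qparams_affine.default in targets
assert torch.ops.quant.quantize_affine.default in targets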

test/integration/test_integration.py

+14 -3

@@ -1259,7 +1259,7 @@ def test_autoquant_manual(self, device, dtype):
         out3 = mod(example_input)
         sqnr2 = SQNR(out, out3)
         self.assertTrue(sqnr2 >= 30)
-
+
 
     @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
         [
@@ -1393,7 +1393,7 @@ class TestExport(unittest.TestCase):
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
     )
     @run_supported_device_dtype
-    def test_aoti(self, api, test_device, test_dtype):
+    def test_export(self, api, test_device, test_dtype):
         if not TORCH_VERSION_AFTER_2_4:
             self.skipTest("aoti compatibility requires 2.4+.")
@@ -1430,9 +1430,20 @@ def forward(self, x):
 
         # make sure it compiles
         example_inputs = (x,)
-        model = torch.export.export(model, example_inputs).module()
+        from torch._export import capture_pre_autograd_graph
+        # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
+        # we can re-enable this after non-functional IR is enabled in export
+        # model = torch.export.export(model, example_inputs).module()
+        model = capture_pre_autograd_graph(model, example_inputs)
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
+        if api is _int8da_int8w_api:
+            targets = [n.target for n in model.graph.nodes]
+            self.assertTrue(torch.ops.quant.choose_qparams_affine.default in targets)
+            self.assertTrue(torch.ops.quant.quantize_affine.default in targets)
+
+
 
 class TestUtils(unittest.TestCase):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
torchao/quantization/quant_primitives.py

+108 -18

@@ -4,13 +4,17 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from enum import Enum
+from enum import Enum, auto
 from typing import List, Optional, Tuple, Dict
 import torch
 
 from torchao.kernel.intmm import int_scaled_matmul
 from torchao.kernel.intmm import safe_int_mm
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_5,
+)
+from torchao.utils import _register_custom_op
 
 
 __all__ = [
@@ -34,17 +38,17 @@ class MappingType(Enum):
     based on this mapping
     e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
     """
-    SYMMETRIC = 0
-    ASYMMETRIC = 1
+    SYMMETRIC = auto()
+    ASYMMETRIC = auto()
 
 class ZeroPointDomain(Enum):
     """Enum that indicate whether zero_point is in integer domain or floating point domain
 
     integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
     float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
     """
-    INT = 0
-    FLOAT = 1
+    INT = auto()
+    FLOAT = auto()
 
 """
 Map from dtype to the bound value of integers
@@ -69,6 +73,10 @@ class ZeroPointDomain(Enum):
 })
 
 
+quant_lib = torch.library.Library("quant", "FRAGMENT")
+
+register_custom_op = _register_custom_op(quant_lib)
+
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +148,7 @@ def quantize_affine(
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-):
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +182,31 @@ def quantize_affine(
     Output:
       quantized tensor with requested dtype
     """
+    return _quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        output_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+    )
+
+
+@register_custom_op
+def _quantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    output_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+    """
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
     assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
@@ -188,12 +221,12 @@ def quantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
 
-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         quant = torch.clamp(
             torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
         ).to(output_dtype)
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name
         mid_point = (quant_max + quant_min + 1) / 2
         min_val = zero_point - scale * mid_point
         quant = (
@@ -216,7 +249,7 @@ def dequantize_affine(
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     *,
     output_dtype: torch.dtype = torch.float32,
-):
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +271,32 @@ def dequantize_affine(
     Output:
       dequantized Tensor, with requested dtype or fp32
     """
+    return _dequantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        input_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+        output_dtype=output_dtype,
+    )
+
+@register_custom_op
+def _dequantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    input_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+    """
 
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +314,16 @@ def dequantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
 
-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         # Force a copy to avoid input modification due
         # to upcoming in-place operations.
         dequant = input.to(torch.int32, copy=True)
         if zero_point is not None:
-            dequant -= zero_point.to(torch.int32)
+            dequant = dequant - zero_point.to(torch.int32)
         dequant = dequant.to(output_dtype)
-        dequant *= scale
+        dequant = dequant * scale
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}"
         mid_point = (quant_max + quant_min + 1) / 2
         # This should allocate new memory and avoid input modification
         dequant = input - mid_point
@@ -320,8 +379,38 @@ def choose_qparams_affine(
     Output:
         Tuple of scales and zero_points Tensor with requested dtype
     """
+    return _choose_qparams_affine(
+        input,
+        mapping_type.name,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+        preserve_zero,
+        zero_point_domain.name
+    )
+
+@register_custom_op
+def _choose_qparams_affine(
+    input: torch.Tensor,
+    mapping_type: str,
+    block_size: List[int],
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    preserve_zero: bool = True,
+    zero_point_domain: str = "INT",
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """op definition that has compatible signatures with custom op library
+    """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
-    assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
+    assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"
 
     if scale_dtype is None:
         scale_dtype = input.dtype
@@ -342,21 +431,22 @@ def choose_qparams_affine(
         min_val_neg = min_val
         max_val_pos = max_val
 
-    if mapping_type == MappingType.SYMMETRIC:
+    if mapping_type == MappingType.SYMMETRIC.name:
         max_val_pos = torch.max(-min_val_neg, max_val_pos)
         scale = max_val_pos / (float(quant_max - quant_min) / 2)
         if not preserve_zero:
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-        if zero_point_domain != ZeroPointDomain.INT:
+        if zero_point_domain != ZeroPointDomain.INT.name:
             raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
         zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
+        assert mapping_type == MappingType.ASYMMETRIC.name
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
         else:
-            assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
+            assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
             mid_point = (quant_max + quant_min + 1) / 2
             zero_point = min_val_neg + scale * mid_point
torchao/utils.py

+51 -1

@@ -13,6 +13,7 @@
     "skip_if_compute_capability_less_than",
     "benchmark_torch_function_in_microseconds",
     "find_multiple",
+    "_register_custom_op",
     "get_model_size_in_bytes",
     "unwrap_tensor_subclass",
     "TORCH_VERSION_AFTER_2_2",
@@ -65,7 +66,7 @@ def wrapper(*args, **kwargs):
 
 def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
     import torch.utils.benchmark as benchmark # this avoids importing numpy when torchao module is loaded
-
+
     # Manual warmup
     f(*args, **kwargs)
     f(*args, **kwargs)
@@ -84,6 +85,55 @@ def find_multiple(n: int, *args: Tuple[int]) -> int:
         return n
     return n + k - (n % k)
 
+def _register_custom_op(lib):
+    """This decorator is used to preserve some high level operators for torch.export.export
+    while still allowing them to be decomposed for the inductor path
+
+    requirement: make sure `fn.__name__[1:]` is the operator name you want to register
+
+    NOTE: This should be applied at the top, after all other decorators have been applied
+    NOTE: We haven't tested the case when `fn` accepts a tensor subclass instance as input,
+    e.g. a uint4 tensor subclass instance, and we'll probably need to figure out what would make
+    sense for downstream systems (like executorch) to accept as well
+
+    Example:
+        lib = torch.library.Library("my_namespace", "FRAGMENT")
+
+        register_custom_op = _register_custom_op(lib)
+
+        @register_custom_op
+        def _the_op_that_needs_to_be_preserved(...):
+            ...
+
+        # after this, `_the_op_that_needs_to_be_preserved` will be preserved as the
+        # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
+        # torch.export.export / torch._export.capture_pre_autograd_graph
+
+    """
+    from torch._inductor.decomposition import register_decomposition
+
+    def decorator(fn):
+        if TORCH_VERSION_AFTER_2_5:
+            from torch._library.infer_schema import infer_schema
+
+            # expecting fn.__name__ to start with `_`; the rest is taken
+            # as the name of the custom op
+            assert fn.__name__[0] == "_", f"Expecting function name to start with `_`, got {fn.__name__}"
+            assert not any(c in fn.__name__ for c in ".<>"), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
+            op_name = fn.__name__[1:]
+            schema = op_name + infer_schema(fn)
+            lib.define(schema)
+            lib.impl(op_name, fn, "CompositeImplicitAutograd")
+
+            lib_namespace = lib.ns
+            op = getattr(getattr(torch.ops, lib_namespace), op_name)
+            register_decomposition([op])(fn)
+            return op
+        else:
+            return fn
+
+    return decorator
+
 def get_model_size_in_bytes(model, ignore_embeddings=False):
     """
     Returns the model size in bytes. The option to ignore embeddings
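For the `quant` ops registered above, the decorator means the same Python body is reachable both through the public wrapper and through `torch.ops.quant` directly, since the `_`-prefixed twin is replaced by the registered operator. A hedged usage sketch (assumes torch 2.5+; on older versions the decorator is a no-op and `torch.ops.quant` is not populated):

import torch

from torchao.quantization.quant_primitives import quantize_affine

x = torch.randn(4, 4)
scale = torch.tensor(0.1)
zero_point = torch.tensor(0)

# eager path: the public API converts the enum default to its .name and
# calls the custom op, whose CompositeImplicitAutograd impl runs the
# original Python body
q1 = quantize_affine(x, [4, 4], scale, zero_point, torch.uint8)

# the same computation, called directly through the registered operator
q2 = torch.ops.quant.quantize_affine(x, [4, 4], scale, zero_point, torch.uint8)
assert torch.equal(q1, q2)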
