Commit f2ca7f2

Add decorator for custom op and inductor decomp registration
Summary:
This PR adds a decorator that registers a custom op and also an inductor decomposition for it. The goal is for the torch.export path to see high-level ops such as quantize_affine instead of their decompositions, because some backends (e.g. XNNPACK) want to work with these higher-level ops. This is a redo of #408; the difference is that this PR preserves the enums on the Python side.

Test Plan:
regression tests:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

also need to check performance with python tutorials/quantize_vit/run_vit_b_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent c2cf973 commit f2ca7f2

2 files changed: +204 -21 lines

test/integration/test_integration.py (+14 -3)
@@ -1244,7 +1244,7 @@ def test_autoquant_manual(self, device, dtype):
         out3 = mod(example_input)
         sqnr2 = SQNR(out, out3)
         self.assertTrue(sqnr2 >= 30)
-
+

     @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
         [
@@ -1376,7 +1376,7 @@ class TestExport(unittest.TestCase):
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
     )
     @run_supported_device_dtype
-    def test_aoti(self, api, test_device, test_dtype):
+    def test_export(self, api, test_device, test_dtype):
         if not TORCH_VERSION_AFTER_2_4:
             self.skipTest("aoti compatibility requires 2.4+.")

@@ -1413,9 +1413,20 @@ def forward(self, x):

         # make sure it compiles
         example_inputs = (x,)
-        model = torch.export.export(model, example_inputs).module()
+        from torch._export import capture_pre_autograd_graph
+        # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
+        # we can re-enable this after non-functional IR is enabled in export
+        # model = torch.export.export(model, example_inputs).module()
+        model = capture_pre_autograd_graph(model, example_inputs)
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
+        if api is _int8da_int8w_api:
+            targets = [n.target for n in model.graph.nodes]
+            self.assertTrue(torch.ops.quant.choose_qparams_affine.default in targets)
+            self.assertTrue(torch.ops.quant.quantize_affine.default in targets)
+
+
+

 class TestUtils(unittest.TestCase):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
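
The new assertions check that, for the int8 dynamic-activation / int8-weight API, the high-level quant ops survive pre-autograd capture as single graph nodes. A similar check can be sketched for any module quantized through torchao's affine primitives; the helper below is illustrative only (the function name and the exact set of ops inspected are assumptions, not part of this PR, and registration of the `quant.*` ops requires PyTorch 2.5+):

    import torch
    import torchao.quantization.quant_primitives  # noqa: F401  (importing registers the quant.* ops)
    from torch._export import capture_pre_autograd_graph

    def preserved_quant_ops(m: torch.nn.Module, example_inputs):
        """Return which of torchao's high-level affine quant ops remain in the captured graph."""
        gm = capture_pre_autograd_graph(m, example_inputs)
        targets = {n.target for n in gm.graph.nodes}
        return targets & {
            torch.ops.quant.choose_qparams_affine.default,
            torch.ops.quant.quantize_affine.default,
            torch.ops.quant.dequantize_affine.default,
        }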

torchao/quantization/quant_primitives.py (+190 -18)
@@ -4,13 +4,16 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from enum import Enum
+from enum import Enum, auto
 from typing import List, Optional, Tuple, Dict
 import torch

 from torchao.kernel.intmm import int_scaled_matmul
 from torchao.kernel.intmm import safe_int_mm
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_5,
+)


 __all__ = [
@@ -34,17 +37,17 @@ class MappingType(Enum):
     based on this mapping
     e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
     """
-    SYMMETRIC = 0
-    ASYMMETRIC = 1
+    SYMMETRIC = auto()
+    ASYMMETRIC = auto()

 class ZeroPointDomain(Enum):
     """Enum that indicate whether zero_point is in integer domain or floating point domain

     integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
     float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
     """
-    INT = 0
-    FLOAT = 1
+    INT = auto()
+    FLOAT = auto()

 """
 Map from dtype to the bound value of integers
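
To make the docstring formulas concrete, here is a small worked example in plain Python using the values from the `e.g.` line in the MappingType docstring; the rounding step mirrors the `preserve_zero` branch of `choose_qparams_affine` further down in this file:

    # ASYMMETRIC mapping of the float range [-3.5, 10.2] onto the int4 range [-8, 7]
    min_val, max_val = -3.5, 10.2
    quant_min, quant_max = -8, 7

    scale = (max_val - min_val) / (quant_max - quant_min)  # (10.2 - (-3.5)) / (7 - (-8)) ~= 0.913
    zero_point = quant_min - round(min_val / scale)         # integer-domain zero point, here -4

    # integer-domain quantization as in the ZeroPointDomain docstring:
    # float 0.0 maps exactly onto the zero point, so zero is preserved
    assert round(0.0 / scale) + zero_point == zero_point
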
@@ -69,6 +72,90 @@ class ZeroPointDomain(Enum):
 })


+# def register_custom_op(name: str):
+#     from torch._inductor.decomposition import register_decomposition
+
+#     def decorator(fn):
+#         if TORCH_VERSION_AFTER_2_5:
+#             opdef = torch.library.custom_op(name, mutates_args=())(fn)
+#             opdef.register_fake(fn)
+#             register_decomposition([opdef._opoverload])(fn)
+#             return opdef
+#         else:
+#             return fn
+
+#     return decorator
+
+quant_lib = torch.library.Library("quant", "FRAGMENT")
+
+# def register_custom_op(lib, schema: str):
+#     """This decorator is used to preserve some high level operators for torch.export.export
+#     while still allow them to be decomposed for inductor path
+
+#     NOTE: This should be applied at the top, after all other decorators have been applied
+#     """
+#     from torch._inductor.decomposition import register_decomposition
+
+#     def decorator(fn):
+#         if TORCH_VERSION_AFTER_2_5:
+#             # TODO: change order
+#             lib_namespace = lib.ns
+#             op_name = schema.split("(")[0]
+#             lib.define(schema)
+#             lib.impl(op_name, fn, "CompositeImplicitAutograd")
+#             op = getattr(getattr(torch.ops, lib_namespace), op_name)
+#             register_decomposition([op])(fn)
+#             return op
+#         else:
+#             return fn
+
+#     return decorator
+
+def register_custom_op(lib):
+    """This decorator is used to preserve some high level operators for torch.export.export
+    while still allowing them to be decomposed on the inductor path
+
+    requirement: make sure `fn.__name__[1:]` is the operator name you want to register
+
+    NOTE: This should be applied at the top, after all other decorators have been applied
+    NOTE: We haven't tested the case when `fn` accepts a tensor subclass instance as input,
+    e.g. a uint4 tensor subclass instance, and we'll probably need to figure out what would make
+    sense for downstream systems (like executorch) to accept as well
+
+    Example:
+        lib = torch.library.Library("my_namespace", "FRAGMENT")
+        @register_custom_op(lib)
+        def _the_op_that_needs_to_be_preserved(...):
+            ...
+
+        # after this, `_the_op_that_needs_to_be_preserved` will be preserved as the
+        # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
+        # torch.export.export
+
+    """
+    from torch._inductor.decomposition import register_decomposition
+
+    def decorator(fn):
+        if TORCH_VERSION_AFTER_2_5:
+            from torch._library.infer_schema import infer_schema
+
+            # assuming fn.__name__ starts with `_` and we want to take the rest
+            # to be the name of the custom op
+            op_name = fn.__name__[1:]
+            schema = op_name + infer_schema(fn)
+            lib.define(schema)
+            lib.impl(op_name, fn, "CompositeImplicitAutograd")
+
+            lib_namespace = lib.ns
+            op = getattr(getattr(torch.ops, lib_namespace), op_name)
+            register_decomposition([op])(fn)
+            return op
+        else:
+            return fn
+
+    return decorator
+
+
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +227,7 @@ def quantize_affine(
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-):
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +261,31 @@ def quantize_affine(
     Output:
       quantized tensor with requested dtype
     """
+    return _quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        output_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+    )
+
+
+@register_custom_op(quant_lib)
+def _quantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    output_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+    """
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
     assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
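
Note the pattern: the public `quantize_affine` keeps the enum-based signature and forwards `zero_point_domain.name` (a plain string) to the private `_quantize_affine`, whose signature is expressible in the custom-op schema language. A quick, hedged way to confirm the registration after importing the module on PyTorch 2.5+ (`_schema` is an internal attribute, so its exact rendering may vary across releases):

    import torch
    import torchao.quantization.quant_primitives  # noqa: F401  (importing runs the registrations)

    print(torch.ops.quant.quantize_affine.default._schema)
    # prints something like: quant::quantize_affine(Tensor input, ...) -> Tensor
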
@@ -188,12 +300,12 @@ def quantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)

-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         quant = torch.clamp(
             torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
         ).to(output_dtype)
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name
         mid_point = (quant_max + quant_min + 1) / 2
         min_val = zero_point - scale * mid_point
         quant = (
@@ -216,7 +328,7 @@ def dequantize_affine(
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     *,
     output_dtype: torch.dtype = torch.float32,
-):
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +350,34 @@ def dequantize_affine(
     Output:
       dequantized Tensor, with requested dtype or fp32
     """
+    return _dequantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        input_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+        output_dtype=output_dtype,
+    )
+
+
+# @register_custom_op(quant_lib, 'dequantize_affine(Tensor input, int[] block_size, Tensor scale, Tensor zero_point, ScalarType input_dtype, int? quant_min=None, int? quant_max=None, str zero_point_domain="INT", ScalarType output_dtype=float) -> Tensor')
+@register_custom_op(quant_lib)
+def _dequantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    input_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+    """

     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +395,16 @@ def dequantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)

-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         # Force a copy to avoid input modification due
         # to upcoming in-place operations.
         dequant = input.to(torch.int32, copy=True)
         if zero_point is not None:
-            dequant -= zero_point.to(torch.int32)
+            dequant = dequant - zero_point.to(torch.int32)
         dequant = dequant.to(output_dtype)
-        dequant *= scale
+        dequant = dequant * scale
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}"
         mid_point = (quant_max + quant_min + 1) / 2
         # This should allocate new memory and avoid input modification
         dequant = input - mid_point
@@ -320,8 +460,39 @@ def choose_qparams_affine(
     Output:
         Tuple of scales and zero_points Tensor with requested dtype
     """
+    return _choose_qparams_affine(
+        input,
+        mapping_type.name,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+        preserve_zero,
+        zero_point_domain.name
+    )
+
+# @register_custom_op(quant_lib, 'choose_qparams_affine(Tensor input, str mapping_type, int[] block_size, ScalarType target_dtype, int? quant_min=None, int? quant_max=None, float? eps=None, ScalarType? scale_dtype=None, ScalarType? zero_point_dtype=None, bool preserve_zero=True, str zero_point_domain="INT") -> (Tensor, Tensor)')
+@register_custom_op(quant_lib)
+def _choose_qparams_affine(
+    input: torch.Tensor,
+    mapping_type: str,
+    block_size: List[int],
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    preserve_zero: bool = True,
+    zero_point_domain: str = "INT",
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """op definition that has compatible signatures with custom op library
+    """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
-    assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
+    assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"

     if scale_dtype is None:
         scale_dtype = input.dtype
@@ -342,21 +513,22 @@ def choose_qparams_affine(
     min_val_neg = min_val
     max_val_pos = max_val

-    if mapping_type == MappingType.SYMMETRIC:
+    if mapping_type == MappingType.SYMMETRIC.name:
         max_val_pos = torch.max(-min_val_neg, max_val_pos)
         scale = max_val_pos / (float(quant_max - quant_min) / 2)
         if not preserve_zero:
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-        if zero_point_domain != ZeroPointDomain.INT:
+        if zero_point_domain != ZeroPointDomain.INT.name:
             raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
         zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
+        assert mapping_type == MappingType.ASYMMETRIC.name
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
         else:
-            assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
+            assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
            mid_point = (quant_max + quant_min + 1) / 2
            zero_point = min_val_neg + scale * mid_point
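
Taken together, the wrappers leave the enum-based Python API unchanged. A minimal round-trip sketch using the public functions from this file (argument order follows the wrapper calls shown in this diff; shapes and values are illustrative, not a definitive torchao example):

    import torch
    from torchao.quantization.quant_primitives import (
        MappingType,
        choose_qparams_affine,
        quantize_affine,
        dequantize_affine,
    )

    x = torch.randn(2, 8)
    block_size = (1, 8)  # one scale/zero_point per row

    scale, zero_point = choose_qparams_affine(x, MappingType.ASYMMETRIC, block_size, torch.int8)
    q = quantize_affine(x, block_size, scale, zero_point, torch.int8)
    dq = dequantize_affine(q, block_size, scale, zero_point, torch.int8)

    # round-trip error should be on the order of the per-row scale
    print((x - dq).abs().max())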
