 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from enum import Enum
+from enum import Enum, auto
 from typing import List, Optional, Tuple, Dict
 import torch

 from torchao.kernel.intmm import int_scaled_matmul
 from torchao.kernel.intmm import safe_int_mm
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_5,
+)


 __all__ = [
@@ -34,17 +37,17 @@ class MappingType(Enum):
     based on this mapping
     e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
     """
-    SYMMETRIC = 0
-    ASYMMETRIC = 1
+    SYMMETRIC = auto()
+    ASYMMETRIC = auto()

 class ZeroPointDomain(Enum):
     """Enum that indicates whether zero_point is in integer domain or floating point domain

     integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
     float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
     """
-    INT = 0
-    FLOAT = 1
+    INT = auto()
+    FLOAT = auto()

 """
 Map from dtype to the bound value of integers
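
The enum members switch from hand-assigned integers to `auto()` because nothing should depend on the numeric values anymore: the wrappers added below pass enums across the custom-op boundary as their `.name` strings, since `torch.library` schemas accept `str` but not Python enum types. A minimal standalone sketch of the string round trip this relies on:

```python
from enum import Enum, auto

class MappingType(Enum):
    SYMMETRIC = auto()
    ASYMMETRIC = auto()

# The public APIs below pass `.name` into the registered ops, and the op
# bodies compare against it, e.g. `mapping_type == MappingType.SYMMETRIC.name`.
assert MappingType.SYMMETRIC.name == "SYMMETRIC"
# A string can also be mapped back to the enum member if ever needed:
assert MappingType["SYMMETRIC"] is MappingType.SYMMETRIC
```
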
@@ -69,6 +72,90 @@ class ZeroPointDomain(Enum):
 })


+# def register_custom_op(name: str):
+#     from torch._inductor.decomposition import register_decomposition
+
+#     def decorator(fn):
+#         if TORCH_VERSION_AFTER_2_5:
+#             opdef = torch.library.custom_op(name, mutates_args=())(fn)
+#             opdef.register_fake(fn)
+#             register_decomposition([opdef._opoverload])(fn)
+#             return opdef
+#         else:
+#             return fn
+
+#     return decorator
+
+quant_lib = torch.library.Library("quant", "FRAGMENT")
+
+# def register_custom_op(lib, schema: str):
+#     """This decorator is used to preserve some high level operators for torch.export.export
+#     while still allowing them to be decomposed for the inductor path
+
+#     NOTE: This should be applied at the top, after all other decorators have been applied
+#     """
+#     from torch._inductor.decomposition import register_decomposition
+
+#     def decorator(fn):
+#         if TORCH_VERSION_AFTER_2_5:
+#             # TODO: change order
+#             lib_namespace = lib.ns
+#             op_name = schema.split("(")[0]
+#             lib.define(schema)
+#             lib.impl(op_name, fn, "CompositeImplicitAutograd")
+#             op = getattr(getattr(torch.ops, lib_namespace), op_name)
+#             register_decomposition([op])(fn)
+#             return op
+#         else:
+#             return fn
+
+#     return decorator
+
+def register_custom_op(lib):
+    """This decorator is used to preserve some high level operators for torch.export.export
+    while still allowing them to be decomposed for the inductor path
+
+    Requirement: make sure `fn.__name__[1:]` is the operator name you want to register
+
+    NOTE: This should be applied at the top, after all other decorators have been applied
+    NOTE: We haven't tested the case when `fn` accepts a tensor subclass instance as input,
+    e.g. a uint4 tensor subclass instance, and we'll probably need to figure out what would make
+    sense for downstream systems (like executorch) to accept as well
+
+    Example:
+        lib = torch.library.Library("my_namespace", "FRAGMENT")
+        @register_custom_op(lib)
+        def _the_op_that_needs_to_be_preserved(...):
+            ...
+
+        # after this, `_the_op_that_needs_to_be_preserved` will be preserved as the
+        # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
+        # torch.export.export
+
+    """
+    from torch._inductor.decomposition import register_decomposition
+
+    def decorator(fn):
+        if TORCH_VERSION_AFTER_2_5:
+            from torch._library.infer_schema import infer_schema
+
+            # assuming fn.__name__ starts with `_` and we want to take the rest
+            # to be the name of the custom op
+            op_name = fn.__name__[1:]
+            schema = op_name + infer_schema(fn)
+            lib.define(schema)
+            lib.impl(op_name, fn, "CompositeImplicitAutograd")
+
+            lib_namespace = lib.ns
+            op = getattr(getattr(torch.ops, lib_namespace), op_name)
+            register_decomposition([op])(fn)
+            return op
+        else:
+            return fn
+
+    return decorator
+
+
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +227,7 @@ def quantize_affine(
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-):
+) -> torch.Tensor:
     """
     Args:
         input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +261,31 @@ def quantize_affine(
     Output:
       quantized tensor with requested dtype
     """
+    return _quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        output_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+    )
+
+
+@register_custom_op(quant_lib)
+def _quantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    output_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+) -> torch.Tensor:
+    """Op definition with a signature compatible with the custom op library
+    """
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
     assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
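
The public `quantize_affine` keeps the enum-typed signature while the registered `_quantize_affine` takes only schema-friendly types (`str`, `List[int]`, `Optional[int]`). In the integer zero-point domain the op computes `clamp(round(input / scale) + zero_point, quant_min, quant_max)`. A worked per-tensor example of that arithmetic (scale and values invented for illustration):

```python
import torch

x = torch.tensor([-1.0, 0.0, 0.5, 1.0])
scale = torch.tensor(1.0 / 127)   # maps |x| <= 1.0 onto the int8 range
zero_point = torch.tensor(0)
quant_min, quant_max = -128, 127

# Same formula as the INT branch below, written out directly:
q = torch.clamp(
    torch.round(x * (1.0 / scale)) + zero_point, quant_min, quant_max
).to(torch.int8)
print(q)  # tensor([-127,    0,   64,  127], dtype=torch.int8)
```
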
@@ -188,12 +300,12 @@ def quantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)

-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         quant = torch.clamp(
             torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
         ).to(output_dtype)
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name
         mid_point = (quant_max + quant_min + 1) / 2
         min_val = zero_point - scale * mid_point
         quant = (
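
The FLOAT-domain branch (truncated in the hunk above) quantizes relative to `min_val = zero_point - scale * mid_point`, matching the formula in `ZeroPointDomain`'s docstring: `quantized_val = (float_val - (zero_point - scale * mid_point)) / scale`. A small numeric sketch of that relationship (4-bit range and values invented):

```python
import torch

quant_min, quant_max = -8, 7                   # e.g. a 4-bit signed range
mid_point = (quant_max + quant_min + 1) / 2    # 0.0 for this range
scale = torch.tensor(0.1)
zero_point = torch.tensor(0.05)                # zero_point lives in the float domain
x = torch.tensor([0.3])

min_val = zero_point - scale * mid_point       # 0.05 here
q = torch.clamp(torch.round((x - min_val) / scale), quant_min, quant_max)
print(q)  # tensor([2.])  -- round((0.3 - 0.05) / 0.1) = round(2.5) = 2
```
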
@@ -216,7 +328,7 @@ def dequantize_affine(
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     *,
     output_dtype: torch.dtype = torch.float32,
-):
+) -> torch.Tensor:
     """
     Args:
         input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +350,34 @@ def dequantize_affine(
     Output:
       dequantized Tensor, with requested dtype or fp32
     """
+    return _dequantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        input_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+        output_dtype=output_dtype,
+    )
+
+
+# @register_custom_op(quant_lib, 'dequantize_affine(Tensor input, int[] block_size, Tensor scale, Tensor zero_point, ScalarType input_dtype, int? quant_min=None, int? quant_max=None, str zero_point_domain="INT", ScalarType output_dtype=float) -> Tensor')
+@register_custom_op(quant_lib)
+def _dequantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    input_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """Op definition with a signature compatible with the custom op library
+    """

     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
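
With both wrappers in place, a quantize/dequantize round trip through the public APIs looks like the following hedged sketch (shapes invented; a `block_size` equal to the full tensor shape means a single quantization block, i.e. per-tensor quantization):

```python
import torch

x = torch.randn(4, 8, dtype=torch.float32)
block_size = (4, 8)  # one block covering the whole tensor == per-tensor quant

scale, zero_point = choose_qparams_affine(
    x, MappingType.ASYMMETRIC, block_size, torch.int8
)
q = quantize_affine(x, block_size, scale, zero_point, torch.int8)
x_hat = dequantize_affine(q, block_size, scale, zero_point, torch.int8)

# Round-trip error should be bounded by roughly scale / 2 per element.
print((x - x_hat).abs().max(), scale)
```
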
@@ -255,16 +395,16 @@ def dequantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)

-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         # Force a copy to avoid input modification due
         # to upcoming in-place operations.
         dequant = input.to(torch.int32, copy=True)
         if zero_point is not None:
-            dequant -= zero_point.to(torch.int32)
+            dequant = dequant - zero_point.to(torch.int32)
         dequant = dequant.to(output_dtype)
-        dequant *= scale
+        dequant = dequant * scale
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}"
         mid_point = (quant_max + quant_min + 1) / 2
         # This should allocate new memory and avoid input modification
         dequant = input - mid_point
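
The in-place `-=` and `*=` are rewritten as out-of-place ops, presumably because `_dequantize_affine` is now registered as a custom op and traced for export/inductor decomposition, where non-mutating (functional) bodies are simpler to reason about. The behavioral difference, in a toy illustration:

```python
import torch

t = torch.tensor([10, 20], dtype=torch.int32)
alias = t
t = t - torch.tensor([1, 2], dtype=torch.int32)  # out-of-place: rebinds `t`
print(alias)   # tensor([10, 20], dtype=torch.int32) -- alias is untouched

u = torch.tensor([10, 20], dtype=torch.int32)
alias_u = u
u -= torch.tensor([1, 2], dtype=torch.int32)     # in-place: mutates storage
print(alias_u) # tensor([ 9, 18], dtype=torch.int32) -- alias sees the change
```
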
@@ -320,8 +460,39 @@ def choose_qparams_affine(
     Output:
         Tuple of scales and zero_points Tensor with requested dtype
     """
+    return _choose_qparams_affine(
+        input,
+        mapping_type.name,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+        preserve_zero,
+        zero_point_domain.name
+    )
+
+# @register_custom_op(quant_lib, 'choose_qparams_affine(Tensor input, str mapping_type, int[] block_size, ScalarType target_dtype, int? quant_min=None, int? quant_max=None, float? eps=None, ScalarType? scale_dtype=None, ScalarType? zero_point_dtype=None, bool preserve_zero=True, str zero_point_domain="INT") -> (Tensor, Tensor)')
+@register_custom_op(quant_lib)
+def _choose_qparams_affine(
+    input: torch.Tensor,
+    mapping_type: str,
+    block_size: List[int],
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    preserve_zero: bool = True,
+    zero_point_domain: str = "INT",
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Op definition with a signature compatible with the custom op library
+    """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
-    assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
+    assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"

     if scale_dtype is None:
         scale_dtype = input.dtype
@@ -342,21 +513,22 @@ def choose_qparams_affine(
     min_val_neg = min_val
     max_val_pos = max_val

-    if mapping_type == MappingType.SYMMETRIC:
+    if mapping_type == MappingType.SYMMETRIC.name:
         max_val_pos = torch.max(-min_val_neg, max_val_pos)
         scale = max_val_pos / (float(quant_max - quant_min) / 2)
         if not preserve_zero:
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-        if zero_point_domain != ZeroPointDomain.INT:
+        if zero_point_domain != ZeroPointDomain.INT.name:
             raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
         zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
+        assert mapping_type == MappingType.ASYMMETRIC.name
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
         else:
-            assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
+            assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
             mid_point = (quant_max + quant_min + 1) / 2
             zero_point = min_val_neg + scale * mid_point