 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from enum import Enum
+from enum import Enum, auto
 from typing import List, Optional, Tuple, Dict
 import torch

 from torchao.kernel.intmm import int_scaled_matmul
 from torchao.kernel.intmm import safe_int_mm
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_5,
+)
+from torchao.utils import _register_custom_op


 __all__ = [
@@ -34,17 +38,17 @@ class MappingType(Enum):
     based on this mapping
     e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
     """
-    SYMMETRIC = 0
-    ASYMMETRIC = 1
+    SYMMETRIC = auto()
+    ASYMMETRIC = auto()

 class ZeroPointDomain(Enum):
     """Enum that indicate whether zero_point is in integer domain or floating point domain

     integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
     float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
     """
-    INT = 0
-    FLOAT = 1
+    INT = auto()
+    FLOAT = auto()

 """
 Map from dtype to the bound value of integers
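
Reviewer note: the two zero_point domains documented in the ZeroPointDomain docstring above can be illustrated with a small standalone sketch. This is plain PyTorch arithmetic with made-up values, not a call into torchao:

import torch

x = torch.tensor([0.05, -1.30, 2.40])        # made-up float values
scale = torch.tensor(0.02)
quant_min, quant_max = -128, 127              # int8 range

# integer-domain zero_point: q = round(x / scale) + zero_point
zp_int = torch.tensor(10)
q_int = torch.clamp(torch.round(x / scale) + zp_int, quant_min, quant_max)

# float-domain zero_point: q = round((x - (zp_float - scale * mid_point)) / scale)
mid_point = (quant_max + quant_min + 1) / 2   # 0 for int8
zp_float = torch.tensor(0.2)
q_float = torch.clamp(
    torch.round((x - (zp_float - scale * mid_point)) / scale), quant_min, quant_max
)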
@@ -69,6 +73,10 @@ class ZeroPointDomain(Enum):
 })


+quant_lib = torch.library.Library("quant", "FRAGMENT")
+
+register_custom_op = _register_custom_op(quant_lib)
+
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +148,7 @@ def quantize_affine(
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-):
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +182,31 @@ def quantize_affine(
     Output:
       quantized tensor with requested dtype
     """
+    return _quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        output_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+    )
+
+
+@register_custom_op
+def _quantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    output_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+    """
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
     assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
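
Reviewer note: a quick usage sketch of the public wrapper after this change. It assumes this file is torchao/quantization/quant_primitives.py and that the names shown are importable from there; shapes and values are made up:

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
    quantize_affine,
)

x = torch.randn(2, 4)
block_size = (1, 4)  # one scale/zero_point per row
scale, zero_point = choose_qparams_affine(
    x, MappingType.ASYMMETRIC, block_size, torch.int8
)
q = quantize_affine(
    x, block_size, scale, zero_point, torch.int8,
    zero_point_domain=ZeroPointDomain.INT,
)
# the wrapper forwards zero_point_domain.name ("INT") to the registered
# custom op, which accepts a plain string instead of the enum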
@@ -188,12 +221,12 @@ def quantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)

-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         quant = torch.clamp(
             torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
         ).to(output_dtype)
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name
         mid_point = (quant_max + quant_min + 1) / 2
         min_val = zero_point - scale * mid_point
         quant = (
@@ -216,7 +249,7 @@ def dequantize_affine(
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     *,
     output_dtype: torch.dtype = torch.float32,
-):
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +271,32 @@ def dequantize_affine(
     Output:
       dequantized Tensor, with requested dtype or fp32
     """
+    return _dequantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        input_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+        output_dtype=output_dtype,
+    )
+
+@register_custom_op
+def _dequantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    input_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+    """

     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +314,16 @@ def dequantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)

-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         # Force a copy to avoid input modification due
         # to upcoming in-place operations.
         dequant = input.to(torch.int32, copy=True)
         if zero_point is not None:
-            dequant -= zero_point.to(torch.int32)
+            dequant = dequant - zero_point.to(torch.int32)
         dequant = dequant.to(output_dtype)
-        dequant *= scale
+        dequant = dequant * scale
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}"
         mid_point = (quant_max + quant_min + 1) / 2
         # This should allocate new memory and avoid input modification
         dequant = input - mid_point
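
Reviewer note: the switch above from the in-place `dequant -= ...` / `dequant *= scale` to out-of-place ops keeps the math unchanged; the integer-domain branch is the usual affine dequant, sketched here with made-up values:

import torch

q = torch.tensor([118, -55, 7], dtype=torch.int8)   # made-up quantized values
scale = torch.tensor(0.02)
zero_point = torch.tensor(10)

# integer-domain dequant: x_hat = (q - zero_point) * scale
dequant = q.to(torch.int32) - zero_point.to(torch.int32)
dequant = dequant.to(torch.float32) * scale          # e.g. (118 - 10) * 0.02 = 2.16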
@@ -320,8 +379,38 @@ def choose_qparams_affine(
     Output:
         Tuple of scales and zero_points Tensor with requested dtype
     """
+    return _choose_qparams_affine(
+        input,
+        mapping_type.name,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+        preserve_zero,
+        zero_point_domain.name
+    )
+
+@register_custom_op
+def _choose_qparams_affine(
+    input: torch.Tensor,
+    mapping_type: str,
+    block_size: List[int],
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    preserve_zero: bool = True,
+    zero_point_domain: str = "INT",
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """op definition that has compatible signatures with custom op library
+    """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
-    assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
+    assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"

     if scale_dtype is None:
         scale_dtype = input.dtype
@@ -342,21 +431,22 @@ def choose_qparams_affine(
     min_val_neg = min_val
     max_val_pos = max_val

-    if mapping_type == MappingType.SYMMETRIC:
+    if mapping_type == MappingType.SYMMETRIC.name:
         max_val_pos = torch.max(-min_val_neg, max_val_pos)
         scale = max_val_pos / (float(quant_max - quant_min) / 2)
         if not preserve_zero:
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-        if zero_point_domain != ZeroPointDomain.INT:
+        if zero_point_domain != ZeroPointDomain.INT.name:
             raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
         zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
+        assert mapping_type == MappingType.ASYMMETRIC.name
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
         else:
-            assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
+            assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
             mid_point = (quant_max + quant_min + 1) / 2
             zero_point = min_val_neg + scale * mid_point
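
Reviewer note: a worked instance of the asymmetric branch above, reusing the numbers from the MappingType docstring example (min_val = -3.5, max_val = 10.2, int4 range [-8, 7]). Plain arithmetic, not a call into torchao:

min_val, max_val = -3.5, 10.2
quant_min, quant_max = -8, 7

# asymmetric: the scale spans the full observed range
scale = (max_val - min_val) / float(quant_max - quant_min)   # 13.7 / 15 ≈ 0.9133

# preserve_zero=True: pick zero_point so that float 0.0 maps to an exact integer
zero_point = quant_min - round(min_val / scale)              # -8 - round(-3.83) = -4
zero_point = max(quant_min, min(quant_max, zero_point))      # clamp to [-8, 7]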