import torch.nn.functional as F
from typing import Any, Callable, Union, Dict, Optional

- from torchao.dtypes.uintx.Uintx import UintxLayoutType
- from torchao.dtypes import (
-     to_affine_quantized,
-     TensorCoreTiledLayoutType,
-     PlainLayoutType,
-     AffineQuantizedTensor,
-     SemiSparseLayoutType
- )
+ from torchao.dtypes import PlainLayoutType
from torchao.utils import (
    TORCH_VERSION_AFTER_2_4,
    unwrap_tensor_subclass,
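Every module-level import removed above reappears below as a deferred import inside the function that needs it, which is what breaks the circular dependency between torchao.dtypes and quant_api. A minimal two-module sketch of that pattern (module names pkg_a.py / pkg_b.py are hypothetical, not torchao modules):

# pkg_a.py
import pkg_b                 # pkg_a depends on pkg_b at import time

def helper():
    return "from pkg_a"

# pkg_b.py
def use_helper():
    # deferred import: resolved only when use_helper() is called, after both
    # modules have finished loading, so pkg_a -> pkg_b -> pkg_a no longer
    # recurses during module initialization
    from pkg_a import helper
    return helper()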
@@ -189,6 +182,9 @@ def _replace_with_custom_fn_if_matches_filter(


def _is_linear(mod, *args):
+     # avoid circular dep
+     from torchao.dtypes import AffineQuantizedTensor
+
    # adding weight tensor subclass isinstance check to make sure the weight is only quantized once
    # when it is shared by multiple linear modules
    return (
@@ -332,6 +328,9 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
    )

def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
+     # avoid circular dep
+     from torchao.dtypes import to_affine_quantized
+
    mapping_type = MappingType.ASYMMETRIC
    target_dtype = torch.int8
    return to_affine_quantized(x, mapping_type, _get_per_token_block_size(x), target_dtype)
@@ -340,6 +339,9 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
    if weight.shape[-1] % group_size != 0:
        return weight

+     # avoid circular dep
+     from torchao.dtypes import to_affine_quantized
+
    # weight settings
    mapping_type = MappingType.SYMMETRIC
    block_size = (1, group_size)
@@ -371,7 +373,7 @@ def insert_subclass(lin):
    return insert_subclass


- def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)):
+ def int4_weight_only(group_size=128, inner_k_tiles=8):
    """
    Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
    "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -387,12 +389,16 @@ def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner
    Args:
        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
         size is more fine grained, choices are [256, 128, 64, 32]
-         `layout_type`: layout type for quantized tensor, default is `TensorCoreTiledLayoutType(inner_k_tiles=8)`
+         `inner_k_tiles`: parameter for int4 mm kernel, choices are [8, 4, 2]
    """
    def apply_int4_weight_only_quant(weight):
        if weight.shape[-1] % group_size != 0:
            return weight

+         # avoid circular dep
+         from torchao.dtypes import to_affine_quantized
+         from torchao.dtypes import TensorCoreTiledLayoutType
+
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.int32
@@ -402,6 +408,7 @@ def apply_int4_weight_only_quant(weight):
        preserve_zero = False
        zero_point_dtype = torch.bfloat16
        zero_point_domain = ZeroPointDomain.FLOAT
+         layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)

    return _get_linear_subclass_inserter(apply_int4_weight_only_quant)
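For callers, the int4_weight_only change only swaps the layout_type argument for inner_k_tiles; the TensorCoreTiledLayoutType is now constructed inside apply_int4_weight_only_quant. A minimal before/after usage sketch, assuming the quantize_ entry point exported from torchao.quantization and a CUDA bfloat16 model (tinygemm's requirements):

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int4_weight_only

model = nn.Sequential(nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# before this change, the caller constructed the layout object:
# quantize_(model, int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8)))

# after this change, only the inner_k_tiles value is passed through:
quantize_(model, int4_weight_only(group_size=128, inner_k_tiles=8))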
@@ -412,6 +419,9 @@ def int8_weight_only():
    Applies int8 weight-only symmetric per-channel quantization to linear layers.
    """
    def apply_int8wo_quant(weight):
+         # avoid circular dep
+         from torchao.dtypes import to_affine_quantized
+
        mapping_type = MappingType.SYMMETRIC
        target_dtype = torch.int8
        eps = torch.finfo(torch.float32).eps
@@ -422,6 +432,8 @@ def apply_int8wo_quant(weight):
    return _get_linear_subclass_inserter(apply_int8wo_quant)

def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
+     # avoid circular dep
+     from torchao.dtypes import to_affine_quantized
    mapping_type = MappingType.SYMMETRIC
    target_dtype = torch.int8
    eps = 1e-5
@@ -441,6 +453,8 @@ def apply_int8_dynamic_activation_int8_weight_quant(weight):
        if in_features <= 16:
            return weight

+         # avoid circular dep
+         from torchao.dtypes import to_affine_quantized
        # weight settings
        mapping_type = MappingType.SYMMETRIC
        def get_weight_block_size(x):
@@ -465,6 +479,7 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
    Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
    quantization + 2:4 sparsity to linear layers.
    """
+     from torchao.dtypes import SemiSparseLayoutType
    return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())


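With the import moved inside the wrapper, int8_dynamic_activation_int8_semi_sparse_weight() remains a thin alias for the dense int8 constructor with a 2:4 sparse layout; a sketch of that equivalence, assuming both constructors are re-exported from torchao.quantization:

from torchao.dtypes import SemiSparseLayoutType
from torchao.quantization import (
    int8_dynamic_activation_int8_weight,
    int8_dynamic_activation_int8_semi_sparse_weight,
)

# the convenience wrapper ...
apply_sparse = int8_dynamic_activation_int8_semi_sparse_weight()

# ... builds the same thing as passing the 2:4 sparse layout explicitly
apply_sparse_explicit = int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())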
@@ -480,6 +495,8 @@ def uintx_weight_only(bit_width, group_size=64, pack_dim=-1):
        quantize_affine,
        dequantize_affine,
    )
+     from torchao.dtypes.uintx.Uintx import UintxLayoutType
+     from torchao.dtypes import to_affine_quantized
    from torchao.quantization.quant_api import _get_linear_subclass_inserter
    def apply_uintx_weight_only_quant(weight):