@@ -334,11 +334,13 @@ def ffn_or_attn_only(mod, fqn):
     if quantization:
         from torchao.quantization import (
+            Float8DynamicActivationFloat8SemiSparseWeightConfig,
             autoquant,
             float8_dynamic_activation_float8_weight,
             float8_weight_only,
             fpx_weight_only,
             gemlite_uintx_weight_only,
+            int4_dynamic_activation_int4_weight,
             int4_weight_only,
             int8_dynamic_activation_int4_weight,
             int8_dynamic_activation_int8_weight,
@@ -434,18 +436,30 @@ def ffn_or_attn_only(mod, fqn):
                 ]
             ), f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
             quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
-        elif "int8adq-int4w-symm" in quantization:
+        elif "int4dq-" in quantization:
             from torchao.dtypes import CutlassInt4PackedLayout
 
-            quantize_(
-                model,
-                int8_dynamic_activation_int4_weight(
-                    group_size=None,
-                    mapping_type=MappingType.SYMMETRIC,
-                    act_mapping_type=MappingType.SYMMETRIC,
-                    layout=CutlassInt4PackedLayout(),
-                ),
-            )
+            nbits = int(quantization.removeprefix("int4dq-"))
+            assert nbits == 4 or nbits == 8
+            if nbits == 4:
+                quantize_(
+                    model,
+                    int4_dynamic_activation_int4_weight(
+                        mapping_type=MappingType.SYMMETRIC,
+                        act_mapping_type=MappingType.SYMMETRIC,
+                        layout=CutlassInt4PackedLayout(),
+                    ),
+                )
+            elif nbits == 8:
+                quantize_(
+                    model,
+                    int8_dynamic_activation_int4_weight(
+                        group_size=None,
+                        mapping_type=MappingType.SYMMETRIC,
+                        act_mapping_type=MappingType.SYMMETRIC,
+                        layout=CutlassInt4PackedLayout(),
+                    ),
+                )
         if "marlin" in quantization:
             if "qqq" in quantization:
                 from torchao.dtypes import MarlinQQQLayout
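For reference, the new "int4dq-4" path boils down to the single quantize_ call below. This is a minimal standalone sketch, not code from the PR: the toy model, dtype, and device are assumptions (the CUTLASS-packed W4A4 kernels need a CUDA device), and MappingType is imported from torchao.quantization.quant_primitives here.

import torch

from torchao.dtypes import CutlassInt4PackedLayout
from torchao.quantization import int4_dynamic_activation_int4_weight, quantize_
from torchao.quantization.quant_primitives import MappingType

# Toy stand-in for the llama model that generate.py quantizes
# (assumption: CUDA device with the CUTLASS int4 kernels available).
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)

# W4A4: symmetric int4 dynamic activations and int4 weights, packed into
# the CUTLASS layout that the new "int4dq-4" branch selects.
quantize_(
    model,
    int4_dynamic_activation_int4_weight(
        mapping_type=MappingType.SYMMETRIC,
        act_mapping_type=MappingType.SYMMETRIC,
        layout=CutlassInt4PackedLayout(),
    ),
)

The "int4dq-8" spelling swaps in int8_dynamic_activation_int4_weight with group_size=None, i.e. int8 activations over int4 weights on the same packed layout.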
@@ -564,16 +578,24 @@ def ffn_or_attn_only(mod, fqn):
         elif "float8wo" in quantization:
             quantize_(model, float8_weight_only())
         elif "float8dq" in quantization:
-            granularity = str(quantization.split("-")[-1])
-            if granularity == "tensor":
-                granularity = PerTensor()
-            elif granularity == "row":
-                granularity = PerRow()
+            if sparsity and "semi" in sparsity:
+                quantize_(
+                    model,
+                    Float8DynamicActivationFloat8SemiSparseWeightConfig(),
+                    filter_fn=ffn_only,
+                )
             else:
-                granularity = PerTensor()
-            quantize_(
-                model, float8_dynamic_activation_float8_weight(granularity=granularity)
-            )
+                granularity = str(quantization.split("-")[-1])
+                if granularity == "tensor":
+                    granularity = PerTensor()
+                elif granularity == "row":
+                    granularity = PerRow()
+                else:
+                    granularity = PerTensor()
+                quantize_(
+                    model,
+                    float8_dynamic_activation_float8_weight(granularity=granularity),
+                )
         elif "autoquant_v2" in quantization:
             from torchao._models._eval import InputRecorder
             from torchao._models.llama.model import prepare_inputs_for_model
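The semi-sparse branch above applies the float8 + 2:4-sparse config only to the feed-forward linears, via the script's ffn_only filter. A minimal sketch of the same call outside the script, under assumptions: the Block module and the ffn_only helper below are hypothetical stand-ins (the real helper lives in generate.py), and a GPU with float8 and 2:4 sparse kernel support is assumed.

import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    quantize_,
)


class Block(torch.nn.Module):
    """Toy block with one attention-like and one FFN-like linear."""

    def __init__(self):
        super().__init__()
        self.attention = torch.nn.Linear(4096, 4096)
        self.feed_forward = torch.nn.Linear(4096, 4096)

    def forward(self, x):
        return self.feed_forward(self.attention(x))


model = Block().cuda().to(torch.bfloat16)


def ffn_only(mod, fqn):
    # Hypothetical stand-in for the script's ffn_only helper: match only
    # feed-forward linears, whose weights carry the 2:4 sparse pattern.
    return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn


quantize_(
    model,
    Float8DynamicActivationFloat8SemiSparseWeightConfig(),
    filter_fn=ffn_only,
)

Restricting the config to FFN weights is the usual choice here, since 2:4 sparsity tends to cost more accuracy on attention projections than on feed-forward layers.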
@@ -1130,7 +1152,7 @@ def callback(x):
         help=(
             "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
             + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
-            + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, int8adq-int4w-symm"
+            + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, float8dq, int4dq-<nbits>"
         ),
     )
     parser.add_argument(
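With the help text updated, the new modes are driven from the command line like the existing ones. Hypothetical invocations (the checkpoint path is a placeholder, and the --quantization/--sparsity flag spellings are assumed to follow the script's existing arguments):

python torchao/_models/llama/generate.py --checkpoint_path $CHECKPOINT --quantization int4dq-4
python torchao/_models/llama/generate.py --checkpoint_path $CHECKPOINT --quantization int4dq-8
python torchao/_models/llama/generate.py --checkpoint_path $CHECKPOINT --quantization float8dq --sparsity semi-structured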