     _replace_with_custom_fn_if_matches_filter,
     Quantizer,
     TwoStepQuantizer,
-    int8da_int4w,
-    int4wo,
-    int8wo,
-    int8da_int8w,
+    int8_dynamic_activation_int4_weight,
+    int4_weight_only,
+    int8_weight_only,
+    int8_dynamic_activation_int8_weight,
 )
 from torchao.utils import (
     TORCH_VERSION_AFTER_2_3,
@@ -89,7 +89,7 @@ def convert(self, model: torch.nn.Module) -> torch.nn.Module:
 
 class TorchCompileDynamicQuantizer(Quantizer):
     def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
-        quantize(model, int8da_int8w())
+        quantize(model, int8_dynamic_activation_int8_weight())
         return model
 
 class ToyLinearModel(torch.nn.Module):
@@ -152,7 +152,7 @@ class TestQuantFlow(TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
-        m = quantize(m, int8da_int8w())
+        m = quantize(m, int8_dynamic_activation_int8_weight())
         quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -195,7 +195,7 @@ def test_int8_wo_quant_save_load(self):
         )
         m = ToyLinearModel().eval().cpu()
         def api(model):
-            model = quantize(model, int8wo())
+            model = quantize(model, int8_weight_only())
             unwrap_tensor_subclass(model)
 
         api(m)
@@ -335,7 +335,7 @@ def test_8da4w_quantizer_eval(self):
         )
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_gptq_quantizer_int4wo(self):
+    def test_gptq_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
         from torchao._models._eval import InputRecorder, TransformerEvalWrapper
         torchao._models.llama.model.use_index_put_for_kv_cache = True
@@ -397,7 +397,7 @@ def test_gptq_quantizer_int4wo(self):
         )
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_quantizer_int4wo(self):
+    def test_quantizer_int4_weight_only(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao._models._eval import TransformerEvalWrapper
         precision = torch.bfloat16
@@ -499,11 +499,11 @@ def test_eval_wrapper_llama3(self):
     # TODO: move to a separate test file
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     def test_quantized_tensor_subclass_8da4w(self):
-        groupsize = 32
+        group_size = 32
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        m = quantize(m, int8da_int4w(groupsize=groupsize))
+        m = quantize(m, int8_dynamic_activation_int4_weight(group_size=group_size))
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -514,7 +514,7 @@ def test_quantized_tensor_subclass_8da4w(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
-        quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size)
         m_copy = quantizer.quantize(m_copy)
         assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
         assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)
@@ -531,13 +531,13 @@ def test_quantized_tensor_subclass_int4(self):
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
 
-        groupsize = 32
-        m = quantize(m, int4wo(groupsize=groupsize))
+        group_size = 32
+        m = quantize(m, int4_weight_only(group_size=group_size))
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
 
         # reference
-        _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
+        _ref_change_linear_weights_to_int4_woqtensors(m_copy, groupsize=group_size)
 
         res = m(*example_inputs)
         ref = m_copy(*example_inputs)
@@ -552,7 +552,7 @@ def test_quantized_tensor_subclass_int8_wo(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
-        m = quantize(m, int8wo())
+        m = quantize(m, int8_weight_only())
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -575,7 +575,7 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
-        m = quantize(m, int8da_int8w())
+        m = quantize(m, int8_dynamic_activation_int8_weight())
 
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
@@ -602,29 +602,14 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
         # make sure it compiles
         torch._export.aot_compile(m_unwrapped, example_inputs)
 
-    def test_register_apply_tensor_subclass(self):
-        from torchao import register_apply_tensor_subclass
-        def apply_my_dtype(weight):
-            return weight * 2
-
-        m = ToyLinearModel().eval()
-        example_inputs = m.example_inputs()
-        with self.assertRaisesRegex(ValueError, "not supported"):
-            quantize(m, "my_dtype")
-
-        register_apply_tensor_subclass("my_dtype", apply_my_dtype)
-        # make sure it runs
-        quantize(m, "my_dtype")
-        m(*example_inputs)
-
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_save_load(self):
         m = ToyLinearModel().eval().to(torch.bfloat16)
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs(dtype=torch.bfloat16)
 
-        m = quantize(m, "int8_weight_only")
+        m = quantize(m, int8_weight_only())
         ref = m(*example_inputs)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(m.state_dict(), f)
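
Note: below is a minimal usage sketch of the renamed APIs this diff switches the tests to. The import path (torchao.quantization.quant_api) follows the test file's own imports, and the small model is a hypothetical stand-in for the ToyLinearModel helper above; treat both as assumptions rather than the library's documented surface.

import torch
from torchao.quantization.quant_api import (
    quantize,
    int8_weight_only,
    int4_weight_only,
    int8_dynamic_activation_int8_weight,
)

class TinyModel(torch.nn.Module):  # stand-in for the ToyLinearModel test helper
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(64, 32, bias=False)
        self.linear2 = torch.nn.Linear(32, 16, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))

m = TinyModel().eval()
# Old spelling was quantize(m, int8wo()); after this change the config is a
# callable constructor rather than a string or abbreviated name:
m = quantize(m, int8_weight_only())
m(torch.randn(4, 64))

# The other renames follow the same pattern (the keyword is now group_size):
#   quantize(m, int4_weight_only(group_size=32))        # was int4wo(groupsize=32)
#   quantize(m, int8_dynamic_activation_int8_weight())  # was int8da_int8w()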