@@ -11,16 +11,19 @@
 )
 
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_3,
+)
+from torchao.quantization import (
+    uintx_weight_only,
+    int4_weight_only,
 )
 
 cuda_available = torch.cuda.is_available()
 
 #Parameters
 device = 'cuda:0'
 compute_dtype = torch.bfloat16
-group_size = 64
+group_size = 64
 mapping_type = MappingType.ASYMMETRIC
 block_size = (1, group_size) #axis=1
 preserve_zero = False
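For reference, the two quantizers imported above replace the low-level `to_affine_quantized_intx` call that the old code used. A minimal sketch of how they apply to a single layer, assuming a CUDA device and torch 2.3+; the layer shapes are illustrative and not part of the diff:

import torch
from torchao.quantization import int4_weight_only, uintx_weight_only

# Illustrative shapes; in_features must be divisible by group_size.
layer = torch.nn.Linear(1024, 512, bias=False).to(device="cuda", dtype=torch.bfloat16)

# Each call returns a transform that is applied directly to a module and
# swaps its weight for a quantized tensor subclass.
q4 = int4_weight_only(group_size=64, use_hqq=True)(layer)
# Bit widths other than 4 go through uintx_weight_only instead, e.g.:
# q3 = uintx_weight_only(torch.uint3, group_size=64, use_hqq=True)(layer)

print(type(q4.weight))               # quantized tensor subclass
print(q4.weight.dequantize().shape)  # torch.Size([512, 1024])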
@@ -34,40 +37,28 @@
 
 def _init_data(in_features, out_features, compute_dtype, device, torch_seed):
     torch.random.manual_seed(torch_seed)
-    linear_layer = torch.nn.Linear(in_features, out_features, bias=False, device=device)
+    linear_layer = torch.nn.Linear(in_features, out_features, bias=False).to(device)
     x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)/20.
     y_ref = linear_layer(x)
     W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)
     return W, x, y_ref
 
-def _eval_hqq(nbits, layout_type):
-    W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)
-
-    #Plain layout
-    target_dtype = torch.uint8
-    #Tensorcore layout
-    if isinstance(layout_type, TensorCoreTiledLayoutType):
-        target_dtype = torch.uint8 if TORCH_VERSION_AT_LEAST_2_5 else torch.int32
-
-    q_tensor_hqq = to_affine_quantized_intx(
-        input_float=W,
-        mapping_type=mapping_type,
-        block_size=block_size,
-        target_dtype=target_dtype,
-        quant_min=0,
-        quant_max=2**nbits - 1,
-        zero_point_domain=zero_point_domain,
-        preserve_zero=preserve_zero,
-        layout_type=layout_type,
-        use_hqq=True,
-    )
+def _eval_hqq(dtype):
+    W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)
+
+    dummy_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False)
+    dummy_linear.weight.data = W
+    if dtype == torch.uint4:
+        q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(dummy_linear).weight
+    else:
+        q_tensor_hqq = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)(dummy_linear).weight
 
     quant_linear_layer = torch.nn.Linear(W.shape[1], W.shape[0], bias=False, device=W.device)
-    del quant_linear_layer.weight
+    del quant_linear_layer.weight
     quant_linear_layer.weight = q_tensor_hqq
     dequantize_error = (W - q_tensor_hqq.dequantize()).abs().mean().item()
     dot_product_error = (y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()
 
     return dequantize_error, dot_product_error
 
 
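The rewritten `_eval_hqq` keeps the old semantics of the two metrics: the dequantize error is the mean absolute difference between the original weight and its quantize-dequantize round trip, and the dot-product error compares the quantized layer's output against the bf16 reference. Note that `max(block_size)` simply recovers `group_size` from `block_size = (1, group_size)`. A standalone sketch of the same measurement, assuming the environment the test itself requires (CUDA, torch 2.3+); shapes and seed are illustrative:

import torch
from torchao.quantization import uintx_weight_only

torch.random.manual_seed(0)
lin = torch.nn.Linear(1024, 512, bias=False).to(device="cuda", dtype=torch.bfloat16)
W = lin.weight.data.clone()
x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda") / 20.0
y_ref = lin(x)

# Applying the transform swaps lin.weight for a quantized tensor that
# still supports dequantize(), so both metrics are cheap to compute.
lin = uintx_weight_only(torch.uint6, group_size=64, use_hqq=True)(lin)
dequantize_error = (W - lin.weight.dequantize()).abs().mean().item()
dot_product_error = (y_ref - lin(x)).abs().mean().item()
print(dequantize_error, dot_product_error)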
@@ -74,41 +65,32 @@
-class TestHQQBase(unittest.TestCase):
-    @unittest.skipIf(not cuda_available, "Need CUDA available")
-    def test_hqq(self, nbits=None, layout_type=None, ref_dequantize_error=None, ref_dot_product_error=None):
-        if(nbits is None): return
-        dequantize_error, dot_product_error = _eval_hqq(nbits=nbits, layout_type=layout_type)
+@unittest.skipIf(not cuda_available, "Need CUDA available")
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "Need torch 2.3+")
+class TestHQQ(unittest.TestCase):
+    def _test_hqq(self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None):
+        if(dtype is None): return
+        dequantize_error, dot_product_error = _eval_hqq(dtype)
         self.assertTrue(dequantize_error < ref_dequantize_error)
         self.assertTrue(dot_product_error < ref_dot_product_error)
 
-class TestHQQ8Bit(TestHQQBase):
     def test_hqq_plain_8bit(self):
-        self.test_hqq(nbits=8, layout_type=PlainLayoutType(), ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)
+        self._test_hqq(dtype=torch.uint8, ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)
 
-class TestHQQ7Bit(TestHQQBase):
     def test_hqq_plain_7bit(self):
-        self.test_hqq(nbits=7, layout_type=PlainLayoutType(), ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)
+        self._test_hqq(dtype=torch.uint7, ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)
 
-class TestHQQ6Bit(TestHQQBase):
     def test_hqq_plain_6bit(self):
-        self.test_hqq(nbits=6, layout_type=PlainLayoutType(), ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)
+        self._test_hqq(dtype=torch.uint6, ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)
 
-class TestHQQ5Bit(TestHQQBase):
     def test_hqq_plain_5bit(self):
-        self.test_hqq(nbits=5, layout_type=PlainLayoutType(), ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)
+        self._test_hqq(dtype=torch.uint5, ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)
 
-class TestHQQ4bit(TestHQQBase):
     def test_hqq_plain_4bit(self):
-        self.test_hqq(nbits=4, layout_type=PlainLayoutType(), ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)
-
-    def test_hqq_tensorcore_4bit(self):
-        self.test_hqq(nbits=4, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles), ref_dequantize_error=0.000487, ref_dot_product_error=0.00147)
+        self._test_hqq(dtype=torch.uint4, ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)
 
-class TestHQQ3Bit(TestHQQBase):
     def test_hqq_plain_3bit(self):
-        self.test_hqq(nbits=3, layout_type=PlainLayoutType(), ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)
+        self._test_hqq(dtype=torch.uint3, ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)
 
-class TestHQQ2Bit(TestHQQBase):
     def test_hqq_plain_2bit(self):
-        self.test_hqq(nbits=2, layout_type=PlainLayoutType(), ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)
+        self._test_hqq(dtype=torch.uint2, ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)
 
 if __name__ == "__main__":
     unittest.main()
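As a usage note, outside this test the same configs are normally applied model-wide through torchao's `quantize_` entry point rather than by calling the transform on one layer by hand. A minimal sketch, assuming a bfloat16 CUDA model; the module and shapes are illustrative:

import torch
from torchao.quantization import quantize_, int4_weight_only

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, bias=False),
).to(device="cuda", dtype=torch.bfloat16)

# quantize_ walks the model and replaces eligible Linear weights in place.
quantize_(model, int4_weight_only(group_size=64, use_hqq=True))

x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda")
y = model(x)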