Commit 0601b5c
Expose hqq through uintx_weight_only API (#786)
Expose hqq through `int4_weight_only` API

Summary: this is a follow-up to #605 that makes hqq available in the `quantize_` API: `quantize_(model, int4_weight_only(group_size, use_hqq=True))`.

Test Plan:

python generate.py --compile --quantization int4wo-hqq-64 --precision bfloat16
Average tokens/sec: 195.24
Average Bandwidth: 729.40 GB/s
Peak Memory Usage: 5.09 GB
Model Size: 3.74 GB

python eval.py --compile --quantization int4wo-hqq-64 --precision bfloat16
wikitext: {'word_perplexity,none': 12.823631773497512, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.611400903914048, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.6883154699192412, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
1 parent 92dcc62 commit 0601b5c
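As a quick illustration of the API this commit exposes, here is a minimal usage sketch (the toy model, dtype, and group size are placeholders; the imports and the `use_hqq` flag follow the diffs below):

import torch
from torchao.quantization import quantize_, uintx_weight_only, int4_weight_only

# toy stand-in for a real model (placeholder; assumes CUDA and torch 2.3+)
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False)).to("cuda", torch.bfloat16)

# hqq-based weight-only quantization for any of torch.uint1 .. torch.uint8
quantize_(model, uintx_weight_only(torch.uint3, group_size=64, use_hqq=True))

# for 4-bit weights, the quant_api change below recommends int4_weight_only(group_size, use_hqq=True) instead:
# quantize_(model, int4_weight_only(group_size=64, use_hqq=True))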

7 files changed, +92 −71 lines changed


test/hqq/test_hqq_affine.py

+30 −48

@@ -11,16 +11,19 @@
 )

 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_3,
+)
+from torchao.quantization import (
+    uintx_weight_only,
+    int4_weight_only,
 )

 cuda_available = torch.cuda.is_available()

 #Parameters
 device = 'cuda:0'
 compute_dtype = torch.bfloat16
-group_size = 64
+group_size = 64
 mapping_type = MappingType.ASYMMETRIC
 block_size = (1, group_size) #axis=1
 preserve_zero = False
@@ -34,81 +37,60 @@

 def _init_data(in_features, out_features, compute_dtype, device, torch_seed):
     torch.random.manual_seed(torch_seed)
-    linear_layer = torch.nn.Linear(in_features, out_features, bias=False, device=device)
+    linear_layer = torch.nn.Linear(in_features, out_features, bias=False).to(device)
     x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)/20.
     y_ref = linear_layer(x)
     W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)
     return W, x, y_ref

-def _eval_hqq(nbits, layout_type):
-    W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)
-
-    #Plain layout
-    target_dtype = torch.uint8
-    #Tensorcore layout
-    if isinstance(layout_type, TensorCoreTiledLayoutType):
-        target_dtype = torch.uint8 if TORCH_VERSION_AT_LEAST_2_5 else torch.int32
-
-    q_tensor_hqq = to_affine_quantized_intx(
-        input_float=W,
-        mapping_type=mapping_type,
-        block_size=block_size,
-        target_dtype=target_dtype,
-        quant_min=0,
-        quant_max=2**nbits - 1,
-        zero_point_domain=zero_point_domain,
-        preserve_zero=preserve_zero,
-        layout_type=layout_type,
-        use_hqq=True,
-    )
+def _eval_hqq(dtype):
+    W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)
+
+    dummy_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False)
+    dummy_linear.weight.data = W
+    if dtype == torch.uint4:
+        q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(dummy_linear).weight
+    else:
+        q_tensor_hqq = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)(dummy_linear).weight

     quant_linear_layer = torch.nn.Linear(W.shape[1], W.shape[0], bias=False, device=W.device)
-    del quant_linear_layer.weight
+    del quant_linear_layer.weight
     quant_linear_layer.weight = q_tensor_hqq
     dequantize_error = (W - q_tensor_hqq.dequantize()).abs().mean().item()
     dot_product_error = (y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()

     return dequantize_error, dot_product_error


-class TestHQQBase(unittest.TestCase):
-    @unittest.skipIf(not cuda_available, "Need CUDA available")
-    def test_hqq(self, nbits=None, layout_type=None, ref_dequantize_error=None, ref_dot_product_error=None):
-        if(nbits is None): return
-        dequantize_error, dot_product_error = _eval_hqq(nbits=nbits, layout_type=layout_type)
+@unittest.skipIf(not cuda_available, "Need CUDA available")
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "Need torch 2.3+")
+class TestHQQ(unittest.TestCase):
+    def _test_hqq(self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None):
+        if(dtype is None): return
+        dequantize_error, dot_product_error = _eval_hqq(dtype)
         self.assertTrue(dequantize_error < ref_dequantize_error)
         self.assertTrue(dot_product_error < ref_dot_product_error)

-class TestHQQ8Bit(TestHQQBase):
     def test_hqq_plain_8bit(self):
-        self.test_hqq(nbits=8, layout_type=PlainLayoutType(), ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)
+        self._test_hqq(dtype=torch.uint8, ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)

-class TestHQQ7Bit(TestHQQBase):
     def test_hqq_plain_7bit(self):
-        self.test_hqq(nbits=7, layout_type=PlainLayoutType(), ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)
+        self._test_hqq(dtype=torch.uint7, ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)

-class TestHQQ6Bit(TestHQQBase):
     def test_hqq_plain_6bit(self):
-        self.test_hqq(nbits=6, layout_type=PlainLayoutType(), ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)
+        self._test_hqq(dtype=torch.uint6, ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)

-class TestHQQ5Bit(TestHQQBase):
     def test_hqq_plain_5bit(self):
-        self.test_hqq(nbits=5, layout_type=PlainLayoutType(), ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)
+        self._test_hqq(dtype=torch.uint5, ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)

-class TestHQQ4bit(TestHQQBase):
     def test_hqq_plain_4bit(self):
-        self.test_hqq(nbits=4, layout_type=PlainLayoutType(), ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)
-
-    def test_hqq_tensorcore_4bit(self):
-        self.test_hqq(nbits=4, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles), ref_dequantize_error=0.000487, ref_dot_product_error=0.00147)
+        self._test_hqq(dtype=torch.uint4, ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)

-class TestHQQ3Bit(TestHQQBase):
     def test_hqq_plain_3bit(self):
-        self.test_hqq(nbits=3, layout_type=PlainLayoutType(), ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)
+        self._test_hqq(dtype=torch.uint3, ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)

-class TestHQQ2Bit(TestHQQBase):
     def test_hqq_plain_2bit(self):
-        self.test_hqq(nbits=2, layout_type=PlainLayoutType(), ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)
+        self._test_hqq(dtype=torch.uint2, ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)

 if __name__ == "__main__":
     unittest.main()
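The new `_eval_hqq` helper reduces to: apply the returned inserter to a single linear, then compare the dequantized weight against the float original. A standalone sketch of that check, assuming a CUDA device and torch 2.3+ (the shapes, dtype, and group size are arbitrary placeholders):

import torch
from torchao.quantization import uintx_weight_only

torch.manual_seed(0)
linear = torch.nn.Linear(128, 256, bias=False).to("cuda", torch.bfloat16)
W = linear.weight.data.clone()  # keep the float weight for comparison

# quantize just this layer with hqq and measure the round-trip error
qlinear = uintx_weight_only(torch.uint6, group_size=64, use_hqq=True)(linear)
dequantize_error = (W - qlinear.weight.dequantize()).abs().mean().item()
print(f"mean |W - dequant(Q(W))| = {dequantize_error:.6f}")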

torchao/_models/llama/eval.py

+10 −5

@@ -81,14 +81,19 @@ def run_evaluation(
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq))
         if "uintx" in quantization:
-            # uintx-nbits-group_size
+            # uintx-nbits-groupsize
             # "uintx-2-64"
+            if "hqq" in quantization:
+                use_hqq = True
+                quantization = quantization[:-4]
+            else:
+                use_hqq = False
             _quant_args = quantization.split("-")
-            nbits = int(_quant_args[1])
+            nbits = int(_quant_args[0])
             _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
             dtype = _NBITS_TO_DTYPE[nbits]
-            group_size = int(_quant_args[2])
-            quantize_(model, uintx_weight_only(dtype, group_size))
+            group_size = int(_quant_args[1])
+            quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
         if "int4wo" in quantization and "gptq" in quantization:
             groupsize=int(quantization.split("-")[-2])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
@@ -135,7 +140,7 @@ def run_evaluation(
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<group_size>")
+    parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq")
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
     parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
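The scripts now accept `uintx-<nbits>-<groupsize>` with an optional `-hqq` suffix. A self-contained sketch of that string convention (the helper name is hypothetical, for illustration only, and is not part of the repo):

def parse_uintx_spec(quantization: str):
    # "uintx-<nbits>-<groupsize>" or "uintx-<nbits>-<groupsize>-hqq"
    use_hqq = quantization.endswith("-hqq")
    if use_hqq:
        quantization = quantization[: -len("-hqq")]
    _, nbits, group_size = quantization.split("-")
    return int(nbits), int(group_size), use_hqq

assert parse_uintx_spec("uintx-2-64") == (2, 64, False)
assert parse_uintx_spec("uintx-4-64-hqq") == (4, 64, True)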

torchao/_models/llama/generate.py

+11 −6

@@ -276,15 +276,20 @@ def main(
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         if "uintx" in quantization:
-            # uintx-nbits-group_size
-            # "uintx-2-64"
+            # uintx-nbits-groupsize, e.g. "uintx-2-64"
+            if "hqq" in quantization:
+                # uintx-nbits-groupsize-hqq
+                quantization = quantization[:-4]
+                use_hqq = True
+            else:
+                use_hqq = False
             _quant_args = quantization.split("-")
-            nbits = int(_quant_args[1])
+            nbits = int(_quant_args[0])
             assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8"
             _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
             dtype = _NBITS_TO_DTYPE[nbits]
-            group_size = int(_quant_args[2])
-            quantize_(model, uintx_weight_only(dtype, group_size))
+            group_size = int(_quant_args[1])
+            quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
         if "autoquant" in quantization:
             if "autoquant-int4" == quantization:
                 model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
@@ -454,7 +459,7 @@ def callback(x):
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
     parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
-    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, uintx-<nbits>-<group_size>')
+    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-<model_device>-<quant_lm_head>-<iters>-<groupsize>-<batch_size>-<seqlen>-<nsamples>, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq')
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
     parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')

torchao/dtypes/affine_quantized_tensor.py

+2 −2

@@ -11,7 +11,7 @@
     ZeroPointDomain,
     MappingType,
     int_scaled_matmul,
-    quantize_affine_hqq,
+    choose_qparams_and_quantize_affine_hqq,
     FP8_TYPES,
     choose_qparams_affine_fpx,
     quantize_affine_fpx,
@@ -266,7 +266,7 @@ def from_hp_to_intx(
            group_size = max(block_size)
            compute_dtype = zero_point_dtype if (zero_point_dtype is not None) else input_float.dtype
            device = input_float.device
-           data, scale, zero_point, _ = quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
+           data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
            data = data.to(target_dtype)
        else:
            scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
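Only the primitive's name changes here; it still picks the qparams and quantizes in one call. A sketch of calling it directly with the same keyword arguments as the call site above (the import path and the example tensor are assumptions based on the files touched in this commit):

import torch
from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_hqq

# placeholder weight; the call site above passes the layer's float weight
W = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")
data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(
    W, nbits=4, group_size=64, axis=1,
    compute_dtype=torch.bfloat16, device="cuda",
    verbose=False, raw_output=False,
)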

torchao/quantization/README.md

+6 −1

@@ -228,7 +228,12 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize_, int4_weight_only
-quantize_(model, int4_weight_only())
+group_size = 32
+
+# you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization which is expected to improves accuracy through
+# use_hqq flag for `int4_weight_only` quantization
+use_hqq = False
+quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))

 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors

torchao/quantization/quant_api.py

+31 −7

@@ -694,7 +694,7 @@ def input_quant_func(x: torch.Tensor):
     return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant)


-def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
+def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
     """
     Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
     x is the number of bits specified by `dtype`
@@ -704,23 +704,46 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
         size is more fine grained, defaults to 64
        `pack_dim`: the dimension we use for packing, defaults to -1
+       `use_hqq`: whether to use hqq algorithm or the default algorithm to quantize the weight
     """
-    def apply_uintx_weight_only_quant(weight):
-        layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)
+    from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS
+
+    SUPPORTED_DTYPES = {torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.uint8}
+    assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}"
+
+    def apply_uintx_weight_only_quant(weight, dtype):
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int32
-        zero_point_domain = ZeroPointDomain.INT
+
+        if use_hqq:
+            if dtype == torch.uint4:
+                logger.warn(f"Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance")
+            quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
+            dtype = torch.uint8
+            eps = None
+            zero_point_dtype = None
+            zero_point_domain = ZeroPointDomain.FLOAT
+            preserve_zero = False
+            layout_type = PlainLayoutType()
+        else:
+            quant_min, quant_max = None, None
+            eps = torch.finfo(torch.float32).eps
+            zero_point_dtype = torch.int32
+            zero_point_domain = ZeroPointDomain.INT
+            preserve_zero = True
+            layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)

         return to_affine_quantized_intx(
             weight, mapping_type, block_size, dtype,
+            quant_min=quant_min, quant_max=quant_max,
             eps=eps, zero_point_dtype=zero_point_dtype,
             zero_point_domain=zero_point_domain,
+            preserve_zero=preserve_zero,
             layout_type=layout_type,
+            use_hqq=use_hqq,
         )

-    return _get_linear_subclass_inserter(apply_uintx_weight_only_quant)
+    return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype)

 def fpx_weight_only(ebits: int, mbits: int):
     """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
@@ -750,5 +773,6 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:
         return to_affine_quantized_fpx(weight, layout_type)
     return _get_linear_subclass_inserter(apply_quant_llm)

+
 if TORCH_VERSION_AT_LEAST_2_5:
     torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant])
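For a concrete picture of what the `use_hqq` branch resolves to, here is a sketch of the equivalent `to_affine_quantized_intx` call for a `torch.uint3` weight, assembled from the branch above (the quant bounds are written out as 0 and 2**3 - 1; the example tensor and the import paths are assumptions, not taken from this diff):

import torch
from torchao.dtypes import to_affine_quantized_intx, PlainLayoutType  # assumed import paths
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

weight = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")  # placeholder weight
group_size = 64

q_weight = to_affine_quantized_intx(
    weight, MappingType.ASYMMETRIC, (1, group_size), torch.uint8,  # data is packed as uint8 in the hqq path
    quant_min=0, quant_max=2**3 - 1,                               # bounds for a 3-bit unsigned dtype
    eps=None, zero_point_dtype=None,
    zero_point_domain=ZeroPointDomain.FLOAT,
    preserve_zero=False,
    layout_type=PlainLayoutType(),
    use_hqq=True,
)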

torchao/quantization/quant_primitives.py

+2 −2

@@ -30,7 +30,7 @@
     "dequantize_affine_fpx",
     "fake_quantize_affine",
     "fake_quantize_affine_cachemask",
-    "quantize_affine_hqq",
+    "choose_qparams_and_quantize_affine_hqq",
 ]

 class MappingType(Enum):
@@ -842,7 +842,7 @@ def _convert_to_affinequantized_format(W_q: torch.Tensor, scale: torch.Tensor, z
     return W_q_ao, scale_ao, zero_ao

 # Main hqq quantizer function
-def quantize_affine_hqq(
+def choose_qparams_and_quantize_affine_hqq(
     tensor: torch.Tensor,
     nbits: float = 4,
     group_size: int = 64,
