
Commit 31f119e

Initial ParetoQ commit (#1876)
This project contains the training code of ParetoQ, introduced in "ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization" (https://arxiv.org/abs/2502.02631). All code is written by @liuzechun and @zxdmike and migrated from https://github.com/facebookresearch/ParetoQ. ParetoQ is the first unified framework that facilitates rigorous comparisons across 1-bit, 1.58-bit, 2-bit, 3-bit, and 4-bit quantization settings. By optimizing training schemes and refining quantization functions, ParetoQ surpasses all previous methods tailored to specific bit widths. Specifically, the 1.58-bit ParetoQ LLaMA-3 8B model reduces the performance gap to full precision by a relative 37.8% compared to the 1-bit Era's 1.58-bit LLaMA-3 8B model, while using only 30% of the training tokens.
1 parent 6726b0b commit 31f119e

18 files changed: +2297 −0 lines

ruff.toml

+1
@@ -7,4 +7,5 @@ lint.ignore = ["E731"]
 # Exclude third-party modules
 exclude = [
     "third_party/*",
+    "torchao/prototype/paretoq/*",
 ]

test/prototype/test_paretoq.py

+52
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from torchao.prototype.paretoq.models.utils_quant import (
    LsqBinaryTernaryExtension,
    QuantizeLinear,
    StretchedElasticQuant,
)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(256, 256, bias=False).to(torch.float32)

    def forward(self, x):
        return self.linear(x)


class TestParetoQ(unittest.TestCase):
    def test_quantized_linear(self):
        m = M()
        example_inputs = torch.randn(1, 256).to(torch.float32)
        for w_bits in [0, 1, 2, 3, 4, 16]:
            m.linear = QuantizeLinear(
                m.linear.in_features,
                m.linear.out_features,
                bias=False,
                w_bits=w_bits,
            )
            m(example_inputs)

    def test_quantize_functions(self):
        x = torch.randn(256, 256).to(torch.float32)
        alpha = torch.Tensor(256, 1)
        for layerwise in [True, False]:
            LsqBinaryTernaryExtension.apply(x, alpha, 1, layerwise)
            LsqBinaryTernaryExtension.apply(x, alpha, 3, layerwise)
            LsqBinaryTernaryExtension.apply(x, alpha, 4, layerwise)
            StretchedElasticQuant.apply(x, alpha, 0, layerwise)
            StretchedElasticQuant.apply(x, alpha, 2, layerwise)


if __name__ == "__main__":
    unittest.main()
+35
@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

torchrun --nnodes=1 --nproc_per_node=1 train.py \
--local_dir "/tmp/llama/" \
--input_model_filename "meta-llama/Llama-3.2-1B" \
--output_model_filename "1B-finetuned" \
--train_data_local_path "/tmp/train.jsonl" \
--do_train True \
--do_eval False \
--model_max_length 2048 \
--fp16 False \
--bf16 True \
--log_on_each_node False \
--logging_dir /tmp/output/runs/current \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2000 \
--report_to "tensorboard" \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0. \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 False \
--gradient_checkpointing False \
--qat True \
--w_bits 4 \
+38
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

CUDA_VISIBLE_DEVICES=0 torchrun --nnodes=1 --nproc_per_node=1 train.py \
--local_dir "/tmp/llama/" \
--input_model_filename "/tmp/llama_1B/llama_1B_bit1" \
--output_model_filename "1B-finetuned" \
--train_data_local_path "/tmp/train.jsonl" \
--eval_data_local_path "/tmp/wikitext-2/test.jsonl" \
--do_train False \
--do_eval True \
--model_max_length 2048 \
--fp16 False \
--bf16 True \
--log_on_each_node False \
--logging_dir /tmp/output/runs/current \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 2000 \
--report_to "tensorboard" \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0. \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 False \
--gradient_checkpointing False \
--qat True \
--w_bits 1 \
--contain_weight_clip_val True \

torchao/prototype/paretoq/README.md

+79
@@ -0,0 +1,79 @@
# ParetoQ

This repository contains the training code of ParetoQ introduced in our work: "[ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization](https://arxiv.org/abs/2502.02631)"

In this work, we present ParetoQ, the first unified framework that facilitates rigorous comparisons across 1-bit, 1.58-bit, 2-bit, 3-bit, and 4-bit quantization settings. By optimizing training schemes and refining quantization functions, ParetoQ surpasses all previous methods tailored to specific bit widths. Specifically, the 1.58-bit ParetoQ LLaMA-3 8B model reduces the performance gap to full precision by a relative 37.8% compared to the 1-bit Era's 1.58-bit LLaMA-3 8B model, while using only 30% of the training tokens.

<div align=center>
<img width=50% src="./main_result_ternary.jpg"/>
</div>

<div align=center>
<img width=100% src="./main_result_234bit.jpg"/>
</div>

With the SoTA points obtained through ParetoQ, we are able to improve the scaling-law analysis. Figures (a) and (b) demonstrate that sub-4-bit quantization, including binary, ternary, 2-bit, and 3-bit, often outperforms 4-bit quantization. Notably, 2-bit and ternary models reside on the Pareto frontier. When considering hardware friendliness and real-time speed, we generally recommend exploring 2-bit quantization for on-device applications.

<div align=center>
<img width=100% src="./main_result_scaling_law.jpg"/>
</div>

## Citation

If you find our code useful for your research, please consider citing:

    @article{liu2025paretoq,
      title={ParetoQ: Scaling Laws in Extremely Low-bit LLM Quantization},
      author={Liu, Zechun and Zhao, Changsheng and Huang, Hanxian and Chen, Sijia and Zhang, Jing and Zhao, Jiawei and Roy, Scott and Jin, Lisa and Xiong, Yunyang and Shi, Yangyang and others},
      journal={arXiv preprint arXiv:2502.02631},
      year={2025}
    }

## Run

### 1. Requirements:
* python 3.11
* pip3 install torch
* pip install -r requirement.txt

### 2. Steps to run:
* Specify the data path and the pre-trained full-precision model path in the run_train.sh file.
* Run `bash 1_run_train.sh $w_bit`, e.g. `bash 1_run_train.sh 2` for 2-bit weight quantization. A minimal Python sketch of the quantized linear layer used during training is shown below.

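For a quick standalone look at the fake-quantized linear layer that the training script relies on, the sketch below mirrors the unit test added in this commit (`test/prototype/test_paretoq.py`). The import path, constructor arguments, and the set of supported `w_bits` values are taken from that test; the specific choice of `w_bits=2` and the tensor shapes are illustrative only, not part of the training flow.

```python
import torch

from torchao.prototype.paretoq.models.utils_quant import QuantizeLinear

# ParetoQ linear layer with fake-quantized weights.
# w_bits=2 requests 2-bit weight quantization; the unit test in this
# commit exercises w_bits in [0, 1, 2, 3, 4, 16].
qlinear = QuantizeLinear(256, 256, bias=False, w_bits=2)

x = torch.randn(1, 256, dtype=torch.float32)
y = qlinear(x)  # forward pass through the quantized layer
print(y.shape)  # torch.Size([1, 256])
```
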
## Comparison to SoTA Ternary LLM methods

The results reported in the paper were obtained with Meta's internal LLaMA codebase. We reproduced our experiments with the Hugging Face codebase and released the code here; the results are close to those in the paper.

| Method | #Params | Arc-e | Arc-c | Boolq | Piqa | Siqa | HellaSwag | Obqa | WinoGrande | Avg. | Wiki |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| RTN | 600M | 26.2 | 24.6 | 62.2 | 49.5 | 36.3 | 26.1 | 27.1 | 48.8 | 37.6 | 6.60E+05 |
| LLM-QAT | 600M | 34.0 | 23.0 | 59.4 | 53.6 | 38.9 | 28.7 | 32.3 | 51.4 | 40.2 | 71.7 |
| 1-bit era | 700M | 49.5 | 29.0 | 59.2 | 67.5 | 43.6 | 43.2 | 38.9 | 53.5 | 48.1 | 17.3 |
| Spectra | 560M | 50.2 | 21.0 | 57.3 | 67.5 | -- | 33.8 | -- | 53.1 | -- | -- |
| **ParetoQ** | **600M** | **65.5** | **43.8** | **62.3** | **70.6** | **44.7** | **51.3** | **47.1** | **58.8** | **55.5** | **11.4** |
| RTN | 1B | 25.7 | 24.8 | 37.8 | 49.3 | 37.1 | 26.2 | 25.2 | 50.2 | 34.5 | 1.40E+05 |
| LLM-QAT | 1B | 36.0 | 26.2 | 47.7 | 55.1 | 39.7 | 31.3 | 33.5 | 49.6 | 39.9 | 56.9 |
| 1-bit era | 1.3B | 52.4 | 34.1 | 61.9 | 69.1 | 44.7 | 47.4 | 41.1 | 55.3 | 50.8 | 23.6 |
| Spectra | 1.1B | 56.3 | 24.6 | 59.1 | 69.3 | -- | 38.8 | -- | 55.5 | -- | -- |
| **ParetoQ** | **1B** | **68.5** | **47.6** | **62.8** | **72.1** | **45.3** | **57.4** | **52.9** | **61.3** | **58.5** | **10.0** |
| RTN | 3B | 26.9 | 23.6 | 62.2 | 51.3 | 37.6 | 26.4 | 27.0 | 49.3 | 38.0 | 4.40E+05 |
| LLM-QAT | 3B | 44.5 | 30.7 | 62.1 | 62.7 | 41.0 | 43.4 | 35.0 | 50.6 | 46.3 | 6.50E+02 |
| 1-bit era | 3B | 58.7 | 37.2 | 61.3 | 71.3 | 45.2 | 56.0 | 45.8 | 60.3 | 54.5 | 265.6 |
| Spectra | 3.9B | 66.0 | 31.9 | 66.5 | 74.4 | -- | 48.3 | -- | 62.1 | -- | -- |
| **ParetoQ** | **3B** | **71.5** | **48.6** | **68.2** | **75.5** | **46.4** | **67.9** | **54.3** | **63.1** | **61.9** | **9.9** |

More results for other bit widths can be found in the [paper](https://arxiv.org/abs/2502.02631).

## Acknowledgement

This code is partially based on the HuggingFace [Transformers](https://github.com/huggingface/transformers) repo under the [Apache License](https://github.com/huggingface/transformers/blob/main/LICENSE).

## Contact

Zechun Liu, Reality Labs, Meta Inc (zechunliu at meta dot com)

Changsheng Zhao, Reality Labs, Meta Inc (cszhao at meta dot com)

## License

ParetoQ is released under the [BSD 3](https://github.com/facebookresearch/ParetoQ/blob/main/LICENSE) license.

torchao/prototype/paretoq/__init__.py

Whitespace-only changes.

torchao/prototype/paretoq/models/__init__.py

Whitespace-only changes.
