
Commit c023f71

mixed-precision quantization milestone1: naive_intNwo + eval/benchmark framework (#531)
* milestone1: naive_intNwo + eval/benchmark
* remove experiment scripts
* remove exp files
* use default ZeroPointDomain.INT for int2/3/5/6
* renamed test_naive_intNwo.py to test_mixed_precision.py
* updated intNwo with _get_linear_subclass_inserter
* adjust sqnr threshold according to bit width
* fixed test for int4wo and add __init__.py
* skip test_aq_int8_weight_only_quant_3_subclass due to seg fault on nightly
* edit the sqnr threshold
* add unittest
* correct import path
1 parent 013cce3 commit c023f71

File tree

5 files changed (+188, -0 lines)

test_mixed_precision.py (new file, +32 lines)
import unittest

import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
from torchao.quantization.utils import compute_error
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only

_CUDA_IS_AVAILABLE = torch.cuda.is_available()

class TestWeightOnlyQuantNaive(unittest.TestCase):

    def test_quantization_intNwo(self):
        # skip int4wo for now since it is under development in torchao
        for quantization_bit in [2, 3, 5, 6, 8]:
            for symmetric in [False, True]:
                with self.subTest(quantization_bit=quantization_bit, symmetric=symmetric):
                    for x_shape in [[64, 32], [80, 80, 80, 32], [16, 64, 32]]:
                        x = torch.randn(*x_shape, dtype=torch.bfloat16)
                        m = nn.Sequential(nn.Linear(32, 80)).bfloat16()
                        y_ref = m(x)
                        quantize_(m, intN_weight_only(n=quantization_bit, group_size=32, symmetric=symmetric))
                        y_wo = m(x)
                        sqnr = compute_error(y_ref, y_wo)
                        # SQNR (in dB) can be approximated by 6.02 * n, where n is the quantization bit width;
                        # e.g., for 8-bit we set the threshold to 44 dB, a few dB below the ideal 6.02 * 8 = 48.16 dB
                        expected_sqnr_threshold = 44.0 - (8 - quantization_bit) * 6.02
                        self.assertGreater(sqnr, expected_sqnr_threshold, f"sqnr: {sqnr} is too low")


if __name__ == '__main__':
    unittest.main()
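
For context on the thresholds above: ideal uniform quantization yields roughly 6.02 dB of SQNR per bit, and the test anchors the 8-bit threshold at 44 dB (just below the ideal 48.16 dB), then subtracts 6.02 dB for each bit removed. A small illustrative snippet, not part of the commit, that prints the thresholds this formula implies:

    # Illustrative only: thresholds implied by the test's formula
    for n in [2, 3, 5, 6, 8]:
        ideal = 6.02 * n                      # approximate ideal SQNR in dB for n-bit quantization
        threshold = 44.0 - (8 - n) * 6.02     # threshold used by the test
        print(f"{n}-bit: ideal ~{ideal:.2f} dB, test threshold {threshold:.2f} dB")
    # e.g., 2-bit -> 7.88 dB, 8-bit -> 44.00 dB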

torchao/quantization/prototype/mixed_precision/__init__.py

Whitespace-only changes.
torchao/quantization/prototype/mixed_precision/scripts/__init__.py (new file, +1 line)

from .naive_intNwo import intN_weight_only
Eval/benchmark script (new file, +95 lines)

import torch
import torch.nn as nn

from naive_intNwo import intN_weight_only
from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval.models.huggingface import HFLM
from lm_eval.evaluator import evaluate
from lm_eval.tasks import get_task_dict

from torchao.quantization import quantize_, int8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight
from torchao._models._eval import TransformerEvalWrapper

from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
)

from torchao.quantization.quant_api import autoquant

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True


def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length, sensi_bit, non_sensi_bit, quant_sym, group_size):

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)

    if quantization == "autoquant":
        model = autoquant(model.to(device=device))

    # naive uniform-precision quantization of all layers
    elif quantization in ["2", "3", "4", "5", "6", "8"]:
        quantize_(model.to(device=device), intN_weight_only(n=int(quantization), group_size=group_size, symmetric=quant_sym))

    # mixed-precision quantization for Llama3
    elif quantization == "MP_llama3":

        # filter for sensitive layers (the first 3 and last 2 layers for Llama3)
        def filter_fn_sen(child: torch.nn.Module, cur_fqn: str) -> bool:
            return isinstance(child, nn.Linear) and any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])

        # filter for non-sensitive layers (the other 27 layers for Llama3)
        def filter_fn_nonsen(child: torch.nn.Module, cur_fqn: str) -> bool:
            return isinstance(child, nn.Linear) and not any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])

        # quantize the sensitive layers
        if sensi_bit != 16:
            quantize_(model.to(device=device), intN_weight_only(n=sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_sen)

        # quantize the less-sensitive layers
        if sensi_bit == 4:
            quantize_(model, intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)
        else:
            quantize_(model.to(device=device), intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)

    if compile:
        model = torch.compile(model, mode="max-autotune", fullgraph=True)

    with torch.no_grad():
        result = evaluate(
            HFLM(
                pretrained=model,
                tokenizer=tokenizer,
                batch_size=batch_size,
                max_length=max_length),
            get_task_dict(tasks),
            limit=limit,
        )

    for task, res in result["results"].items():
        print(f"{task}: {res}")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Run HF Model Evaluation')
    parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.')
    parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eval tasks to evaluate; usage: --tasks task1 task2')
    parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
    parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
    parser.add_argument('-q', '--quantization', default="None", choices=["2", "3", "4", "5", "6", "8", "MP_llama3", "None"], help='Which quantization technique to apply: "2"/"3"/"4"/"5"/"6"/"8" for uniform quantization, "MP_llama3" for mixed precision on Llama3 (also set sensi_bit and non_sensi_bit), or "None" for no quantization')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation; int8wo and int4wo work best with small batch sizes, while int8dq works better with large batch sizes')
    parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
    parser.add_argument('--sensi_bit', type=int, default=16, choices=[16, 8, 6, 5, 4, 3], help='Bit setting for sensitive layers')
    parser.add_argument('--non_sensi_bit', type=int, default=8, choices=[8, 6, 5, 4, 3, 2], help='Bit setting for non-sensitive layers')
    parser.add_argument('--quant_sym', type=bool, default=False, help='Symmetric or asymmetric quantization, asymmetric by default')
    parser.add_argument('--group_size', type=int, default=32, help='Group size to perform quantization on')
    args = parser.parse_args()
    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length, args.sensi_bit, args.non_sensi_bit, args.quant_sym, args.group_size)
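
To show how the pieces fit together, here is a minimal sketch (not part of the commit; values are illustrative and mirror the argparse defaults) of driving run_evaluation directly for the MP_llama3 mixed-precision path, with 8-bit sensitive layers and 2-bit non-sensitive layers. Normally the script is run via its CLI with -q MP_llama3.

    # Illustrative call; argument values mirror the defaults plus an 8-bit / 2-bit split
    run_evaluation(
        repo_id="checkpoints/meta-llama/Meta-Llama-3-8B",
        tasks=["wikitext"],
        limit=None,
        device="cuda",
        precision=torch.bfloat16,
        quantization="MP_llama3",  # first 3 and last 2 layers use sensi_bit, the rest use non_sensi_bit
        compile=False,
        batch_size=1,
        max_length=None,
        sensi_bit=8,
        non_sensi_bit=2,
        quant_sym=False,
        group_size=32,
    )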
torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py (new file, +60 lines)

import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
)

from torchao.quantization import int8_weight_only, int4_weight_only
from torchao.quantization.quant_api import _get_linear_subclass_inserter

def intN_weight_only(group_size=32, n=8, symmetric=False):
    '''
    Apply int N-bit weight-only quantization to a linear layer.
    Args:
        `group_size`: controls the granularity of quantization; a smaller size is more fine-grained, choices are [512, 256, 128, 64, 32]
        `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
    Usage:
        from torchao.quantization import quantize_
        quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
    '''
    # for asymmetric quantization
    def apply_intN_weight_only_quant_asym(weight):
        # avoid circular dependency
        from torchao.dtypes import to_affine_quantized
        mapping_type = MappingType.ASYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.uint8
        quant_min = 0
        quant_max = 2**n - 1
        eps = 1e-6
        preserve_zero = True
        zero_point_dtype = torch.int64
        zero_point_domain = ZeroPointDomain.INT
        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype)  # , preserve_zero=preserve_zero, zero_point_domain=zero_point_domain

    # for symmetric quantization
    def apply_intN_weight_only_quant_sym(weight):
        # avoid circular dependency
        from torchao.dtypes import to_affine_quantized
        mapping_type = MappingType.SYMMETRIC
        block_size = (1, group_size)
        target_dtype = torch.int8
        eps = 1e-6
        zero_point_dtype = torch.int64
        return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

    try:
        assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]"
        if n == 8:
            return int8_weight_only()
        elif n == 4:
            return int4_weight_only(group_size=group_size)
        else:
            if symmetric:
                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_sym)
            else:
                return _get_linear_subclass_inserter(apply_intN_weight_only_quant_asym)
    except Exception as e:
        raise
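
A minimal usage sketch based on the docstring above (the toy model and the 3-bit setting are illustrative, not from the commit): n=8 and n=4 dispatch to int8_weight_only / int4_weight_only, other bit widths go through the affine-quantized subclass inserter, and an optional filter function (as in the Llama3 eval script above) restricts quantization to selected layers.

    import torch
    import torch.nn as nn
    from torchao.quantization import quantize_
    from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only

    # toy model; every nn.Linear that passes the (optional) filter gets quantized in place
    model = nn.Sequential(nn.Linear(32, 80)).bfloat16()

    # 3-bit asymmetric weight-only quantization with group size 32
    quantize_(model, intN_weight_only(n=3, group_size=32, symmetric=False))

    out = model(torch.randn(16, 32, dtype=torch.bfloat16))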
