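"""Evaluate a Hugging Face causal LM with EleutherAI's lm-evaluation-harness,
optionally quantizing it with torchao first: autoquant, naive uniform intN
weight-only quantization (2/3/4/5/6/8 bit), or a mixed-precision scheme for
Llama3 that uses different bit widths for sensitive and non-sensitive layers.
"""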
import torch
import torch.nn as nn

from naive_intNwo import intN_weight_only
from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval.models.huggingface import HFLM
from lm_eval.evaluator import evaluate
from lm_eval.tasks import get_task_dict

from torchao.quantization import quantize_, int8_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight
from torchao._models._eval import TransformerEvalWrapper

from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
)

from torchao.quantization.quant_api import autoquant


torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True


def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compile, batch_size, max_length, sensi_bit, non_sensi_bit, quant_sym, group_size):

    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)

    if quantization == "autoquant":
        model = autoquant(model.to(device=device))

    # naive implementation of uniform-precision quantization for all layers
    elif quantization in ["2", "3", "4", "5", "6", "8"]:
        quantize_(model.to(device=device), intN_weight_only(n=int(quantization), group_size=group_size, symmetric=quant_sym))

    # mixed-precision quantization for Llama3
    elif quantization == "MP_llama3":

        # filter for sensitive layers (the first 3 and last 2 layers for Llama3)
        def filter_fn_sen(child: torch.nn.Module, cur_fqn: str) -> bool:
            return isinstance(child, nn.Linear) and any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])

        # filter for non-sensitive layers (the other 27 layers for Llama3)
        def filter_fn_nonsen(child: torch.nn.Module, cur_fqn: str) -> bool:
            return isinstance(child, nn.Linear) and not any(skiplayer in cur_fqn for skiplayer in ['.0.', '.1.', '.2.', '.30.', '.31.'])

        # quantize the sensitive layers
        if sensi_bit != 16:
            quantize_(model.to(device=device), intN_weight_only(n=sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_sen)

        # quantize the less-sensitive layers
        if sensi_bit == 4:
            quantize_(model, intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)
        else:
            quantize_(model.to(device=device), intN_weight_only(n=non_sensi_bit, group_size=group_size, symmetric=quant_sym), filter_fn_nonsen)

    if compile:
        model = torch.compile(model, mode="max-autotune", fullgraph=True)

    with torch.no_grad():
        result = evaluate(
            HFLM(
                pretrained=model,
                tokenizer=tokenizer,
                batch_size=batch_size,
                max_length=max_length),
            get_task_dict(tasks),
            limit=limit,
        )

    for task, res in result["results"].items():
        print(f"{task}: {res}")


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Run HF Model Evaluation')
    parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.')
    parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eval tasks to evaluate, usage: --tasks task1 task2')
    parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
    parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
    parser.add_argument('-q', '--quantization', default="None", choices=["2", "3", "4", "5", "6", "8", "MP_llama3", "None"], help='Which quantization technique to apply: choose from ["2", "3", "4", "5", "6", "8"] for uniform quantization, "MP_llama3" for mixed-precision quantization of Llama3 (set sensi_bit and non_sensi_bit accordingly), or "None" for no quantization')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation; note that int8wo and int4wo work best with small batch sizes, while int8dq works better with large batch sizes')
    parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
    parser.add_argument('--sensi_bit', type=int, default=16, choices=[16, 8, 6, 5, 4, 3], help='Bit width for sensitive layers')
    parser.add_argument('--non_sensi_bit', type=int, default=8, choices=[8, 6, 5, 4, 3, 2], help='Bit width for non-sensitive layers')
    parser.add_argument('--quant_sym', action='store_true', help='Use symmetric quantization (asymmetric by default)')
    parser.add_argument('--group_size', type=int, default=32, help='Group size to perform quantization on')
    args = parser.parse_args()
    run_evaluation(args.repo_id, args.tasks, args.limit, args.device, args.precision, args.quantization, args.compile, args.batch_size, args.max_length, args.sensi_bit, args.non_sensi_bit, args.quant_sym, args.group_size)
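

# Example invocation of the mixed-precision path (a sketch: the script filename,
# --limit value, and bit widths are assumed values, adjust them to your setup;
# --repo_id uses the script's default and --tasks its default task):
#
#   python eval_mixed_precision.py \
#       --repo_id checkpoints/meta-llama/Meta-Llama-3-8B \
#       --tasks wikitext --limit 100 \
#       --quantization MP_llama3 --sensi_bit 8 --non_sensi_bit 4 --group_size 32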