Use model compression pathways #1419

Open
wants to merge 8 commits into base: main
@@ -68,15 +68,13 @@
model=model,
**oneshot_kwargs,
stage="sparsity_stage",
output_dir=output_dir,
)

# Sparse finetune
finetune_applied_model = train(
model=oneshot_applied_model,
**oneshot_kwargs,
**training_kwargs,
output_dir=output_dir,
stage="finetuning_stage",
)

14 changes: 14 additions & 0 deletions src/llmcompressor/pytorch/model_load/helpers.py
@@ -41,6 +41,10 @@ def save_checkpoint(
:param save_safetensors: save model checkpoint using safetensors file type
:param save_compressed: save model checkpoint using compressed-tensors format
"""
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
get_model_compressor, # avoid circular import
)

# saving the model also saves the recipe
model.save_pretrained(
save_path,
@@ -51,6 +55,16 @@
if processor is not None:
processor.save_pretrained(save_path)

# saving the model modifies the model structure
# as this is only a checkpoint, decompress model to enable future training/oneshot
compressor = get_model_compressor(
model=model,
save_compressed=save_compressed,
skip_sparsity_compression_stats=skip_sparsity_compression_stats,
)
if compressor is not None:
compressor.decompress_model(model)


def fallback_to_cpu(device: str) -> str:
"""
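Note on the change above: save_checkpoint now writes the checkpoint in compressed form and then immediately decompresses the in-memory model, so a later training or oneshot stage can keep updating dense weights. A minimal usage sketch of the same pattern, assuming a model whose save_pretrained has already been wrapped by modify_save_pretrained (the model object and the checkpoint path are placeholders):

from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    get_model_compressor,
)

# write compressed weights to disk; the wrapped save_pretrained handles
# compression when save_compressed=True
model.save_pretrained("./checkpoint", save_compressed=True)

# saving compresses the in-memory modules, so decompress before the next
# training/oneshot stage touches the model again
compressor = get_model_compressor(model=model, save_compressed=True)
if compressor is not None:  # None means nothing was compressed
    compressor.decompress_model(model)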
@@ -2,7 +2,7 @@
import re
import weakref
from functools import wraps
from typing import Dict, Optional
from typing import Optional

import torch
import transformers
@@ -91,43 +91,27 @@ def save_pretrained_wrapper(
# https://github.com/huggingface/transformers/pull/30488
transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

# state_dict gets passed in as a kwarg for FSDP models
state_dict = kwargs.pop("state_dict", None)
if state_dict is None:
logger.info("Fetching state_dict - this may take some time")
state_dict = get_state_dict_offloaded_model(model)

logger.info("Fetching compressor")
# compress model using compressor
compressor = get_model_compressor(
model=model,
sparsity_config=sparsity_config,
quantization_format=quantization_format,
save_compressed=save_compressed,
skip_sparsity_compression_stats=skip_sparsity_compression_stats,
state_dict=state_dict,
disable_sparse_compression=disable_sparse_compression,
)
if compressor is not None:
compressor.compress_model(model)

# save (compressed) model structure
original_save_pretrained.__get__(model, model_class)(
save_directory,
safe_serialization=safe_serialization,
**kwargs,
)

if compressor is None:
# model is not compressed or quantized, save as normal
original_save_pretrained_func = original_save_pretrained.__get__(
model, model_class
)
original_save_pretrained_func(
save_directory, state_dict=state_dict, **kwargs
)
return

# make sure we're on the main process when saving
if state_dict is not None and len(state_dict) > 0:
compressed_state_dict = compressor.compress(model, state_dict)
logger.info("Saving compressed model to disk")
original_save_pretrained.__get__(model, model_class)(
save_directory,
state_dict=compressed_state_dict,
safe_serialization=safe_serialization,
**kwargs,
)
# update config to reflect compression
if compressor is not None:
compressor.update_config(save_directory)

# update existing recipe
@@ -195,7 +179,6 @@ def get_model_compressor(
quantization_format: Optional[str] = None,
save_compressed: bool = True,
skip_sparsity_compression_stats: bool = True,
state_dict: Optional[Dict] = None,
disable_sparse_compression: bool = False,
):
"""
@@ -209,12 +192,8 @@
:param save_compressed: boolean indicating whether to save in a compressed
format
:param skip_sparsity_compression_stats: bool to skip calculating sparsity compression statistics
:param state_dict: state_dict of the model
:param disable_sparse_compression: bool to skip sparse compression
"""
# find offloaded state dict if none is provided
if state_dict is None:
state_dict = get_state_dict_offloaded_model(model)

if sparsity_config is None:
"""
@@ -242,6 +221,8 @@
)
sparsity_config = None
else:
state_dict = get_state_dict_offloaded_model(model)

sparsity_config = SparsityConfigMetadata.from_pretrained(
model,
state_dict=state_dict,
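Net effect of the wrapper changes above: rather than materializing an offloaded state_dict and compressing it separately, save_pretrained_wrapper now compresses the module structure in place, saves through the original save_pretrained, and then records the compression format in the config. A condensed sketch of the new flow, using only names that appear in this diff (sparsity_config and quantization_format arguments omitted for brevity):

compressor = get_model_compressor(
    model=model,
    save_compressed=save_compressed,
    skip_sparsity_compression_stats=skip_sparsity_compression_stats,
)

if compressor is not None:
    compressor.compress_model(model)  # compress modules in place

# serialize the (possibly compressed) model structure
original_save_pretrained.__get__(model, model_class)(
    save_directory,
    safe_serialization=safe_serialization,
)

if compressor is not None:
    compressor.update_config(save_directory)  # reflect compression in config.json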
@@ -7,6 +7,9 @@
from parameterized import parameterized_class
from transformers import AutoConfig

from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
get_model_compressor,
)
from tests.testing_utils import parse_params, requires_gpu

CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/finetune/finetune_oneshot_configs"
@@ -34,17 +37,21 @@ def _test_oneshot_and_finetune(self):
output_dir=self.output,
)

train_args = dict(
num_train_epochs=self.num_train_epochs,
precision="bfloat16",
bf16=True,
)
oneshot_model = oneshot(
model=self.model,
**oneshot_args,
stage="test_oneshot_stage",
)

compressor = get_model_compressor(model=oneshot_model, save_compressed=True)
if compressor is not None:
compressor.decompress_model(oneshot_model)

train_args = dict(
num_train_epochs=self.num_train_epochs,
precision="bfloat16",
bf16=True,
)
train(
model=oneshot_model,
**oneshot_args,
@@ -55,7 +55,6 @@ def test_oneshot_and_finetune_with_tokenizer(self):
concatenate_data=concatenate_data,
splits=splits,
tokenizer=tokenizer,
output_dir=self.output,
)

oneshot_model = oneshot(
@@ -70,6 +69,7 @@
max_steps=max_steps,
stage="test_train_stage",
**model_and_data_kwargs,
output_dir=self.output,
)

input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
@@ -4,7 +4,7 @@

import pytest
import torch
from accelerate import cpu_offload
from accelerate import dispatch_model
from accelerate.accelerator import get_state_dict_offloaded_model
from compressed_tensors import QUANTIZATION_CONFIG_NAME, CompressionFormat
from compressed_tensors.compressors import ModelCompressor
@@ -14,7 +14,7 @@
QuantizationStatus,
quantize,
)
from compressed_tensors.utils import get_offloaded_device, update_prefix_dict
from compressed_tensors.utils import align_module_device, update_offload_parameter
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.utils.quantization_config import CompressedTensorsConfig
@@ -234,7 +234,7 @@ def test_quant_model_reload(format, dtype, tmp_path):
# technically only tie_word_embeddings=False is supported right now
# setting to True is discouraged
@pytest.mark.parametrize(
"offload,torch_dtype,tie_word_embeddings,device_map",
"offload,torch_dtype,tie_word_embeddings,device",
[
# dtype
(False, torch.float16, False, "cpu"),
@@ -248,26 +248,25 @@ def test_quant_model_reload(format, dtype, tmp_path):
# (True, torch.float32, True, "cpu"), # TODO: fails
],
)
def test_model_reload(offload, torch_dtype, tie_word_embeddings, device_map, tmp_path):
def test_model_reload(offload, torch_dtype, tie_word_embeddings, device, tmp_path):
model_path = "nm-testing/llama2.c-stories15M"
save_path = tmp_path / "save_path"

model = AutoModelForCausalLM.from_pretrained(
model_path,
tie_word_embeddings=tie_word_embeddings,
torch_dtype=torch_dtype,
device_map=device_map,
)
if offload:
model = cpu_offload(model)
model = dispatch_model(model, {"": device}, force_hooks=True)
else:
model = model.to(device)

patch_tied_tensors_bug(model)
modify_save_pretrained(model)
model.save_pretrained(save_path, safe_serialization=True)

reloaded = AutoModelForCausalLM.from_pretrained(
save_path, torch_dtype="auto", device_map="cpu"
)
reloaded = AutoModelForCausalLM.from_pretrained(save_path, torch_dtype="auto")

model_dict = get_state_dict_offloaded_model(model)
reloaded_dict = get_state_dict_offloaded_model(reloaded)
@@ -278,22 +277,20 @@ def test_model_reload(offload, torch_dtype, tie_word_embeddings, device_map, tmp

@requires_gpu
@pytest.mark.parametrize(
"offload,torch_dtype,tie_word_embeddings,device_map",
"offload,torch_dtype,tie_word_embeddings,device",
[
(False, torch.float32, False, "cuda:0"),
(True, torch.float32, False, "cuda:0"),
(True, torch.float16, True, "cuda:0"),
(True, torch.float32, True, "cuda:0"),
],
)
def test_model_reload_gpu(
offload, torch_dtype, tie_word_embeddings, device_map, tmp_path
):
test_model_reload(offload, torch_dtype, tie_word_embeddings, device_map, tmp_path)
def test_model_reload_gpu(offload, torch_dtype, tie_word_embeddings, device, tmp_path):
test_model_reload(offload, torch_dtype, tie_word_embeddings, device, tmp_path)


@pytest.mark.parametrize(
"offload,torch_dtype,tie_word_embeddings,device_map",
"offload,torch_dtype,tie_word_embeddings,device",
[
(False, torch.float16, False, "cpu"),
(False, torch.float32, False, "cpu"),
@@ -305,31 +302,24 @@ def test_model_shared_tensors_gpu(
],
)
def test_model_shared_tensors(
offload, torch_dtype, tie_word_embeddings, device_map, tmp_path
offload, torch_dtype, tie_word_embeddings, device, tmp_path
):
# load model
model = AutoModelForCausalLM.from_pretrained(
"nm-testing/llama2.c-stories15M",
torch_dtype=torch_dtype,
tie_word_embeddings=tie_word_embeddings,
device_map=device_map,
)
patch_tied_tensors_bug(model)

if offload:
model = cpu_offload(model)
model = dispatch_model(model, {"": device}, force_hooks=True)
else:
model = model.to(device)

# modify lm head
with torch.no_grad():
if offload:
model.lm_head._hf_hook.pre_forward(model.lm_head)

model.lm_head.weight += 1

if offload:
device = get_offloaded_device(model.lm_head)
update_prefix_dict(model.lm_head, "weight", model.lm_head.weight.to(device))
model.lm_head._hf_hook.post_forward(model.lm_head, None)
with torch.no_grad(), align_module_device(model.lm_head):
update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1)

# check that embed_tokens is not modified
model_dict = get_state_dict_offloaded_model(model)
@@ -343,17 +333,17 @@ def test_model_shared_tensors(

@requires_gpu
@pytest.mark.parametrize(
"offload,torch_dtype,tie_word_embeddings,device_map",
"offload,torch_dtype,tie_word_embeddings,device",
[
(False, torch.float32, False, "cuda:0"),
(False, torch.float32, True, "cuda:0"),
],
)
def test_model_shared_tensors_gpu(
offload, torch_dtype, tie_word_embeddings, device_map, tmp_path
offload, torch_dtype, tie_word_embeddings, device, tmp_path
):
test_model_shared_tensors(
offload, torch_dtype, tie_word_embeddings, device_map, tmp_path
offload, torch_dtype, tie_word_embeddings, device, tmp_path
)


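The offloading changes in these tests replace accelerate's cpu_offload with dispatch_model(..., force_hooks=True) and swap the manual hook bookkeeping for the compressed-tensors helpers align_module_device and update_offload_parameter. A minimal sketch of that offload-safe weight update, assuming only the utilities imported in the test above (the model and device choices are placeholders):

import torch
from accelerate import dispatch_model
from compressed_tensors.utils import align_module_device, update_offload_parameter
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
model = dispatch_model(model, {"": "cpu"}, force_hooks=True)  # attach offload hooks

# align_module_device temporarily materializes the module's parameters on their
# execution device; update_offload_parameter writes the new value back to the
# offloaded copy so the change persists across forward passes
with torch.no_grad(), align_module_device(model.lm_head):
    update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1)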