diff --git a/examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py b/examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py
index 51bdad4d5..6ed01e7d1 100644
--- a/examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py
+++ b/examples/quantization_2of4_sparse_w4a16/llama7b_sparse_w4a16.py
@@ -68,7 +68,6 @@
     model=model,
     **oneshot_kwargs,
     stage="sparsity_stage",
-    output_dir=output_dir,
 )

 # Sparse finetune
@@ -76,7 +75,6 @@
     model=oneshot_applied_model,
     **oneshot_kwargs,
     **training_kwargs,
-    output_dir=output_dir,
     stage="finetuning_stage",
 )

diff --git a/src/llmcompressor/pytorch/model_load/helpers.py b/src/llmcompressor/pytorch/model_load/helpers.py
index 8e1782328..1e665db43 100644
--- a/src/llmcompressor/pytorch/model_load/helpers.py
+++ b/src/llmcompressor/pytorch/model_load/helpers.py
@@ -41,6 +41,10 @@ def save_checkpoint(
     :param save_safetensors: save model checkpoint using safetensors file type
     :param save_compressed: save model checkpoint using compressed-tensors format
     """
+    from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
+        get_model_compressor,  # avoid circular import
+    )
+
     # saving the model also saves the recipe
     model.save_pretrained(
         save_path,
@@ -51,6 +55,16 @@
     if processor is not None:
         processor.save_pretrained(save_path)

+    # saving the model modifies the model structure
+    # as this is only a checkpoint, decompress model to enable future training/oneshot
+    compressor = get_model_compressor(
+        model=model,
+        save_compressed=save_compressed,
+        skip_sparsity_compression_stats=skip_sparsity_compression_stats,
+    )
+    if compressor is not None:
+        compressor.decompress_model(model)
+

 def fallback_to_cpu(device: str) -> str:
     """
diff --git a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
index 71d328c8a..25efc792d 100644
--- a/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
+++ b/src/llmcompressor/transformers/sparsification/compressed_tensors_utils.py
@@ -2,7 +2,7 @@
 import re
 import weakref
 from functools import wraps
-from typing import Dict, Optional
+from typing import Optional

 import torch
 import transformers
@@ -91,43 +91,27 @@ def save_pretrained_wrapper(
             # https://github.com/huggingface/transformers/pull/30488
             transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

-            # state_dict gets passed in as a kwarg for FSDP models
-            state_dict = kwargs.pop("state_dict", None)
-            if state_dict is None:
-                logger.info("Fetching state_dict - this may take some time")
-                state_dict = get_state_dict_offloaded_model(model)
-
-            logger.info("Fetching compressor")
+            # compress model using compressor
             compressor = get_model_compressor(
                 model=model,
                 sparsity_config=sparsity_config,
                 quantization_format=quantization_format,
                 save_compressed=save_compressed,
                 skip_sparsity_compression_stats=skip_sparsity_compression_stats,
-                state_dict=state_dict,
                 disable_sparse_compression=disable_sparse_compression,
             )
+            if compressor is not None:
+                compressor.compress_model(model)
+
+            # save (compressed) model structure
+            original_save_pretrained.__get__(model, model_class)(
+                save_directory,
+                safe_serialization=safe_serialization,
+                **kwargs,
+            )

-            if compressor is None:
-                # model is not compressed or quantized, save as normal
-                original_save_pretrained_func = original_save_pretrained.__get__(
-                    model, model_class
-                )
-                original_save_pretrained_func(
-                    save_directory, state_dict=state_dict, **kwargs
-                )
-                return
-
-            # make sure we're on the main process when saving
-            if state_dict is not None and len(state_dict) > 0:
-                compressed_state_dict = compressor.compress(model, state_dict)
-                logger.info("Saving compressed model to disk")
-                original_save_pretrained.__get__(model, model_class)(
-                    save_directory,
-                    state_dict=compressed_state_dict,
-                    safe_serialization=safe_serialization,
-                    **kwargs,
-                )
+            # update config to reflect compression
+            if compressor is not None:
                 compressor.update_config(save_directory)

             # update existing recipe
@@ -195,7 +179,6 @@ def get_model_compressor(
     quantization_format: Optional[str] = None,
     save_compressed: bool = True,
     skip_sparsity_compression_stats: bool = True,
-    state_dict: Optional[Dict] = None,
     disable_sparse_compression: bool = False,
 ):
     """
@@ -209,12 +192,8 @@
     :param save_compressed: boolean representing to save in a compressed format
     :param skip_sparsity_compression_stats: bool allowing compression stats
         on std out
-    :param state_dict: state_dict of the model
     :param disable_sparse_compression: bool to skip sparse compression
     """
-    # find offloaded state dict if none is provided
-    if state_dict is None:
-        state_dict = get_state_dict_offloaded_model(model)

     if sparsity_config is None:
         """
@@ -242,6 +221,8 @@
             )
             sparsity_config = None
         else:
+            state_dict = get_state_dict_offloaded_model(model)
+
             sparsity_config = SparsityConfigMetadata.from_pretrained(
                 model,
                 state_dict=state_dict,
diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py
index 67197232d..e08421ee0 100644
--- a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py
+++ b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune.py
@@ -7,6 +7,9 @@
 from parameterized import parameterized_class
 from transformers import AutoConfig

+from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
+    get_model_compressor,
+)
 from tests.testing_utils import parse_params, requires_gpu

 CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/finetune/finetune_oneshot_configs"
@@ -34,17 +37,21 @@ def _test_oneshot_and_finetune(self):
             output_dir=self.output,
         )

-        train_args = dict(
-            num_train_epochs=self.num_train_epochs,
-            precision="bfloat16",
-            bf16=True,
-        )
         oneshot_model = oneshot(
             model=self.model,
             **oneshot_args,
             stage="test_oneshot_stage",
         )

+        compressor = get_model_compressor(model=oneshot_model, save_compressed=True)
+        if compressor is not None:
+            compressor.decompress_model(oneshot_model)
+
+        train_args = dict(
+            num_train_epochs=self.num_train_epochs,
+            precision="bfloat16",
+            bf16=True,
+        )
         train(
             model=oneshot_model,
             **oneshot_args,
diff --git a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py
index c1e476306..dd10613ee 100644
--- a/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py
+++ b/tests/llmcompressor/transformers/finetune/test_oneshot_and_finetune_with_tokenizer.py
@@ -55,7 +55,6 @@ def test_oneshot_and_finetune_with_tokenizer(self):
             concatenate_data=concatenate_data,
             splits=splits,
             tokenizer=tokenizer,
-            output_dir=self.output,
         )

         oneshot_model = oneshot(
@@ -70,6 +69,7 @@
             max_steps=max_steps,
             stage="test_train_stage",
             **model_and_data_kwargs,
+            output_dir=self.output,
         )

tokenizer("Hello my name is", return_tensors="pt").input_ids.to( diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py index d9a6d0d31..8dd1a2cf5 100644 --- a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py +++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py @@ -4,7 +4,7 @@ import pytest import torch -from accelerate import cpu_offload +from accelerate import dispatch_model from accelerate.accelerator import get_state_dict_offloaded_model from compressed_tensors import QUANTIZATION_CONFIG_NAME, CompressionFormat from compressed_tensors.compressors import ModelCompressor @@ -14,7 +14,7 @@ QuantizationStatus, quantize, ) -from compressed_tensors.utils import get_offloaded_device, update_prefix_dict +from compressed_tensors.utils import align_module_device, update_offload_parameter from torch import nn from transformers import AutoConfig, AutoModelForCausalLM from transformers.utils.quantization_config import CompressedTensorsConfig @@ -234,7 +234,7 @@ def test_quant_model_reload(format, dtype, tmp_path): # technically only tie_word_embeddings=False is supported right now # setting to True is discouraged @pytest.mark.parametrize( - "offload,torch_dtype,tie_word_embeddings,device_map", + "offload,torch_dtype,tie_word_embeddings,device", [ # dtype (False, torch.float16, False, "cpu"), @@ -248,7 +248,7 @@ def test_quant_model_reload(format, dtype, tmp_path): # (True, torch.float32, True, "cpu"), # TODO: fails ], ) -def test_model_reload(offload, torch_dtype, tie_word_embeddings, device_map, tmp_path): +def test_model_reload(offload, torch_dtype, tie_word_embeddings, device, tmp_path): model_path = "nm-testing/llama2.c-stories15M" save_path = tmp_path / "save_path" @@ -256,18 +256,17 @@ def test_model_reload(offload, torch_dtype, tie_word_embeddings, device_map, tmp model_path, tie_word_embeddings=tie_word_embeddings, torch_dtype=torch_dtype, - device_map=device_map, ) if offload: - model = cpu_offload(model) + model = dispatch_model(model, {"": device}, force_hooks=True) + else: + model = model.to(device) patch_tied_tensors_bug(model) modify_save_pretrained(model) model.save_pretrained(save_path, safe_serialization=True) - reloaded = AutoModelForCausalLM.from_pretrained( - save_path, torch_dtype="auto", device_map="cpu" - ) + reloaded = AutoModelForCausalLM.from_pretrained(save_path, torch_dtype="auto") model_dict = get_state_dict_offloaded_model(model) reloaded_dict = get_state_dict_offloaded_model(reloaded) @@ -278,7 +277,7 @@ def test_model_reload(offload, torch_dtype, tie_word_embeddings, device_map, tmp @requires_gpu @pytest.mark.parametrize( - "offload,torch_dtype,tie_word_embeddings,device_map", + "offload,torch_dtype,tie_word_embeddings,device", [ (False, torch.float32, False, "cuda:0"), (True, torch.float32, False, "cuda:0"), @@ -286,14 +285,12 @@ def test_model_reload(offload, torch_dtype, tie_word_embeddings, device_map, tmp (True, torch.float32, True, "cuda:0"), ], ) -def test_model_reload_gpu( - offload, torch_dtype, tie_word_embeddings, device_map, tmp_path -): - test_model_reload(offload, torch_dtype, tie_word_embeddings, device_map, tmp_path) +def test_model_reload_gpu(offload, torch_dtype, tie_word_embeddings, device, tmp_path): + test_model_reload(offload, torch_dtype, tie_word_embeddings, device, tmp_path) @pytest.mark.parametrize( - "offload,torch_dtype,tie_word_embeddings,device_map", + 
"offload,torch_dtype,tie_word_embeddings,device", [ (False, torch.float16, False, "cpu"), (False, torch.float32, False, "cpu"), @@ -305,31 +302,24 @@ def test_model_reload_gpu( ], ) def test_model_shared_tensors( - offload, torch_dtype, tie_word_embeddings, device_map, tmp_path + offload, torch_dtype, tie_word_embeddings, device, tmp_path ): # load model model = AutoModelForCausalLM.from_pretrained( "nm-testing/llama2.c-stories15M", torch_dtype=torch_dtype, tie_word_embeddings=tie_word_embeddings, - device_map=device_map, ) patch_tied_tensors_bug(model) if offload: - model = cpu_offload(model) + model = dispatch_model(model, {"": device}, force_hooks=True) + else: + model = model.to(device) # modify lm head - with torch.no_grad(): - if offload: - model.lm_head._hf_hook.pre_forward(model.lm_head) - - model.lm_head.weight += 1 - - if offload: - device = get_offloaded_device(model.lm_head) - update_prefix_dict(model.lm_head, "weight", model.lm_head.weight.to(device)) - model.lm_head._hf_hook.post_forward(model.lm_head, None) + with torch.no_grad(), align_module_device(model.lm_head): + update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1) # check that embed_tokens is not modified model_dict = get_state_dict_offloaded_model(model) @@ -343,17 +333,17 @@ def test_model_shared_tensors( @requires_gpu @pytest.mark.parametrize( - "offload,torch_dtype,tie_word_embeddings,device_map", + "offload,torch_dtype,tie_word_embeddings,device", [ (False, torch.float32, False, "cuda:0"), (False, torch.float32, True, "cuda:0"), ], ) def test_model_shared_tensors_gpu( - offload, torch_dtype, tie_word_embeddings, device_map, tmp_path + offload, torch_dtype, tie_word_embeddings, device, tmp_path ): test_model_shared_tensors( - offload, torch_dtype, tie_word_embeddings, device_map, tmp_path + offload, torch_dtype, tie_word_embeddings, device, tmp_path )