def _fetch_index_file_legacy(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
):
    """Locate a sharded-checkpoint index file stored under the legacy variant naming.

    In the legacy naming the variant is inserted right after the first two dot-separated parts of the index
    *file name* (a four-part name `a.b.index.json` becomes `a.b.<variant>.index.json`). When such a file is
    found, a deprecation warning is emitted through `deprecate`.

    Args:
        is_local: Whether `pretrained_model_name_or_path` is a local directory.
        pretrained_model_name_or_path: Local directory or repository identifier to look into.
        subfolder: Optional subfolder that holds the weights (`None` and `""` both mean "no subfolder").
        use_safetensors: Selects `SAFE_WEIGHTS_INDEX_NAME` over `WEIGHTS_INDEX_NAME` as the base index name.
        variant: Variant name (for example `"fp16"`). With `None` there is nothing to look for.
        cache_dir, force_download, proxies, local_files_only, token, revision, user_agent, commit_hash:
            Forwarded unchanged to `_get_model_file` when the index has to be fetched remotely.

    Returns:
        A `pathlib.Path` pointing to the legacy index file, or `None` when no such file could be found.
    """
    # Without a variant there is no legacy *variant* file name; previously this path failed with a
    # `TypeError` (local) or an `UnboundLocalError` (remote).
    if variant is None:
        return None

    index_name = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
    # Build the legacy name from the bare file name only. Splitting the full path on "." breaks as soon
    # as any directory contains a dot (e.g. `~/.cache/...` or `stable-diffusion-v1.5/`).
    name_parts = index_name.split(".")
    legacy_index_name = ".".join(name_parts[:2] + [variant] + name_parts[2:])

    deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."

    if is_local:
        index_file = Path(pretrained_model_name_or_path, subfolder or "", legacy_index_name)
        if not os.path.exists(index_file):
            return None
        deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
        return index_file

    # The subfolder is already part of the requested file name, hence `subfolder=None` below.
    index_file_in_repo = Path(subfolder or "", legacy_index_name).as_posix()
    try:
        index_file = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=index_file_in_repo,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=None,
            user_agent=user_agent,
            commit_hash=commit_hash,
        )
    except (EntryNotFoundError, EnvironmentError):
        # No legacy index available: the caller treats `None` as "not sharded in the legacy format".
        return None

    deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
    return Path(index_file)
cache_dir, + "variant": variant, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "token": token, + "revision": revision, + "user_agent": user_agent, + "commit_hash": commit_hash, + } + index_file = _fetch_index_file(**index_file_kwargs) + # In case the index file was not found we still have to consider the legacy format. + # this becomes applicable when the variant is not None. + if variant is not None and (index_file is None or not os.path.exists(index_file)): + index_file = _fetch_index_file_legacy(**index_file_kwargs) if index_file is not None and index_file.is_file(): is_sharded = True diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index ccd1c9485d0e..6721706b5689 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -50,7 +50,6 @@ DEPRECATED_REVISION_ARGS, BaseOutput, PushToHubMixin, - deprecate, is_accelerate_available, is_accelerate_version, is_torch_npu_available, @@ -58,7 +57,7 @@ logging, numpy_to_pil, ) -from ..utils.hub_utils import load_or_create_model_card, populate_model_card +from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card from ..utils.torch_utils import is_compiled_module @@ -735,6 +734,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: cached_folder = pretrained_model_name_or_path + # The variant filenames can have the legacy sharding checkpoint format that we check and throw + # a warning if detected. + if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant): + warn_msg = ( + f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. 
" + "Please check your files carefully:\n\n" + "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" + "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" + "If you find any files in the deprecated format:\n" + "1. Remove all existing checkpoint files for this variant.\n" + "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" + "This will ensure you're using the most up-to-date and compatible checkpoint format." + ) + logger.warning(warn_msg) + config_dict = cls.load_config(cached_folder) # pop out "_ignore_files" as it is only needed for download @@ -745,6 +759,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors` # with variant being `"fp16"`. model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict) + if len(model_variants) == 0 and variant is not None: + error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." + raise ValueError(error_message) # 3. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it @@ -1251,6 +1268,22 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: model_info_call_error = e # save error to reraise it if model is not cached locally if not local_files_only: + filenames = {sibling.rfilename for sibling in info.siblings} + if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant): + warn_msg = ( + f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. 
" + "Please check your files carefully:\n\n" + "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" + "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" + "If you find any files in the deprecated format:\n" + "1. Remove all existing checkpoint files for this variant.\n" + "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" + "This will ensure you're using the most up-to-date and compatible checkpoint format." + ) + logger.warning(warn_msg) + + model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) + config_file = hf_hub_download( pretrained_model_name, cls.config_name, @@ -1267,9 +1300,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: # retrieve all folder_names that contain relevant files folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] - filenames = {sibling.rfilename for sibling in info.siblings} - model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) - diffusers_module = importlib.import_module(__name__.split(".")[0]) pipelines = getattr(diffusers_module, "pipelines") @@ -1292,13 +1322,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: ) if len(variant_filenames) == 0 and variant is not None: - deprecation_message = ( - f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." - f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`" - "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant" - "modeling files is deprecated." 
- ) - deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False) + error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." + raise ValueError(error_message) # remove ignored filenames model_filenames = set(model_filenames) - set(ignore_filenames) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 1cdc02e87328..a79c6cdbfed8 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -271,8 +271,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str] def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: if variant is not None: splits = weights_name.split(".") - split_index = -2 if weights_name.endswith(".index.json") else -1 - splits = splits[:-split_index] + [variant] + splits[-split_index:] + splits = splits[:-1] + [variant] + splits[-1:] weights_name = ".".join(splits) return weights_name @@ -502,6 +501,19 @@ def _get_checkpoint_shard_files( return cached_folder, sharded_metadata +def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None): + if filenames and folder: + raise ValueError("Both `filenames` and `folder` cannot be provided.") + if not filenames: + filenames = [] + for _, _, files in os.walk(folder): + for file in files: + filenames.append(os.path.basename(file)) + transformers_index_format = r"\d{5}-of-\d{5}" + variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$") + return any(variant_file_re.match(f) is not None for f in filenames) + + class PushToHubMixin: """ A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub. 
@parameterized.expand(
    [
        ("hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", "unet", False),
        ("hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds", "unet", True),
        ("hf-internal-testing/tiny-sd-unet-with-sharded-ckpt", None, False),
        ("hf-internal-testing/tiny-sd-unet-with-sharded-ckpt", None, True),
    ]
)
def test_variant_sharded_ckpt_legacy_format_raises_warning(self, repo_id, subfolder, use_local):
    """Loading a legacy-format sharded variant checkpoint must emit the deprecation `FutureWarning`."""

    def load_model(path):
        kwargs = {"variant": "fp16"}
        if subfolder:
            kwargs["subfolder"] = subfolder
        return UNet2DConditionModel.from_pretrained(path, **kwargs)

    # `snapshot_download` is called without `cache_dir`, so no temporary directory is needed.
    # Resolving the path up front also keeps unrelated download warnings out of the assertion.
    model_path = snapshot_download(repo_id=repo_id) if use_local else repo_id

    with self.assertWarns(FutureWarning) as warning:
        _ = load_model(model_path)

    # Other warnings may be recorded first, so look at all of them rather than only the first one.
    expected = "This serialization format is now deprecated to standardize the serialization"
    messages = [str(w.message) for w in warning.warnings]
    self.assertTrue(any(expected in message for message in messages), f"{expected!r} not found in {messages}")


# Local tests are already covered down below.
@parameterized.expand(
    [
        ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", None, "fp16"),
        ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "unet", "fp16"),
        ("hf-internal-testing/tiny-sd-unet-sharded-no-variants", None, None),
        ("hf-internal-testing/tiny-sd-unet-sharded-no-variants-subfolder", "unet", None),
    ]
)
def test_variant_sharded_ckpt_loads_from_hub(self, repo_id, subfolder, variant=None):
    """Sharded checkpoints (with or without variant / subfolder) must load from the Hub."""

    def load_model():
        kwargs = {}
        if variant:
            kwargs["variant"] = variant
        if subfolder:
            kwargs["subfolder"] = subfolder
        return UNet2DConditionModel.from_pretrained(repo_id, **kwargs)

    # `assertTrue` instead of a bare `assert`, which is stripped when Python runs with `-O`.
    self.assertTrue(load_model())
def test_variant_sharded_ckpt_right_format(self):
    """Check that sharded checkpoints saved with a variant follow the expected file-name format.

    Nothing is executed on the model here; the test only serializes it and inspects the resulting
    file names.
    """
    for use_safe in [True, False]:
        extension = ".safetensors" if use_safe else ".bin"
        config, _ = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**config).eval()

        model_size = compute_module_sizes(model)[""]
        max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
        variant = "fp16"
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.cpu().save_pretrained(
                tmp_dir, variant=variant, max_shard_size=f"{max_shard_size}KB", safe_serialization=use_safe
            )
            index_variant = _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safe else WEIGHTS_INDEX_NAME, variant)
            index_path = os.path.join(tmp_dir, index_variant)
            self.assertTrue(os.path.exists(index_path))

            # The directory content does not change any more: list it once and reuse the result.
            saved_files = os.listdir(tmp_dir)
            shard_files = [file for file in saved_files if file.endswith(extension)]

            # Now check if the right number of shards exists. Since this number can be dependent on
            # the model being tested, it is calculated instead of hardcoded.
            expected_num_shards = caculate_expected_num_shards(index_path)
            self.assertEqual(len(shard_files), expected_num_shards)

            # Check if the variant is present as a substring in the checkpoints and index files.
            variant_files = [
                file for file in saved_files if file.endswith(extension) or ("index" in file and "json" in file)
            ]
            for file in variant_files:
                self.assertIn(variant, file)

            # Check if the sharded checkpoints were serialized in the right format.
            # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
            for file in shard_files:
                self.assertEqual(file.split(".")[1].split("-")[0], variant, f"Unexpected shard name: {file}")
self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant) loaded_model = loaded_model.to(torch_device) new_output = loaded_model(**inputs_dict) @@ -1080,20 +1090,30 @@ def test_load_sharded_checkpoint_from_hub_local_subfolder(self): assert new_output.sample.shape == (4, 4, 16, 16) @require_torch_gpu - def test_load_sharded_checkpoint_device_map_from_hub(self): + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"), + ] + ) + def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained("hf-internal-testing/unet2d-sharded-dummy", device_map="auto") + loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto") new_output = loaded_model(**inputs_dict) assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) @require_torch_gpu - def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self): + @parameterized.expand( + [ + ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None), + ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"), + ] + ) + def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant): _, inputs_dict = self.prepare_init_args_and_inputs_for_common() - loaded_model = self.model_class.from_pretrained( - "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map="auto" - ) + loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto") new_output = loaded_model(**inputs_dict) assert loaded_model @@ -1121,18 +1141,6 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self): assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) - @require_torch_gpu - def 
def test_download_variants_with_sharded_checkpoints(self):
    """Download variant files of the `unet` and `text_encoder`, whose checkpoints can be sharded."""
    for use_safetensors in [True, False]:
        for variant in ["fp16", None]:
            with tempfile.TemporaryDirectory() as cache_root:
                snapshot_dir = DiffusionPipeline.download(
                    "hf-internal-testing/tiny-stable-diffusion-pipe-variants-right-format",
                    safety_checker=None,
                    cache_dir=cache_root,
                    variant=variant,
                    use_safetensors=use_safetensors,
                )

                # Flatten the file names of every directory below the snapshot.
                files = []
                for _, _, names in os.walk(os.path.join(snapshot_dir)):
                    files.extend(names)

                # Check for `model_ext` and `variant`.
                if use_safetensors:
                    model_ext, unexpected_ext = ".safetensors", ".bin"
                else:
                    model_ext, unexpected_ext = ".bin", ".safetensors"
                model_files = [name for name in files if name.endswith(model_ext)]
                assert not any(name.endswith(unexpected_ext) for name in files)
                if variant is not None:
                    assert all(variant in name for name in model_files)


def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
    """Loading a pipeline with legacy-format sharded variant files must log the deprecation warning."""
    repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds"
    logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
    deprecated_warning_msg = "Warning: The repository contains sharded checkpoints for variant"

    for is_local in [True, False]:
        with CaptureLogger(logger) as cap_logger:
            with tempfile.TemporaryDirectory() as tmpdirname:
                # Either a local snapshot of the repository or the plain repository id.
                source = snapshot_download(repo_id, cache_dir=tmpdirname) if is_local else repo_id

                _ = DiffusionPipeline.from_pretrained(
                    source,
                    safety_checker=None,
                    variant="fp16",
                    use_safetensors=True,
                )
        assert deprecated_warning_msg in str(cap_logger), "Deprecation warning not found in logs"
def test_serialization_with_variants(self):
    """Saving with a variant must write variant-named weight files for every model component."""
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    model_components = [
        component_name for component_name, component in pipe.components.items() if isinstance(component, nn.Module)
    ]
    variant = "fp16"

    with tempfile.TemporaryDirectory() as tmpdir:
        pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)

        with open(f"{tmpdir}/model_index.json", "r") as f:
            config = json.load(f)

        for subfolder in os.listdir(tmpdir):
            folder_path = os.path.join(tmpdir, subfolder)
            # `os.listdir` returns bare names: the file check has to be made on the joined path,
            # otherwise it is resolved against the current working directory.
            if os.path.isfile(folder_path) or subfolder not in model_components:
                continue

            self.assertTrue(os.path.isdir(folder_path) and subfolder in config)
            # Guard the index access so that a file name without any "." cannot raise `IndexError`.
            has_variant_file = any(
                len(parts) > 1 and parts[1].startswith(variant)
                for parts in (p.split(".") for p in os.listdir(folder_path))
            )
            self.assertTrue(has_variant_file, f"No `{variant}` file found in `{subfolder}`.")


def test_loading_with_variants(self):
    """A pipeline saved with a variant must load back with identical model parameters."""
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    variant = "fp16"

    def is_nan(tensor):
        if tensor.ndimension() == 0:
            has_nan = torch.isnan(tensor).item()
        else:
            has_nan = torch.isnan(tensor).any()
        return has_nan

    with tempfile.TemporaryDirectory() as tmpdir:
        pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
        pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, variant=variant)

    model_components_pipe = {
        component_name: component
        for component_name, component in pipe.components.items()
        if isinstance(component, nn.Module)
    }
    model_components_pipe_loaded = {
        component_name: component
        for component_name, component in pipe_loaded.components.items()
        if isinstance(component, nn.Module)
    }
    for component_name, pipe_component in model_components_pipe.items():
        pipe_loaded_component = model_components_pipe_loaded[component_name]
        for p1, p2 in zip(pipe_component.parameters(), pipe_loaded_component.parameters()):
            # nan check for luminanext (mps).
            if not (is_nan(p1) and is_nan(p2)):
                self.assertTrue(torch.equal(p1, p2))


def test_loading_with_incorrect_variants_raises_error(self):
    """Requesting a variant that was never saved must raise a `ValueError`."""
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    variant = "fp16"

    with tempfile.TemporaryDirectory() as tmpdir:
        # Don't save with variants.
        pipe.save_pretrained(tmpdir, safe_serialization=False)

        with self.assertRaises(ValueError) as error:
            _ = self.pipeline_class.from_pretrained(tmpdir, variant=variant)

    self.assertIn(f"You are trying to load the model files of the `variant={variant}`", str(error.exception))