Skip to content

Commit 2a38836

Browse files
sayakpaul and yiyixuxu committed
[Core] fix variant-identification. (#9253)
* fix variant-identification. * fix variant * fix sharded variant checkpoint loading. * Apply suggestions from code review * fixes. * more fixes. * remove print. * fixes * fixes * comments * fixes * apply suggestions. * hub_utils.py * fix test * updates * fixes * fixes * Apply suggestions from code review Co-authored-by: YiYi Xu <yixu310@gmail.com> * updates. * remove patch file. --------- Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 9405b1a commit 2a38836

File tree

8 files changed

+381
-62
lines changed

8 files changed

+381
-62
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
WEIGHTS_INDEX_NAME,
3232
_add_variant,
3333
_get_model_file,
34+
deprecate,
3435
is_accelerate_available,
3536
is_torch_version,
3637
logging,
@@ -228,3 +229,67 @@ def _fetch_index_file(
228229
index_file = None
229230

230231
return index_file
232+
233+
234+
def _add_legacy_variant_to_index_name(index_name, variant):
    """Insert `variant` after the first two dot-separated components of an index *file name*.

    Example: `model.safetensors.index.json` with variant `fp16` becomes `model.safetensors.fp16.index.json`.
    """
    splits = index_name.split(".")
    return ".".join(splits[:2] + [variant] + splits[2:])


def _fetch_index_file_legacy(
    is_local,
    pretrained_model_name_or_path,
    subfolder,
    use_safetensors,
    cache_dir,
    variant,
    force_download,
    proxies,
    local_files_only,
    token,
    revision,
    user_agent,
    commit_hash,
):
    """Look up the index file of a sharded variant checkpoint saved in the deprecated (legacy) naming format.

    The legacy index name is built from `SAFE_WEIGHTS_INDEX_NAME` (when `use_safetensors` is truthy) or
    `WEIGHTS_INDEX_NAME`, with `variant` inserted after the first two dot-separated components of the file name.
    A deprecation warning is emitted through `deprecate` whenever such an index file is found.

    Args:
        is_local: If truthy, the index is looked up on disk under `pretrained_model_name_or_path`; otherwise it
            is requested through `_get_model_file`.
        pretrained_model_name_or_path: Local directory (when `is_local`) or the identifier forwarded to
            `_get_model_file`.
        subfolder: Optional subfolder that contains the index file; `None` and `""` are treated alike.
        use_safetensors: Selects the safetensors index name instead of the default weights index name.
        variant: Variant (e.g. `"fp16"`) to insert into the index name. `None` yields `None`.
        cache_dir, force_download, proxies, local_files_only, token, revision, user_agent, commit_hash:
            Forwarded unchanged to `_get_model_file` in the non-local case.

    Returns:
        A `pathlib.Path` to the legacy index file, or `None` if `variant` is `None` or no such file was found.
    """
    # There is no legacy *variant* index to look for without a variant. Returning early also avoids joining
    # `None` into the file name (local case) and returning an unbound name (non-local case).
    if variant is None:
        return None

    # Build the legacy name from the bare file name only, so that dots in the directory part of the path
    # (e.g. "../models" or "sd-v1.5") cannot shift the position at which the variant is inserted.
    index_name = _add_legacy_variant_to_index_name(
        SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant
    )
    deprecation_message = (
        "This serialization format is now deprecated to standardize the serialization format between "
        "`transformers` and `diffusers`. We recommend you to remove the existing files associated with the "
        f"current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
    )

    if is_local:
        index_file = Path(pretrained_model_name_or_path, subfolder or "", index_name)
        if not os.path.exists(index_file):
            return None
        deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
        return index_file

    # The subfolder is already part of `weights_name`, hence `subfolder=None` below.
    index_file_in_repo = Path(subfolder or "", index_name).as_posix()
    try:
        index_file = _get_model_file(
            pretrained_model_name_or_path,
            weights_name=index_file_in_repo,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=None,
            user_agent=user_agent,
            commit_hash=commit_hash,
        )
    except (EntryNotFoundError, EnvironmentError):
        # A missing legacy index is expected for most checkpoints; it is reported as `None`.
        return None

    deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
    return Path(index_file)

src/diffusers/models/modeling_utils.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from .model_loading_utils import (
5555
_determine_device_map,
5656
_fetch_index_file,
57+
_fetch_index_file_legacy,
5758
_load_state_dict_into_model,
5859
load_model_dict_into_meta,
5960
load_state_dict,
@@ -309,11 +310,9 @@ def save_pretrained(
309310

310311
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
311312
weights_name = _add_variant(weights_name, variant)
312-
weight_name_split = weights_name.split(".")
313-
if len(weight_name_split) in [2, 3]:
314-
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
315-
else:
316-
raise ValueError(f"Invalid {weights_name} provided.")
313+
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
314+
".safetensors", "{suffix}.safetensors"
315+
)
317316

318317
os.makedirs(save_directory, exist_ok=True)
319318

@@ -624,21 +623,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
624623
is_sharded = False
625624
index_file = None
626625
is_local = os.path.isdir(pretrained_model_name_or_path)
627-
index_file = _fetch_index_file(
628-
is_local=is_local,
629-
pretrained_model_name_or_path=pretrained_model_name_or_path,
630-
subfolder=subfolder or "",
631-
use_safetensors=use_safetensors,
632-
cache_dir=cache_dir,
633-
variant=variant,
634-
force_download=force_download,
635-
proxies=proxies,
636-
local_files_only=local_files_only,
637-
token=token,
638-
revision=revision,
639-
user_agent=user_agent,
640-
commit_hash=commit_hash,
641-
)
626+
index_file_kwargs = {
627+
"is_local": is_local,
628+
"pretrained_model_name_or_path": pretrained_model_name_or_path,
629+
"subfolder": subfolder or "",
630+
"use_safetensors": use_safetensors,
631+
"cache_dir": cache_dir,
632+
"variant": variant,
633+
"force_download": force_download,
634+
"proxies": proxies,
635+
"local_files_only": local_files_only,
636+
"token": token,
637+
"revision": revision,
638+
"user_agent": user_agent,
639+
"commit_hash": commit_hash,
640+
}
641+
index_file = _fetch_index_file(**index_file_kwargs)
642+
# In case the index file was not found we still have to consider the legacy format.
643+
# this becomes applicable when the variant is not None.
644+
if variant is not None and (index_file is None or not os.path.exists(index_file)):
645+
index_file = _fetch_index_file_legacy(**index_file_kwargs)
642646
if index_file is not None and index_file.is_file():
643647
is_sharded = True
644648

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,14 @@
5050
DEPRECATED_REVISION_ARGS,
5151
BaseOutput,
5252
PushToHubMixin,
53-
deprecate,
5453
is_accelerate_available,
5554
is_accelerate_version,
5655
is_torch_npu_available,
5756
is_torch_version,
5857
logging,
5958
numpy_to_pil,
6059
)
61-
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
60+
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
6261
from ..utils.torch_utils import is_compiled_module
6362

6463

@@ -735,6 +734,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
735734
else:
736735
cached_folder = pretrained_model_name_or_path
737736

737+
# The variant filenames can have the legacy sharding checkpoint format that we check and throw
738+
# a warning if detected.
739+
if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
740+
warn_msg = (
741+
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
742+
"Please check your files carefully:\n\n"
743+
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
744+
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
745+
"If you find any files in the deprecated format:\n"
746+
"1. Remove all existing checkpoint files for this variant.\n"
747+
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
748+
"This will ensure you're using the most up-to-date and compatible checkpoint format."
749+
)
750+
logger.warning(warn_msg)
751+
738752
config_dict = cls.load_config(cached_folder)
739753

740754
# pop out "_ignore_files" as it is only needed for download
@@ -745,6 +759,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
745759
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
746760
# with variant being `"fp16"`.
747761
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
762+
if len(model_variants) == 0 and variant is not None:
763+
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
764+
raise ValueError(error_message)
748765

749766
# 3. Load the pipeline class, if using custom module then load it from the hub
750767
# if we load from explicit class, let's use it
@@ -1251,6 +1268,22 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12511268
model_info_call_error = e # save error to reraise it if model is not cached locally
12521269

12531270
if not local_files_only:
1271+
filenames = {sibling.rfilename for sibling in info.siblings}
1272+
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
1273+
warn_msg = (
1274+
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
1275+
"Please check your files carefully:\n\n"
1276+
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
1277+
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
1278+
"If you find any files in the deprecated format:\n"
1279+
"1. Remove all existing checkpoint files for this variant.\n"
1280+
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
1281+
"This will ensure you're using the most up-to-date and compatible checkpoint format."
1282+
)
1283+
logger.warning(warn_msg)
1284+
1285+
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
1286+
12541287
config_file = hf_hub_download(
12551288
pretrained_model_name,
12561289
cls.config_name,
@@ -1267,9 +1300,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12671300
# retrieve all folder_names that contain relevant files
12681301
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
12691302

1270-
filenames = {sibling.rfilename for sibling in info.siblings}
1271-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
1272-
12731303
diffusers_module = importlib.import_module(__name__.split(".")[0])
12741304
pipelines = getattr(diffusers_module, "pipelines")
12751305

@@ -1292,13 +1322,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12921322
)
12931323

12941324
if len(variant_filenames) == 0 and variant is not None:
1295-
deprecation_message = (
1296-
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
1297-
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
1298-
"if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
1299-
"modeling files is deprecated."
1300-
)
1301-
deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
1325+
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
1326+
raise ValueError(error_message)
13021327

13031328
# remove ignored filenames
13041329
model_filenames = set(model_filenames) - set(ignore_filenames)

src/diffusers/utils/hub_utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
271271
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    """Insert `variant` right before the last dot-separated component of `weights_name`.

    Example: `diffusion_pytorch_model.safetensors` with variant `"fp16"` becomes
    `diffusion_pytorch_model.fp16.safetensors`. The name is returned unchanged when `variant` is `None`.
    """
    if variant is None:
        return weights_name

    name_parts = weights_name.split(".")
    # `split` always yields at least one item, so the insertion index is never negative.
    name_parts.insert(len(name_parts) - 1, variant)
    return ".".join(name_parts)
@@ -502,6 +501,19 @@ def _get_checkpoint_shard_files(
502501
return cached_folder, sharded_metadata
503502

504503

504+
def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
    """Return whether any file name follows the deprecated naming scheme for sharded variant checkpoints.

    The deprecated scheme places the variant *after* the shard index, e.g.
    `diffusion_pytorch_model-00001-of-00002.fp16.safetensors`. Names with the variant before the shard index,
    e.g. `diffusion_pytorch_model.fp16-00001-of-00002.safetensors`, do not match.

    Args:
        folder: Directory that is walked recursively to collect file names. Only used when `filenames` is empty
            or `None`.
        filenames: File names (or repo-relative paths) to inspect.
        variant: Variant to look for, e.g. `"fp16"`. It is matched literally.

    Returns:
        `True` if at least one name matches the deprecated scheme, `False` otherwise. This includes the cases
        where `variant` is `None` or there are no names to inspect.

    Raises:
        ValueError: If both `filenames` and `folder` are provided.
    """
    if filenames and folder:
        raise ValueError("Both `filenames` and `folder` cannot be provided.")
    if variant is None:
        return False
    if not filenames:
        # Nothing to inspect (e.g. an empty file listing and no folder): avoid calling `os.walk(None)`.
        if folder is None:
            return False
        # `os.walk` already yields bare file names, so no basename extraction is needed.
        filenames = [file for _, _, files in os.walk(folder) for file in files]
    transformers_index_format = r"\d{5}-of-\d{5}"
    # Escape the variant so that characters such as "." or "+" are not interpreted as regex syntax.
    variant_file_re = re.compile(rf".*-{transformers_index_format}\.{re.escape(variant)}\.[a-z]+$")
    return any(variant_file_re.match(f) is not None for f in filenames)
515+
516+
505517
class PushToHubMixin:
506518
"""
507519
A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub.

0 commit comments

Comments
 (0)