Skip to content

[Core] fix variant-identification. #9253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 37 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
6b379a9
fix variant-idenitification.
sayakpaul Aug 23, 2024
f155ec7
fix variant
sayakpaul Aug 23, 2024
3f36e59
Merge branch 'main' into variant-tests
sayakpaul Aug 23, 2024
91253e8
fix sharded variant checkpoint loading.
sayakpaul Aug 27, 2024
dd5941e
Merge branch 'main' into variant-tests
sayakpaul Aug 27, 2024
564b8b4
Apply suggestions from code review
sayakpaul Aug 27, 2024
fdd0435
Merge branch 'main' into variant-tests
sayakpaul Sep 4, 2024
d5cad9e
Merge branch 'main' into variant-tests
sayakpaul Sep 10, 2024
c0b1ceb
fixes.
sayakpaul Sep 10, 2024
247dd93
more fixes.
sayakpaul Sep 10, 2024
b024a6d
remove print.
sayakpaul Sep 10, 2024
fdfdc5f
Merge branch 'main' into variant-tests
sayakpaul Sep 11, 2024
dcf1852
Merge branch 'main' into variant-tests
yiyixuxu Sep 12, 2024
3a71ad9
fixes
sayakpaul Sep 13, 2024
ab91852
fixes
sayakpaul Sep 13, 2024
aa631c5
comments
sayakpaul Sep 13, 2024
453bfa5
fixes
sayakpaul Sep 13, 2024
11e4b71
Merge branch 'main' into variant-tests
sayakpaul Sep 13, 2024
dbdf0f9
apply suggestions.
sayakpaul Sep 14, 2024
671038a
hub_utils.py
sayakpaul Sep 14, 2024
57382f2
Merge branch 'main' into variant-tests
sayakpaul Sep 14, 2024
ea5ecdb
fix test
sayakpaul Sep 14, 2024
a510a9b
Merge branch 'main' into variant-tests
sayakpaul Sep 17, 2024
f583dad
Merge branch 'main' into variant-tests
sayakpaul Sep 18, 2024
dc0255a
updates
sayakpaul Sep 19, 2024
f2ab3de
Merge branch 'main' into variant-tests
sayakpaul Sep 21, 2024
10baa9d
Merge branch 'main' into variant-tests
sayakpaul Sep 23, 2024
25ac01f
fixes
sayakpaul Sep 23, 2024
bac62ac
Merge branch 'main' into variant-tests
sayakpaul Sep 24, 2024
b6794ed
Merge branch 'main' into variant-tests
sayakpaul Sep 25, 2024
fcb4e39
Merge branch 'main' into variant-tests
sayakpaul Sep 26, 2024
4c0c5d2
fixes
sayakpaul Sep 26, 2024
0b1c2a6
Merge branch 'main' into variant-tests
sayakpaul Sep 27, 2024
8ad6b23
Apply suggestions from code review
sayakpaul Sep 28, 2024
1190f7d
updates.
sayakpaul Sep 28, 2024
59cfefb
removep patch file.
sayakpaul Sep 28, 2024
d72f5c1
Merge branch 'main' into variant-tests
sayakpaul Sep 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions src/diffusers/models/model_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
WEIGHTS_INDEX_NAME,
_add_variant,
_get_model_file,
deprecate,
is_accelerate_available,
is_torch_version,
logging,
Expand Down Expand Up @@ -228,3 +229,67 @@ def _fetch_index_file(
index_file = None

return index_file


def _fetch_index_file_legacy(
is_local,
pretrained_model_name_or_path,
subfolder,
use_safetensors,
cache_dir,
variant,
force_download,
proxies,
local_files_only,
token,
revision,
user_agent,
commit_hash,
):
if is_local:
index_file = Path(
pretrained_model_name_or_path,
subfolder or "",
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
).as_posix()
splits = index_file.split(".")
split_index = -3 if ".cache" in index_file else -2
splits = splits[:-split_index] + [variant] + splits[-split_index:]
index_file = ".".join(splits)
if os.path.exists(index_file):
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
index_file = Path(index_file)
else:
index_file = None
else:
if variant is not None:
index_file_in_repo = Path(
subfolder or "",
SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME,
).as_posix()
splits = index_file_in_repo.split(".")
split_index = -2
splits = splits[:-split_index] + [variant] + splits[-split_index:]
index_file_in_repo = ".".join(splits)
try:
index_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=index_file_in_repo,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
)
index_file = Path(index_file)
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
deprecate("legacy_sharded_ckpts_with_variant", "1.0.0", deprecation_message, standard_warn=False)
except (EntryNotFoundError, EnvironmentError):
index_file = None

return index_file
44 changes: 24 additions & 20 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from .model_loading_utils import (
_determine_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
Expand Down Expand Up @@ -309,11 +310,9 @@ def save_pretrained(

weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weight_name_split = weights_name.split(".")
if len(weight_name_split) in [2, 3]:
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
else:
raise ValueError(f"Invalid {weights_name} provided.")
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)

os.makedirs(save_directory, exist_ok=True)

Expand Down Expand Up @@ -624,21 +623,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
is_sharded = False
index_file = None
is_local = os.path.isdir(pretrained_model_name_or_path)
index_file = _fetch_index_file(
is_local=is_local,
pretrained_model_name_or_path=pretrained_model_name_or_path,
subfolder=subfolder or "",
use_safetensors=use_safetensors,
cache_dir=cache_dir,
variant=variant,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
user_agent=user_agent,
commit_hash=commit_hash,
)
index_file_kwargs = {
"is_local": is_local,
"pretrained_model_name_or_path": pretrained_model_name_or_path,
"subfolder": subfolder or "",
"use_safetensors": use_safetensors,
"cache_dir": cache_dir,
"variant": variant,
"force_download": force_download,
"proxies": proxies,
"local_files_only": local_files_only,
"token": token,
"revision": revision,
"user_agent": user_agent,
"commit_hash": commit_hash,
}
index_file = _fetch_index_file(**index_file_kwargs)
# In case the index file was not found we still have to consider the legacy format.
# this becomes applicable when the variant is not None.
if variant is not None and (index_file is None or not os.path.exists(index_file)):
index_file = _fetch_index_file_legacy(**index_file_kwargs)
if index_file is not None and index_file.is_file():
is_sharded = True

Expand Down
49 changes: 37 additions & 12 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,14 @@
DEPRECATED_REVISION_ARGS,
BaseOutput,
PushToHubMixin,
deprecate,
is_accelerate_available,
is_accelerate_version,
is_torch_npu_available,
is_torch_version,
logging,
numpy_to_pil,
)
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
from ..utils.torch_utils import is_compiled_module


Expand Down Expand Up @@ -735,6 +734,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
cached_folder = pretrained_model_name_or_path

# The variant filenames can have the legacy sharding checkpoint format that we check and throw
# a warning if detected.
if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
warn_msg = (
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
"Please check your files carefully:\n\n"
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
"If you find any files in the deprecated format:\n"
"1. Remove all existing checkpoint files for this variant.\n"
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
"This will ensure you're using the most up-to-date and compatible checkpoint format."
)
logger.warning(warn_msg)

config_dict = cls.load_config(cached_folder)

# pop out "_ignore_files" as it is only needed for download
Expand All @@ -745,6 +759,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
# with variant being `"fp16"`.
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
if len(model_variants) == 0 and variant is not None:
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
Expand Down Expand Up @@ -1251,6 +1268,22 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
model_info_call_error = e # save error to reraise it if model is not cached locally

if not local_files_only:
filenames = {sibling.rfilename for sibling in info.siblings}
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
warn_msg = (
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
"Please check your files carefully:\n\n"
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
"If you find any files in the deprecated format:\n"
"1. Remove all existing checkpoint files for this variant.\n"
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
"This will ensure you're using the most up-to-date and compatible checkpoint format."
)
logger.warning(warn_msg)

model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

config_file = hf_hub_download(
pretrained_model_name,
cls.config_name,
Expand All @@ -1267,9 +1300,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
# retrieve all folder_names that contain relevant files
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]

filenames = {sibling.rfilename for sibling in info.siblings}
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)

Comment on lines -1270 to -1272
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was moved up to raise error earlier in code.

diffusers_module = importlib.import_module(__name__.split(".")[0])
pipelines = getattr(diffusers_module, "pipelines")

Expand All @@ -1292,13 +1322,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)

if len(variant_filenames) == 0 and variant is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's not remove this error in download

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not an error, though. It's a deprecation. Do we exactly want to keep it that way? If so, we will have to remove it anyway because the deprecation is supposed to expire after "0.24.0" version.

Instead, we are erroring out now from from_pretrained():

model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah got it. I think this should be resolved now.

WDYT about catching these errors without having to download the actual files and leveraging model_info() (in case we're querying the Hub) or regular string matching (in case it's local)? Currently, we're still calling download() in case we don't have the model files cached. I think many errors can be caught and warnings can be thrown without having to do that.

This could live in a future PR.

deprecation_message = (
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
"if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
"modeling files is deprecated."
)
deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
raise ValueError(error_message)

# remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames)
Expand Down
16 changes: 14 additions & 2 deletions src/diffusers/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,7 @@ def move_cache(old_cache_dir: Optional[str] = None, new_cache_dir: Optional[str]
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
splits = weights_name.split(".")
split_index = -2 if weights_name.endswith(".index.json") else -1
splits = splits[:-split_index] + [variant] + splits[-split_index:]
splits = splits[:-1] + [variant] + splits[-1:]
weights_name = ".".join(splits)

return weights_name
Expand Down Expand Up @@ -502,6 +501,19 @@ def _get_checkpoint_shard_files(
return cached_folder, sharded_metadata


def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
if filenames and folder:
raise ValueError("Both `filenames` and `folder` cannot be provided.")
if not filenames:
filenames = []
for _, _, files in os.walk(folder):
for file in files:
filenames.append(os.path.basename(file))
transformers_index_format = r"\d{5}-of-\d{5}"
variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
return any(variant_file_re.match(f) is not None for f in filenames)


class PushToHubMixin:
"""
A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub.
Expand Down
Loading
Loading