-
Notifications
You must be signed in to change notification settings - Fork 6k
[Core] fix variant-identification. #9253
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6b379a9
f155ec7
3f36e59
91253e8
dd5941e
564b8b4
fdd0435
d5cad9e
c0b1ceb
247dd93
b024a6d
fdfdc5f
dcf1852
3a71ad9
ab91852
aa631c5
453bfa5
11e4b71
dbdf0f9
671038a
57382f2
ea5ecdb
a510a9b
f583dad
dc0255a
f2ab3de
10baa9d
25ac01f
bac62ac
b6794ed
fcb4e39
4c0c5d2
0b1c2a6
8ad6b23
1190f7d
59cfefb
d72f5c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -50,15 +50,14 @@ | |||
DEPRECATED_REVISION_ARGS, | ||||
BaseOutput, | ||||
PushToHubMixin, | ||||
deprecate, | ||||
is_accelerate_available, | ||||
is_accelerate_version, | ||||
is_torch_npu_available, | ||||
is_torch_version, | ||||
logging, | ||||
numpy_to_pil, | ||||
) | ||||
from ..utils.hub_utils import load_or_create_model_card, populate_model_card | ||||
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card | ||||
from ..utils.torch_utils import is_compiled_module | ||||
|
||||
|
||||
|
@@ -735,6 +734,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |||
else: | ||||
cached_folder = pretrained_model_name_or_path | ||||
|
||||
# The variant filenames can have the legacy sharding checkpoint format that we check and throw | ||||
# a warning if detected. | ||||
if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant): | ||||
warn_msg = ( | ||||
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " | ||||
"Please check your files carefully:\n\n" | ||||
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" | ||||
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" | ||||
"If you find any files in the deprecated format:\n" | ||||
"1. Remove all existing checkpoint files for this variant.\n" | ||||
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n" | ||||
"This will ensure you're using the most up-to-date and compatible checkpoint format." | ||||
) | ||||
logger.warning(warn_msg) | ||||
|
||||
config_dict = cls.load_config(cached_folder) | ||||
|
||||
# pop out "_ignore_files" as it is only needed for download | ||||
|
@@ -745,6 +759,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P | |||
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors` | ||||
# with variant being `"fp16"`. | ||||
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict) | ||||
if len(model_variants) == 0 and variant is not None: | ||||
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
raise ValueError(error_message) | ||||
|
||||
# 3. Load the pipeline class, if using custom module then load it from the hub | ||||
# if we load from explicit class, let's use it | ||||
|
@@ -1251,6 +1268,22 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
model_info_call_error = e # save error to reraise it if model is not cached locally | ||||
|
||||
if not local_files_only: | ||||
filenames = {sibling.rfilename for sibling in info.siblings} | ||||
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant): | ||||
warn_msg = ( | ||||
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " | ||||
"Please check your files carefully:\n\n" | ||||
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" | ||||
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" | ||||
"If you find any files in the deprecated format:\n" | ||||
"1. Remove all existing checkpoint files for this variant.\n" | ||||
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n" | ||||
"This will ensure you're using the most up-to-date and compatible checkpoint format." | ||||
) | ||||
logger.warning(warn_msg) | ||||
|
||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
|
||||
config_file = hf_hub_download( | ||||
pretrained_model_name, | ||||
cls.config_name, | ||||
|
@@ -1267,9 +1300,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
# retrieve all folder_names that contain relevant files | ||||
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] | ||||
|
||||
filenames = {sibling.rfilename for sibling in info.siblings} | ||||
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant) | ||||
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0]) | ||||
pipelines = getattr(diffusers_module, "pipelines") | ||||
|
||||
|
@@ -1292,13 +1322,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: | |||
) | ||||
|
||||
if len(variant_filenames) == 0 and variant is not None: | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's not remove this error in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not an error, though. It's a deprecation. Do we exactly want to keep it that way? If so, we will have to remove it anyway because the deprecation is supposed to expire after "0.24.0" version. Instead, we are erroring out now from
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah got it. I think this should be resolved now. WDYT about catching these errors without having to download the actual files and leveraging This could live in a future PR.
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
deprecation_message = ( | ||||
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`" | ||||
"if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant" | ||||
"modeling files is deprecated." | ||||
) | ||||
deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False) | ||||
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||||
raise ValueError(error_message) | ||||
|
||||
# remove ignored filenames | ||||
model_filenames = set(model_filenames) - set(ignore_filenames) | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was moved up to raise error earlier in code.