50
50
DEPRECATED_REVISION_ARGS ,
51
51
BaseOutput ,
52
52
PushToHubMixin ,
53
- deprecate ,
54
53
is_accelerate_available ,
55
54
is_accelerate_version ,
56
55
is_torch_npu_available ,
57
56
is_torch_version ,
58
57
logging ,
59
58
numpy_to_pil ,
60
59
)
61
- from ..utils .hub_utils import load_or_create_model_card , populate_model_card
60
+ from ..utils .hub_utils import _check_legacy_sharding_variant_format , load_or_create_model_card , populate_model_card
62
61
from ..utils .torch_utils import is_compiled_module
63
62
64
63
@@ -735,6 +734,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
735
734
else :
736
735
cached_folder = pretrained_model_name_or_path
737
736
737
+ # The variant filenames can have the legacy sharding checkpoint format that we check and throw
738
+ # a warning if detected.
739
+ if variant is not None and _check_legacy_sharding_variant_format (folder = cached_folder , variant = variant ):
740
+ warn_msg = (
741
+ f"Warning: The repository contains sharded checkpoints for variant '{ variant } ' maybe in a deprecated format. "
742
+ "Please check your files carefully:\n \n "
743
+ "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n "
744
+ "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n \n "
745
+ "If you find any files in the deprecated format:\n "
746
+ "1. Remove all existing checkpoint files for this variant.\n "
747
+ "2. Re-obtain the correct files by running `save_pretrained()`.\n \n "
748
+ "This will ensure you're using the most up-to-date and compatible checkpoint format."
749
+ )
750
+ logger .warning (warn_msg )
751
+
738
752
config_dict = cls .load_config (cached_folder )
739
753
740
754
# pop out "_ignore_files" as it is only needed for download
@@ -745,6 +759,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
745
759
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
746
760
# with variant being `"fp16"`.
747
761
model_variants = _identify_model_variants (folder = cached_folder , variant = variant , config = config_dict )
762
+ if len (model_variants ) == 0 and variant is not None :
763
+ error_message = f"You are trying to load the model files of the `variant={ variant } `, but no such modeling files are available."
764
+ raise ValueError (error_message )
748
765
749
766
# 3. Load the pipeline class, if using custom module then load it from the hub
750
767
# if we load from explicit class, let's use it
@@ -1251,6 +1268,22 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
1251
1268
model_info_call_error = e # save error to reraise it if model is not cached locally
1252
1269
1253
1270
if not local_files_only :
1271
+ filenames = {sibling .rfilename for sibling in info .siblings }
1272
+ if variant is not None and _check_legacy_sharding_variant_format (filenames = filenames , variant = variant ):
1273
+ warn_msg = (
1274
+ f"Warning: The repository contains sharded checkpoints for variant '{ variant } ' maybe in a deprecated format. "
1275
+ "Please check your files carefully:\n \n "
1276
+ "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n "
1277
+ "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n \n "
1278
+ "If you find any files in the deprecated format:\n "
1279
+ "1. Remove all existing checkpoint files for this variant.\n "
1280
+ "2. Re-obtain the correct files by running `save_pretrained()`.\n \n "
1281
+ "This will ensure you're using the most up-to-date and compatible checkpoint format."
1282
+ )
1283
+ logger .warning (warn_msg )
1284
+
1285
+ model_filenames , variant_filenames = variant_compatible_siblings (filenames , variant = variant )
1286
+
1254
1287
config_file = hf_hub_download (
1255
1288
pretrained_model_name ,
1256
1289
cls .config_name ,
@@ -1267,9 +1300,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
1267
1300
# retrieve all folder_names that contain relevant files
1268
1301
folder_names = [k for k , v in config_dict .items () if isinstance (v , list ) and k != "_class_name" ]
1269
1302
1270
- filenames = {sibling .rfilename for sibling in info .siblings }
1271
- model_filenames , variant_filenames = variant_compatible_siblings (filenames , variant = variant )
1272
-
1273
1303
diffusers_module = importlib .import_module (__name__ .split ("." )[0 ])
1274
1304
pipelines = getattr (diffusers_module , "pipelines" )
1275
1305
@@ -1292,13 +1322,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
1292
1322
)
1293
1323
1294
1324
if len (variant_filenames ) == 0 and variant is not None :
1295
- deprecation_message = (
1296
- f"You are trying to load the model files of the `variant={ variant } `, but no such modeling files are available."
1297
- f"The default model files: { model_filenames } will be loaded instead. Make sure to not load from `variant={ variant } `"
1298
- "if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
1299
- "modeling files is deprecated."
1300
- )
1301
- deprecate ("no variant default" , "0.24.0" , deprecation_message , standard_warn = False )
1325
+ error_message = f"You are trying to load the model files of the `variant={ variant } `, but no such modeling files are available."
1326
+ raise ValueError (error_message )
1302
1327
1303
1328
# remove ignored filenames
1304
1329
model_filenames = set (model_filenames ) - set (ignore_filenames )
0 commit comments