Commit c39bcdf

Re-enable classification API as fallback (#8007)
## Summary

- Fallback to new classification API if legacy probe fails
- Method to read model metadata
- Created `StrippedModelOnDisk` class for testing
- Test to verify only a single config `matches` with a model
2 parents 19ecdb1 + 32f2223 commit c39bcdf

File tree

6 files changed: +91 additions, -38 deletions

invokeai/app/services/model_install/model_install_default.py

Lines changed: 13 additions & 8 deletions

```diff
@@ -38,6 +38,7 @@
     AnyModelConfig,
     CheckpointConfigBase,
     InvalidModelConfigException,
+    ModelConfigBase,
 )
 from invokeai.backend.model_manager.legacy_probe import ModelProbe
 from invokeai.backend.model_manager.metadata import (
@@ -646,14 +647,18 @@ def _probe(self, model_path: Path, config: Optional[ModelRecordChanges] = None):
         hash_algo = self._app_config.hashing_algorithm
         fields = config.model_dump()
 
-        return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)
-
-        # New model probe API is disabled pending resolution of issue caused by a change of the ordering of checks.
-        # See commit message for details.
-        # try:
-        #     return ModelConfigBase.classify(model_path=model_path, hash_algo=hash_algo, **fields)
-        # except InvalidModelConfigException:
-        #     return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)  # type: ignore
+        # WARNING!
+        # The legacy probe relies on the implicit order of tests to determine model classification.
+        # This can lead to regressions between the legacy and new probes.
+        # Do NOT change the order of `probe` and `classify` without implementing one of the following fixes:
+        # Short-term fix: `classify` tests `matches` in the same order as the legacy probe.
+        # Long-term fix: Improve `matches` to be more specific so that only one config matches
+        #   any given model - eliminating ambiguity and removing reliance on order.
+        # After implementing either of these fixes, remove @pytest.mark.xfail from `test_regression_against_model_probe`
+        try:
+            return ModelProbe.probe(model_path=model_path, fields=fields, hash_algo=hash_algo)  # type: ignore
+        except InvalidModelConfigException:
+            return ModelConfigBase.classify(model_path, hash_algo, **fields)
 
     def _register(
         self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
```
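The new control flow runs the legacy `ModelProbe` first and consults the new classify API only when the probe raises. Below is a minimal, hedged sketch of that pattern as a standalone helper: `probe_with_fallback` is hypothetical, the import paths are inferred from the diff context, and the real `_probe` method additionally threads `fields` from the install config through both calls.

```python
from pathlib import Path

# Import paths inferred from the diff; InvalidModelConfigException and
# ModelConfigBase belong to the same import block extended above.
from invokeai.backend.model_manager.config import (
    InvalidModelConfigException,
    ModelConfigBase,
)
from invokeai.backend.model_manager.legacy_probe import ModelProbe


def probe_with_fallback(model_path: Path, hash_algo: str = "blake3_single"):
    """Hypothetical helper mirroring the fallback order in `_probe`."""
    try:
        # The legacy probe runs first: its implicit check ordering still
        # decides classification for models both systems can recognize.
        return ModelProbe.probe(model_path=model_path, hash_algo=hash_algo)
    except InvalidModelConfigException:
        # Only models the legacy probe rejects reach the new classify API.
        return ModelConfigBase.classify(model_path, hash_algo)
```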

invokeai/backend/model_manager/config.py

Lines changed: 10 additions & 8 deletions

```diff
@@ -146,33 +146,35 @@ def json_schema_extra(schema: dict[str, Any]) -> None:
     )
     usage_info: Optional[str] = Field(default=None, description="Usage information for this model")
 
-    _USING_LEGACY_PROBE: ClassVar[set] = set()
-    _USING_CLASSIFY_API: ClassVar[set] = set()
+    USING_LEGACY_PROBE: ClassVar[set] = set()
+    USING_CLASSIFY_API: ClassVar[set] = set()
     _MATCH_SPEED: ClassVar[MatchSpeed] = MatchSpeed.MED
 
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
         if issubclass(cls, LegacyProbeMixin):
-            ModelConfigBase._USING_LEGACY_PROBE.add(cls)
+            ModelConfigBase.USING_LEGACY_PROBE.add(cls)
         else:
-            ModelConfigBase._USING_CLASSIFY_API.add(cls)
+            ModelConfigBase.USING_CLASSIFY_API.add(cls)
 
     @staticmethod
     def all_config_classes():
-        subclasses = ModelConfigBase._USING_LEGACY_PROBE | ModelConfigBase._USING_CLASSIFY_API
+        subclasses = ModelConfigBase.USING_LEGACY_PROBE | ModelConfigBase.USING_CLASSIFY_API
         concrete = {cls for cls in subclasses if not isabstract(cls)}
         return concrete
 
     @staticmethod
-    def classify(model_path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
+    def classify(mod: str | Path | ModelOnDisk, hash_algo: HASHING_ALGORITHMS = "blake3_single", **overrides):
         """
         Returns the best matching ModelConfig instance from a model's file/folder path.
         Raises InvalidModelConfigException if no valid configuration is found.
         Created to deprecate ModelProbe.probe
         """
-        candidates = ModelConfigBase._USING_CLASSIFY_API
+        if isinstance(mod, Path | str):
+            mod = ModelOnDisk(mod, hash_algo)
+
+        candidates = ModelConfigBase.USING_CLASSIFY_API
         sorted_by_match_speed = sorted(candidates, key=lambda cls: (cls._MATCH_SPEED, cls.__name__))
-        mod = ModelOnDisk(model_path, hash_algo)
 
         for config_cls in sorted_by_match_speed:
             try:
```
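Since `classify` now normalizes its first argument, callers can pass a string, a `Path`, or a pre-built `ModelOnDisk`. A hedged sketch of the three call styles; the model path is a placeholder, not a file from this PR:

```python
from pathlib import Path

from invokeai.backend.model_manager.config import ModelConfigBase
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk

p = Path("models/example.safetensors")  # placeholder path

# All three forms converge: str and Path are wrapped in ModelOnDisk internally.
config_a = ModelConfigBase.classify("models/example.safetensors")
config_b = ModelConfigBase.classify(p)
config_c = ModelConfigBase.classify(ModelOnDisk(p, "blake3_single"))
```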

invokeai/backend/model_manager/model_on_disk.py

Lines changed: 27 additions & 13 deletions

```diff
@@ -4,6 +4,7 @@
 import safetensors.torch
 import torch
 from picklescan.scanner import scan_file_path
+from safetensors import safe_open
 
 from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
 from invokeai.backend.model_manager.taxonomy import ModelRepoVariant
@@ -35,12 +36,21 @@ def size(self) -> int:
             return self.path.stat().st_size
         return sum(file.stat().st_size for file in self.path.rglob("*"))
 
-    def component_paths(self) -> set[Path]:
+    def weight_files(self) -> set[Path]:
         if self.path.is_file():
             return {self.path}
         extensions = {".safetensors", ".pt", ".pth", ".ckpt", ".bin", ".gguf"}
         return {f for f in self.path.rglob("*") if f.suffix in extensions}
 
+    def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
+        try:
+            with safe_open(self.path, framework="pt", device="cpu") as f:
+                metadata = f.metadata()
+                assert isinstance(metadata, dict)
+                return metadata
+        except Exception:
+            return {}
+
     def repo_variant(self) -> Optional[ModelRepoVariant]:
         if self.path.is_file():
             return None
@@ -64,18 +74,7 @@ def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
         if path in sd_cache:
             return sd_cache[path]
 
-        if not path:
-            components = list(self.component_paths())
-            match components:
-                case []:
-                    raise ValueError("No weight files found for this model")
-                case [p]:
-                    path = p
-                case ps if len(ps) >= 2:
-                    raise ValueError(
-                        f"Multiple weight files found for this model: {ps}. "
-                        f"Please specify the intended file using the 'path' argument"
-                    )
+        path = self.resolve_weight_file(path)
 
         with SilenceWarnings():
             if path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")):
@@ -94,3 +93,18 @@ def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
             state_dict = checkpoint.get("state_dict", checkpoint)
         sd_cache[path] = state_dict
         return state_dict
+
+    def resolve_weight_file(self, path: Optional[Path] = None) -> Path:
+        if not path:
+            weight_files = list(self.weight_files())
+            match weight_files:
+                case []:
+                    raise ValueError("No weight files found for this model")
+                case [p]:
+                    return p
+                case ps if len(ps) >= 2:
+                    raise ValueError(
+                        f"Multiple weight files found for this model: {ps}. "
+                        f"Please specify the intended file using the 'path' argument"
+                    )
+        return path
```
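Together, `weight_files`, `resolve_weight_file`, `metadata`, and `load_state_dict` now separate the discovery of weight files from loading them. A hedged usage sketch with a placeholder path:

```python
from pathlib import Path

from invokeai.backend.model_manager.model_on_disk import ModelOnDisk

mod = ModelOnDisk(Path("models/example.safetensors"))  # placeholder path

meta = mod.metadata()               # safetensors header metadata, or {} on any failure
weights = mod.weight_files()        # all candidate weight files under the model path
single = mod.resolve_weight_file()  # raises ValueError unless exactly one candidate
state_dict = mod.load_state_dict()  # delegates file selection to resolve_weight_file
```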

scripts/classify-model.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -28,9 +28,9 @@
 
 def classify_with_fallback(path: Path, hash_algo: HASHING_ALGORITHMS):
     try:
-        return ModelConfigBase.classify(path, hash_algo)
-    except InvalidModelConfigException:
         return ModelProbe.probe(path, hash_algo=hash_algo)
+    except InvalidModelConfigException:
+        return ModelConfigBase.classify(path, hash_algo)
 
 
 for path in args.model_path:
```

scripts/strip_models.py

Lines changed: 22 additions & 3 deletions

```diff
@@ -18,13 +18,16 @@
 import shutil
 import sys
 from pathlib import Path
+from typing import Optional
 
 import humanize
 import torch
 
-from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
+from invokeai.backend.model_manager.model_on_disk import ModelOnDisk, StateDict
 from invokeai.backend.model_manager.search import ModelSearch
 
+METADATA_KEY = "metadata_key_for_stripped_models"
+
 
 def strip(v):
     match v:
@@ -57,9 +60,22 @@ def dress(v):
 def load_stripped_model(path: Path, *args, **kwargs):
     with open(path, "r") as f:
         contents = json.load(f)
+        contents.pop(METADATA_KEY, None)
         return dress(contents)
 
 
+class StrippedModelOnDisk(ModelOnDisk):
+    def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
+        path = self.resolve_weight_file(path)
+        return load_stripped_model(path)
+
+    def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
+        path = self.resolve_weight_file(path)
+        with open(path, "r") as f:
+            contents = json.load(f)
+        return contents.get(METADATA_KEY, {})
+
+
 def create_stripped_model(original_model_path: Path, stripped_model_path: Path) -> ModelOnDisk:
     original = ModelOnDisk(original_model_path)
     if original.path.is_file():
@@ -69,11 +85,14 @@ def create_stripped_model(original_model_path: Path, stripped_model_path: Path)
     stripped = ModelOnDisk(stripped_model_path)
     print(f"Created clone of {original.name} at {stripped.path}")
 
-    for component_path in stripped.component_paths():
+    for component_path in stripped.weight_files():
         original_state_dict = stripped.load_state_dict(component_path)
 
         stripped_state_dict = strip(original_state_dict)  # type: ignore
+        metadata = stripped.metadata()
+        contents = {**stripped_state_dict, METADATA_KEY: metadata}
         with open(component_path, "w") as f:
-            json.dump(stripped_state_dict, f, indent=4)
+            json.dump(contents, f, indent=4)
 
     before_size = humanize.naturalsize(original.size())
     after_size = humanize.naturalsize(stripped.size())
```
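A stripped model is therefore a JSON file whose top level mixes the stripped state-dict entries with one reserved metadata key. An illustrative, assumed layout follows; the tensor key, the stand-in encoding produced by `strip`, and the metadata values are all hypothetical, since neither `strip`'s output format nor a real header appears in this diff:

```python
# Hypothetical stripped-file contents; only the reserved key name is from the PR.
stripped_file_contents = {
    # State-dict entries with tensors replaced by lightweight stand-ins:
    "model.layer.weight": {"shape": [320, 4], "dtype": "float16"},
    # Reserved entry carrying the original safetensors header metadata:
    "metadata_key_for_stripped_models": {"modelspec.architecture": "..."},
}
```

`load_stripped_model` pops the reserved key before `dress` rebuilds the stand-in tensors, and `StrippedModelOnDisk.metadata` returns only that entry, so the state dict and the metadata stay cleanly separated in one file.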

tests/test_model_probe.py

Lines changed: 17 additions & 4 deletions

```diff
@@ -29,6 +29,9 @@
 from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
 from invokeai.backend.model_manager.search import ModelSearch
 from invokeai.backend.util.logging import InvokeAILogger
+from scripts.strip_models import StrippedModelOnDisk
+
+logger = InvokeAILogger.get_logger(__file__)
 
 
 @pytest.mark.parametrize(
@@ -156,7 +159,8 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
         pass
 
     try:
-        new_config = ModelConfigBase.classify(path, hash=fake_hash, key=fake_key)
+        stripped_mod = StrippedModelOnDisk(path)
+        new_config = ModelConfigBase.classify(stripped_mod, hash=fake_hash, key=fake_key)
     except InvalidModelConfigException:
         pass
 
@@ -165,10 +169,10 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
         assert legacy_config.model_dump_json() == new_config.model_dump_json()
 
     elif legacy_config:
-        assert type(legacy_config) in ModelConfigBase._USING_LEGACY_PROBE
+        assert type(legacy_config) in ModelConfigBase.USING_LEGACY_PROBE
 
     elif new_config:
-        assert type(new_config) in ModelConfigBase._USING_CLASSIFY_API
+        assert type(new_config) in ModelConfigBase.USING_CLASSIFY_API
 
     else:
         raise ValueError(f"Both probe and classify failed to classify model at path {path}.")
@@ -177,7 +181,6 @@ def test_regression_against_model_probe(datadir: Path, override_model_loading):
         configs_with_tests.add(config_type)
 
     untested_configs = ModelConfigBase.all_config_classes() - configs_with_tests - {MinimalConfigExample}
-    logger = InvokeAILogger.get_logger(__file__)
     logger.warning(f"Function test_regression_against_model_probe missing test case for: {untested_configs}")
 
 
@@ -255,3 +258,13 @@ def test_any_model_config_includes_all_config_classes():
 
     expected = set(ModelConfigBase.all_config_classes()) - {MinimalConfigExample}
     assert extracted == expected
+
+
+def test_config_uniquely_matches_model(datadir: Path):
+    model_paths = ModelSearch().search(datadir / "stripped_models")
+    for path in model_paths:
+        mod = StrippedModelOnDisk(path)
+        matches = {cls for cls in ModelConfigBase.USING_CLASSIFY_API if cls.matches(mod)}
+        assert len(matches) <= 1, f"Model at path {path} matches multiple config classes: {matches}"
+        if not matches:
+            logger.warning(f"Model at path {path} does not match any config classes using classify API.")
```
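The new `test_config_uniquely_matches_model` encodes the long-term fix described in the `_probe` warning: any given model should be claimed by at most one classify-API config class. A hedged restatement of the per-model invariant, using a hypothetical fixture path:

```python
from pathlib import Path

from invokeai.backend.model_manager.config import ModelConfigBase
from scripts.strip_models import StrippedModelOnDisk

# Hypothetical fixture location; the real test discovers paths via ModelSearch.
mod = StrippedModelOnDisk(Path("tests/data/stripped_models/example"))

claimants = {cls for cls in ModelConfigBase.USING_CLASSIFY_API if cls.matches(mod)}

# Zero matches is tolerated (the real test only logs a warning); two or more
# would reintroduce dependence on check ordering.
assert len(claimants) <= 1
```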
