Skip to content

Commit b8d0e03

Browse files
committed
py: extract model ID components for base models and datasets when they are given as a Hugging Face URL
1 parent 1d295cf commit b8d0e03

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

gguf-py/gguf/metadata.py

+38-1
Original file line numberDiff line numberDiff line change
@@ -370,9 +370,27 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
370370
if isinstance(model_id, str):
371371
if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"):
372372
base_model["repo_url"] = model_id
373+
374+
# Check if Hugging Face ID is present in URL
375+
if "huggingface.co" in model_id:
376+
match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id)
377+
if match:
378+
model_id_component = match.group(1)
379+
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params)
380+
381+
# Populate model dictionary with extracted components
382+
if model_full_name_component is not None:
383+
base_model["name"] = Metadata.id_to_title(model_full_name_component)
384+
if org_component is not None:
385+
base_model["organization"] = Metadata.id_to_title(org_component)
386+
if version is not None:
387+
base_model["version"] = version
388+
373389
else:
374390
# Likely a Hugging Face ID
375391
model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params)
392+
393+
# Populate model dictionary with extracted components
376394
if model_full_name_component is not None:
377395
base_model["name"] = Metadata.id_to_title(model_full_name_component)
378396
if org_component is not None:
@@ -405,11 +423,29 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
405423
# NOTE: model size of base model is assumed to be similar to the size of the current model
406424
dataset = {}
407425
if isinstance(dataset_id, str):
408-
if dataset_id.startswith("http://") or dataset_id.startswith("https://") or dataset_id.startswith("ssh://"):
426+
if dataset_id.startswith(("http://", "https://", "ssh://")):
409427
dataset["repo_url"] = dataset_id
428+
429+
# Check if Hugging Face ID is present in URL
430+
if "huggingface.co" in dataset_id:
431+
match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id)
432+
if match:
433+
dataset_id_component = match.group(1)
434+
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params)
435+
436+
# Populate dataset dictionary with extracted components
437+
if dataset_name_component is not None:
438+
dataset["name"] = Metadata.id_to_title(dataset_name_component)
439+
if org_component is not None:
440+
dataset["organization"] = Metadata.id_to_title(org_component)
441+
if version is not None:
442+
dataset["version"] = version
443+
410444
else:
411445
# Likely a Hugging Face ID
412446
dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params)
447+
448+
# Populate dataset dictionary with extracted components
413449
if dataset_name_component is not None:
414450
dataset["name"] = Metadata.id_to_title(dataset_name_component)
415451
if org_component is not None:
@@ -418,6 +454,7 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
418454
dataset["version"] = version
419455
if org_component is not None and dataset_name_component is not None:
420456
dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}"
457+
421458
elif isinstance(dataset_id, dict):
422459
dataset = dataset_id
423460
else:

gguf-py/tests/test_metadata.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,14 @@ def test_apply_metadata_heuristic_from_model_card(self):
186186
self.assertEqual(got, expect)
187187

188188
# Base Model spec is inferred from model id
189-
model_card = {'base_models': ['teknium/OpenHermes-2.5']}
189+
model_card = {'base_models': 'teknium/OpenHermes-2.5'}
190190
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
191191
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
192192
self.assertEqual(got, expect)
193193

194194
# Base Model spec is only url
195195
model_card = {'base_models': ['https://huggingface.co/teknium/OpenHermes-2.5']}
196-
expect = gguf.Metadata(base_models=[{'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
196+
expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
197197
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
198198
self.assertEqual(got, expect)
199199

@@ -204,14 +204,14 @@ def test_apply_metadata_heuristic_from_model_card(self):
204204
self.assertEqual(got, expect)
205205

206206
# Dataset spec is inferred from model id
207-
model_card = {'datasets': ['teknium/OpenHermes-2.5']}
207+
model_card = {'datasets': 'teknium/OpenHermes-2.5'}
208208
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
209209
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
210210
self.assertEqual(got, expect)
211211

212212
# Dataset spec is only url
213213
model_card = {'datasets': ['https://huggingface.co/teknium/OpenHermes-2.5']}
214-
expect = gguf.Metadata(datasets=[{'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
214+
expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
215215
got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
216216
self.assertEqual(got, expect)
217217

0 commit comments

Comments
 (0)