@@ -41,7 +41,7 @@ class Metadata:
41
41
base_models : Optional [list [dict ]] = None
42
42
tags : Optional [list [str ]] = None
43
43
languages : Optional [list [str ]] = None
44
- datasets : Optional [list [str ]] = None
44
+ datasets : Optional [list [dict ]] = None
45
45
46
46
@staticmethod
47
47
def load (metadata_override_path : Optional [Path ] = None , model_path : Optional [Path ] = None , model_name : Optional [str ] = None , total_params : int = 0 ) -> Metadata :
@@ -91,9 +91,11 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat
91
91
# Base Models is received here as an array of models
92
92
metadata .base_models = metadata_override .get ("general.base_models" , metadata .base_models )
93
93
94
+ # Datasets is received here as an array of datasets
95
+ metadata .datasets = metadata_override .get ("general.datasets" , metadata .datasets )
96
+
94
97
metadata .tags = metadata_override .get (Keys .General .TAGS , metadata .tags )
95
98
metadata .languages = metadata_override .get (Keys .General .LANGUAGES , metadata .languages )
96
- metadata .datasets = metadata_override .get (Keys .General .DATASETS , metadata .datasets )
97
99
98
100
# Direct Metadata Override (via direct cli argument)
99
101
if model_name is not None :
@@ -346,12 +348,12 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
346
348
use_model_card_metadata ("author" , "model_creator" )
347
349
use_model_card_metadata ("basename" , "model_type" )
348
350
349
- if "base_model" in model_card :
351
+ if "base_model" in model_card or "base_models" in model_card or "base_model_sources" in model_card :
350
352
# This represents the parent models that this is based on
351
353
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
352
354
# Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
353
355
metadata_base_models = []
354
- base_model_value = model_card .get ("base_model" , None )
356
+ base_model_value = model_card .get ("base_model" , model_card . get ( "base_models" , model_card . get ( "base_model_sources" , None )) )
355
357
356
358
if base_model_value is not None :
357
359
if isinstance (base_model_value , str ):
@@ -364,18 +366,106 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
364
366
365
367
for model_id in metadata_base_models :
366
368
# NOTE: model size of base model is assumed to be similar to the size of the current model
367
- model_full_name_component , org_component , basename , finetune , version , size_label = Metadata .get_model_id_components (model_id , total_params )
368
369
base_model = {}
369
- if model_full_name_component is not None :
370
- base_model ["name" ] = Metadata .id_to_title (model_full_name_component )
371
- if org_component is not None :
372
- base_model ["organization" ] = Metadata .id_to_title (org_component )
373
- if version is not None :
374
- base_model ["version" ] = version
375
- if org_component is not None and model_full_name_component is not None :
376
- base_model ["repo_url" ] = f"https://huggingface.co/{ org_component } /{ model_full_name_component } "
370
+ if isinstance (model_id , str ):
371
+ if model_id .startswith ("http://" ) or model_id .startswith ("https://" ) or model_id .startswith ("ssh://" ):
372
+ base_model ["repo_url" ] = model_id
373
+
374
+ # Check if Hugging Face ID is present in URL
375
+ if "huggingface.co" in model_id :
376
+ match = re .match (r"https?://huggingface.co/([^/]+/[^/]+)$" , model_id )
377
+ if match :
378
+ model_id_component = match .group (1 )
379
+ model_full_name_component , org_component , basename , finetune , version , size_label = Metadata .get_model_id_components (model_id_component , total_params )
380
+
381
+ # Populate model dictionary with extracted components
382
+ if model_full_name_component is not None :
383
+ base_model ["name" ] = Metadata .id_to_title (model_full_name_component )
384
+ if org_component is not None :
385
+ base_model ["organization" ] = Metadata .id_to_title (org_component )
386
+ if version is not None :
387
+ base_model ["version" ] = version
388
+
389
+ else :
390
+ # Likely a Hugging Face ID
391
+ model_full_name_component , org_component , basename , finetune , version , size_label = Metadata .get_model_id_components (model_id , total_params )
392
+
393
+ # Populate model dictionary with extracted components
394
+ if model_full_name_component is not None :
395
+ base_model ["name" ] = Metadata .id_to_title (model_full_name_component )
396
+ if org_component is not None :
397
+ base_model ["organization" ] = Metadata .id_to_title (org_component )
398
+ if version is not None :
399
+ base_model ["version" ] = version
400
+ if org_component is not None and model_full_name_component is not None :
401
+ base_model ["repo_url" ] = f"https://huggingface.co/{ org_component } /{ model_full_name_component } "
402
+
403
+ elif isinstance (model_id , dict ):
404
+ base_model = model_id
405
+
406
+ else :
407
+ logger .error (f"base model entry '{ str (model_id )} ' not in a known format" )
408
+
377
409
metadata .base_models .append (base_model )
378
410
411
+ if "datasets" in model_card or "dataset" in model_card or "dataset_sources" in model_card :
412
+ # This represents the datasets that this was trained from
413
+ metadata_datasets = []
414
+ dataset_value = model_card .get ("datasets" , model_card .get ("dataset" , model_card .get ("dataset_sources" , None )))
415
+
416
+ if dataset_value is not None :
417
+ if isinstance (dataset_value , str ):
418
+ metadata_datasets .append (dataset_value )
419
+ elif isinstance (dataset_value , list ):
420
+ metadata_datasets .extend (dataset_value )
421
+
422
+ if metadata .datasets is None :
423
+ metadata .datasets = []
424
+
425
+ for dataset_id in metadata_datasets :
426
+ # NOTE: model size of base model is assumed to be similar to the size of the current model
427
+ dataset = {}
428
+ if isinstance (dataset_id , str ):
429
+ if dataset_id .startswith (("http://" , "https://" , "ssh://" )):
430
+ dataset ["repo_url" ] = dataset_id
431
+
432
+ # Check if Hugging Face ID is present in URL
433
+ if "huggingface.co" in dataset_id :
434
+ match = re .match (r"https?://huggingface.co/([^/]+/[^/]+)$" , dataset_id )
435
+ if match :
436
+ dataset_id_component = match .group (1 )
437
+ dataset_name_component , org_component , basename , finetune , version , size_label = Metadata .get_model_id_components (dataset_id_component , total_params )
438
+
439
+ # Populate dataset dictionary with extracted components
440
+ if dataset_name_component is not None :
441
+ dataset ["name" ] = Metadata .id_to_title (dataset_name_component )
442
+ if org_component is not None :
443
+ dataset ["organization" ] = Metadata .id_to_title (org_component )
444
+ if version is not None :
445
+ dataset ["version" ] = version
446
+
447
+ else :
448
+ # Likely a Hugging Face ID
449
+ dataset_name_component , org_component , basename , finetune , version , size_label = Metadata .get_model_id_components (dataset_id , total_params )
450
+
451
+ # Populate dataset dictionary with extracted components
452
+ if dataset_name_component is not None :
453
+ dataset ["name" ] = Metadata .id_to_title (dataset_name_component )
454
+ if org_component is not None :
455
+ dataset ["organization" ] = Metadata .id_to_title (org_component )
456
+ if version is not None :
457
+ dataset ["version" ] = version
458
+ if org_component is not None and dataset_name_component is not None :
459
+ dataset ["repo_url" ] = f"https://huggingface.co/{ org_component } /{ dataset_name_component } "
460
+
461
+ elif isinstance (dataset_id , dict ):
462
+ dataset = dataset_id
463
+
464
+ else :
465
+ logger .error (f"dataset entry '{ str (dataset_id )} ' not in a known format" )
466
+
467
+ metadata .datasets .append (dataset )
468
+
379
469
use_model_card_metadata ("license" , "license" )
380
470
use_model_card_metadata ("license_name" , "license_name" )
381
471
use_model_card_metadata ("license_link" , "license_link" )
@@ -386,9 +476,6 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
386
476
use_array_model_card_metadata ("languages" , "languages" )
387
477
use_array_model_card_metadata ("languages" , "language" )
388
478
389
- use_array_model_card_metadata ("datasets" , "datasets" )
390
- use_array_model_card_metadata ("datasets" , "dataset" )
391
-
392
479
# Hugging Face Parameter Heuristics
393
480
####################################
394
481
@@ -493,6 +580,8 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
493
580
gguf_writer .add_base_model_version (key , base_model_entry ["version" ])
494
581
if "organization" in base_model_entry :
495
582
gguf_writer .add_base_model_organization (key , base_model_entry ["organization" ])
583
+ if "description" in base_model_entry :
584
+ gguf_writer .add_base_model_description (key , base_model_entry ["description" ])
496
585
if "url" in base_model_entry :
497
586
gguf_writer .add_base_model_url (key , base_model_entry ["url" ])
498
587
if "doi" in base_model_entry :
@@ -502,9 +591,29 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
502
591
if "repo_url" in base_model_entry :
503
592
gguf_writer .add_base_model_repo_url (key , base_model_entry ["repo_url" ])
504
593
594
+ if self .datasets is not None :
595
+ gguf_writer .add_dataset_count (len (self .datasets ))
596
+ for key , dataset_entry in enumerate (self .datasets ):
597
+ if "name" in dataset_entry :
598
+ gguf_writer .add_dataset_name (key , dataset_entry ["name" ])
599
+ if "author" in dataset_entry :
600
+ gguf_writer .add_dataset_author (key , dataset_entry ["author" ])
601
+ if "version" in dataset_entry :
602
+ gguf_writer .add_dataset_version (key , dataset_entry ["version" ])
603
+ if "organization" in dataset_entry :
604
+ gguf_writer .add_dataset_organization (key , dataset_entry ["organization" ])
605
+ if "description" in dataset_entry :
606
+ gguf_writer .add_dataset_description (key , dataset_entry ["description" ])
607
+ if "url" in dataset_entry :
608
+ gguf_writer .add_dataset_url (key , dataset_entry ["url" ])
609
+ if "doi" in dataset_entry :
610
+ gguf_writer .add_dataset_doi (key , dataset_entry ["doi" ])
611
+ if "uuid" in dataset_entry :
612
+ gguf_writer .add_dataset_uuid (key , dataset_entry ["uuid" ])
613
+ if "repo_url" in dataset_entry :
614
+ gguf_writer .add_dataset_repo_url (key , dataset_entry ["repo_url" ])
615
+
505
616
if self .tags is not None :
506
617
gguf_writer .add_tags (self .tags )
507
618
if self .languages is not None :
508
619
gguf_writer .add_languages (self .languages )
509
- if self .datasets is not None :
510
- gguf_writer .add_datasets (self .datasets )
0 commit comments