#if defined(_WIN32)
-#    include <windows.h>
#    include <io.h>
+#    include <windows.h>
#else
#    include <sys/file.h>
#    include <sys/ioctl.h>
...
#endif

#include <signal.h>
+#include <sys/stat.h>

#include <climits>
#include <cstdarg>
#include <cstdio>
#include <cstring>
#include <filesystem>
+#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
...
#endif

GGML_ATTRIBUTE_FORMAT(1, 2)
+
static std::string fmt(const char * fmt, ...) {
    va_list ap;
    va_list ap2;
    va_start(ap, fmt);
    va_copy(ap2, ap);
    const int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
    std::string buf;
    buf.resize(size);
    const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
@@ -53,6 +56,7 @@ static std::string fmt(const char * fmt, ...) {
}

GGML_ATTRIBUTE_FORMAT(1, 2)
+
static int printe(const char * fmt, ...) {
    va_list args;
    va_start(args, fmt);
@@ -101,7 +105,8 @@ class Opt {

    llama_context_params ctx_params;
    llama_model_params model_params;
-    std::string model_;
+    std::string model_;
+    std::string chat_template_;
    std::string user;
    int context_size = -1, ngl = -1;
    float temperature = -1;
@@ -137,7 +142,7 @@ class Opt {
    }

    int parse(int argc, const char ** argv) {
-        bool options_parsing = true;
+        bool options_parsing = true;
        for (int i = 1, positional_args_i = 0; i < argc; ++i) {
            if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
                if (handle_option_with_value(argc, argv, i, context_size) == 1) {
@@ -166,6 +171,11 @@ class Opt {

                ++positional_args_i;
                model_ = argv[i];
+            } else if (options_parsing && strcmp(argv[i], "--chat-template") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+                chat_template_ = argv[++i];
            } else if (positional_args_i == 1) {
                ++positional_args_i;
                user = argv[i];
@@ -475,7 +485,9 @@ class HttpClient {
        return (now_downloaded_plus_file_size * 100) / total_to_download;
    }

-    static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast<long int>(percentage)); }
+    static std::string generate_progress_prefix(curl_off_t percentage) {
+        return fmt("%3ld%% |", static_cast<long int>(percentage));
+    }

    static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
        const auto now = std::chrono::steady_clock::now();
@@ -515,6 +527,7 @@ class HttpClient {
        printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
               progress_suffix.c_str());
    }
+
    // Function to write data to a file
    static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
        FILE * out = static_cast<FILE *>(stream);
@@ -538,19 +551,23 @@ class LlamaData {
    std::vector<llama_chat_message> messages;
    std::vector<std::string> msg_strs;
    std::vector<char> fmtted;
+    std::string chat_template;

    int init(Opt & opt) {
        model = initialize_model(opt);
        if (!model) {
            return 1;
        }

+        chat_template = initialize_chat_template(model, opt);
+
        context = initialize_context(model, opt);
        if (!context) {
            return 1;
        }

        sampler = initialize_sampler(opt);
+
        return 0;
    }

@@ -573,21 +590,74 @@ class LlamaData {
    }
#endif

-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int huggingface_dl_tmpl(const std::string & hfr, const std::vector<std::string> headers, const std::string & tn) {
+        if (std::filesystem::exists(tn)) {
+            return 0;
+        }
+
+        const std::string config_url = "https://huggingface.co/" + hfr + "/resolve/main/tokenizer_config.json";
+        std::string tokenizer_config_str;
+        download(config_url, headers, "", true, &tokenizer_config_str);
+        if (tokenizer_config_str.empty()) {
+            // still return success since tokenizer_config is optional
+            return 0;
+        }
+
+        nlohmann::json config = nlohmann::json::parse(tokenizer_config_str);
+        std::string tmpl = config["chat_template"];
+
+        FILE * tmpl_file = fopen(tn.c_str(), "w");
+        if (tmpl_file == NULL) {
+            return 1;
+        }
+        fprintf(tmpl_file, "%s", tmpl.c_str());
+        fclose(tmpl_file);
+
+        return 0;
+    }
+
+    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                       const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
        // Find the second occurrence of '/' after protocol string
        size_t pos = model.find('/');
        pos = model.find('/', pos + 1);
        if (pos == std::string::npos) {
            return 1;
        }
-
        const std::string hfr = model.substr(0, pos);
        const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+
+        if (!chat_tmpl_exists) {
+            const int ret = huggingface_dl_tmpl(hfr, headers, tn);
+            if (ret) {
+                return ret;
+            }
+        }
+
+        if (!model_exists) {
+            const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+            const int ret = download(url, headers, bn, true);
+            if (ret) {
+                return ret;
+            }
+        }
+        return 0;
    }

-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                  const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
        if (model.find('/') == std::string::npos) {
            model = "library/" + model;
        }
@@ -607,16 +677,34 @@ class LlamaData {
        }

        nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string layer;
+        std::string sha_model;
+        std::string sha_template;
        for (const auto & l : manifest["layers"]) {
            if (l["mediaType"] == "application/vnd.ollama.image.model") {
-                layer = l["digest"];
-                break;
+                sha_model = l["digest"];
+            }
+            if (l["mediaType"] == "application/vnd.ollama.image.template") {
+                sha_template = l["digest"];
+            }
+        }
+
+        if (!chat_tmpl_exists && !sha_template.empty()) {
+            std::string tmpl_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_template;
+            const int tmpl_ret = download(tmpl_blob_url, headers, tn, true);
+            if (tmpl_ret) {
+                return tmpl_ret;
+            }
+        }
+
+        if (!model_exists) {
+            std::string model_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_model;
+            const int model_ret = download(model_blob_url, headers, bn, true);
+            if (model_ret) {
+                return model_ret;
            }
        }

-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        return 0;
    }

    std::string basename(const std::string & path) {
@@ -628,6 +716,15 @@ class LlamaData {
        return path.substr(pos + 1);
    }

+    std::string get_proto(const std::string & model_) {
+        const std::string::size_type pos = model_.find("://");
+        if (pos == std::string::npos) {
+            return "";
+        }
+
+        return model_.substr(0, pos + 3);  // Include "://"
+    }
+
    int remove_proto(std::string & model_) {
        const std::string::size_type pos = model_.find("://");
        if (pos == std::string::npos) {
@@ -638,38 +735,40 @@ class LlamaData {
        return 0;
    }

-    int resolve_model(std::string & model_) {
-        int ret = 0;
-        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
+    int resolve_model(std::string & model_, std::string & chat_template_) {
+        int ret = 0;
+        if (string_starts_with(model_, "file://")) {
            remove_proto(model_);
-
            return ret;
        }

+        std::string proto = get_proto(model_);
+        remove_proto(model_);
+
        const std::string bn = basename(model_);
+        const std::string tn = chat_template_.empty() ? bn + ".template" : chat_template_;
        const std::vector<std::string> headers = { "--header",
                                                   "Accept: application/vnd.docker.distribution.manifest.v2+json" };
-        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
-            remove_proto(model_);
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "ollama://")) {
-            remove_proto(model_);
-            ret = ollama_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "https://")) {
+        if (string_starts_with(proto, "hf://") || string_starts_with(proto, "huggingface://")) {
+            ret = huggingface_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "ollama://")) {
+            ret = ollama_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "https://")) {
            download(model_, headers, bn, true);
        } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
        }

-        model_ = bn;
+        model_ = bn;
+        chat_template_ = tn;

        return ret;
    }

    // Initializes the model and returns a unique pointer to it
    llama_model_ptr initialize_model(Opt & opt) {
        ggml_backend_load_all();
-        resolve_model(opt.model_);
+        resolve_model(opt.model_, opt.chat_template_);
        printe(
            "\r%*s"
            "\rLoading model",
@@ -702,6 +801,31 @@ class LlamaData {

        return sampler;
    }
+
+    std::string initialize_chat_template(const llama_model_ptr & model, const Opt & opt) {
+        if (!std::filesystem::exists(opt.chat_template_)) {
+            return common_get_builtin_chat_template(model.get());
+        }
+
+        FILE * tmpl_file = ggml_fopen(opt.chat_template_.c_str(), "r");
+        if (!tmpl_file) {
+            std::cerr << "Error opening file '" << opt.chat_template_ << "': " << strerror(errno) << "\n";
+            return "";
+        }
+
+        fseek(tmpl_file, 0, SEEK_END);
+        size_t size = ftell(tmpl_file);
+        fseek(tmpl_file, 0, SEEK_SET);
+
+        std::vector<unsigned char> data(size);
+        size_t read_size = fread(data.data(), 1, size, tmpl_file);
+        fclose(tmpl_file);
+        if (read_size != size) {
+            std::cerr << "Error reading file '" << opt.chat_template_ << "': " << strerror(errno) << "\n";
+            return "";
+        }
+        return std::string(data.begin(), data.end());
+    }
};

// Add a message to `messages` and store its content in `msg_strs`
@@ -713,11 +837,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &

// Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(LlamaData & llama_data, const bool append) {
    int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_data.chat_template.c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
        append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
    if (append && result > static_cast<int>(llama_data.fmtted.size())) {
        llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_data.chat_template.c_str(), llama_data.messages.data(),
                                           llama_data.messages.size(), append, llama_data.fmtted.data(),
                                           llama_data.fmtted.size());
    }
@@ -730,8 +854,8 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
                           std::vector<llama_token> & prompt_tokens) {
    const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
    prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
-                       true) < 0) {
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
+        0) {
        printe("failed to tokenize the prompt\n");
        return -1;
    }