
Commit a899673

Added chat template support to llama-run
Fixes: #11178

The llama-run CLI currently doesn't take a model's chat template into account, so running llama-run against a model that requires a chat template fails. To solve this, the chat template is now downloaded from Ollama or Hugging Face as well and applied during the chat.

Signed-off-by: Michael Engel <mengel@redhat.com>
1 parent bbf3e55
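
With this change, llama-run resolves a chat template alongside the model: for hf:// and huggingface:// sources it reads the template from the repository's tokenizer_config.json, for ollama:// sources it fetches the template layer referenced by the registry manifest, and an existing template file can be supplied explicitly via the new --chat-template option. A hypothetical invocation (model name and template path are placeholders):

    llama-run --chat-template ./custom.template ollama://some-model "Hello"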

1 file changed: +157 -33 lines

examples/run/run.cpp

+157 -33
@@ -1,6 +1,6 @@
 #if defined(_WIN32)
-#    include <windows.h>
 #    include <io.h>
+#    include <windows.h>
 #else
 #    include <sys/file.h>
 #    include <sys/ioctl.h>
@@ -12,12 +12,14 @@
 #endif

 #include <signal.h>
+#include <sys/stat.h>

 #include <climits>
 #include <cstdarg>
 #include <cstdio>
 #include <cstring>
 #include <filesystem>
+#include <fstream>
 #include <iostream>
 #include <sstream>
 #include <string>
@@ -35,13 +37,14 @@
 #endif

 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static std::string fmt(const char * fmt, ...) {
     va_list ap;
     va_list ap2;
     va_start(ap, fmt);
     va_copy(ap2, ap);
     const int size = vsnprintf(NULL, 0, fmt, ap);
-    GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
+    GGML_ASSERT(size >= 0 && size < INT_MAX);  // NOLINT
     std::string buf;
     buf.resize(size);
     const int size2 = vsnprintf(const_cast<char *>(buf.data()), buf.size() + 1, fmt, ap2);
@@ -53,6 +56,7 @@ static std::string fmt(const char * fmt, ...) {
 }

 GGML_ATTRIBUTE_FORMAT(1, 2)
+
 static int printe(const char * fmt, ...) {
     va_list args;
     va_start(args, fmt);
@@ -101,7 +105,8 @@ class Opt {

     llama_context_params ctx_params;
     llama_model_params model_params;
-    std::string model_;
+    std::string model_;
+    std::string chat_template_;
     std::string user;
     int context_size = -1, ngl = -1;
     float temperature = -1;
@@ -137,7 +142,7 @@ class Opt {
     }

     int parse(int argc, const char ** argv) {
-        bool options_parsing = true;
+        bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
             if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) {
                 if (handle_option_with_value(argc, argv, i, context_size) == 1) {
@@ -166,6 +171,11 @@ class Opt {

                 ++positional_args_i;
                 model_ = argv[i];
+            } else if (options_parsing && strcmp(argv[i], "--chat-template") == 0) {
+                if (i + 1 >= argc) {
+                    return 1;
+                }
+                chat_template_ = argv[++i];
             } else if (positional_args_i == 1) {
                 ++positional_args_i;
                 user = argv[i];
@@ -475,7 +485,9 @@ class HttpClient {
         return (now_downloaded_plus_file_size * 100) / total_to_download;
     }

-    static std::string generate_progress_prefix(curl_off_t percentage) { return fmt("%3ld%% |", static_cast<long int>(percentage)); }
+    static std::string generate_progress_prefix(curl_off_t percentage) {
+        return fmt("%3ld%% |", static_cast<long int>(percentage));
+    }

     static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) {
         const auto now = std::chrono::steady_clock::now();
@@ -515,6 +527,7 @@ class HttpClient {
         printe("\r%*s\r%s%s| %s", get_terminal_width(), " ", progress_prefix.c_str(), progress_bar.c_str(),
               progress_suffix.c_str());
     }
+
     // Function to write data to a file
     static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) {
         FILE * out = static_cast<FILE *>(stream);
@@ -538,19 +551,23 @@ class LlamaData {
     std::vector<llama_chat_message> messages;
     std::vector<std::string> msg_strs;
     std::vector<char> fmtted;
+    std::string chat_template;

     int init(Opt & opt) {
         model = initialize_model(opt);
         if (!model) {
             return 1;
         }

+        chat_template = initialize_chat_template(model, opt);
+
         context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }

         sampler = initialize_sampler(opt);
+
         return 0;
     }

@@ -573,21 +590,74 @@ class LlamaData {
     }
 #endif

-    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int huggingface_dl_tmpl(const std::string & hfr, const std::vector<std::string> headers, const std::string & tn) {
+        if (std::filesystem::exists(tn)) {
+            return 0;
+        }
+
+        const std::string config_url = "https://huggingface.co/" + hfr + "/resolve/main/tokenizer_config.json";
+        std::string tokenizer_config_str;
+        download(config_url, headers, "", true, &tokenizer_config_str);
+        if (tokenizer_config_str.empty()) {
+            // still return success since tokenizer_config is optional
+            return 0;
+        }
+
+        nlohmann::json config = nlohmann::json::parse(tokenizer_config_str);
+        std::string tmpl = config["chat_template"];
+
+        FILE * tmpl_file = fopen(tn.c_str(), "w");
+        if (tmpl_file == NULL) {
+            return 1;
+        }
+        fprintf(tmpl_file, "%s", tmpl.c_str());
+        fclose(tmpl_file);
+
+        return 0;
+    }
+
+    int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                       const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         // Find the second occurrence of '/' after protocol string
         size_t pos = model.find('/');
         pos = model.find('/', pos + 1);
         if (pos == std::string::npos) {
             return 1;
         }
-
         const std::string hfr = model.substr(0, pos);
         const std::string hff = model.substr(pos + 1);
-        const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
-        return download(url, headers, bn, true);
+
+        if (!chat_tmpl_exists) {
+            const int ret = huggingface_dl_tmpl(hfr, headers, tn);
+            if (ret) {
+                return ret;
+            }
+        }
+
+        if (!model_exists) {
+            const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
+            const int ret = download(url, headers, bn, true);
+            if (ret) {
+                return ret;
+            }
+        }
+        return 0;
     }

-    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
+    int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn,
+                  const std::string & tn) {
+        bool model_exists = std::filesystem::exists(bn);
+        bool chat_tmpl_exists = std::filesystem::exists(tn);
+        if (model_exists && chat_tmpl_exists) {
+            return 0;
+        }
+
         if (model.find('/') == std::string::npos) {
             model = "library/" + model;
         }
@@ -607,16 +677,34 @@ class LlamaData {
         }

         nlohmann::json manifest = nlohmann::json::parse(manifest_str);
-        std::string layer;
+        std::string sha_model;
+        std::string sha_template;
         for (const auto & l : manifest["layers"]) {
             if (l["mediaType"] == "application/vnd.ollama.image.model") {
-                layer = l["digest"];
-                break;
+                sha_model = l["digest"];
+            }
+            if (l["mediaType"] == "application/vnd.ollama.image.template") {
+                sha_template = l["digest"];
+            }
+        }
+
+        if (!chat_tmpl_exists && !sha_template.empty()) {
+            std::string tmpl_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_template;
+            const int tmpl_ret = download(tmpl_blob_url, headers, tn, true);
+            if (tmpl_ret) {
+                return tmpl_ret;
+            }
+        }
+
+        if (!model_exists) {
+            std::string model_blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + sha_model;
+            const int model_ret = download(model_blob_url, headers, bn, true);
+            if (model_ret) {
+                return model_ret;
             }
         }

-        std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
-        return download(blob_url, headers, bn, true);
+        return 0;
     }

     std::string basename(const std::string & path) {
@@ -628,6 +716,15 @@ class LlamaData {
         return path.substr(pos + 1);
     }

+    std::string get_proto(const std::string & model_) {
+        const std::string::size_type pos = model_.find("://");
+        if (pos == std::string::npos) {
+            return "";
+        }
+
+        return model_.substr(0, pos + 3);  // Include "://"
+    }
+
     int remove_proto(std::string & model_) {
         const std::string::size_type pos = model_.find("://");
         if (pos == std::string::npos) {
@@ -638,38 +735,40 @@ class LlamaData {
         return 0;
     }

-    int resolve_model(std::string & model_) {
-        int ret = 0;
-        if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) {
+    int resolve_model(std::string & model_, std::string & chat_template_) {
+        int ret = 0;
+        if (string_starts_with(model_, "file://")) {
             remove_proto(model_);
-
             return ret;
         }

+        std::string proto = get_proto(model_);
+        remove_proto(model_);
+
         const std::string bn = basename(model_);
+        const std::string tn = chat_template_.empty() ? bn + ".template" : chat_template_;
         const std::vector<std::string> headers = { "--header",
                                                    "Accept: application/vnd.docker.distribution.manifest.v2+json" };
-        if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
-            remove_proto(model_);
-            ret = huggingface_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "ollama://")) {
-            remove_proto(model_);
-            ret = ollama_dl(model_, headers, bn);
-        } else if (string_starts_with(model_, "https://")) {
+        if (string_starts_with(proto, "hf://") || string_starts_with(proto, "huggingface://")) {
+            ret = huggingface_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "ollama://")) {
+            ret = ollama_dl(model_, headers, bn, tn);
+        } else if (string_starts_with(proto, "https://")) {
             download(model_, headers, bn, true);
         } else {
-            ret = ollama_dl(model_, headers, bn);
+            ret = ollama_dl(model_, headers, bn, tn);
         }

-        model_ = bn;
+        model_ = bn;
+        chat_template_ = tn;

         return ret;
     }

     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
-        resolve_model(opt.model_);
+        resolve_model(opt.model_, opt.chat_template_);
         printe(
             "\r%*s"
             "\rLoading model",
@@ -702,6 +801,31 @@ class LlamaData {

         return sampler;
     }
+
+    std::string initialize_chat_template(const llama_model_ptr & model, const Opt & opt) {
+        if (!std::filesystem::exists(opt.chat_template_)) {
+            return common_get_builtin_chat_template(model.get());
+        }
+
+        FILE * tmpl_file = ggml_fopen(opt.chat_template_.c_str(), "r");
+        if (!tmpl_file) {
+            std::cerr << "Error opening file '" << opt.chat_template_ << "': " << strerror(errno) << "\n";
+            return "";
+        }
+
+        fseek(tmpl_file, 0, SEEK_END);
+        size_t size = ftell(tmpl_file);
+        fseek(tmpl_file, 0, SEEK_SET);
+
+        std::vector<unsigned char> data(size);
+        size_t read_size = fread(data.data(), 1, size, tmpl_file);
+        fclose(tmpl_file);
+        if (read_size != size) {
+            std::cerr << "Error reading file '" << opt.chat_template_ << "': " << strerror(errno) << "\n";
+            return "";
+        }
+        return std::string(data.begin(), data.end());
+    }
 };

 // Add a message to `messages` and store its content in `msg_strs`
@@ -713,11 +837,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_data.chat_template.c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_data.chat_template.c_str(), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(),
                                            llama_data.fmtted.size());
     }
@@ -730,8 +854,8 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt
                            std::vector<llama_token> & prompt_tokens) {
     const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
     prompt_tokens.resize(n_prompt_tokens);
-    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
-                       true) < 0) {
+    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
+        0) {
         printe("failed to tokenize the prompt\n");
         return -1;
     }
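
For reference, a minimal standalone sketch of the resize-and-retry pattern that apply_chat_template() uses above, driven by an explicit template string such as the one llama-run now downloads. It assumes llama.h is on the include path and that the message vector is already populated; the helper name is illustrative.

#include <string>
#include <vector>

#include "llama.h"

// Illustrative helper: render a chat message list with a given chat template
// string. The first call reports the required buffer size; if it exceeds the
// current buffer, the buffer is resized and the call repeated.
static std::string format_with_template(const std::string & tmpl,
                                        const std::vector<llama_chat_message> & messages) {
    std::vector<char> buf(1024);
    int32_t len = llama_chat_apply_template(tmpl.c_str(), messages.data(), messages.size(),
                                            /*add_ass=*/true, buf.data(), (int32_t) buf.size());
    if (len > (int32_t) buf.size()) {
        buf.resize(len);
        len = llama_chat_apply_template(tmpl.c_str(), messages.data(), messages.size(),
                                        /*add_ass=*/true, buf.data(), (int32_t) buf.size());
    }
    return len < 0 ? std::string() : std::string(buf.data(), len);
}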
