Skip to content

Commit 9c405c9

Browse files
ngxsonggerganov
andauthored
Server: use llama_chat_apply_template (#5593)
* server: use llama_chat_apply_template * server: remove trailing space * server: fix format_chat * server: fix help message Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * server: fix formatted_chat --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 5207b3f commit 9c405c9

File tree

4 files changed

+45
-49
lines changed

4 files changed

+45
-49
lines changed

examples/server/oai.hpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,11 @@
1515
using json = nlohmann::json;
1616

1717
inline static json oaicompat_completion_params_parse(
18+
const struct llama_model * model,
1819
const json &body, /* openai api json semantics */
1920
const std::string &chat_template)
2021
{
2122
json llama_params;
22-
std::string formatted_prompt = chat_template == "chatml"
23-
? format_chatml(body["messages"]) // OpenAI 'messages' to chatml (with <|im_start|>,...)
24-
: format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)
2523

2624
llama_params["__oaicompat"] = true;
2725

@@ -34,7 +32,7 @@ inline static json oaicompat_completion_params_parse(
3432
// https://platform.openai.com/docs/api-reference/chat/create
3533
llama_sampling_params default_sparams;
3634
llama_params["model"] = json_value(body, "model", std::string("unknown"));
37-
llama_params["prompt"] = formatted_prompt;
35+
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
3836
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
3937
llama_params["temperature"] = json_value(body, "temperature", 0.0);
4038
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);

examples/server/server.cpp

+9-8
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ struct server_params
3737
std::string hostname = "127.0.0.1";
3838
std::vector<std::string> api_keys;
3939
std::string public_path = "examples/server/public";
40-
std::string chat_template = "chatml";
40+
std::string chat_template = "";
4141
int32_t port = 8080;
4242
int32_t read_timeout = 600;
4343
int32_t write_timeout = 600;
@@ -1937,8 +1937,9 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
19371937
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
19381938
printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
19391939
printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
1940-
printf(" --chat-template FORMAT_NAME");
1941-
printf(" set chat template, possible value is: llama2, chatml (default %s)", sparams.chat_template.c_str());
1940+
printf(" --chat-template JINJA_TEMPLATE\n");
1941+
printf(" set custom jinja chat template (default: template taken from model's metadata)\n");
1942+
printf(" Note: only commonly used templates are accepted, since we don't have jinja parser\n");
19421943
printf("\n");
19431944
}
19441945

@@ -2389,13 +2390,13 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
23892390
invalid_param = true;
23902391
break;
23912392
}
2392-
std::string value(argv[i]);
2393-
if (value != "chatml" && value != "llama2") {
2394-
fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str());
2393+
if (!verify_custom_template(argv[i])) {
2394+
fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
2395+
fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
23952396
invalid_param = true;
23962397
break;
23972398
}
2398-
sparams.chat_template = value;
2399+
sparams.chat_template = argv[i];
23992400
}
24002401
else if (arg == "--override-kv")
24012402
{
@@ -2913,7 +2914,7 @@ int main(int argc, char **argv)
29132914
if (!validate_api_key(req, res)) {
29142915
return;
29152916
}
2916-
json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template);
2917+
json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
29172918

29182919
const int task_id = llama.queue_tasks.get_new_id();
29192920
llama.queue_results.add_waiting_task_id(task_id);

examples/server/utils.hpp

+33-36
Original file line numberDiff line numberDiff line change
@@ -167,50 +167,47 @@ static T json_value(const json &body, const std::string &key, const T &default_v
167167
: default_value;
168168
}
169169

170-
inline std::string format_llama2(std::vector<json> messages)
171-
{
172-
std::ostringstream output;
173-
bool is_inside_turn = false;
170+
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
171+
inline bool verify_custom_template(const std::string & tmpl) {
172+
llama_chat_message chat[] = {{"user", "test"}};
173+
std::vector<char> buf(1);
174+
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
175+
return res >= 0;
176+
}
174177

175-
for (auto it = messages.begin(); it != messages.end(); ++it) {
176-
if (!is_inside_turn) {
177-
output << "[INST] ";
178-
}
179-
std::string role = json_value(*it, "role", std::string("user"));
180-
std::string content = json_value(*it, "content", std::string(""));
181-
if (role == "system") {
182-
output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
183-
is_inside_turn = true;
184-
} else if (role == "user") {
185-
output << content << " [/INST]";
186-
is_inside_turn = true;
187-
} else {
188-
output << " " << content << " </s>";
189-
is_inside_turn = false;
190-
}
178+
// Format given chat. If tmpl is empty, we take the template from model metadata
179+
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages)
180+
{
181+
size_t alloc_size = 0;
182+
// vector holding all allocated string to be passed to llama_chat_apply_template
183+
std::vector<std::string> str(messages.size() * 2);
184+
std::vector<llama_chat_message> chat(messages.size());
185+
186+
for (size_t i = 0; i < messages.size(); ++i) {
187+
auto &curr_msg = messages[i];
188+
str[i*2 + 0] = json_value(curr_msg, "role", std::string(""));
189+
str[i*2 + 1] = json_value(curr_msg, "content", std::string(""));
190+
alloc_size += str[i*2 + 1].length();
191+
chat[i].role = str[i*2 + 0].c_str();
192+
chat[i].content = str[i*2 + 1].c_str();
191193
}
192194

193-
LOG_VERBOSE("format_llama2", {{"text", output.str()}});
195+
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
196+
std::vector<char> buf(alloc_size * 2);
194197

195-
return output.str();
196-
}
197-
198-
inline std::string format_chatml(std::vector<json> messages)
199-
{
200-
std::ostringstream chatml_msgs;
198+
// run the first time to get the total output length
199+
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
201200

202-
for (auto it = messages.begin(); it != messages.end(); ++it) {
203-
chatml_msgs << "<|im_start|>"
204-
<< json_value(*it, "role", std::string("user")) << '\n';
205-
chatml_msgs << json_value(*it, "content", std::string(""))
206-
<< "<|im_end|>\n";
201+
// if it turns out that our buffer is too small, we resize it
202+
if ((size_t) res > buf.size()) {
203+
buf.resize(res);
204+
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
207205
}
208206

209-
chatml_msgs << "<|im_start|>assistant" << '\n';
210-
211-
LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
207+
std::string formatted_chat(buf.data(), res);
208+
LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
212209

213-
return chatml_msgs.str();
210+
return formatted_chat;
214211
}
215212

216213
//

llama.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -12602,7 +12602,7 @@ LLAMA_API int32_t llama_chat_apply_template(
1260212602
// load template from model
1260312603
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
1260412604
std::string template_key = "tokenizer.chat_template";
12605-
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), curr_tmpl.size());
12605+
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
1260612606
if (res < 0) {
1260712607
// worst case: there is no information about template, we will use chatml by default
1260812608
curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal

0 commit comments

Comments
 (0)