Server: use llama_chat_apply_template #5593

Merged
merged 5 commits on Feb 20, 2024
Changes from 2 commits
6 changes: 2 additions & 4 deletions examples/server/oai.hpp
@@ -15,13 +15,11 @@
using json = nlohmann::json;

inline static json oaicompat_completion_params_parse(
const struct llama_model * model,
const json &body, /* openai api json semantics */
const std::string &chat_template)
{
json llama_params;
std::string formatted_prompt = chat_template == "chatml"
? format_chatml(body["messages"]) // OpenAI 'messages' to chatml (with <|im_start|>,...)
: format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)

llama_params["__oaicompat"] = true;

@@ -34,7 +32,7 @@ inline static json oaicompat_completion_params_parse(
// https://platform.openai.com/docs/api-reference/chat/create
llama_sampling_params default_sparams;
llama_params["model"] = json_value(body, "model", std::string("unknown"));
llama_params["prompt"] = formatted_prompt;
llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
llama_params["temperature"] = json_value(body, "temperature", 0.0);
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
16 changes: 9 additions & 7 deletions examples/server/server.cpp
@@ -37,7 +37,7 @@ struct server_params
std::string hostname = "127.0.0.1";
std::vector<std::string> api_keys;
std::string public_path = "examples/server/public";
std::string chat_template = "chatml";
std::string chat_template = "";
int32_t port = 8080;
int32_t read_timeout = 600;
int32_t write_timeout = 600;
@@ -1937,8 +1937,9 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
printf(" --chat-template FORMAT_NAME");
printf(" set chat template, possible value is: llama2, chatml (default %s)", sparams.chat_template.c_str());
printf(" --chat-template JINJA_TEMPLATE");
printf(" set custom jinja chat template (default: template taken from model's metadata)");
printf(" Note: only commonly used templates are accepted, since we don't have jinja parser");
printf("\n");
}
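
For illustration, here is the kind of value that would typically count as a "commonly used template" (an assumption on my part: as the help text says, the server has no jinja parser, so acceptance is heuristic and keyed off recognizable markers such as <|im_start|> rather than real template validation):

```cpp
// Hypothetical ChatML-style jinja template a user might pass via --chat-template.
// The template text is only inspected for known markers, never evaluated.
const char * chatml_template =
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{'<|im_start|>assistant\\n'}}{% endif %}";
```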

@@ -2390,12 +2391,13 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
break;
}
std::string value(argv[i]);
Member:

value seems unused now?

Collaborator (Author):

Yeah, it's unused; I forgot to remove that. It's now removed.

if (value != "chatml" && value != "llama2") {
fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str());
if (!verify_custom_template(argv[i])) {
fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
invalid_param = true;
break;
}
sparams.chat_template = value;
sparams.chat_template = argv[i];
}
else if (arg == "--override-kv")
{
@@ -2913,7 +2915,7 @@ int main(int argc, char **argv)
if (!validate_api_key(req, res)) {
return;
}
json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template);
json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);

const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
69 changes: 33 additions & 36 deletions examples/server/utils.hpp
@@ -167,50 +167,47 @@ static T json_value(const json &body, const std::string &key, const T &default_v
: default_value;
}

inline std::string format_llama2(std::vector<json> messages)
{
std::ostringstream output;
bool is_inside_turn = false;
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
inline bool verify_custom_template(std::string tmpl) {
Member:

Suggested change:
- inline bool verify_custom_template(std::string tmpl) {
+ inline bool verify_custom_template(const std::string & tmpl) {

llama_chat_message chat[] = {{"user", "test"}};
std::vector<char> buf(1);
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
return res >= 0;
}
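
A quick usage sketch of the new helper (the calls and template strings below are hypothetical): it applies a one-message dummy chat with a null model, so only the template string itself decides whether formatting succeeds. Which strings pass depends on the marker-based detection inside llama_chat_apply_template.

```cpp
// Likely accepted: contains the ChatML markers the detection looks for.
bool ok  = verify_custom_template(
    "{% for m in messages %}{{'<|im_start|>' + m['role'] + '\\n' + m['content'] + '<|im_end|>\\n'}}{% endfor %}");

// Likely rejected: no recognizable format, so llama_chat_apply_template returns < 0.
bool bad = verify_custom_template("{{ bos_token }}{{ text }}");
```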

for (auto it = messages.begin(); it != messages.end(); ++it) {
if (!is_inside_turn) {
output << "[INST] ";
}
std::string role = json_value(*it, "role", std::string("user"));
std::string content = json_value(*it, "content", std::string(""));
if (role == "system") {
output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
is_inside_turn = true;
} else if (role == "user") {
output << content << " [/INST]";
is_inside_turn = true;
} else {
output << " " << content << " </s>";
is_inside_turn = false;
}
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string tmpl, std::vector<json> messages)
{
size_t alloc_size = 0;
// vector holding all allocated string to be passed to llama_chat_apply_template
std::vector<std::string> str(messages.size() * 2);
std::vector<llama_chat_message> chat(messages.size());

for (size_t i = 0; i < messages.size(); ++i) {
auto &curr_msg = messages[i];
str[i] = json_value(curr_msg, "role", std::string(""));
str[i + 1] = json_value(curr_msg, "content", std::string(""));
alloc_size += str[i + 1].length();
chat[i].role = str[i].c_str();
chat[i].content = str[i + 1].c_str();
Member:

There seems to be a bug here. Maybe change to str[2*i + 0] = ... and str[2*i + 1] = ...

Collaborator (Author):

Thanks for noticing that. That explains why the bot's response was quite weird when I tested this PR yesterday.

Fixed on c53b34d

Looking at the debug log, I can confirm that the formatted chat is correct:

{"timestamp":1708423580,"level":"VERBOSE","function":"format_chat","line":208,"message":"formatted_chat","text":"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nhi, how are you<|im_end|>\n<|im_start|>assistant\n"}

}

LOG_VERBOSE("format_llama2", {{"text", output.str()}});
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
std::vector<char> buf(alloc_size * 2);

return output.str();
}

inline std::string format_chatml(std::vector<json> messages)
{
std::ostringstream chatml_msgs;
// run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());

for (auto it = messages.begin(); it != messages.end(); ++it) {
chatml_msgs << "<|im_start|>"
<< json_value(*it, "role", std::string("user")) << '\n';
chatml_msgs << json_value(*it, "content", std::string(""))
<< "<|im_end|>\n";
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
}

chatml_msgs << "<|im_start|>assistant" << '\n';

LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
std::string formatted_chat(buf.data(), buf.size());
LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});

return chatml_msgs.str();
return formatted_chat;
}
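
A hypothetical call site, consistent with the verbose log in the discussion above (model is assumed to be a loaded llama_model pointer for a ChatML-style model; passing an empty template string means the model's own tokenizer.chat_template is used):

```cpp
std::vector<json> messages = {
    {{"role", "system"}, {"content", "You are a helpful assistant."}},
    {{"role", "user"},   {"content", "hi, how are you"}}
};
std::string prompt = format_chat(model, "", messages);
// prompt is roughly:
//   <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
//   <|im_start|>user\nhi, how are you<|im_end|>\n
//   <|im_start|>assistant\n
```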

//
2 changes: 1 addition & 1 deletion llama.cpp
@@ -12602,7 +12602,7 @@ LLAMA_API int32_t llama_chat_apply_template(
// load template from model
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
std::string template_key = "tokenizer.chat_template";
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), curr_tmpl.size());
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
if (res < 0) {
// worst case: there is no information about template, we will use chatml by default
curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal