Add Jinja template support #11016

Merged
merged 47 commits into from Jan 21, 2025

Commits (47)
abd274a
Copy minja from https://github.com/google/minja/commit/58f0ca6dd74bcb…
ochafik Dec 30, 2024
e5113e8
Add --jinja and --chat-template-file flags
ochafik Dec 30, 2024
80138d9
Add missing <optional> include
ochafik Dec 30, 2024
06b5159
Avoid print in get_hf_chat_template.py
ochafik Dec 30, 2024
ce48584
No designated initializers yet
ochafik Dec 30, 2024
389d79b
Try and work around msvc++ non-macro max resolution quirk
ochafik Dec 30, 2024
238b968
Update test_chat_completion.py
ochafik Dec 30, 2024
cb72cf1
Merge remote-tracking branch 'origin/master' into jinja
ochafik Jan 13, 2025
78861a3
Wire LLM_KV_TOKENIZER_CHAT_TEMPLATE_N in llama_model_chat_template
ochafik Jan 13, 2025
1aac99a
Refactor test-chat-template
ochafik Jan 13, 2025
7c84ebc
Test templates w/ minja
ochafik Jan 13, 2025
18f257b
Fix deprecation
ochafik Jan 13, 2025
8dd4f33
Add --jinja to llama-run
ochafik Jan 13, 2025
c04c50e
Merge remote-tracking branch 'origin/master' into jinja
ochafik Jan 13, 2025
a6afb27
Update common_chat_format_example to use minja template wrapper
ochafik Jan 13, 2025
b4083e4
Test chat_template in e2e test
ochafik Jan 13, 2025
b7e2171
Update utils.py
ochafik Jan 13, 2025
a57bb94
Update test_chat_completion.py
ochafik Jan 13, 2025
4daae0b
Update run.cpp
ochafik Jan 13, 2025
1b3bb7e
Update arg.cpp
ochafik Jan 14, 2025
3ed670b
Merge remote-tracking branch 'origin/master' into jinja
ochafik Jan 14, 2025
b75d062
Refactor common_chat_* functions to accept minja template + use_jinja…
ochafik Jan 18, 2025
40db789
Merge remote-tracking branch 'origin/master' into jinja
ochafik Jan 18, 2025
81c0d43
Attempt to fix linkage of LLAMA_CHATML_TEMPLATE
ochafik Jan 18, 2025
d5fa351
Revert LLAMA_CHATML_TEMPLATE refactor
ochafik Jan 18, 2025
ee1e10e
Normalize newlines in test-chat-templates for windows tests
ochafik Jan 18, 2025
e63520f
Forward decl minja::chat_template to avoid eager json dep
ochafik Jan 18, 2025
33322e8
Flush stdout in chat template before potential crash
ochafik Jan 18, 2025
5074e6f
Fix copy elision warning
ochafik Jan 18, 2025
fc60802
Rm unused optional include
ochafik Jan 18, 2025
0e74c9d
Add missing optional include to server.cpp
ochafik Jan 18, 2025
e3c475c
Disable jinja test that has a cryptic windows failure
ochafik Jan 18, 2025
cc50356
minja: fix vigogne (https://github.com/google/minja/pull/22)
ochafik Jan 18, 2025
153e852
Apply suggestions from code review
ochafik Jan 20, 2025
db9dd0c
Finish suggested renamings
ochafik Jan 20, 2025
c9e8fdd
Move chat_templates inside server_context + remove mutex
ochafik Jan 20, 2025
8c84aef
Update --chat-template-file w/ recent change to --chat-template
ochafik Jan 20, 2025
154bfaa
Refactor chat template validation
ochafik Jan 20, 2025
099f983
Merge remote-tracking branch 'origin/master' into jinja
ochafik Jan 20, 2025
54a669e
Guard against missing eos/bos tokens (null token otherwise throws in …
ochafik Jan 20, 2025
8348c60
Warn against missing eos / bos tokens when jinja template references …
ochafik Jan 20, 2025
ee475d2
rename: common_chat_template[s]
ochafik Jan 20, 2025
8a7c89e
reinstate assert on chat_templates.template_default
ochafik Jan 20, 2025
8347da9
Update minja to https://github.com/google/minja/commit/b8437df626ac6c…
ochafik Jan 20, 2025
ff2cce5
Update minja to https://github.com/google/minja/pull/25
ochafik Jan 21, 2025
9d8ebd6
Update minja from https://github.com/google/minja/pull/27
ochafik Jan 21, 2025
cbb9b81
rm unused optional header
ochafik Jan 21, 2025
2 changes: 2 additions & 0 deletions Makefile
@@ -1361,7 +1361,9 @@ llama-server: \
examples/server/httplib.h \
examples/server/index.html.hpp \
examples/server/loading.html.hpp \
common/chat-template.hpp \
common/json.hpp \
common/minja.hpp \
$(OBJ_ALL)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
2 changes: 2 additions & 0 deletions common/CMakeLists.txt
@@ -56,6 +56,7 @@ add_library(${TARGET} STATIC
arg.cpp
arg.h
base64.hpp
chat-template.hpp
common.cpp
common.h
console.cpp
@@ -64,6 +65,7 @@ add_library(${TARGET} STATIC
json.hpp
log.cpp
log.h
minja.hpp
ngram-cache.cpp
ngram-cache.h
sampling.cpp
42 changes: 35 additions & 7 deletions common/arg.cpp
@@ -325,6 +325,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both");
}

if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s%s\n",
params.chat_template.c_str(),
params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates"
));
}

return true;
}

@@ -1947,24 +1955,44 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--jinja"},
"use jinja template for chat (default: disabled)",
[](common_params & params) {
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
add_opt(common_arg(
{"--chat-template"}, "JINJA_TEMPLATE",
string_format(
"set custom jinja chat template (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
"only commonly used templates are accepted (unless --jinja is set before this flag):\n"
"list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
),
[](common_params & params, const std::string & value) {
if (!common_chat_verify_template(value)) {
throw std::runtime_error(string_format(
"error: the supplied chat template is not supported: %s\n"
"note: llama.cpp does not use jinja parser, we only support commonly used templates\n",
value.c_str()
));
}
params.chat_template = value;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
add_opt(common_arg(
{"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
string_format(
"set custom jinja chat template file (default: template taken from model's metadata)\n"
"if suffix/prefix are specified, template will be disabled\n"
"only commonly used templates are accepted (unless --jinja is set before this flag):\n"
"list of built-in templates:\n%s", list_builtin_chat_templates().c_str()
),
[](common_params & params, const std::string & value) {
std::ifstream file(value);
if (!file) {
throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str()));
}
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(params.chat_template));
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE"));
add_opt(common_arg(
{"-sps", "--slot-prompt-similarity"}, "SIMILARITY",
string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity),
249 changes: 249 additions & 0 deletions common/chat-template.hpp
@@ -0,0 +1,249 @@
/*
Copyright 2024 Google LLC

Use of this source code is governed by an MIT-style
license that can be found in the LICENSE file or at
https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once

#include "minja.hpp"
#include <json.hpp>
#include <string>
#include <vector>

using json = nlohmann::ordered_json;

namespace minja {

class chat_template {
Collaborator: One idea to be able to #include "chat-template.hpp" in main is to forward-declare json here without #include <json.hpp>, and only define the prototype of class chat_template here. Then we would need a new file, chat-template.cpp, that holds the actual implementation, including #include <json.hpp>.

Collaborator: (Not sure if this even works, but we can do it in another PR; just noting my idea here.)

Collaborator (Author): I was hoping to keep minja header-only for now, but happy to explore options as a follow-up :-)
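For illustration, a minimal sketch of the forward-declaration idea discussed above. This is hypothetical and not part of this PR; it assumes nlohmann's json_fwd.hpp is available (the repo currently vendors only the single-header json.hpp), and the chat-template.cpp split is invented for the example.

// Hypothetical chat-template.hpp if the implementation were moved out of the header (not part of this PR).
#pragma once
#include <nlohmann/json_fwd.hpp>  // assumption: forward declarations only, no full <json.hpp>
#include <memory>
#include <string>

namespace minja {

class TemplateNode;  // forward declaration; the full definition lives in minja.hpp

class chat_template {
  public:
    chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token);

    std::string apply(
        const nlohmann::ordered_json & messages,
        const nlohmann::ordered_json & tools,
        bool add_generation_prompt,
        const nlohmann::ordered_json & extra_context) const;

  private:
    std::string source_;
    std::string bos_token_;
    std::string eos_token_;
    std::shared_ptr<TemplateNode> template_root_;
};

}  // namespace minja

// A new chat-template.cpp would then #include <json.hpp> and "minja.hpp"
// and hold the constructor / apply() bodies shown in this file.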

public:

private:
bool supports_tools_ = true;
// Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
// Most other templates (and OpenAI's API) expect the arguments object to be stringified.
bool requires_object_arguments_ = false;
bool supports_system_role_ = true;
bool supports_parallel_tool_calls_ = false;
std::string source_;
std::string bos_token_;
std::string eos_token_;
std::shared_ptr<minja::TemplateNode> template_root_;

std::string try_render(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
try {
auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
// fprintf(stderr, "Prompt: %s\n", prompt.c_str());
return prompt;
} catch (const std::exception & e) {
// fprintf(stderr, "Error: %s\n", e.what());
return "";
}
}

public:
chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
: source_(source), bos_token_(bos_token), eos_token_(eos_token)
{
template_root_ = minja::Parser::parse(source_, {
/* .trim_blocks = */ true,
/* .lstrip_blocks = */ true,
/* .keep_trailing_newline = */ false,
});
supports_tools_ = source.find("tools") != std::string::npos;

auto renders_string_arguments =
try_render({
{
{"role", "user"},
{"content", "Hey"}
},
{
{"role", "assistant"},
{"tool_calls", json::array({
{
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"arguments", "{\"code\": \"print('Hello, World!')\"}"},
{"name", "ipython"},
}},
},
})},
}
}, {}, false).find("{\"code\": \"print") != std::string::npos;
if (!renders_string_arguments) {
auto renders_object_arguments =
try_render({
{
{"role", "user"},
{"content", "Hey"}
},
{
{"role", "assistant"},
{"tool_calls", json::array({
{
{"id", "call_1___"},
{"type", "function"},
{"function", {
{"arguments", {
{"code", "print('Hello, World!')"},
}},
{"name", "ipython"},
}},
},
})},
}
}, {}, false).find("{\"code\": \"print") != std::string::npos;
requires_object_arguments_ = renders_object_arguments;
}
supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;

supports_system_role_ = try_render({
{{"role", "system"}, {"content", "<System Needle>"}},
{{"role", "user"}, {"content", "Hey"}}
}, {}, false).find("<System Needle>") != std::string::npos;
}

const std::string & source() const { return source_; }
const std::string & bos_token() const { return bos_token_; }
const std::string & eos_token() const { return eos_token_; }
bool supports_tools() const { return supports_tools_; }
bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }

std::string apply(
const nlohmann::ordered_json & messages,
const nlohmann::ordered_json & tools,
bool add_generation_prompt,
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
{
json actual_messages;

// First, "fix" messages so they have a chance to be rendered correctly by the template

if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) {
actual_messages = json::array();

std::string pending_system;
auto flush_sys = [&]() {
if (!pending_system.empty()) {
actual_messages.push_back({
{"role", "user"},
{"content", pending_system},
});
pending_system.clear();
}
};
for (const auto & message_ : messages) {
auto message = message_;
if (!message.contains("role") || !message.contains("content")) {
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
}
std::string role = message.at("role");

if (message.contains("tool_calls")) {
if (requires_object_arguments_ || !supports_tools_) {
for (auto & tool_call : message.at("tool_calls")) {
if (tool_call["type"] == "function") {
auto & function = tool_call.at("function");
std::string arguments = function.at("arguments");
function["arguments"] = json::parse(arguments);
}
}
}
if (!supports_tools_) {
auto content = message.at("content");
auto tool_calls = json::array();
for (const auto & tool_call : message.at("tool_calls")) {
if (tool_call.at("type") != "function") {
continue;
}
const auto & function = tool_call.at("function");
auto tc = json {
{"name", function.at("name")},
{"arguments", function.at("arguments")},
};
if (tool_call.contains("id")) {
tc["id"] = tool_call["id"];
}
tool_calls.push_back(tc);
}
auto obj = json {
{"tool_calls", tool_calls},
};
if (!content.is_null() && content != "") {
obj["content"] = content;
}
message["content"] = obj.dump(2);
message.erase("tool_calls");
}
}
if (!supports_tools_ && role == "tool") {
message["role"] = "user";
auto obj = json {
{"tool_response", {
{"tool", message.at("name")},
{"content", message.at("content")},
}},
};
if (message.contains("tool_call_id")) {
obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
}
message["content"] = obj.dump(2);
message.erase("name");
}

if (!message["content"].is_null() && !supports_system_role_) {
std::string content = message.at("content");
if (role == "system") {
if (!pending_system.empty()) pending_system += "\n";
pending_system += content;
continue;
} else {
if (role == "user") {
if (!pending_system.empty()) {
message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
pending_system.clear();
}
} else {
flush_sys();
}
}
}
actual_messages.push_back(message);
}
flush_sys();
} else {
actual_messages = messages;
}

auto context = minja::Context::make(json({
{"messages", actual_messages},
{"add_generation_prompt", add_generation_prompt},
{"bos_token", bos_token_},
{"eos_token", eos_token_},
}));

if (!tools.is_null()) {
auto tools_val = minja::Value(tools);
context->set("tools", tools_val);
}
if (!extra_context.is_null()) {
for (auto & kv : extra_context.items()) {
minja::Value val(kv.value());
context->set(kv.key(), val);
}
}

return template_root_->render(context);
}
};

} // namespace minja
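For context, here is a minimal usage sketch of the class defined above. It is a hypothetical standalone snippet: the file name chat-template-demo.cpp and the ChatML-style template string are made up for the example; the constructor and apply() signatures are taken from the header as shown in this diff.

// chat-template-demo.cpp (hypothetical): render a prompt with minja::chat_template
#include "chat-template.hpp"
#include <cstdio>
#include <string>

int main() {
    // A trivial ChatML-style template, just for illustration.
    std::string source =
        "{% for message in messages %}"
        "<|im_start|>{{ message['role'] }}\n{{ message['content'] }}<|im_end|>\n"
        "{% endfor %}"
        "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}";

    minja::chat_template tmpl(source, /* bos_token= */ "<s>", /* eos_token= */ "</s>");

    // `json` is the nlohmann::ordered_json alias declared in chat-template.hpp.
    json messages = json::array({
        {{"role", "system"}, {"content", "You are a helpful assistant."}},
        {{"role", "user"},   {"content", "Hello!"}},
    });

    // No tools; request a trailing assistant header via add_generation_prompt.
    std::string prompt = tmpl.apply(messages, /* tools= */ json(), /* add_generation_prompt= */ true);
    printf("%s\n", prompt.c_str());
    return 0;
}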