Skip to content

Commit 9f4cc8f

Browse files
authored
1 parent fd08255 commit 9f4cc8f

File tree

6 files changed

+234
-59
lines changed

6 files changed

+234
-59
lines changed

common/chat-template.hpp

+180-33
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,29 @@ struct chat_template_caps {
3333
bool requires_typed_content = false;
3434
};
3535

36+
struct chat_template_inputs {
37+
nlohmann::ordered_json messages;
38+
nlohmann::ordered_json tools;
39+
bool add_generation_prompt = true;
40+
nlohmann::ordered_json extra_context;
41+
std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
42+
};
43+
44+
// Switches that control how chat_template::apply builds the prompt.
// Every switch defaults to enabled.
struct chat_template_options {
    // Master gate: when false, no message rewriting (polyfill) is performed.
    bool apply_polyfills{true};
    // When false, `bos_token` / `eos_token` are exposed to the template as "".
    bool use_bos_token{true};
    bool use_eos_token{true};
    // When true, a `strftime_now` callable is registered in the template context.
    bool define_strftime_now{true};

    // Individual polyfills. Each one only takes effect when the template lacks
    // (or requires) the matching capability and the inputs actually need it.
    bool polyfill_tools{true};
    // Appends an example tool call to the tools system prompt (needs polyfill_tools).
    bool polyfill_tool_call_examples{true};
    bool polyfill_tool_calls{true};
    bool polyfill_tool_responses{true};
    bool polyfill_system_role{true};
    bool polyfill_object_arguments{true};
    bool polyfill_typed_content{true};
};
58+
3659
class chat_template {
3760

3861
private:
@@ -41,6 +64,7 @@ class chat_template {
4164
std::string bos_token_;
4265
std::string eos_token_;
4366
std::shared_ptr<minja::TemplateNode> template_root_;
67+
std::string tool_call_example_;
4468

4569
std::string try_raw_render(
4670
const nlohmann::ordered_json & messages,
@@ -49,7 +73,18 @@ class chat_template {
4973
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
5074
{
5175
try {
52-
auto prompt = apply(messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false);
76+
chat_template_inputs inputs;
77+
inputs.messages = messages;
78+
inputs.tools = tools;
79+
inputs.add_generation_prompt = add_generation_prompt;
80+
inputs.extra_context = extra_context;
81+
// Use fixed date for tests
82+
inputs.now = std::chrono::system_clock::from_time_t(0);
83+
84+
chat_template_options opts;
85+
opts.apply_polyfills = false;
86+
87+
auto prompt = apply(inputs, opts);
5388
// fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
5489
return prompt;
5590
} catch (const std::exception & e) {
@@ -176,35 +211,131 @@ class chat_template {
176211
caps_.supports_tool_responses = contains(out, "Some response!");
177212
caps_.supports_tool_call_id = contains(out, "call_911_");
178213
}
214+
215+
try {
216+
if (!caps_.supports_tools) {
217+
const json user_msg {
218+
{"role", "user"},
219+
{"content", "Hey"},
220+
};
221+
const json args {
222+
{"arg1", "some_value"},
223+
};
224+
const json tool_call_msg {
225+
{"role", "assistant"},
226+
{"content", nullptr},
227+
{"tool_calls", json::array({
228+
{
229+
// TODO: detect if requires numerical id or fixed length == 6 like Nemo
230+
{"id", "call_1___"},
231+
{"type", "function"},
232+
{"function", {
233+
{"name", "tool_name"},
234+
{"arguments", (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(-1, /* to_json= */ true)))},
235+
}},
236+
},
237+
})},
238+
};
239+
std::string prefix, full;
240+
{
241+
chat_template_inputs inputs;
242+
inputs.messages = json::array({user_msg});
243+
inputs.add_generation_prompt = true;
244+
prefix = apply(inputs);
245+
}
246+
{
247+
chat_template_inputs inputs;
248+
inputs.messages = json::array({user_msg, tool_call_msg});
249+
inputs.add_generation_prompt = false;
250+
full = apply(inputs);
251+
}
252+
253+
if (full.find(prefix) != 0) {
254+
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
255+
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
256+
}
257+
}
258+
if (full.find(prefix) != 0) {
259+
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
260+
}
261+
tool_call_example_ = full.substr(prefix.size());
262+
}
263+
} catch (const std::exception & e) {
264+
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
265+
}
179266
}
180267

181268
const std::string & source() const { return source_; }
182269
const std::string & bos_token() const { return bos_token_; }
183270
const std::string & eos_token() const { return eos_token_; }
184271
const chat_template_caps & original_caps() const { return caps_; }
185272

273+
// Deprecated, please use the form with chat_template_inputs and chat_template_options
186274
std::string apply(
187275
const nlohmann::ordered_json & messages,
188276
const nlohmann::ordered_json & tools,
189277
bool add_generation_prompt,
190278
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
191-
bool adjust_inputs = true) const
279+
bool apply_polyfills = true)
280+
{
281+
fprintf(stderr, "[%s] Deprecated!\n", __func__);
282+
chat_template_inputs inputs;
283+
inputs.messages = messages;
284+
inputs.tools = tools;
285+
inputs.add_generation_prompt = add_generation_prompt;
286+
inputs.extra_context = extra_context;
287+
inputs.now = std::chrono::system_clock::now();
288+
289+
chat_template_options opts;
290+
opts.apply_polyfills = apply_polyfills;
291+
292+
return apply(inputs, opts);
293+
}
294+
295+
std::string apply(
296+
const chat_template_inputs & inputs,
297+
const chat_template_options & opts = chat_template_options()) const
192298
{
193299
json actual_messages;
194300

195-
auto needs_adjustments = adjust_inputs && (false
196-
|| !caps_.supports_system_role
197-
|| !caps_.supports_tools
198-
|| !caps_.supports_tool_responses
199-
|| !caps_.supports_tool_calls
200-
|| caps_.requires_object_arguments
201-
|| caps_.requires_typed_content
301+
auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
302+
auto has_tool_calls = false;
303+
auto has_tool_responses = false;
304+
auto has_string_content = false;
305+
for (const auto & message : inputs.messages) {
306+
if (message.contains("tool_calls") && !message["tool_calls"].is_null()) {
307+
has_tool_calls = true;
308+
}
309+
if (message.contains("role") && message["role"] == "tool") {
310+
has_tool_responses = true;
311+
}
312+
if (message.contains("content") && message["content"].is_string()) {
313+
has_string_content = true;
314+
}
315+
}
316+
317+
auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role;
318+
auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools;
319+
auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples;
320+
auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls;
321+
auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses;
322+
auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments;
323+
auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content;
324+
325+
auto needs_polyfills = opts.apply_polyfills && (false
326+
|| polyfill_system_role
327+
|| polyfill_tools
328+
|| polyfill_tool_calls
329+
|| polyfill_tool_responses
330+
|| polyfill_object_arguments
331+
|| polyfill_typed_content
202332
);
203-
if (needs_adjustments) {
333+
334+
if (needs_polyfills) {
204335
actual_messages = json::array();
205336

206337
auto add_message = [&](const json & msg) {
207-
if (caps_.requires_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
338+
if (polyfill_typed_content && msg.contains("content") && !msg.at("content").is_null() && msg.at("content").is_string()) {
208339
actual_messages.push_back({
209340
{"role", msg.at("role")},
210341
{"content", {{
@@ -227,17 +358,25 @@ class chat_template {
227358
pending_system.clear();
228359
}
229360
};
230-
auto needs_tools_in_system = !tools.is_null() && tools.size() > 0 && !caps_.supports_tools;
231361

232-
for (const auto & message_ : needs_tools_in_system ? add_system(messages, "Available tools: " + tools.dump(2)) : messages) {
362+
json adjusted_messages;
363+
if (polyfill_tools) {
364+
adjusted_messages = add_system(inputs.messages,
365+
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
366+
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
367+
} else {
368+
adjusted_messages = inputs.messages;
369+
}
370+
371+
for (const auto & message_ : adjusted_messages) {
233372
auto message = message_;
234373
if (!message.contains("role") || !message.contains("content")) {
235374
throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
236375
}
237376
std::string role = message.at("role");
238377

239378
if (message.contains("tool_calls")) {
240-
if (caps_.requires_object_arguments || !caps_.supports_tool_calls) {
379+
if (polyfill_object_arguments || polyfill_tool_calls) {
241380
for (auto & tool_call : message.at("tool_calls")) {
242381
if (tool_call["type"] == "function") {
243382
auto & function = tool_call.at("function");
@@ -252,7 +391,7 @@ class chat_template {
252391
}
253392
}
254393
}
255-
if (!caps_.supports_tool_calls) {
394+
if (polyfill_tool_calls) {
256395
auto content = message.at("content");
257396
auto tool_calls = json::array();
258397
for (const auto & tool_call : message.at("tool_calls")) {
@@ -279,7 +418,7 @@ class chat_template {
279418
message.erase("tool_calls");
280419
}
281420
}
282-
if (!caps_.supports_tool_responses && role == "tool") {
421+
if (polyfill_tool_responses && role == "tool") {
283422
message["role"] = "user";
284423
auto obj = json {
285424
{"tool_response", {
@@ -296,7 +435,7 @@ class chat_template {
296435
message.erase("name");
297436
}
298437

299-
if (!message["content"].is_null() && !caps_.supports_system_role) {
438+
if (!message["content"].is_null() && polyfill_system_role) {
300439
std::string content = message.at("content");
301440
if (role == "system") {
302441
if (!pending_system.empty()) pending_system += "\n";
@@ -315,28 +454,36 @@ class chat_template {
315454
}
316455
add_message(message);
317456
}
318-
if (!caps_.supports_system_role) {
319-
flush_sys();
320-
}
457+
flush_sys();
321458
} else {
322-
actual_messages = messages;
459+
actual_messages = inputs.messages;
323460
}
324461

325462
auto context = minja::Context::make(json({
326463
{"messages", actual_messages},
327-
{"add_generation_prompt", add_generation_prompt},
328-
{"bos_token", bos_token_},
329-
{"eos_token", eos_token_},
464+
{"add_generation_prompt", inputs.add_generation_prompt},
330465
}));
331-
332-
if (!tools.is_null()) {
333-
auto tools_val = minja::Value(tools);
334-
context->set("tools", tools_val);
466+
context->set("bos_token", opts.use_bos_token ? bos_token_ : "");
467+
context->set("eos_token", opts.use_eos_token ? eos_token_ : "");
468+
if (opts.define_strftime_now) {
469+
auto now = inputs.now;
470+
context->set("strftime_now", Value::callable([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
471+
args.expectArgs("strftime_now", {1, 1}, {0, 0});
472+
auto format = args.args[0].get<std::string>();
473+
474+
auto time = std::chrono::system_clock::to_time_t(now);
475+
auto local_time = *std::localtime(&time);
476+
std::ostringstream ss;
477+
ss << std::put_time(&local_time, format.c_str());
478+
return ss.str();
479+
}));
480+
}
481+
if (!inputs.tools.is_null()) {
482+
context->set("tools", minja::Value(inputs.tools));
335483
}
336-
if (!extra_context.is_null()) {
337-
for (auto & kv : extra_context.items()) {
338-
minja::Value val(kv.value());
339-
context->set(kv.key(), val);
484+
if (!inputs.extra_context.is_null()) {
485+
for (auto & kv : inputs.extra_context.items()) {
486+
context->set(kv.key(), minja::Value(kv.value()));
340487
}
341488
}
342489

@@ -353,7 +500,7 @@ class chat_template {
353500
std::string existing_system = messages_with_system.at(0).at("content");
354501
messages_with_system[0] = json {
355502
{"role", "system"},
356-
{"content", existing_system + "\n" + system_prompt},
503+
{"content", existing_system + "\n\n" + system_prompt},
357504
};
358505
} else {
359506
messages_with_system.insert(messages_with_system.begin(), json {

0 commit comments

Comments
 (0)