@@ -33,6 +33,29 @@ struct chat_template_caps {
33
33
bool requires_typed_content = false ;
34
34
};
35
35
36
+ struct chat_template_inputs {
37
+ nlohmann::ordered_json messages;
38
+ nlohmann::ordered_json tools;
39
+ bool add_generation_prompt = true ;
40
+ nlohmann::ordered_json extra_context;
41
+ std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
42
+ };
43
+
44
+ struct chat_template_options {
45
+ bool apply_polyfills = true ;
46
+ bool use_bos_token = true ;
47
+ bool use_eos_token = true ;
48
+ bool define_strftime_now = true ;
49
+
50
+ bool polyfill_tools = true ;
51
+ bool polyfill_tool_call_examples = true ;
52
+ bool polyfill_tool_calls = true ;
53
+ bool polyfill_tool_responses = true ;
54
+ bool polyfill_system_role = true ;
55
+ bool polyfill_object_arguments = true ;
56
+ bool polyfill_typed_content = true ;
57
+ };
58
+
36
59
class chat_template {
37
60
38
61
private:
@@ -41,6 +64,7 @@ class chat_template {
41
64
std::string bos_token_;
42
65
std::string eos_token_;
43
66
std::shared_ptr<minja::TemplateNode> template_root_;
67
+ std::string tool_call_example_;
44
68
45
69
std::string try_raw_render (
46
70
const nlohmann::ordered_json & messages,
@@ -49,7 +73,18 @@ class chat_template {
49
73
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
50
74
{
51
75
try {
52
- auto prompt = apply (messages, tools, add_generation_prompt, extra_context, /* adjust_inputs= */ false );
76
+ chat_template_inputs inputs;
77
+ inputs.messages = messages;
78
+ inputs.tools = tools;
79
+ inputs.add_generation_prompt = add_generation_prompt;
80
+ inputs.extra_context = extra_context;
81
+ // Use fixed date for tests
82
+ inputs.now = std::chrono::system_clock::from_time_t (0 );
83
+
84
+ chat_template_options opts;
85
+ opts.apply_polyfills = false ;
86
+
87
+ auto prompt = apply (inputs, opts);
53
88
// fprintf(stderr, "try_raw_render: %s\n", prompt.c_str());
54
89
return prompt;
55
90
} catch (const std::exception & e) {
@@ -176,35 +211,131 @@ class chat_template {
176
211
caps_.supports_tool_responses = contains (out, " Some response!" );
177
212
caps_.supports_tool_call_id = contains (out, " call_911_" );
178
213
}
214
+
215
+ try {
216
+ if (!caps_.supports_tools ) {
217
+ const json user_msg {
218
+ {" role" , " user" },
219
+ {" content" , " Hey" },
220
+ };
221
+ const json args {
222
+ {" arg1" , " some_value" },
223
+ };
224
+ const json tool_call_msg {
225
+ {" role" , " assistant" },
226
+ {" content" , nullptr },
227
+ {" tool_calls" , json::array ({
228
+ {
229
+ // TODO: detect if requires numerical id or fixed length == 6 like Nemo
230
+ {" id" , " call_1___" },
231
+ {" type" , " function" },
232
+ {" function" , {
233
+ {" name" , " tool_name" },
234
+ {" arguments" , (caps_.requires_object_arguments ? args : json (minja::Value (args).dump (-1 , /* to_json= */ true )))},
235
+ }},
236
+ },
237
+ })},
238
+ };
239
+ std::string prefix, full;
240
+ {
241
+ chat_template_inputs inputs;
242
+ inputs.messages = json::array ({user_msg});
243
+ inputs.add_generation_prompt = true ;
244
+ prefix = apply (inputs);
245
+ }
246
+ {
247
+ chat_template_inputs inputs;
248
+ inputs.messages = json::array ({user_msg, tool_call_msg});
249
+ inputs.add_generation_prompt = false ;
250
+ full = apply (inputs);
251
+ }
252
+
253
+ if (full.find (prefix) != 0 ) {
254
+ if (prefix.rfind (eos_token_) == prefix.size () - eos_token_.size ()) {
255
+ prefix = prefix.substr (0 , prefix.size () - eos_token_.size ());
256
+ }
257
+ }
258
+ if (full.find (prefix) != 0 ) {
259
+ fprintf (stderr, " Failed to infer a tool call example (possible template bug)\n " );
260
+ }
261
+ tool_call_example_ = full.substr (prefix.size ());
262
+ }
263
+ } catch (const std::exception & e) {
264
+ fprintf (stderr, " Failed to generate tool call example: %s\n " , e.what ());
265
+ }
179
266
}
180
267
181
268
const std::string & source () const { return source_; }
182
269
const std::string & bos_token () const { return bos_token_; }
183
270
const std::string & eos_token () const { return eos_token_; }
184
271
const chat_template_caps & original_caps () const { return caps_; }
185
272
273
+ // Deprecated, please use the form with chat_template_inputs and chat_template_options
186
274
std::string apply (
187
275
const nlohmann::ordered_json & messages,
188
276
const nlohmann::ordered_json & tools,
189
277
bool add_generation_prompt,
190
278
const nlohmann::ordered_json & extra_context = nlohmann::ordered_json(),
191
- bool adjust_inputs = true) const
279
+ bool apply_polyfills = true)
280
+ {
281
+ fprintf (stderr, " [%s] Deprecated!\n " , __func__);
282
+ chat_template_inputs inputs;
283
+ inputs.messages = messages;
284
+ inputs.tools = tools;
285
+ inputs.add_generation_prompt = add_generation_prompt;
286
+ inputs.extra_context = extra_context;
287
+ inputs.now = std::chrono::system_clock::now ();
288
+
289
+ chat_template_options opts;
290
+ opts.apply_polyfills = apply_polyfills;
291
+
292
+ return apply (inputs, opts);
293
+ }
294
+
295
+ std::string apply (
296
+ const chat_template_inputs & inputs,
297
+ const chat_template_options & opts = chat_template_options()) const
192
298
{
193
299
json actual_messages;
194
300
195
- auto needs_adjustments = adjust_inputs && (false
196
- || !caps_.supports_system_role
197
- || !caps_.supports_tools
198
- || !caps_.supports_tool_responses
199
- || !caps_.supports_tool_calls
200
- || caps_.requires_object_arguments
201
- || caps_.requires_typed_content
301
+ auto has_tools = inputs.tools .is_array () && !inputs.tools .empty ();
302
+ auto has_tool_calls = false ;
303
+ auto has_tool_responses = false ;
304
+ auto has_string_content = false ;
305
+ for (const auto & message : inputs.messages ) {
306
+ if (message.contains (" tool_calls" ) && !message[" tool_calls" ].is_null ()) {
307
+ has_tool_calls = true ;
308
+ }
309
+ if (message.contains (" role" ) && message[" role" ] == " tool" ) {
310
+ has_tool_responses = true ;
311
+ }
312
+ if (message.contains (" content" ) && message[" content" ].is_string ()) {
313
+ has_string_content = true ;
314
+ }
315
+ }
316
+
317
+ auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role ;
318
+ auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools ;
319
+ auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples ;
320
+ auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls ;
321
+ auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses ;
322
+ auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments ;
323
+ auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content ;
324
+
325
+ auto needs_polyfills = opts.apply_polyfills && (false
326
+ || polyfill_system_role
327
+ || polyfill_tools
328
+ || polyfill_tool_calls
329
+ || polyfill_tool_responses
330
+ || polyfill_object_arguments
331
+ || polyfill_typed_content
202
332
);
203
- if (needs_adjustments) {
333
+
334
+ if (needs_polyfills) {
204
335
actual_messages = json::array ();
205
336
206
337
auto add_message = [&](const json & msg) {
207
- if (caps_. requires_typed_content && msg.contains (" content" ) && !msg.at (" content" ).is_null () && msg.at (" content" ).is_string ()) {
338
+ if (polyfill_typed_content && msg.contains (" content" ) && !msg.at (" content" ).is_null () && msg.at (" content" ).is_string ()) {
208
339
actual_messages.push_back ({
209
340
{" role" , msg.at (" role" )},
210
341
{" content" , {{
@@ -227,17 +358,25 @@ class chat_template {
227
358
pending_system.clear ();
228
359
}
229
360
};
230
- auto needs_tools_in_system = !tools.is_null () && tools.size () > 0 && !caps_.supports_tools ;
231
361
232
- for (const auto & message_ : needs_tools_in_system ? add_system (messages, " Available tools: " + tools.dump (2 )) : messages) {
362
+ json adjusted_messages;
363
+ if (polyfill_tools) {
364
+ adjusted_messages = add_system (inputs.messages ,
365
+ " You can call any of the following tools to satisfy the user's requests: " + minja::Value (inputs.tools ).dump (2 , /* to_json= */ true ) +
366
+ (!polyfill_tool_call_example || tool_call_example_.empty () ? " " : " \n\n Example tool call syntax:\n\n " + tool_call_example_));
367
+ } else {
368
+ adjusted_messages = inputs.messages ;
369
+ }
370
+
371
+ for (const auto & message_ : adjusted_messages) {
233
372
auto message = message_;
234
373
if (!message.contains (" role" ) || !message.contains (" content" )) {
235
374
throw std::runtime_error (" message must have 'role' and 'content' fields: " + message.dump ());
236
375
}
237
376
std::string role = message.at (" role" );
238
377
239
378
if (message.contains (" tool_calls" )) {
240
- if (caps_. requires_object_arguments || !caps_. supports_tool_calls ) {
379
+ if (polyfill_object_arguments || polyfill_tool_calls ) {
241
380
for (auto & tool_call : message.at (" tool_calls" )) {
242
381
if (tool_call[" type" ] == " function" ) {
243
382
auto & function = tool_call.at (" function" );
@@ -252,7 +391,7 @@ class chat_template {
252
391
}
253
392
}
254
393
}
255
- if (!caps_. supports_tool_calls ) {
394
+ if (polyfill_tool_calls ) {
256
395
auto content = message.at (" content" );
257
396
auto tool_calls = json::array ();
258
397
for (const auto & tool_call : message.at (" tool_calls" )) {
@@ -279,7 +418,7 @@ class chat_template {
279
418
message.erase (" tool_calls" );
280
419
}
281
420
}
282
- if (!caps_. supports_tool_responses && role == " tool" ) {
421
+ if (polyfill_tool_responses && role == " tool" ) {
283
422
message[" role" ] = " user" ;
284
423
auto obj = json {
285
424
{" tool_response" , {
@@ -296,7 +435,7 @@ class chat_template {
296
435
message.erase (" name" );
297
436
}
298
437
299
- if (!message[" content" ].is_null () && !caps_. supports_system_role ) {
438
+ if (!message[" content" ].is_null () && polyfill_system_role ) {
300
439
std::string content = message.at (" content" );
301
440
if (role == " system" ) {
302
441
if (!pending_system.empty ()) pending_system += " \n " ;
@@ -315,28 +454,36 @@ class chat_template {
315
454
}
316
455
add_message (message);
317
456
}
318
- if (!caps_.supports_system_role ) {
319
- flush_sys ();
320
- }
457
+ flush_sys ();
321
458
} else {
322
- actual_messages = messages;
459
+ actual_messages = inputs. messages ;
323
460
}
324
461
325
462
auto context = minja::Context::make (json ({
326
463
{" messages" , actual_messages},
327
- {" add_generation_prompt" , add_generation_prompt},
328
- {" bos_token" , bos_token_},
329
- {" eos_token" , eos_token_},
464
+ {" add_generation_prompt" , inputs.add_generation_prompt },
330
465
}));
331
-
332
- if (!tools.is_null ()) {
333
- auto tools_val = minja::Value (tools);
334
- context->set (" tools" , tools_val);
466
+ context->set (" bos_token" , opts.use_bos_token ? bos_token_ : " " );
467
+ context->set (" eos_token" , opts.use_eos_token ? eos_token_ : " " );
468
+ if (opts.define_strftime_now ) {
469
+ auto now = inputs.now ;
470
+ context->set (" strftime_now" , Value::callable ([now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) {
471
+ args.expectArgs (" strftime_now" , {1 , 1 }, {0 , 0 });
472
+ auto format = args.args [0 ].get <std::string>();
473
+
474
+ auto time = std::chrono::system_clock::to_time_t (now);
475
+ auto local_time = *std::localtime (&time );
476
+ std::ostringstream ss;
477
+ ss << std::put_time (&local_time, format.c_str ());
478
+ return ss.str ();
479
+ }));
480
+ }
481
+ if (!inputs.tools .is_null ()) {
482
+ context->set (" tools" , minja::Value (inputs.tools ));
335
483
}
336
- if (!extra_context.is_null ()) {
337
- for (auto & kv : extra_context.items ()) {
338
- minja::Value val (kv.value ());
339
- context->set (kv.key (), val);
484
+ if (!inputs.extra_context .is_null ()) {
485
+ for (auto & kv : inputs.extra_context .items ()) {
486
+ context->set (kv.key (), minja::Value (kv.value ()));
340
487
}
341
488
}
342
489
@@ -353,7 +500,7 @@ class chat_template {
353
500
std::string existing_system = messages_with_system.at (0 ).at (" content" );
354
501
messages_with_system[0 ] = json {
355
502
{" role" , " system" },
356
- {" content" , existing_system + " \n " + system_prompt},
503
+ {" content" , existing_system + " \n\n " + system_prompt},
357
504
};
358
505
} else {
359
506
messages_with_system.insert (messages_with_system.begin (), json {
0 commit comments