Skip to content

Commit d7b31a9

Browse files
authored
1 parent 9ac3457 commit d7b31a9

File tree

2 files changed

+45
-16
lines changed

2 files changed

+45
-16
lines changed

common/chat-template.hpp

+21-7
Original file line numberDiff line numberDiff line change
@@ -249,16 +249,30 @@ class chat_template {
249249
inputs.add_generation_prompt = false;
250250
full = apply(inputs);
251251
}
252-
253-
if (full.find(prefix) != 0) {
254-
if (prefix.rfind(eos_token_) == prefix.size() - eos_token_.size()) {
255-
prefix = prefix.substr(0, prefix.size() - eos_token_.size());
252+
auto eos_pos_last = full.rfind(eos_token_);
253+
if (eos_pos_last == prefix.size() - eos_token_.size() ||
254+
(full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) {
255+
full = full.substr(0, eos_pos_last);
256+
}
257+
size_t common_prefix_length = 0;
258+
for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) {
259+
if (prefix[i] != full[i]) {
260+
break;
256261
}
262+
if (prefix[i] == '<') {
263+
// DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt,
264+
// but it removes thinking tags for past messages.
265+
// The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <.
266+
continue;
267+
}
268+
common_prefix_length = i + 1;
257269
}
258-
if (full.find(prefix) != 0) {
270+
auto example = full.substr(common_prefix_length);
271+
if (example.find("tool_name") == std::string::npos && example.find("some_value") == std::string::npos) {
259272
fprintf(stderr, "Failed to infer a tool call example (possible template bug)\n");
273+
} else {
274+
tool_call_example_ = example;
260275
}
261-
tool_call_example_ = full.substr(prefix.size());
262276
}
263277
} catch (const std::exception & e) {
264278
fprintf(stderr, "Failed to generate tool call example: %s\n", e.what());
@@ -363,7 +377,7 @@ class chat_template {
363377
if (polyfill_tools) {
364378
adjusted_messages = add_system(inputs.messages,
365379
"You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(2, /* to_json= */ true) +
366-
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_));
380+
(!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n"));
367381
} else {
368382
adjusted_messages = inputs.messages;
369383
}

common/minja.hpp

+24-9
Original file line numberDiff line numberDiff line change
@@ -1385,6 +1385,13 @@ static std::string strip(const std::string & s) {
13851385
return s.substr(start, end - start + 1);
13861386
}
13871387

1388+
static std::string capitalize(const std::string & s) {
1389+
if (s.empty()) return s;
1390+
auto result = s;
1391+
result[0] = std::toupper(result[0]);
1392+
return result;
1393+
}
1394+
13881395
static std::string html_escape(const std::string & s) {
13891396
std::string result;
13901397
result.reserve(s.size());
@@ -1462,6 +1469,9 @@ class MethodCallExpr : public Expression {
14621469
if (method->get_name() == "strip") {
14631470
vargs.expectArgs("strip method", {0, 0}, {0, 0});
14641471
return Value(strip(str));
1472+
} else if (method->get_name() == "capitalize") {
1473+
vargs.expectArgs("capitalize method", {0, 0}, {0, 0});
1474+
return Value(capitalize(str));
14651475
} else if (method->get_name() == "endswith") {
14661476
vargs.expectArgs("endswith method", {1, 1}, {0, 0});
14671477
auto suffix = vargs.args[0].get<std::string>();
@@ -1792,7 +1802,7 @@ class Parser {
17921802
auto left = parseStringConcat();
17931803
if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression");
17941804

1795-
static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)");
1805+
static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not\s+in\b)");
17961806
static std::regex not_tok(R"(not\b)");
17971807
std::string op_str;
17981808
while (!(op_str = consumeToken(compare_tok)).empty()) {
@@ -2171,7 +2181,7 @@ class Parser {
21712181
using TemplateTokenIterator = TemplateTokenVector::const_iterator;
21722182

21732183
std::vector<std::string> parseVarNames() {
2174-
static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)");
2184+
static std::regex varnames_regex(R"(((?:\w+)(?:\s*,\s*(?:\w+))*)\s*)");
21752185

21762186
std::vector<std::string> group;
21772187
if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names");
@@ -2194,13 +2204,13 @@ class Parser {
21942204
}
21952205

21962206
TemplateTokenVector tokenize() {
2197-
static std::regex comment_tok(R"(\{#([-~]?)([\s\S\r\n]*?)([-~]?)#\})");
2207+
static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})");
21982208
static std::regex expr_open_regex(R"(\{\{([-~])?)");
2199-
static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)");
2209+
static std::regex block_open_regex(R"(^\{%([-~])?\s*)");
22002210
static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)");
22012211
static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)");
2202-
static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})");
2203-
static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})");
2212+
static std::regex expr_close_regex(R"(\s*([-~])?\}\})");
2213+
static std::regex block_close_regex(R"(\s*([-~])?%\})");
22042214

22052215
TemplateTokenVector tokens;
22062216
std::vector<std::string> group;
@@ -2284,7 +2294,7 @@ class Parser {
22842294
auto post_space = parseBlockClose();
22852295
tokens.push_back(std::make_unique<EndGenerationTemplateToken>(location, pre_space, post_space));
22862296
} else if (keyword == "set") {
2287-
static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))");
2297+
static std::regex namespaced_var_regex(R"((\w+)\s*\.\s*(\w+))");
22882298

22892299
std::string ns;
22902300
std::vector<std::string> var_names;
@@ -2336,6 +2346,11 @@ class Parser {
23362346
throw std::runtime_error("Unexpected block: " + keyword);
23372347
}
23382348
} else if (std::regex_search(it, end, match, non_text_open_regex)) {
2349+
if (!match.position()) {
2350+
if (match[0] != "{#")
2351+
throw std::runtime_error("Internal error: Expected a comment");
2352+
throw std::runtime_error("Missing end of comment tag");
2353+
}
23392354
auto text_end = it + match.position();
23402355
text = std::string(it, text_end);
23412356
it = text_end;
@@ -2400,7 +2415,7 @@ class Parser {
24002415

24012416
auto text = text_token->text;
24022417
if (post_space == SpaceHandling::Strip) {
2403-
static std::regex trailing_space_regex(R"((\s|\r|\n)+$)");
2418+
static std::regex trailing_space_regex(R"(\s+$)");
24042419
text = std::regex_replace(text, trailing_space_regex, "");
24052420
} else if (options.lstrip_blocks && it != end) {
24062421
auto i = text.size();
@@ -2410,7 +2425,7 @@ class Parser {
24102425
}
24112426
}
24122427
if (pre_space == SpaceHandling::Strip) {
2413-
static std::regex leading_space_regex(R"(^(\s|\r|\n)+)");
2428+
static std::regex leading_space_regex(R"(^\s+)");
24142429
text = std::regex_replace(text, leading_space_regex, "");
24152430
} else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast<ExpressionTemplateToken*>((*(it - 2)).get())) {
24162431
if (text.length() > 0 && text[0] == '\n') {

0 commit comments

Comments
 (0)