From abd274a48f381cb3f790025685218cc8272b97c7 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 30 Dec 2024 03:21:44 +0000
Subject: [PATCH 01/42] Copy minja from
 https://github.com/google/minja/commit/58f0ca6dd74bcbfbd4e71229736640322b31c7f9

---
 common/chat-template.hpp |  247 ++++
 common/minja.hpp         | 2758 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 3005 insertions(+)
 create mode 100644 common/chat-template.hpp
 create mode 100644 common/minja.hpp

diff --git a/common/chat-template.hpp b/common/chat-template.hpp
new file mode 100644
index 0000000000000..302a173c29d95
--- /dev/null
+++ b/common/chat-template.hpp
@@ -0,0 +1,247 @@
+/*
+    Copyright 2024 Google LLC
+
+    Use of this source code is governed by an MIT-style
+    license that can be found in the LICENSE file or at
+    https://opensource.org/licenses/MIT.
+*/
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "minja.hpp"
+#include <json.hpp>
+#include <string>
+#include <vector>
+
+using json = nlohmann::ordered_json;
+
+namespace minja {
+
+class chat_template {
+  public:
+
+  private:
+    bool supports_tools_ = true;
+    // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object.
+    // Most other templates (and OpenAI's API) expect the arguments object to be stringified.
+    bool requires_object_arguments_ = false;
+    bool supports_system_role_ = true;
+    bool supports_parallel_tool_calls_ = false;
+    std::string source_;
+    std::string bos_token_;
+    std::string eos_token_;
+    std::shared_ptr<minja::TemplateNode> template_root_;
+
+    std::string try_render(
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
+    {
+        try {
+            auto prompt = apply(messages, tools, add_generation_prompt, extra_context);
+            // fprintf(stderr, "Prompt: %s\n", prompt.c_str());
+            return prompt;
+        } catch (const std::exception & e) {
+            // fprintf(stderr, "Error: %s\n", e.what());
+            return "";
+        }
+    }
+
+  public:
+    chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token)
+      : source_(source), bos_token_(bos_token), eos_token_(eos_token)
+    {
+        template_root_ = minja::Parser::parse(source_, {
+            /* .trim_blocks = */ true,
+            /* .lstrip_blocks = */ true,
+            /* .keep_trailing_newline = */ false,
+        });
+        supports_tools_ = source.find("tools") != std::string::npos;
+
+        auto renders_string_arguments =
+            try_render({
+                {
+                    {"role", "user"},
+                    {"content", "Hey"}
+                },
+                {
+                    {"role", "assistant"},
+                    {"tool_calls", json::array({
+                        {
+                            {"id", "call_1___"},
+                            {"type", "function"},
+                            {"function", {
+                                {"arguments", "{\"code\": \"print('Hello, World!')\"}"},
+                                {"name", "ipython"},
+                            }},
+                        },
+                    })},
+                }
+            }, {}, false).find("{\"code\": \"print") != std::string::npos;
+        if (!renders_string_arguments) {
+            auto renders_object_arguments =
+                try_render({
+                    {
+                        {"role", "user"},
+                        {"content", "Hey"}
+                    },
+                    {
+                        {"role", "assistant"},
+                        {"tool_calls", json::array({
+                            {
+                                {"id", "call_1___"},
+                                {"type", "function"},
+                                {"function", {
+                                    {"arguments", {
+                                        {"code", "print('Hello, World!')"},
+                                    }},
+                                    {"name", "ipython"},
+                                }},
+                            },
+                        })},
+                    }
+                }, {}, false).find("{\"code\": \"print") != std::string::npos;
+            requires_object_arguments_ = renders_object_arguments;
+        }
+        supports_parallel_tool_calls_ = source.find("tool_call_id") != std::string::npos;
+
+        supports_system_role_ = try_render({
+            {{"role", "system"}, {"content", "<System Needle>"}},
+            {{"role", "user"},   {"content", "Hey"}}
+        }, {}, false).find("<System Needle>") != std::string::npos;
+    }
+
+    const std::string & source() const { return source_; }
+    bool supports_tools() const { return supports_tools_; }
+    bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; }
+
+    std::string apply(
+        const nlohmann::ordered_json & messages,
+        const nlohmann::ordered_json & tools,
+        bool add_generation_prompt,
+        const nlohmann::ordered_json & extra_context = nlohmann::ordered_json()) const
+    {
+        json actual_messages;
+
+        // First, "fix" messages so they have a chance to be rendered correctly by the template
+
+        if (requires_object_arguments_ || !supports_system_role_ || !supports_tools_) {
+            actual_messages = json::array();
+
+            std::string pending_system;
+            auto flush_sys = [&]() {
+                if (!pending_system.empty()) {
+                    actual_messages.push_back({
+                        {"role", "user"},
+                        {"content", pending_system},
+                    });
+                    pending_system.clear();
+                }
+            };
+            for (const auto & message_ : messages) {
+                auto message = message_;
+                if (!message.contains("role") || !message.contains("content")) {
+                    throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump());
+                }
+                std::string role = message.at("role");
+
+                if (message.contains("tool_calls")) {
+                    if (requires_object_arguments_ || !supports_tools_) {
+                        for (auto & tool_call : message.at("tool_calls")) {
+                            if (tool_call["type"] == "function") {
+                                auto & function = tool_call.at("function");
+                                std::string arguments = function.at("arguments");
+                                function["arguments"] = json::parse(arguments);
+                            }
+                        }
+                    }
+                    if (!supports_tools_) {
+                        auto content = message.at("content");
+                        auto tool_calls = json::array();
+                        for (const auto & tool_call : message.at("tool_calls")) {
+                            if (tool_call.at("type") != "function") {
+                                continue;
+                            }
+                            const auto & function = tool_call.at("function");
+                            auto tc = json {
+                                {"name", function.at("name")},
+                                {"arguments", function.at("arguments")},
+                            };
+                            if (tool_call.contains("id")) {
+                                tc["id"] = tool_call["id"];
+                            }
+                            tool_calls.push_back(tc);
+                        }
+                        auto obj = json {
+                            {"tool_calls", tool_calls},
+                        };
+                        if (!content.is_null() && content != "") {
+                            obj["content"] = content;
+                        }
+                        message["content"] = obj.dump(2);
+                        message.erase("tool_calls");
+                    }
+                }
+                if (!supports_tools_ && role == "tool") {
+                    message["role"] = "user";
+                    auto obj = json {
+                        {"tool_response", {
+                            {"tool", message.at("name")},
+                            {"content", message.at("content")},
+                        }},
+                    };
+                    if (message.contains("tool_call_id")) {
+                        obj["tool_response"]["tool_call_id"] = message.at("tool_call_id");
+                    }
+                    message["content"] = obj.dump(2);
+                    message.erase("name");
+                }
+
+                if (!message["content"].is_null() && !supports_system_role_) {
+                    std::string content = message.at("content");
+                    if (role == "system") {
+                        if (!pending_system.empty()) pending_system += "\n";
+                        pending_system += content;
+                        continue;
+                    } else {
+                        if (role == "user") {
+                            if (!pending_system.empty()) {
+                                message["content"] = pending_system + (content.empty() ? "" : "\n" + content);
+                                pending_system.clear();
+                            }
+                        } else {
+                            flush_sys();
+                        }
+                    }
+                }
+                actual_messages.push_back(message);
+            }
+            flush_sys();
+        } else {
+            actual_messages = messages;
+        }
+
+        auto context = minja::Context::make(json({
+            {"messages", actual_messages},
+            {"add_generation_prompt", add_generation_prompt},
+            {"bos_token", bos_token_},
+            {"eos_token", eos_token_},
+        }));
+
+        if (!tools.is_null()) {
+            auto tools_val = minja::Value(tools);
+            context->set("tools", tools_val);
+        }
+        if (!extra_context.is_null()) {
+            for (auto & kv : extra_context.items()) {
+                minja::Value val(kv.value());
+                context->set(kv.key(), val);
+            }
+        }
+
+        return template_root_->render(context);
+    }
+};
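+
+// Example (hypothetical usage sketch; the template source and the BOS/EOS
+// tokens below are placeholders, not part of the upstream header):
+//
+//   minja::chat_template tmpl(
+//       "{% for message in messages %}{{ message.content }}{% endfor %}",
+//       /* bos_token= */ "<s>", /* eos_token= */ "</s>");
+//   auto prompt = tmpl.apply(
+//       json::parse(R"([{"role": "user", "content": "Hey"}])"),
+//       /* tools= */ json(), /* add_generation_prompt= */ true);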
"" : "\n" + content); + pending_system.clear(); + } + } else { + flush_sys(); + } + } + } + actual_messages.push_back(message); + } + flush_sys(); + } else { + actual_messages = messages; + } + + auto context = minja::Context::make(json({ + {"messages", actual_messages}, + {"add_generation_prompt", add_generation_prompt}, + {"bos_token", bos_token_}, + {"eos_token", eos_token_}, + })); + + if (!tools.is_null()) { + auto tools_val = minja::Value(tools); + context->set("tools", tools_val); + } + if (!extra_context.is_null()) { + for (auto & kv : extra_context.items()) { + minja::Value val(kv.value()); + context->set(kv.key(), val); + } + } + + return template_root_->render(context); + } +}; + +} // namespace minja diff --git a/common/minja.hpp b/common/minja.hpp new file mode 100644 index 0000000000000..9d9a1a08faf4d --- /dev/null +++ b/common/minja.hpp @@ -0,0 +1,2758 @@ +/* + Copyright 2024 Google LLC + + Use of this source code is governed by an MIT-style + license that can be found in the LICENSE file or at + https://opensource.org/licenses/MIT. +*/ +// SPDX-License-Identifier: MIT +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#define ENDL "\r\n" +#else +#define ENDL "\n" +#endif + +using json = nlohmann::ordered_json; + +namespace minja { + +class Context; + +struct Options { + bool trim_blocks; // removes the first newline after a block + bool lstrip_blocks; // removes leading whitespace on the line of the block + bool keep_trailing_newline; // don't remove last newline +}; + +struct ArgumentsValue; + +static std::string normalize_newlines(const std::string & s) { +#ifdef _WIN32 + static const std::regex nl_regex("\r\n"); + return std::regex_replace(s, nl_regex, "\n"); +#else + return s; +#endif +} + +/* Values that behave roughly like in Python. */ +class Value : public std::enable_shared_from_this { +public: + using CallableType = std::function &, ArgumentsValue &)>; + using FilterType = std::function &, ArgumentsValue &)>; + +private: + using ObjectType = nlohmann::ordered_map; // Only contains primitive keys + using ArrayType = std::vector; + + std::shared_ptr array_; + std::shared_ptr object_; + std::shared_ptr callable_; + json primitive_; + + Value(const std::shared_ptr & array) : array_(array) {} + Value(const std::shared_ptr & object) : object_(object) {} + Value(const std::shared_ptr & callable) : object_(std::make_shared()), callable_(callable) {} + + /* Python-style string repr */ + static void dump_string(const json & primitive, std::ostringstream & out, char string_quote = '\'') { + if (!primitive.is_string()) throw std::runtime_error("Value is not a string: " + primitive.dump()); + auto s = primitive.dump(); + if (string_quote == '"' || s.find('\'') != std::string::npos) { + out << s; + return; + } + // Reuse json dump, just changing string quotes + out << string_quote; + for (size_t i = 1, n = s.size() - 1; i < n; ++i) { + if (s[i] == '\\' && s[i + 1] == '"') { + out << '"'; + i++; + } else if (s[i] == string_quote) { + out << '\\' << string_quote; + } else { + out << s[i]; + } + } + out << string_quote; + } + void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const { + auto print_indent = [&](int level) { + if (indent > 0) { + out << ENDL; + for (int i = 0, n = level * indent; i < n; ++i) out << ' '; + } + }; + auto print_sub_sep = [&]() { + out << ','; + if (indent < 0) out << ' '; + else print_indent(level + 1); + }; + + auto string_quote = to_json ? 
'"' : '\''; + + if (is_null()) out << "null"; + else if (array_) { + out << "["; + print_indent(level + 1); + for (size_t i = 0; i < array_->size(); ++i) { + if (i) print_sub_sep(); + (*array_)[i].dump(out, indent, level + 1, to_json); + } + print_indent(level); + out << "]"; + } else if (object_) { + out << "{"; + print_indent(level + 1); + for (auto begin = object_->begin(), it = begin; it != object_->end(); ++it) { + if (it != begin) print_sub_sep(); + if (it->first.is_string()) { + dump_string(it->first, out, string_quote); + } else { + out << string_quote << it->first.dump() << string_quote; + } + out << ": "; + it->second.dump(out, indent, level + 1, to_json); + } + print_indent(level); + out << "}"; + } else if (callable_) { + throw std::runtime_error("Cannot dump callable to JSON"); + } else if (is_boolean() && !to_json) { + out << (this->to_bool() ? "True" : "False"); + } else if (is_string() && !to_json) { + dump_string(primitive_, out, string_quote); + } else { + out << primitive_.dump(); + } + } + +public: + Value() {} + Value(const bool& v) : primitive_(v) {} + Value(const int64_t & v) : primitive_(v) {} + Value(const double& v) : primitive_(v) {} + Value(const std::nullptr_t &) {} + Value(const std::string & v) : primitive_(v) {} + Value(const char * v) : primitive_(std::string(v)) {} + + Value(const json & v) { + if (v.is_object()) { + auto object = std::make_shared(); + for (auto it = v.begin(); it != v.end(); ++it) { + (*object)[it.key()] = it.value(); + } + object_ = std::move(object); + } else if (v.is_array()) { + auto array = std::make_shared(); + for (const auto& item : v) { + array->push_back(Value(item)); + } + array_ = array; + } else { + primitive_ = v; + } + } + + std::vector keys() { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + std::vector res; + for (const auto& item : *object_) { + res.push_back(item.first); + } + return res; + } + + size_t size() const { + if (is_object()) return object_->size(); + if (is_array()) return array_->size(); + if (is_string()) return primitive_.get().length(); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + + static Value array(const std::vector values = {}) { + auto array = std::make_shared(); + for (const auto& item : values) { + array->push_back(item); + } + return Value(array); + } + static Value object(const std::shared_ptr object = std::make_shared()) { + return Value(object); + } + static Value callable(const CallableType & callable) { + return Value(std::make_shared(callable)); + } + + void insert(size_t index, const Value& v) { + if (!array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->insert(array_->begin() + index, v); + } + void push_back(const Value& v) { + if (!array_) + throw std::runtime_error("Value is not an array: " + dump()); + array_->push_back(v); + } + Value get(const Value& key) { + if (array_) { + if (!key.is_number_integer()) { + return Value(); + } + auto index = key.get(); + return array_->at(index < 0 ? 
array_->size() + index : index); + } else if (object_) { + if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + auto it = object_->find(key.primitive_); + if (it == object_->end()) return Value(); + return it->second; + } + return Value(); + } + void set(const Value& key, const Value& value) { + if (!object_) throw std::runtime_error("Value is not an object: " + dump()); + if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + (*object_)[key.primitive_] = value; + } + Value call(const std::shared_ptr & context, ArgumentsValue & args) const { + if (!callable_) throw std::runtime_error("Value is not callable: " + dump()); + return (*callable_)(context, args); + } + + bool is_object() const { return !!object_; } + bool is_array() const { return !!array_; } + bool is_callable() const { return !!callable_; } + bool is_null() const { return !object_ && !array_ && primitive_.is_null() && !callable_; } + bool is_boolean() const { return primitive_.is_boolean(); } + bool is_number_integer() const { return primitive_.is_number_integer(); } + bool is_number_float() const { return primitive_.is_number_float(); } + bool is_number() const { return primitive_.is_number(); } + bool is_string() const { return primitive_.is_string(); } + bool is_iterable() const { return is_array() || is_object() || is_string(); } + + bool is_primitive() const { return !array_ && !object_ && !callable_; } + bool is_hashable() const { return is_primitive(); } + + bool empty() const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_string()) return primitive_.empty(); + if (is_array()) return array_->empty(); + if (is_object()) return object_->empty(); + return false; + } + + void for_each(const std::function & callback) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (array_) { + for (auto& item : *array_) { + callback(item); + } + } else if (object_) { + for (auto & item : *object_) { + Value key(item.first); + callback(key); + } + } else if (is_string()) { + for (char c : primitive_.get()) { + auto val = Value(std::string(1, c)); + callback(val); + } + } else { + throw std::runtime_error("Value is not iterable: " + dump()); + } + } + + bool to_bool() const { + if (is_null()) return false; + if (is_boolean()) return get(); + if (is_number()) return get() != 0; + if (is_string()) return !get().empty(); + if (is_array()) return !empty(); + return true; + } + + int64_t to_int() const { + if (is_null()) return 0; + if (is_boolean()) return get() ? 
1 : 0; + if (is_number()) return static_cast(get()); + if (is_string()) { + try { + return std::stol(get()); + } catch (const std::exception &) { + return 0; + } + } + return 0; + } + + bool operator<(const Value & other) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_number() && other.is_number()) return get() < other.get(); + if (is_string() && other.is_string()) return get() < other.get(); + throw std::runtime_error("Cannot compare values: " + dump() + " < " + other.dump()); + } + bool operator>=(const Value & other) const { return !(*this < other); } + + bool operator>(const Value & other) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_number() && other.is_number()) return get() > other.get(); + if (is_string() && other.is_string()) return get() > other.get(); + throw std::runtime_error("Cannot compare values: " + dump() + " > " + other.dump()); + } + bool operator<=(const Value & other) const { return !(*this > other); } + + bool operator==(const Value & other) const { + if (callable_ || other.callable_) { + if (callable_.get() != other.callable_.get()) return false; + } + if (array_) { + if (!other.array_) return false; + if (array_->size() != other.array_->size()) return false; + for (size_t i = 0; i < array_->size(); ++i) { + if (!(*array_)[i].to_bool() || !(*other.array_)[i].to_bool() || (*array_)[i] != (*other.array_)[i]) return false; + } + return true; + } else if (object_) { + if (!other.object_) return false; + if (object_->size() != other.object_->size()) return false; + for (const auto& item : *object_) { + if (!item.second.to_bool() || !other.object_->count(item.first) || item.second != other.object_->at(item.first)) return false; + } + return true; + } else { + return primitive_ == other.primitive_; + } + } + bool operator!=(const Value & other) const { return !(*this == other); } + + bool contains(const char * key) const { return contains(std::string(key)); } + bool contains(const std::string & key) const { + if (array_) { + return false; + } else if (object_) { + return object_->find(key) != object_->end(); + } else { + throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); + } + } + bool contains(const Value & value) const { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (array_) { + for (const auto& item : *array_) { + if (item.to_bool() && item == value) return true; + } + return false; + } else if (object_) { + if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump()); + return object_->find(value.primitive_) != object_->end(); + } else { + throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); + } + } + void erase(size_t index) { + if (array_) throw std::runtime_error("Value is not an array: " + dump()); + array_->erase(array_->begin() + index); + } + void erase(const std::string & key) { + if (object_) throw std::runtime_error("Value is not an object: " + dump()); + object_->erase(key); + } + const Value& at(const Value & index) const { + return const_cast(this)->at(index); + } + Value& at(const Value & index) { + if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (is_array()) return array_->at(index.get()); + if (is_object()) return object_->at(index.primitive_); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + const Value& at(size_t index) const { + return 
const_cast(this)->at(index); + } + Value& at(size_t index) { + if (is_null()) + throw std::runtime_error("Undefined value or reference"); + if (is_array()) return array_->at(index); + if (is_object()) return object_->at(index); + throw std::runtime_error("Value is not an array or object: " + dump()); + } + + template + T get(const std::string & key, T default_value) const { + if (!contains(key)) return default_value; + return at(key).get(); + } + + template + T get() const { + if (is_primitive()) return primitive_.get(); + throw std::runtime_error("get not defined for this value type: " + dump()); + } + + std::string dump(int indent=-1, bool to_json=false) const { + std::ostringstream out; + dump(out, indent, 0, to_json); + return out.str(); + } + + Value operator-() const { + if (is_number_integer()) + return -get(); + else + return -get(); + } + std::string to_str() const { + if (is_string()) return get(); + if (is_number_integer()) return std::to_string(get()); + if (is_number_float()) return std::to_string(get()); + if (is_boolean()) return get() ? "True" : "False"; + if (is_null()) return "None"; + return dump(); + } + Value operator+(const Value& rhs) const { + if (is_string() || rhs.is_string()) { + return to_str() + rhs.to_str(); + } else if (is_number_integer() && rhs.is_number_integer()) { + return get() + rhs.get(); + } else if (is_array() && rhs.is_array()) { + auto res = Value::array(); + for (const auto& item : *array_) res.push_back(item); + for (const auto& item : *rhs.array_) res.push_back(item); + return res; + } else { + return get() + rhs.get(); + } + } + Value operator-(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() - rhs.get(); + else + return get() - rhs.get(); + } + Value operator*(const Value& rhs) const { + if (is_string() && rhs.is_number_integer()) { + std::ostringstream out; + for (int64_t i = 0, n = rhs.get(); i < n; ++i) { + out << to_str(); + } + return out.str(); + } + else if (is_number_integer() && rhs.is_number_integer()) + return get() * rhs.get(); + else + return get() * rhs.get(); + } + Value operator/(const Value& rhs) const { + if (is_number_integer() && rhs.is_number_integer()) + return get() / rhs.get(); + else + return get() / rhs.get(); + } + Value operator%(const Value& rhs) const { + return get() % rhs.get(); + } +}; + +struct ArgumentsValue { + std::vector args; + std::vector> kwargs; + + bool has_named(const std::string & name) { + for (const auto & p : kwargs) { + if (p.first == name) return true; + } + return false; + } + + Value get_named(const std::string & name) { + for (const auto & [key, value] : kwargs) { + if (key == name) return value; + } + return Value(); + } + + bool empty() { + return args.empty() && kwargs.empty(); + } + + void expectArgs(const std::string & method_name, const std::pair & pos_count, const std::pair & kw_count) { + if (args.size() < pos_count.first || args.size() > pos_count.second || kwargs.size() < kw_count.first || kwargs.size() > kw_count.second) { + std::ostringstream out; + out << method_name << " must have between " << pos_count.first << " and " << pos_count.second << " positional arguments and between " << kw_count.first << " and " << kw_count.second << " keyword arguments"; + throw std::runtime_error(out.str()); + } + } +}; + +template <> +inline json Value::get() const { + if (is_primitive()) return primitive_; + if (is_null()) return json(); + if (array_) { + std::vector res; + for (const auto& item : *array_) { + res.push_back(item.get()); + } + return 
res; + } + if (object_) { + json res = json::object(); + for (const auto& [key, value] : *object_) { + if (key.is_string()) { + res[key.get()] = value.get(); + } else if (key.is_primitive()) { + res[key.dump()] = value.get(); + } else { + throw std::runtime_error("Invalid key type for conversion to JSON: " + key.dump()); + } + } + if (is_callable()) { + res["__callable__"] = true; + } + return res; + } + throw std::runtime_error("get not defined for this value type: " + dump()); +} + +} // namespace minja + +namespace std { + template <> + struct hash { + size_t operator()(const minja::Value & v) const { + if (!v.is_hashable()) + throw std::runtime_error("Unsupported type for hashing: " + v.dump()); + return std::hash()(v.get()); + } + }; +} // namespace std + +namespace minja { + +static std::string error_location_suffix(const std::string & source, size_t pos) { + auto get_line = [&](size_t line) { + auto start = source.begin(); + for (size_t i = 1; i < line; ++i) { + start = std::find(start, source.end(), '\n') + 1; + } + auto end = std::find(start, source.end(), '\n'); + return std::string(start, end); + }; + auto start = source.begin(); + auto end = source.end(); + auto it = start + pos; + auto line = std::count(start, it, '\n') + 1; + auto max_line = std::count(start, end, '\n') + 1; + auto col = pos - std::string(start, it).rfind('\n'); + std::ostringstream out; + out << " at row " << line << ", column " << col << ":" ENDL; + if (line > 1) out << get_line(line - 1) << ENDL; + out << get_line(line) << ENDL; + out << std::string(col - 1, ' ') << "^" << ENDL; + if (line < max_line) out << get_line(line + 1) << ENDL; + + return out.str(); +} + +class Context : public std::enable_shared_from_this { + protected: + Value values_; + std::shared_ptr parent_; + public: + Context(Value && values, const std::shared_ptr & parent = nullptr) : values_(std::move(values)), parent_(parent) { + if (!values_.is_object()) throw std::runtime_error("Context values must be an object: " + values_.dump()); + } + virtual ~Context() {} + + static std::shared_ptr builtins(); + static std::shared_ptr make(Value && values, const std::shared_ptr & parent = builtins()); + + std::vector keys() { + return values_.keys(); + } + virtual Value get(const Value & key) { + if (values_.contains(key)) return values_.at(key); + if (parent_) return parent_->get(key); + return Value(); + } + virtual Value & at(const Value & key) { + if (values_.contains(key)) return values_.at(key); + if (parent_) return parent_->at(key); + throw std::runtime_error("Undefined variable: " + key.dump()); + } + virtual bool contains(const Value & key) { + if (values_.contains(key)) return true; + if (parent_) return parent_->contains(key); + return false; + } + virtual void set(const Value & key, Value & value) { + values_.set(key, value); + } +}; + +struct Location { + std::shared_ptr source; + size_t pos; +}; + +class Expression { +protected: + virtual Value do_evaluate(const std::shared_ptr & context) const = 0; +public: + using Parameters = std::vector>>; + + Location location; + + Expression(const Location & location) : location(location) {} + virtual ~Expression() = default; + + Value evaluate(const std::shared_ptr & context) const { + try { + return do_evaluate(context); + } catch (const std::exception & e) { + std::ostringstream out; + out << e.what(); + if (location.source) out << error_location_suffix(*location.source, location.pos); + throw std::runtime_error(out.str()); + } + } +}; + +class VariableExpr : public Expression { + 
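+    // Name of the variable to resolve in the context, e.g. "messages" in {{ messages }}.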
std::string name; +public: + VariableExpr(const Location & location, const std::string& n) + : Expression(location), name(n) {} + std::string get_name() const { return name; } + Value do_evaluate(const std::shared_ptr & context) const override { + if (!context->contains(name)) { + return Value(); + } + return context->at(name); + } +}; + +static void destructuring_assign(const std::vector & var_names, const std::shared_ptr & context, Value& item) { + if (var_names.size() == 1) { + Value name(var_names[0]); + context->set(name, item); + } else { + if (!item.is_array() || item.size() != var_names.size()) { + throw std::runtime_error("Mismatched number of variables and items in destructuring assignment"); + } + for (size_t i = 0; i < var_names.size(); ++i) { + context->set(var_names[i], item.at(i)); + } + } +} + +enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; + +class TemplateToken { +public: + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter }; + + static std::string typeToString(Type t) { + switch (t) { + case Type::Text: return "text"; + case Type::Expression: return "expression"; + case Type::If: return "if"; + case Type::Else: return "else"; + case Type::Elif: return "elif"; + case Type::EndIf: return "endif"; + case Type::For: return "for"; + case Type::EndFor: return "endfor"; + case Type::Set: return "set"; + case Type::EndSet: return "endset"; + case Type::Comment: return "comment"; + case Type::Macro: return "macro"; + case Type::EndMacro: return "endmacro"; + case Type::Filter: return "filter"; + case Type::EndFilter: return "endfilter"; + } + return "Unknown"; + } + + TemplateToken(Type type, const Location & location, SpaceHandling pre, SpaceHandling post) : type(type), location(location), pre_space(pre), post_space(post) {} + virtual ~TemplateToken() = default; + + Type type; + Location location; + SpaceHandling pre_space = SpaceHandling::Keep; + SpaceHandling post_space = SpaceHandling::Keep; +}; + +struct TextTemplateToken : public TemplateToken { + std::string text; + TextTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Text, location, pre, post), text(t) {} +}; + +struct ExpressionTemplateToken : public TemplateToken { + std::shared_ptr expr; + ExpressionTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) : TemplateToken(Type::Expression, location, pre, post), expr(std::move(e)) {} +}; + +struct IfTemplateToken : public TemplateToken { + std::shared_ptr condition; + IfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::If, location, pre, post), condition(std::move(c)) {} +}; + +struct ElifTemplateToken : public TemplateToken { + std::shared_ptr condition; + ElifTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && c) : TemplateToken(Type::Elif, location, pre, post), condition(std::move(c)) {} +}; + +struct ElseTemplateToken : public TemplateToken { + ElseTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::Else, location, pre, post) {} +}; + +struct EndIfTemplateToken : public TemplateToken { + EndIfTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndIf, location, pre, post) {} +}; + +struct MacroTemplateToken : public TemplateToken { + 
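+    // Macro name and parameter list (with optional defaults), e.g. {% macro greet(name, punct='!') %} (illustrative).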
std::shared_ptr name; + Expression::Parameters params; + MacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && n, Expression::Parameters && p) + : TemplateToken(Type::Macro, location, pre, post), name(std::move(n)), params(std::move(p)) {} +}; + +struct EndMacroTemplateToken : public TemplateToken { + EndMacroTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndMacro, location, pre, post) {} +}; + +struct FilterTemplateToken : public TemplateToken { + std::shared_ptr filter; + FilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, std::shared_ptr && filter) + : TemplateToken(Type::Filter, location, pre, post), filter(std::move(filter)) {} +}; + +struct EndFilterTemplateToken : public TemplateToken { + EndFilterTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFilter, location, pre, post) {} +}; + +struct ForTemplateToken : public TemplateToken { + std::vector var_names; + std::shared_ptr iterable; + std::shared_ptr condition; + bool recursive; + ForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::vector & vns, std::shared_ptr && iter, + std::shared_ptr && c, bool r) + : TemplateToken(Type::For, location, pre, post), var_names(vns), iterable(std::move(iter)), condition(std::move(c)), recursive(r) {} +}; + +struct EndForTemplateToken : public TemplateToken { + EndForTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndFor, location, pre, post) {} +}; + +struct SetTemplateToken : public TemplateToken { + std::string ns; + std::vector var_names; + std::shared_ptr value; + SetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateToken(Type::Set, location, pre, post), ns(ns), var_names(vns), value(std::move(v)) {} +}; + +struct EndSetTemplateToken : public TemplateToken { + EndSetTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post) : TemplateToken(Type::EndSet, location, pre, post) {} +}; + +struct CommentTemplateToken : public TemplateToken { + std::string text; + CommentTemplateToken(const Location & location, SpaceHandling pre, SpaceHandling post, const std::string& t) : TemplateToken(Type::Comment, location, pre, post), text(t) {} +}; + +class TemplateNode { + Location location_; +protected: + virtual void do_render(std::ostringstream & out, const std::shared_ptr & context) const = 0; + +public: + TemplateNode(const Location & location) : location_(location) {} + void render(std::ostringstream & out, const std::shared_ptr & context) const { + try { + do_render(out, context); + } catch (const std::exception & e) { + std::ostringstream err; + err << e.what(); + if (location_.source) err << error_location_suffix(*location_.source, location_.pos); + throw std::runtime_error(err.str()); + } + } + const Location & location() const { return location_; } + virtual ~TemplateNode() = default; + std::string render(const std::shared_ptr & context) const { + std::ostringstream out; + render(out, context); + return normalize_newlines(out.str()); + } +}; + +class SequenceNode : public TemplateNode { + std::vector> children; +public: + SequenceNode(const Location & location, std::vector> && c) + : TemplateNode(location), children(std::move(c)) {} + void do_render(std::ostringstream & out, const 
std::shared_ptr & context) const override { + for (const auto& child : children) child->render(out, context); + } +}; + +class TextNode : public TemplateNode { + std::string text; +public: + TextNode(const Location & location, const std::string& t) : TemplateNode(location), text(t) {} + void do_render(std::ostringstream & out, const std::shared_ptr &) const override { + out << text; + } +}; + +class ExpressionNode : public TemplateNode { + std::shared_ptr expr; +public: + ExpressionNode(const Location & location, std::shared_ptr && e) : TemplateNode(location), expr(std::move(e)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!expr) throw std::runtime_error("ExpressionNode.expr is null"); + auto result = expr->evaluate(context); + if (result.is_string()) { + out << result.get(); + } else if (result.is_boolean()) { + out << (result.get() ? "True" : "False"); + } else if (!result.is_null()) { + out << result.dump(); + } + } +}; + +class IfNode : public TemplateNode { + std::vector, std::shared_ptr>> cascade; +public: + IfNode(const Location & location, std::vector, std::shared_ptr>> && c) + : TemplateNode(location), cascade(std::move(c)) {} + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + for (const auto& branch : cascade) { + auto enter_branch = true; + if (branch.first) { + enter_branch = branch.first->evaluate(context).to_bool(); + } + if (enter_branch) { + if (!branch.second) throw std::runtime_error("IfNode.cascade.second is null"); + branch.second->render(out, context); + return; + } + } + } +}; + +class ForNode : public TemplateNode { + std::vector var_names; + std::shared_ptr iterable; + std::shared_ptr condition; + std::shared_ptr body; + bool recursive; + std::shared_ptr else_body; +public: + ForNode(const Location & location, std::vector && var_names, std::shared_ptr && iterable, + std::shared_ptr && condition, std::shared_ptr && body, bool recursive, std::shared_ptr && else_body) + : TemplateNode(location), var_names(var_names), iterable(std::move(iterable)), condition(std::move(condition)), body(std::move(body)), recursive(recursive), else_body(std::move(else_body)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + // https://jinja.palletsprojects.com/en/3.0.x/templates/#for + if (!iterable) throw std::runtime_error("ForNode.iterable is null"); + if (!body) throw std::runtime_error("ForNode.body is null"); + + auto iterable_value = iterable->evaluate(context); + Value::CallableType loop_function; + + std::function visit = [&](Value& iter) { + auto filtered_items = Value::array(); + if (!iter.is_null()) { + if (!iterable_value.is_iterable()) { + throw std::runtime_error("For loop iterable must be iterable: " + iterable_value.dump()); + } + iterable_value.for_each([&](Value & item) { + destructuring_assign(var_names, context, item); + if (!condition || condition->evaluate(context).to_bool()) { + filtered_items.push_back(item); + } + }); + } + if (filtered_items.empty()) { + if (else_body) { + else_body->render(out, context); + } + } else { + auto loop = recursive ? 
Value::callable(loop_function) : Value::object(); + loop.set("length", (int64_t) filtered_items.size()); + + size_t cycle_index = 0; + loop.set("cycle", Value::callable([&](const std::shared_ptr &, ArgumentsValue & args) { + if (args.args.empty() || !args.kwargs.empty()) { + throw std::runtime_error("cycle() expects at least 1 positional argument and no named arg"); + } + auto item = args.args[cycle_index]; + cycle_index = (cycle_index + 1) % args.args.size(); + return item; + })); + auto loop_context = Context::make(Value::object(), context); + loop_context->set("loop", loop); + for (size_t i = 0, n = filtered_items.size(); i < n; ++i) { + auto & item = filtered_items.at(i); + destructuring_assign(var_names, loop_context, item); + loop.set("index", (int64_t) i + 1); + loop.set("index0", (int64_t) i); + loop.set("revindex", (int64_t) (n - i)); + loop.set("revindex0", (int64_t) (n - i - 1)); + loop.set("length", (int64_t) n); + loop.set("first", i == 0); + loop.set("last", i == (n - 1)); + loop.set("previtem", i > 0 ? filtered_items.at(i - 1) : Value()); + loop.set("nextitem", i < n - 1 ? filtered_items.at(i + 1) : Value()); + body->render(out, loop_context); + } + } + }; + + if (recursive) { + loop_function = [&](const std::shared_ptr &, ArgumentsValue & args) { + if (args.args.size() != 1 || !args.kwargs.empty() || !args.args[0].is_array()) { + throw std::runtime_error("loop() expects exactly 1 positional iterable argument"); + } + auto & items = args.args[0]; + visit(items); + return Value(); + }; + } + + visit(iterable_value); + } +}; + +class MacroNode : public TemplateNode { + std::shared_ptr name; + Expression::Parameters params; + std::shared_ptr body; + std::unordered_map named_param_positions; +public: + MacroNode(const Location & location, std::shared_ptr && n, Expression::Parameters && p, std::shared_ptr && b) + : TemplateNode(location), name(std::move(n)), params(std::move(p)), body(std::move(b)) { + for (size_t i = 0; i < params.size(); ++i) { + const auto & name = params[i].first; + if (!name.empty()) { + named_param_positions[name] = i; + } + } + } + void do_render(std::ostringstream &, const std::shared_ptr & macro_context) const override { + if (!name) throw std::runtime_error("MacroNode.name is null"); + if (!body) throw std::runtime_error("MacroNode.body is null"); + auto callable = Value::callable([&](const std::shared_ptr & context, ArgumentsValue & args) { + auto call_context = macro_context; + std::vector param_set(params.size(), false); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto & arg = args.args[i]; + if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name()); + param_set[i] = true; + auto & param_name = params[i].first; + call_context->set(param_name, arg); + } + for (auto & [arg_name, value] : args.kwargs) { + auto it = named_param_positions.find(arg_name); + if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name); + + call_context->set(arg_name, value); + param_set[it->second] = true; + } + // Set default values for parameters that were not passed + for (size_t i = 0, n = params.size(); i < n; i++) { + if (!param_set[i] && params[i].second != nullptr) { + auto val = params[i].second->evaluate(context); + call_context->set(params[i].first, val); + } + } + return body->render(call_context); + }); + macro_context->set(name->get_name(), callable); + } +}; + +class FilterNode : public TemplateNode { + 
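+    // A {% filter ... %} block: the body is rendered first, then piped through the filter callable (see do_render below).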
std::shared_ptr filter; + std::shared_ptr body; + +public: + FilterNode(const Location & location, std::shared_ptr && f, std::shared_ptr && b) + : TemplateNode(location), filter(std::move(f)), body(std::move(b)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!filter) throw std::runtime_error("FilterNode.filter is null"); + if (!body) throw std::runtime_error("FilterNode.body is null"); + auto filter_value = filter->evaluate(context); + if (!filter_value.is_callable()) { + throw std::runtime_error("Filter must be a callable: " + filter_value.dump()); + } + std::string rendered_body = body->render(context); + + ArgumentsValue filter_args = {{Value(rendered_body)}, {}}; + auto result = filter_value.call(context, filter_args); + out << result.to_str(); + } +}; + +class SetNode : public TemplateNode { + std::string ns; + std::vector var_names; + std::shared_ptr value; +public: + SetNode(const Location & location, const std::string & ns, const std::vector & vns, std::shared_ptr && v) + : TemplateNode(location), ns(ns), var_names(vns), value(std::move(v)) {} + void do_render(std::ostringstream &, const std::shared_ptr & context) const override { + if (!value) throw std::runtime_error("SetNode.value is null"); + if (!ns.empty()) { + if (var_names.size() != 1) { + throw std::runtime_error("Namespaced set only supports a single variable name"); + } + auto & name = var_names[0]; + auto ns_value = context->get(ns); + if (!ns_value.is_object()) throw std::runtime_error("Namespace '" + ns + "' is not an object"); + ns_value.set(name, this->value->evaluate(context)); + } else { + auto val = value->evaluate(context); + destructuring_assign(var_names, context, val); + } + } +}; + +class SetTemplateNode : public TemplateNode { + std::string name; + std::shared_ptr template_value; +public: + SetTemplateNode(const Location & location, const std::string & name, std::shared_ptr && tv) + : TemplateNode(location), name(name), template_value(std::move(tv)) {} + void do_render(std::ostringstream &, const std::shared_ptr & context) const override { + if (!template_value) throw std::runtime_error("SetTemplateNode.template_value is null"); + Value value { template_value->render(context) }; + context->set(name, value); + } +}; + +class IfExpr : public Expression { + std::shared_ptr condition; + std::shared_ptr then_expr; + std::shared_ptr else_expr; +public: + IfExpr(const Location & location, std::shared_ptr && c, std::shared_ptr && t, std::shared_ptr && e) + : Expression(location), condition(std::move(c)), then_expr(std::move(t)), else_expr(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!condition) throw std::runtime_error("IfExpr.condition is null"); + if (!then_expr) throw std::runtime_error("IfExpr.then_expr is null"); + if (condition->evaluate(context).to_bool()) { + return then_expr->evaluate(context); + } + if (else_expr) { + return else_expr->evaluate(context); + } + return nullptr; + } +}; + +class LiteralExpr : public Expression { + Value value; +public: + LiteralExpr(const Location & location, const Value& v) + : Expression(location), value(v) {} + Value do_evaluate(const std::shared_ptr &) const override { return value; } +}; + +class ArrayExpr : public Expression { + std::vector> elements; +public: + ArrayExpr(const Location & location, std::vector> && e) + : Expression(location), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto result = 
Value::array(); + for (const auto& e : elements) { + if (!e) throw std::runtime_error("Array element is null"); + result.push_back(e->evaluate(context)); + } + return result; + } +}; + +class DictExpr : public Expression { + std::vector, std::shared_ptr>> elements; +public: + DictExpr(const Location & location, std::vector, std::shared_ptr>> && e) + : Expression(location), elements(std::move(e)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + auto result = Value::object(); + for (const auto& [key, value] : elements) { + if (!key) throw std::runtime_error("Dict key is null"); + if (!value) throw std::runtime_error("Dict value is null"); + result.set(key->evaluate(context), value->evaluate(context)); + } + return result; + } +}; + +class SliceExpr : public Expression { +public: + std::shared_ptr start, end; + SliceExpr(const Location & location, std::shared_ptr && s, std::shared_ptr && e) + : Expression(location), start(std::move(s)), end(std::move(e)) {} + Value do_evaluate(const std::shared_ptr &) const override { + throw std::runtime_error("SliceExpr not implemented"); + } +}; + +class SubscriptExpr : public Expression { + std::shared_ptr base; + std::shared_ptr index; +public: + SubscriptExpr(const Location & location, std::shared_ptr && b, std::shared_ptr && i) + : Expression(location), base(std::move(b)), index(std::move(i)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!base) throw std::runtime_error("SubscriptExpr.base is null"); + if (!index) throw std::runtime_error("SubscriptExpr.index is null"); + auto target_value = base->evaluate(context); + if (auto slice = dynamic_cast(index.get())) { + auto start = slice->start ? slice->start->evaluate(context).get() : 0; + auto end = slice->end ? slice->end->evaluate(context).get() : (int64_t) target_value.size(); + if (target_value.is_string()) { + std::string s = target_value.get(); + if (start < 0) start = s.size() + start; + if (end < 0) end = s.size() + end; + return s.substr(start, end - start); + } else if (target_value.is_array()) { + if (start < 0) start = target_value.size() + start; + if (end < 0) end = target_value.size() + end; + auto result = Value::array(); + for (auto i = start; i < end; ++i) { + result.push_back(target_value.at(i)); + } + return result; + } else { + throw std::runtime_error(target_value.is_null() ? "Cannot subscript null" : "Subscripting only supported on arrays and strings"); + } + } else { + auto index_value = index->evaluate(context); + if (target_value.is_null()) { + if (auto t = dynamic_cast(base.get())) { + throw std::runtime_error("'" + t->get_name() + "' is " + (context->contains(t->get_name()) ? 
"null" : "not defined")); + } + throw std::runtime_error("Trying to access property '" + index_value.dump() + "' on null!"); + } + return target_value.get(index_value); + } + } +}; + +class UnaryOpExpr : public Expression { +public: + enum class Op { Plus, Minus, LogicalNot, Expansion, ExpansionDict }; + std::shared_ptr expr; + Op op; + UnaryOpExpr(const Location & location, std::shared_ptr && e, Op o) + : Expression(location), expr(std::move(e)), op(o) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!expr) throw std::runtime_error("UnaryOpExpr.expr is null"); + auto e = expr->evaluate(context); + switch (op) { + case Op::Plus: return e; + case Op::Minus: return -e; + case Op::LogicalNot: return !e.to_bool(); + case Op::Expansion: + case Op::ExpansionDict: + throw std::runtime_error("Expansion operator is only supported in function calls and collections"); + + } + throw std::runtime_error("Unknown unary operator"); + } +}; + +class BinaryOpExpr : public Expression { +public: + enum class Op { StrConcat, Add, Sub, Mul, MulMul, Div, DivDiv, Mod, Eq, Ne, Lt, Gt, Le, Ge, And, Or, In, NotIn, Is, IsNot }; +private: + std::shared_ptr left; + std::shared_ptr right; + Op op; +public: + BinaryOpExpr(const Location & location, std::shared_ptr && l, std::shared_ptr && r, Op o) + : Expression(location), left(std::move(l)), right(std::move(r)), op(o) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!left) throw std::runtime_error("BinaryOpExpr.left is null"); + if (!right) throw std::runtime_error("BinaryOpExpr.right is null"); + auto l = left->evaluate(context); + + auto do_eval = [&](const Value & l) -> Value { + if (op == Op::Is || op == Op::IsNot) { + auto t = dynamic_cast(right.get()); + if (!t) throw std::runtime_error("Right side of 'is' operator must be a variable"); + + auto eval = [&]() { + const auto & name = t->get_name(); + if (name == "none") return l.is_null(); + if (name == "boolean") return l.is_boolean(); + if (name == "integer") return l.is_number_integer(); + if (name == "float") return l.is_number_float(); + if (name == "number") return l.is_number(); + if (name == "string") return l.is_string(); + if (name == "mapping") return l.is_object(); + if (name == "iterable") return l.is_iterable(); + if (name == "sequence") return l.is_array(); + if (name == "defined") return !l.is_null(); + throw std::runtime_error("Unknown type for 'is' operator: " + name); + }; + auto value = eval(); + return Value(op == Op::Is ? 
value : !value); + } + + if (op == Op::And) { + if (!l.to_bool()) return Value(false); + return right->evaluate(context).to_bool(); + } else if (op == Op::Or) { + if (l.to_bool()) return l; + return right->evaluate(context); + } + + auto r = right->evaluate(context); + switch (op) { + case Op::StrConcat: return l.to_str() + r.to_str(); + case Op::Add: return l + r; + case Op::Sub: return l - r; + case Op::Mul: return l * r; + case Op::Div: return l / r; + case Op::MulMul: return std::pow(l.get(), r.get()); + case Op::DivDiv: return l.get() / r.get(); + case Op::Mod: return l.get() % r.get(); + case Op::Eq: return l == r; + case Op::Ne: return l != r; + case Op::Lt: return l < r; + case Op::Gt: return l > r; + case Op::Le: return l <= r; + case Op::Ge: return l >= r; + case Op::In: return (r.is_array() || r.is_object()) && r.contains(l); + case Op::NotIn: return !(r.is_array() && r.contains(l)); + default: break; + } + throw std::runtime_error("Unknown binary operator"); + }; + + if (l.is_callable()) { + return Value::callable([l, do_eval](const std::shared_ptr & context, ArgumentsValue & args) { + auto ll = l.call(context, args); + return do_eval(ll); //args[0].second); + }); + } else { + return do_eval(l); + } + } +}; + +struct ArgumentsExpression { + std::vector> args; + std::vector>> kwargs; + + ArgumentsValue evaluate(const std::shared_ptr & context) const { + ArgumentsValue vargs; + for (const auto& arg : this->args) { + if (auto un_expr = std::dynamic_pointer_cast(arg)) { + if (un_expr->op == UnaryOpExpr::Op::Expansion) { + auto array = un_expr->expr->evaluate(context); + if (!array.is_array()) { + throw std::runtime_error("Expansion operator only supported on arrays"); + } + array.for_each([&](Value & value) { + vargs.args.push_back(value); + }); + continue; + } else if (un_expr->op == UnaryOpExpr::Op::ExpansionDict) { + auto dict = un_expr->expr->evaluate(context); + if (!dict.is_object()) { + throw std::runtime_error("ExpansionDict operator only supported on objects"); + } + dict.for_each([&](const Value & key) { + vargs.kwargs.push_back({key.get(), dict.at(key)}); + }); + continue; + } + } + vargs.args.push_back(arg->evaluate(context)); + } + for (const auto& [name, value] : this->kwargs) { + vargs.kwargs.push_back({name, value->evaluate(context)}); + } + return vargs; + } +}; + +static std::string strip(const std::string & s) { + static std::regex trailing_spaces_regex("^\\s+|\\s+$"); + return std::regex_replace(s, trailing_spaces_regex, ""); + // auto start = s.find_first_not_of(" \t\n\r"); + // if (start == std::string::npos) return ""; + // auto end = s.find_last_not_of(" \t\n\r"); + // return s.substr(start, end - start + 1); +} + +static std::string html_escape(const std::string & s) { + std::string result; + result.reserve(s.size()); + for (const auto & c : s) { + switch (c) { + case '&': result += "&"; break; + case '<': result += "<"; break; + case '>': result += ">"; break; + case '"': result += """; break; + case '\'': result += "'"; break; + default: result += c; break; + } + } + return result; +} + +class MethodCallExpr : public Expression { + std::shared_ptr object; + std::shared_ptr method; + ArgumentsExpression args; +public: + MethodCallExpr(const Location & location, std::shared_ptr && obj, std::shared_ptr && m, ArgumentsExpression && a) + : Expression(location), object(std::move(obj)), method(std::move(m)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!object) throw std::runtime_error("MethodCallExpr.object 
is null"); + if (!method) throw std::runtime_error("MethodCallExpr.method is null"); + auto obj = object->evaluate(context); + auto vargs = args.evaluate(context); + if (obj.is_null()) { + throw std::runtime_error("Trying to call method '" + method->get_name() + "' on null"); + } + if (obj.is_array()) { + if (method->get_name() == "append") { + vargs.expectArgs("append method", {1, 1}, {0, 0}); + obj.push_back(vargs.args[0]); + return Value(); + } else if (method->get_name() == "insert") { + vargs.expectArgs("insert method", {2, 2}, {0, 0}); + auto index = vargs.args[0].get(); + if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method"); + obj.insert(index, vargs.args[1]); + return Value(); + } + } else if (obj.is_object()) { + if (method->get_name() == "items") { + vargs.expectArgs("items method", {0, 0}, {0, 0}); + auto result = Value::array(); + for (const auto& key : obj.keys()) { + result.push_back(Value::array({key, obj.at(key)})); + } + return result; + } else if (method->get_name() == "get") { + vargs.expectArgs("get method", {1, 2}, {0, 0}); + auto key = vargs.args[0]; + if (vargs.args.size() == 1) { + return obj.contains(key) ? obj.at(key) : Value(); + } else { + return obj.contains(key) ? obj.at(key) : vargs.args[1]; + } + } else if (obj.contains(method->get_name())) { + auto callable = obj.at(method->get_name()); + if (!callable.is_callable()) { + throw std::runtime_error("Property '" + method->get_name() + "' is not callable"); + } + return callable.call(context, vargs); + } + } else if (obj.is_string()) { + auto str = obj.get(); + if (method->get_name() == "strip") { + vargs.expectArgs("strip method", {0, 0}, {0, 0}); + return Value(strip(str)); + } else if (method->get_name() == "endswith") { + vargs.expectArgs("endswith method", {1, 1}, {0, 0}); + auto suffix = vargs.args[0].get(); + return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); + } else if (method->get_name() == "title") { + vargs.expectArgs("title method", {0, 0}, {0, 0}); + auto res = str; + for (size_t i = 0, n = res.size(); i < n; ++i) { + if (i == 0 || std::isspace(res[i - 1])) res[i] = std::toupper(res[i]); + else res[i] = std::tolower(res[i]); + } + return res; + } + } + throw std::runtime_error("Unknown method: " + method->get_name()); + } +}; + +class CallExpr : public Expression { +public: + std::shared_ptr object; + ArgumentsExpression args; + CallExpr(const Location & location, std::shared_ptr && obj, ArgumentsExpression && a) + : Expression(location), object(std::move(obj)), args(std::move(a)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + if (!object) throw std::runtime_error("CallExpr.object is null"); + auto obj = object->evaluate(context); + if (!obj.is_callable()) { + throw std::runtime_error("Object is not callable: " + obj.dump(2)); + } + auto vargs = args.evaluate(context); + return obj.call(context, vargs); + } +}; + +class FilterExpr : public Expression { + std::vector> parts; +public: + FilterExpr(const Location & location, std::vector> && p) + : Expression(location), parts(std::move(p)) {} + Value do_evaluate(const std::shared_ptr & context) const override { + Value result; + bool first = true; + for (const auto& part : parts) { + if (!part) throw std::runtime_error("FilterExpr.part is null"); + if (first) { + first = false; + result = part->evaluate(context); + } else { + if (auto ce = dynamic_cast(part.get())) { + auto target = ce->object->evaluate(context); + 
ArgumentsValue args = ce->args.evaluate(context); + args.args.insert(args.args.begin(), result); + result = target.call(context, args); + } else { + auto callable = part->evaluate(context); + ArgumentsValue args; + args.args.insert(args.args.begin(), result); + result = callable.call(context, args); + } + } + } + return result; + } + + void prepend(std::shared_ptr && e) { + parts.insert(parts.begin(), std::move(e)); + } +}; + +class Parser { +private: + using CharIterator = std::string::const_iterator; + + std::shared_ptr template_str; + CharIterator start, end, it; + Options options; + + Parser(const std::shared_ptr& template_str, const Options & options) : template_str(template_str), options(options) { + if (!template_str) throw std::runtime_error("Template string is null"); + start = it = this->template_str->begin(); + end = this->template_str->end(); + } + + bool consumeSpaces(SpaceHandling space_handling = SpaceHandling::Strip) { + if (space_handling == SpaceHandling::Strip) { + while (it != end && std::isspace(*it)) ++it; + } + return true; + } + + std::unique_ptr parseString() { + auto doParse = [&](char quote) -> std::unique_ptr { + if (it == end || *it != quote) return nullptr; + std::string result; + bool escape = false; + for (++it; it != end; ++it) { + if (escape) { + escape = false; + switch (*it) { + case 'n': result += '\n'; break; + case 'r': result += '\r'; break; + case 't': result += '\t'; break; + case 'b': result += '\b'; break; + case 'f': result += '\f'; break; + case '\\': result += '\\'; break; + default: + if (*it == quote) { + result += quote; + } else { + result += *it; + } + break; + } + } else if (*it == '\\') { + escape = true; + } else if (*it == quote) { + ++it; + return std::make_unique(std::move(result)); + } else { + result += *it; + } + } + return nullptr; + }; + + consumeSpaces(); + if (it == end) return nullptr; + if (*it == '"') return doParse('"'); + if (*it == '\'') return doParse('\''); + return nullptr; + } + + json parseNumber(CharIterator& it, const CharIterator& end) { + auto before = it; + consumeSpaces(); + auto start = it; + bool hasDecimal = false; + bool hasExponent = false; + + if (it != end && (*it == '-' || *it == '+')) ++it; + + while (it != end) { + if (std::isdigit(*it)) { + ++it; + } else if (*it == '.') { + if (hasDecimal) throw std::runtime_error("Multiple decimal points"); + hasDecimal = true; + ++it; + } else if (it != start && (*it == 'e' || *it == 'E')) { + if (hasExponent) throw std::runtime_error("Multiple exponents"); + hasExponent = true; + ++it; + } else { + break; + } + } + if (start == it) { + it = before; + return json(); // No valid characters found + } + + std::string str(start, it); + try { + return json::parse(str); + } catch (json::parse_error& e) { + throw std::runtime_error("Failed to parse number: '" + str + "' (" + std::string(e.what()) + ")"); + return json(); + } + } + + /** integer, float, bool, string */ + std::shared_ptr parseConstant() { + auto start = it; + consumeSpaces(); + if (it == end) return nullptr; + if (*it == '"' || *it == '\'') { + auto str = parseString(); + if (str) return std::make_shared(*str); + } + static std::regex prim_tok(R"(true\b|True\b|false\b|False\b|None\b)"); + auto token = consumeToken(prim_tok); + if (!token.empty()) { + if (token == "true" || token == "True") return std::make_shared(true); + if (token == "false" || token == "False") return std::make_shared(false); + if (token == "None") return std::make_shared(nullptr); + throw std::runtime_error("Unknown constant token: " 
+ token); + } + + auto number = parseNumber(it, end); + if (!number.is_null()) return std::make_shared(number); + + it = start; + return nullptr; + } + + class expression_parsing_error : public std::runtime_error { + const CharIterator it; + public: + expression_parsing_error(const std::string & message, const CharIterator & it) + : std::runtime_error(message), it(it) {} + size_t get_pos(const CharIterator & begin) const { + return std::distance(begin, it); + } + }; + + bool peekSymbols(const std::vector & symbols) const { + for (const auto & symbol : symbols) { + if (std::distance(it, end) >= (int64_t) symbol.size() && std::string(it, it + symbol.size()) == symbol) { + return true; + } + } + return false; + } + + std::vector consumeTokenGroups(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + std::vector ret; + for (size_t i = 0, n = match.size(); i < n; ++i) { + ret.push_back(match[i].str()); + } + return ret; + } + it = start; + return {}; + } + std::string consumeToken(const std::regex & regex, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + std::smatch match; + if (std::regex_search(it, end, match, regex) && match.position() == 0) { + it += match[0].length(); + return match[0].str(); + } + it = start; + return ""; + } + + std::string consumeToken(const std::string & token, SpaceHandling space_handling = SpaceHandling::Strip) { + auto start = it; + consumeSpaces(space_handling); + if (std::distance(it, end) >= (int64_t) token.size() && std::string(it, it + token.size()) == token) { + it += token.size(); + return token; + } + it = start; + return ""; + } + + std::shared_ptr parseExpression(bool allow_if_expr = true) { + auto left = parseLogicalOr(); + if (it == end) return left; + + if (!allow_if_expr) return left; + + static std::regex if_tok(R"(if\b)"); + if (consumeToken(if_tok).empty()) { + return left; + } + + auto location = get_location(); + auto [condition, else_expr] = parseIfExpression(); + return std::make_shared(location, std::move(condition), std::move(left), std::move(else_expr)); + } + + Location get_location() const { + return {template_str, (size_t) std::distance(start, it)}; + } + + std::pair, std::shared_ptr> parseIfExpression() { + auto condition = parseLogicalOr(); + if (!condition) throw std::runtime_error("Expected condition expression"); + + static std::regex else_tok(R"(else\b)"); + std::shared_ptr else_expr; + if (!consumeToken(else_tok).empty()) { + else_expr = parseExpression(); + if (!else_expr) throw std::runtime_error("Expected 'else' expression"); + } + return std::pair(std::move(condition), std::move(else_expr)); + } + + std::shared_ptr parseLogicalOr() { + auto left = parseLogicalAnd(); + if (!left) throw std::runtime_error("Expected left side of 'logical or' expression"); + + static std::regex or_tok(R"(or\b)"); + auto location = get_location(); + while (!consumeToken(or_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) throw std::runtime_error("Expected right side of 'or' expression"); + left = std::make_shared(location, std::move(left), std::move(right), BinaryOpExpr::Op::Or); + } + return left; + } + + std::shared_ptr parseLogicalNot() { + static std::regex not_tok(R"(not\b)"); + auto location = get_location(); + + if (!consumeToken(not_tok).empty()) { + auto sub = 
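+      // 'not' is right-associative (e.g. `not not x`), hence the recursion
+      // here instead of a loop: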
parseLogicalNot(); + if (!sub) throw std::runtime_error("Expected expression after 'not' keyword"); + return std::make_shared(location, std::move(sub), UnaryOpExpr::Op::LogicalNot); + } + return parseLogicalCompare(); + } + + std::shared_ptr parseLogicalAnd() { + auto left = parseLogicalNot(); + if (!left) throw std::runtime_error("Expected left side of 'logical and' expression"); + + static std::regex and_tok(R"(and\b)"); + auto location = get_location(); + while (!consumeToken(and_tok).empty()) { + auto right = parseLogicalNot(); + if (!right) throw std::runtime_error("Expected right side of 'and' expression"); + left = std::make_shared(location, std::move(left), std::move(right), BinaryOpExpr::Op::And); + } + return left; + } + + std::shared_ptr parseLogicalCompare() { + auto left = parseStringConcat(); + if (!left) throw std::runtime_error("Expected left side of 'logical compare' expression"); + + static std::regex compare_tok(R"(==|!=|<=?|>=?|in\b|is\b|not[\r\n\s]+in\b)"); + static std::regex not_tok(R"(not\b)"); + std::string op_str; + while (!(op_str = consumeToken(compare_tok)).empty()) { + auto location = get_location(); + if (op_str == "is") { + auto negated = !consumeToken(not_tok).empty(); + + auto identifier = parseIdentifier(); + if (!identifier) throw std::runtime_error("Expected identifier after 'is' keyword"); + + return std::make_shared( + left->location, + std::move(left), std::move(identifier), + negated ? BinaryOpExpr::Op::IsNot : BinaryOpExpr::Op::Is); + } + auto right = parseStringConcat(); + if (!right) throw std::runtime_error("Expected right side of 'logical compare' expression"); + BinaryOpExpr::Op op; + if (op_str == "==") op = BinaryOpExpr::Op::Eq; + else if (op_str == "!=") op = BinaryOpExpr::Op::Ne; + else if (op_str == "<") op = BinaryOpExpr::Op::Lt; + else if (op_str == ">") op = BinaryOpExpr::Op::Gt; + else if (op_str == "<=") op = BinaryOpExpr::Op::Le; + else if (op_str == ">=") op = BinaryOpExpr::Op::Ge; + else if (op_str == "in") op = BinaryOpExpr::Op::In; + else if (op_str.substr(0, 3) == "not") op = BinaryOpExpr::Op::NotIn; + else throw std::runtime_error("Unknown comparison operator: " + op_str); + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + return left; + } + + Expression::Parameters parseParameters() { + consumeSpaces(); + if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in param list"); + + Expression::Parameters result; + + while (it != end) { + if (!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call args"); + + if (auto ident = dynamic_cast(expr.get())) { + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected expression in for named arg"); + result.emplace_back(ident->get_name(), std::move(value)); + } else { + result.emplace_back(ident->get_name(), nullptr); + } + } else { + result.emplace_back(std::string(), std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + throw std::runtime_error("Expected closing parenthesis in call args"); + } + return result; + } + } + throw std::runtime_error("Expected closing parenthesis in call args"); + } + + ArgumentsExpression parseCallArgs() { + consumeSpaces(); + if (consumeToken("(").empty()) throw std::runtime_error("Expected opening parenthesis in call args"); + + ArgumentsExpression result; + + while (it != end) { + if 
(!consumeToken(")").empty()) { + return result; + } + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call args"); + + if (auto ident = dynamic_cast(expr.get())) { + if (!consumeToken("=").empty()) { + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected expression in for named arg"); + result.kwargs.emplace_back(ident->get_name(), std::move(value)); + } else { + result.args.emplace_back(std::move(expr)); + } + } else { + result.args.emplace_back(std::move(expr)); + } + if (consumeToken(",").empty()) { + if (consumeToken(")").empty()) { + throw std::runtime_error("Expected closing parenthesis in call args"); + } + return result; + } + } + throw std::runtime_error("Expected closing parenthesis in call args"); + } + + std::shared_ptr parseIdentifier() { + static std::regex ident_regex(R"((?!(?:not|is|and|or|del)\b)[a-zA-Z_]\w*)"); + auto location = get_location(); + auto ident = consumeToken(ident_regex); + if (ident.empty()) + return nullptr; + return std::make_shared(location, ident); + } + + std::shared_ptr parseStringConcat() { + auto left = parseMathPow(); + if (!left) throw std::runtime_error("Expected left side of 'string concat' expression"); + + static std::regex concat_tok(R"(~(?!\}))"); + if (!consumeToken(concat_tok).empty()) { + auto right = parseLogicalAnd(); + if (!right) throw std::runtime_error("Expected right side of 'string concat' expression"); + left = std::make_shared(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::StrConcat); + } + return left; + } + + std::shared_ptr parseMathPow() { + auto left = parseMathPlusMinus(); + if (!left) throw std::runtime_error("Expected left side of 'math pow' expression"); + + while (!consumeToken("**").empty()) { + auto right = parseMathPlusMinus(); + if (!right) throw std::runtime_error("Expected right side of 'math pow' expression"); + left = std::make_shared(get_location(), std::move(left), std::move(right), BinaryOpExpr::Op::MulMul); + } + return left; + } + + std::shared_ptr parseMathPlusMinus() { + static std::regex plus_minus_tok(R"(\+|-(?![}%#]\}))"); + + auto left = parseMathMulDiv(); + if (!left) throw std::runtime_error("Expected left side of 'math plus/minus' expression"); + std::string op_str; + while (!(op_str = consumeToken(plus_minus_tok)).empty()) { + auto right = parseMathMulDiv(); + if (!right) throw std::runtime_error("Expected right side of 'math plus/minus' expression"); + auto op = op_str == "+" ? BinaryOpExpr::Op::Add : BinaryOpExpr::Op::Sub; + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + return left; + } + + std::shared_ptr parseMathMulDiv() { + auto left = parseMathUnaryPlusMinus(); + if (!left) throw std::runtime_error("Expected left side of 'math mul/div' expression"); + + static std::regex mul_div_tok(R"(\*\*?|//?|%(?!\}))"); + std::string op_str; + while (!(op_str = consumeToken(mul_div_tok)).empty()) { + auto right = parseMathUnaryPlusMinus(); + if (!right) throw std::runtime_error("Expected right side of 'math mul/div' expression"); + auto op = op_str == "*" ? BinaryOpExpr::Op::Mul + : op_str == "**" ? BinaryOpExpr::Op::MulMul + : op_str == "/" ? BinaryOpExpr::Op::Div + : op_str == "//" ? 
BinaryOpExpr::Op::DivDiv + : BinaryOpExpr::Op::Mod; + left = std::make_shared(get_location(), std::move(left), std::move(right), op); + } + + if (!consumeToken("|").empty()) { + auto expr = parseMathMulDiv(); + if (auto filter = dynamic_cast(expr.get())) { + filter->prepend(std::move(left)); + return expr; + } else { + std::vector> parts; + parts.emplace_back(std::move(left)); + parts.emplace_back(std::move(expr)); + return std::make_shared(get_location(), std::move(parts)); + } + } + return left; + } + + std::shared_ptr call_func(const std::string & name, ArgumentsExpression && args) const { + return std::make_shared(get_location(), std::make_shared(get_location(), name), std::move(args)); + } + + std::shared_ptr parseMathUnaryPlusMinus() { + static std::regex unary_plus_minus_tok(R"(\+|-(?![}%#]\}))"); + auto op_str = consumeToken(unary_plus_minus_tok); + auto expr = parseExpansion(); + if (!expr) throw std::runtime_error("Expected expr of 'unary plus/minus/expansion' expression"); + + if (!op_str.empty()) { + auto op = op_str == "+" ? UnaryOpExpr::Op::Plus : UnaryOpExpr::Op::Minus; + return std::make_shared(get_location(), std::move(expr), op); + } + return expr; + } + + std::shared_ptr parseExpansion() { + static std::regex expansion_tok(R"(\*\*?)"); + auto op_str = consumeToken(expansion_tok); + auto expr = parseValueExpression(); + if (op_str.empty()) return expr; + if (!expr) throw std::runtime_error("Expected expr of 'expansion' expression"); + return std::make_shared(get_location(), std::move(expr), op_str == "*" ? UnaryOpExpr::Op::Expansion : UnaryOpExpr::Op::ExpansionDict); + } + + std::shared_ptr parseValueExpression() { + auto parseValue = [&]() -> std::shared_ptr { + auto location = get_location(); + auto constant = parseConstant(); + if (constant) return std::make_shared(location, *constant); + + static std::regex null_regex(R"(null\b)"); + if (!consumeToken(null_regex).empty()) return std::make_shared(location, Value()); + + auto identifier = parseIdentifier(); + if (identifier) return identifier; + + auto braced = parseBracedExpressionOrArray(); + if (braced) return braced; + + auto array = parseArray(); + if (array) return array; + + auto dictionary = parseDictionary(); + if (dictionary) return dictionary; + + throw std::runtime_error("Expected value expression"); + }; + + auto value = parseValue(); + + while (it != end && consumeSpaces() && peekSymbols({ "[", "." 
})) { + if (!consumeToken("[").empty()) { + std::shared_ptr index; + if (!consumeToken(":").empty()) { + auto slice_end = parseExpression(); + index = std::make_shared(slice_end->location, nullptr, std::move(slice_end)); + } else { + auto slice_start = parseExpression(); + if (!consumeToken(":").empty()) { + consumeSpaces(); + if (peekSymbols({ "]" })) { + index = std::make_shared(slice_start->location, std::move(slice_start), nullptr); + } else { + auto slice_end = parseExpression(); + index = std::make_shared(slice_start->location, std::move(slice_start), std::move(slice_end)); + } + } else { + index = std::move(slice_start); + } + } + if (!index) throw std::runtime_error("Empty index in subscript"); + if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); + + value = std::make_shared(value->location, std::move(value), std::move(index)); + } else if (!consumeToken(".").empty()) { + auto identifier = parseIdentifier(); + if (!identifier) throw std::runtime_error("Expected identifier in subscript"); + + consumeSpaces(); + if (peekSymbols({ "(" })) { + auto callParams = parseCallArgs(); + value = std::make_shared(identifier->location, std::move(value), std::move(identifier), std::move(callParams)); + } else { + auto key = std::make_shared(identifier->location, Value(identifier->get_name())); + value = std::make_shared(identifier->location, std::move(value), std::move(key)); + } + } + consumeSpaces(); + } + + if (peekSymbols({ "(" })) { + auto location = get_location(); + auto callParams = parseCallArgs(); + value = std::make_shared(location, std::move(value), std::move(callParams)); + } + return value; + } + + std::shared_ptr parseBracedExpressionOrArray() { + if (consumeToken("(").empty()) return nullptr; + + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in braced expression"); + + if (!consumeToken(")").empty()) { + return expr; // Drop the parentheses + } + + std::vector> tuple; + tuple.emplace_back(std::move(expr)); + + while (it != end) { + if (consumeToken(",").empty()) throw std::runtime_error("Expected comma in tuple"); + auto next = parseExpression(); + if (!next) throw std::runtime_error("Expected expression in tuple"); + tuple.push_back(std::move(next)); + + if (!consumeToken(")").empty()) { + return std::make_shared(get_location(), std::move(tuple)); + } + } + throw std::runtime_error("Expected closing parenthesis"); + } + + std::shared_ptr parseArray() { + if (consumeToken("[").empty()) return nullptr; + + std::vector> elements; + if (!consumeToken("]").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } + auto first_expr = parseExpression(); + if (!first_expr) throw std::runtime_error("Expected first expression in array"); + elements.push_back(std::move(first_expr)); + + while (it != end) { + if (!consumeToken(",").empty()) { + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in array"); + elements.push_back(std::move(expr)); + } else if (!consumeToken("]").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } else { + throw std::runtime_error("Expected comma or closing bracket in array"); + } + } + throw std::runtime_error("Expected closing bracket"); + } + + std::shared_ptr parseDictionary() { + if (consumeToken("{").empty()) return nullptr; + + std::vector, std::shared_ptr>> elements; + if (!consumeToken("}").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } + + auto 
parseKeyValuePair = [&]() { + auto key = parseExpression(); + if (!key) throw std::runtime_error("Expected key in dictionary"); + if (consumeToken(":").empty()) throw std::runtime_error("Expected colon between key & value in dictionary"); + auto value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in dictionary"); + elements.emplace_back(std::pair(std::move(key), std::move(value))); + }; + + parseKeyValuePair(); + + while (it != end) { + if (!consumeToken(",").empty()) { + parseKeyValuePair(); + } else if (!consumeToken("}").empty()) { + return std::make_shared(get_location(), std::move(elements)); + } else { + throw std::runtime_error("Expected comma or closing brace in dictionary"); + } + } + throw std::runtime_error("Expected closing brace"); + } + + SpaceHandling parsePreSpace(const std::string& s) const { + if (s == "-") + return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + SpaceHandling parsePostSpace(const std::string& s) const { + if (s == "-") return SpaceHandling::Strip; + return SpaceHandling::Keep; + } + + using TemplateTokenVector = std::vector>; + using TemplateTokenIterator = TemplateTokenVector::const_iterator; + + std::vector parseVarNames() { + static std::regex varnames_regex(R"(((?:\w+)(?:[\r\n\s]*,[\r\n\s]*(?:\w+))*)[\r\n\s]*)"); + + std::vector group; + if ((group = consumeTokenGroups(varnames_regex)).empty()) throw std::runtime_error("Expected variable names"); + std::vector varnames; + std::istringstream iss(group[1]); + std::string varname; + while (std::getline(iss, varname, ',')) { + varnames.push_back(strip(varname)); + } + return varnames; + } + + std::runtime_error unexpected(const TemplateToken & token) const { + return std::runtime_error("Unexpected " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + std::runtime_error unterminated(const TemplateToken & token) const { + return std::runtime_error("Unterminated " + TemplateToken::typeToString(token.type) + + error_location_suffix(*template_str, token.location.pos)); + } + + TemplateTokenVector tokenize() { + static std::regex comment_tok(R"(\{#([-~]?)(.*?)([-~]?)#\})"); + static std::regex expr_open_regex(R"(\{\{([-~])?)"); + static std::regex block_open_regex(R"(^\{%([-~])?[\s\n\r]*)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|set|endset|block|endblock|macro|endmacro|filter|endfilter)\b)"); + static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); + static std::regex expr_close_regex(R"([\s\n\r]*([-~])?\}\})"); + static std::regex block_close_regex(R"([\s\n\r]*([-~])?%\})"); + + TemplateTokenVector tokens; + std::vector group; + std::string text; + std::smatch match; + + try { + while (it != end) { + auto location = get_location(); + + if (!(group = consumeTokenGroups(comment_tok, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + auto content = group[2]; + auto post_space = parsePostSpace(group[3]); + tokens.push_back(std::make_unique(location, pre_space, post_space, content)); + } else if (!(group = consumeTokenGroups(expr_open_regex, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + auto expr = parseExpression(); + + if ((group = consumeTokenGroups(expr_close_regex)).empty()) { + throw std::runtime_error("Expected closing expression tag"); + } + + auto post_space = parsePostSpace(group[1]); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(expr))); + } else if (!(group =
consumeTokenGroups(block_open_regex, SpaceHandling::Keep)).empty()) { + auto pre_space = parsePreSpace(group[1]); + + std::string keyword; + + auto parseBlockClose = [&]() -> SpaceHandling { + if ((group = consumeTokenGroups(block_close_regex)).empty()) throw std::runtime_error("Expected closing block tag"); + return parsePostSpace(group[1]); + }; + + if ((keyword = consumeToken(block_keyword_tok)).empty()) throw std::runtime_error("Expected block keyword"); + + if (keyword == "if") { + auto condition = parseExpression(); + if (!condition) throw std::runtime_error("Expected condition in if block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(condition))); + } else if (keyword == "elif") { + auto condition = parseExpression(); + if (!condition) throw std::runtime_error("Expected condition in elif block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(condition))); + } else if (keyword == "else") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "endif") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "for") { + static std::regex recursive_tok(R"(recursive\b)"); + static std::regex if_tok(R"(if\b)"); + + auto varnames = parseVarNames(); + static std::regex in_tok(R"(in\b)"); + if (consumeToken(in_tok).empty()) throw std::runtime_error("Expected 'in' keyword in for block"); + auto iterable = parseExpression(/* allow_if_expr = */ false); + if (!iterable) throw std::runtime_error("Expected iterable in for block"); + + std::shared_ptr condition; + if (!consumeToken(if_tok).empty()) { + condition = parseExpression(); + } + auto recursive = !consumeToken(recursive_tok).empty(); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(varnames), std::move(iterable), std::move(condition), recursive)); + } else if (keyword == "endfor") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "set") { + static std::regex namespaced_var_regex(R"((\w+)[\s\n\r]*\.[\s\n\r]*(\w+))"); + + std::string ns; + std::vector var_names; + std::shared_ptr value; + if (!(group = consumeTokenGroups(namespaced_var_regex)).empty()) { + ns = group[1]; + var_names.push_back(group[2]); + + if (consumeToken("=").empty()) throw std::runtime_error("Expected equals sign in set block"); + + value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in set block"); + } else { + var_names = parseVarNames(); + + if (!consumeToken("=").empty()) { + value = parseExpression(); + if (!value) throw std::runtime_error("Expected value in set block"); + } + } + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, ns, var_names, std::move(value))); + } else if (keyword == "endset") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "macro") { + auto macroname = parseIdentifier(); + if (!macroname) throw std::runtime_error("Expected macro name in macro block"); + auto params = parseParameters(); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(macroname), 
std::move(params))); + } else if (keyword == "endmacro") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "filter") { + auto filter = parseExpression(); + if (!filter) throw std::runtime_error("Expected expression in filter block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(filter))); + } else if (keyword == "endfilter") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else { + throw std::runtime_error("Unexpected block: " + keyword); + } + } else if (std::regex_search(it, end, match, non_text_open_regex)) { + auto text_end = it + match.position(); + text = std::string(it, text_end); + it = text_end; + tokens.push_back(std::make_unique(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); + } else { + text = std::string(it, end); + it = end; + tokens.push_back(std::make_unique(location, SpaceHandling::Keep, SpaceHandling::Keep, text)); + } + } + return tokens; + } catch (const std::exception & e) { + throw std::runtime_error(e.what() + error_location_suffix(*template_str, std::distance(start, it))); + } + } + + std::shared_ptr parseTemplate( + const TemplateTokenIterator & begin, + TemplateTokenIterator & it, + const TemplateTokenIterator & end, + bool fully = false) const { + std::vector> children; + while (it != end) { + const auto start = it; + const auto & token = *(it++); + if (auto if_token = dynamic_cast(token.get())) { + std::vector, std::shared_ptr>> cascade; + cascade.emplace_back(std::move(if_token->condition), parseTemplate(begin, it, end)); + + while (it != end && (*it)->type == TemplateToken::Type::Elif) { + auto elif_token = dynamic_cast((*(it++)).get()); + cascade.emplace_back(std::move(elif_token->condition), parseTemplate(begin, it, end)); + } + + if (it != end && (*it)->type == TemplateToken::Type::Else) { + cascade.emplace_back(nullptr, parseTemplate(begin, ++it, end)); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndIf) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(cascade))); + } else if (auto for_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + auto else_body = std::shared_ptr(); + if (it != end && (*it)->type == TemplateToken::Type::Else) { + else_body = parseTemplate(begin, ++it, end); + } + if (it == end || (*(it++))->type != TemplateToken::Type::EndFor) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(for_token->var_names), std::move(for_token->iterable), std::move(for_token->condition), std::move(body), for_token->recursive, std::move(else_body))); + } else if (auto text_token = dynamic_cast(token.get())) { + SpaceHandling pre_space = (it - 1) != begin ? (*(it - 2))->post_space : SpaceHandling::Keep; + SpaceHandling post_space = it != end ? 
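+          // how much whitespace to trim around this text chunk is decided by
+          // the '-' whitespace-control flags captured from the neighbouring
+          // {%- ... -%} / {{- ... -}} tokens: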
(*it)->pre_space : SpaceHandling::Keep; + + auto text = text_token->text; + if (post_space == SpaceHandling::Strip) { + static std::regex trailing_space_regex(R"((\s|\r|\n)+$)"); + text = std::regex_replace(text, trailing_space_regex, ""); + } else if (options.lstrip_blocks && it != end) { + auto i = text.size(); + while (i > 0 && (text[i - 1] == ' ' || text[i - 1] == '\t')) i--; + if ((i == 0 && (it - 1) == begin) || (i > 0 && text[i - 1] == '\n')) { + text.resize(i); + } + } + if (pre_space == SpaceHandling::Strip) { + static std::regex leading_space_regex(R"(^(\s|\r|\n)+)"); + text = std::regex_replace(text, leading_space_regex, ""); + } else if (options.trim_blocks && (it - 1) != begin && !dynamic_cast((*(it - 2)).get())) { + if (text.length() > 0 && text[0] == '\n') { + text.erase(0, 1); + } + } + if (it == end && !options.keep_trailing_newline) { + auto i = text.size(); + if (i > 0 && text[i - 1] == '\n') { + i--; + if (i > 0 && text[i - 1] == '\r') i--; + text.resize(i); + } + } + children.emplace_back(std::make_shared(token->location, text)); + } else if (auto expr_token = dynamic_cast(token.get())) { + children.emplace_back(std::make_shared(token->location, std::move(expr_token->expr))); + } else if (auto set_token = dynamic_cast(token.get())) { + if (set_token->value) { + children.emplace_back(std::make_shared(token->location, set_token->ns, set_token->var_names, std::move(set_token->value))); + } else { + auto value_template = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndSet) { + throw unterminated(**start); + } + if (!set_token->ns.empty()) throw std::runtime_error("Namespaced set not supported in set with template value"); + if (set_token->var_names.size() != 1) throw std::runtime_error("Structural assignment not supported in set with template value"); + auto & name = set_token->var_names[0]; + children.emplace_back(std::make_shared(token->location, name, std::move(value_template))); + } + } else if (auto macro_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndMacro) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body))); + } else if (auto filter_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(filter_token->filter), std::move(body))); + } else if (dynamic_cast(token.get())) { + // Ignore comments + } else if (dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get()) + || dynamic_cast(token.get())) { + it--; // unconsume the token + break; // exit the loop + } else { + throw unexpected(**(it-1)); + } + } + if (fully && it != end) { + throw unexpected(**it); + } + if (children.empty()) { + return std::make_shared(Location { template_str, 0 }, std::string()); + } else if (children.size() == 1) { + return std::move(children[0]); + } else { + return std::make_shared(children[0]->location(), std::move(children)); + } + } + +public: + + static std::shared_ptr parse(const std::string& template_str, const Options & options) { + Parser parser(std::make_shared(normalize_newlines(template_str)), options); 
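+        // two-phase parse: tokenize() lexes the template into TemplateTokens,
+        // then parseTemplate() folds them into a tree of template nodes,
+        // recursing into if/for/set/macro/filter bodies until each matching
+        // end* token is consumed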
+ auto tokens = parser.tokenize(); + TemplateTokenIterator begin = tokens.begin(); + auto it = begin; + TemplateTokenIterator end = tokens.end(); + return parser.parseTemplate(begin, it, end, /* full= */ true); + } +}; + +static Value simple_function(const std::string & fn_name, const std::vector & params, const std::function &, Value & args)> & fn) { + std::map named_positions; + for (size_t i = 0, n = params.size(); i < n; i++) named_positions[params[i]] = i; + + return Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) -> Value { + auto args_obj = Value::object(); + std::vector provided_args(params.size()); + for (size_t i = 0, n = args.args.size(); i < n; i++) { + auto & arg = args.args[i]; + if (i < params.size()) { + args_obj.set(params[i], arg); + provided_args[i] = true; + } else { + throw std::runtime_error("Too many positional params for " + fn_name); + } + } + for (auto & [name, value] : args.kwargs) { + auto named_pos_it = named_positions.find(name); + if (named_pos_it == named_positions.end()) { + throw std::runtime_error("Unknown argument " + name + " for function " + fn_name); + } + provided_args[named_pos_it->second] = true; + args_obj.set(name, value); + } + return fn(context, args_obj); + }); +} + +inline std::shared_ptr Context::builtins() { + auto globals = Value::object(); + + globals.set("raise_exception", simple_function("raise_exception", { "message" }, [](const std::shared_ptr &, Value & args) -> Value { + throw std::runtime_error(args.at("message").get()); + })); + globals.set("tojson", simple_function("tojson", { "value", "indent" }, [](const std::shared_ptr &, Value & args) { + return Value(args.at("value").dump(args.get("indent", -1), /* tojson= */ true)); + })); + globals.set("items", simple_function("items", { "object" }, [](const std::shared_ptr &, Value & args) { + auto items = Value::array(); + if (args.contains("object")) { + auto & obj = args.at("object"); + if (obj.is_string()) { + auto json_obj = json::parse(obj.get()); + for (const auto & kv : json_obj.items()) { + items.push_back(Value::array({kv.key(), kv.value()})); + } + } else if (!obj.is_null()) { + for (auto & key : obj.keys()) { + items.push_back(Value::array({key, obj.at(key)})); + } + } + } + return items; + })); + globals.set("last", simple_function("last", { "items" }, [](const std::shared_ptr &, Value & args) { + auto items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not a list"); + if (items.size() == 0) return Value(); + return items.at(items.size() - 1); + })); + globals.set("trim", simple_function("trim", { "text" }, [](const std::shared_ptr &, Value & args) { + auto & text = args.at("text"); + return text.is_null() ? text : Value(strip(text.get())); + })); + globals.set("lower", simple_function("lower", { "text" }, [](const std::shared_ptr &, Value & args) { + auto text = args.at("text"); + if (text.is_null()) return text; + std::string res; + auto str = text.get(); + std::transform(str.begin(), str.end(), std::back_inserter(res), ::tolower); + return Value(res); + })); + globals.set("default", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + args.expectArgs("default", {2, 3}, {0, 1}); + auto & value = args.args[0]; + auto & default_value = args.args[1]; + bool boolean = false; + if (args.args.size() == 3) { + boolean = args.args[2].get(); + } else { + Value bv = args.get_named("boolean"); + if (!bv.is_null()) { + boolean = bv.get(); + } + } + return boolean ? (value.to_bool() ? 
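+      // Jinja semantics: default(v, d) replaces only null/undefined values,
+      // while default(v, d, true) (or boolean=true) also replaces falsy ones: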
value : default_value) : value.is_null() ? default_value : value; + })); + auto escape = simple_function("escape", { "text" }, [](const std::shared_ptr &, Value & args) { + return Value(html_escape(args.at("text").get())); + }); + globals.set("e", escape); + globals.set("escape", escape); + globals.set("joiner", simple_function("joiner", { "sep" }, [](const std::shared_ptr &, Value & args) { + auto sep = args.get("sep", ""); + auto first = std::make_shared(true); + return simple_function("", {}, [sep, first](const std::shared_ptr &, const Value &) -> Value { + if (*first) { + *first = false; + return ""; + } + return sep; + }); + })); + globals.set("count", simple_function("count", { "items" }, [](const std::shared_ptr &, Value & args) { + return Value((int64_t) args.at("items").size()); + })); + globals.set("dictsort", simple_function("dictsort", { "value" }, [](const std::shared_ptr &, Value & args) { + if (args.size() != 1) throw std::runtime_error("dictsort expects exactly 1 argument (TODO: fix implementation)"); + auto & value = args.at("value"); + auto keys = value.keys(); + std::sort(keys.begin(), keys.end()); + auto res = Value::array(); + for (auto & key : keys) { + res.push_back(Value::array({key, value.at(key)})); + } + return res; + })); + globals.set("join", simple_function("join", { "items", "d" }, [](const std::shared_ptr &, Value & args) { + auto do_join = [](Value & items, const std::string & sep) { + std::ostringstream oss; + auto first = true; + for (size_t i = 0, n = items.size(); i < n; ++i) { + if (first) first = false; + else oss << sep; + oss << items.at(i).to_str(); + } + return Value(oss.str()); + }; + auto sep = args.get("d", ""); + if (args.contains("items")) { + auto & items = args.at("items"); + return do_join(items, sep); + } else { + return simple_function("", {"items"}, [sep, do_join](const std::shared_ptr &, Value & args) { + auto & items = args.at("items"); + if (!items.to_bool() || !items.is_array()) throw std::runtime_error("join expects an array for items, got: " + items.dump()); + return do_join(items, sep); + }); + } + })); + globals.set("namespace", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + auto ns = Value::object(); + args.expectArgs("namespace", {0, 0}, {0, std::numeric_limits::max()}); + for (auto & [name, value] : args.kwargs) { + ns.set(name, value); + } + return ns; + })); + auto equalto = simple_function("equalto", { "expected", "actual" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("actual") == args.at("expected"); + }); + globals.set("equalto", equalto); + globals.set("==", equalto); + globals.set("length", simple_function("length", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + return (int64_t) items.size(); + })); + globals.set("safe", simple_function("safe", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_str(); + })); + globals.set("string", simple_function("string", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_str(); + })); + globals.set("int", simple_function("int", { "value" }, [](const std::shared_ptr &, Value & args) -> Value { + return args.at("value").to_int(); + })); + globals.set("list", simple_function("list", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is
not iterable"); + return items; + })); + globals.set("unique", simple_function("unique", { "items" }, [](const std::shared_ptr &, Value & args) -> Value { + auto & items = args.at("items"); + if (!items.is_array()) throw std::runtime_error("object is not iterable"); + std::unordered_set seen; + auto result = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto pair = seen.insert(items.at(i)); + if (pair.second) { + result.push_back(items.at(i)); + } + } + return result; + })); + auto make_filter = [](const Value & filter, Value & extra_args) -> Value { + return simple_function("", { "value" }, [=](const std::shared_ptr & context, Value & args) { + auto & value = args.at("value"); + ArgumentsValue actual_args; + actual_args.args.emplace_back(value); + for (size_t i = 0, n = extra_args.size(); i < n; i++) { + actual_args.args.emplace_back(extra_args.at(i)); + } + return filter.call(context, actual_args); + }); + }; + // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject + globals.set("reject", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + args.expectArgs("reject", {2, std::numeric_limits::max()}, {0, 0}); + auto & items = args.args[0]; + auto filter_fn = context->get(args.args[1]); + if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + + auto filter_args = Value::array(); + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.push_back(args.args[i]); + } + auto filter = make_filter(filter_fn, filter_args); + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + ArgumentsValue filter_args; + filter_args.args.emplace_back(item); + auto pred_res = filter.call(context, filter_args); + if (!pred_res.to_bool()) { + res.push_back(item); + } + } + return res; + })); + globals.set("map", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + auto res = Value::array(); + if (args.args.size() == 1 && + ((args.has_named("attribute") && args.kwargs.size() == 1) || (args.has_named("default") && args.kwargs.size() == 2))) { + auto & items = args.args[0]; + auto attr_name = args.get_named("attribute"); + auto default_value = args.get_named("default"); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + res.push_back(attr.is_null() ? 
default_value : attr); + } + } else if (args.kwargs.empty() && args.args.size() >= 2) { + auto fn = context->get(args.args[1]); + if (fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); + ArgumentsValue filter_args { {Value()}, {} }; + for (size_t i = 2, n = args.args.size(); i < n; i++) { + filter_args.args.emplace_back(args.args[i]); + } + for (size_t i = 0, n = args.args[0].size(); i < n; i++) { + auto & item = args.args[0].at(i); + filter_args.args[0] = item; + res.push_back(fn.call(context, filter_args)); + } + } else { + throw std::runtime_error("Invalid or unsupported arguments for map"); + } + return res; + })); + globals.set("indent", simple_function("indent", { "text", "indent", "first" }, [](const std::shared_ptr &, Value & args) { + auto text = args.at("text").get(); + auto first = args.get("first", false); + std::string out; + std::string indent(args.get("indent", 0), ' '); + std::istringstream iss(text); + std::string line; + auto is_first = true; + while (std::getline(iss, line, '\n')) { + auto needs_indent = !is_first || first; + if (is_first) is_first = false; + else out += ENDL; + if (needs_indent) out += indent; + out += line; + } + if (!text.empty() && text.back() == '\n') out += ENDL; + return out; + })); + globals.set("selectattr", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { + args.expectArgs("selectattr", {2, std::numeric_limits::max()}, {0, 0}); + auto & items = args.args[0]; + if (items.is_null()) + return Value::array(); + auto attr_name = args.args[1].get(); + + bool has_test = false; + Value test_fn; + ArgumentsValue test_args {{Value()}, {}}; + if (args.args.size() >= 3) { + has_test = true; + test_fn = context->get(args.args[2]); + if (test_fn.is_null()) throw std::runtime_error("Undefined test: " + args.args[2].dump()); + for (size_t i = 3, n = args.args.size(); i < n; i++) { + test_args.args.emplace_back(args.args[i]); + } + test_args.kwargs = args.kwargs; + } + + auto res = Value::array(); + for (size_t i = 0, n = items.size(); i < n; i++) { + auto & item = items.at(i); + auto attr = item.get(attr_name); + if (has_test) { + test_args.args[0] = attr; + if (test_fn.call(context, test_args).to_bool()) { + res.push_back(item); + } + } else { + res.push_back(attr); + } + } + return res; + })); + globals.set("range", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { + std::vector startEndStep(3); + std::vector param_set(3); + if (args.args.size() == 1) { + startEndStep[1] = args.args[0].get(); + param_set[1] = true; + } else { + for (size_t i = 0; i < args.args.size(); i++) { + auto & arg = args.args[i]; + auto v = arg.get(); + startEndStep[i] = v; + param_set[i] = true; + } + } + for (auto & [name, value] : args.kwargs) { + size_t i; + if (name == "start") i = 0; + else if (name == "end") i = 1; + else if (name == "step") i = 2; + else throw std::runtime_error("Unknown argument " + name + " for function range"); + + if (param_set[i]) { + throw std::runtime_error("Duplicate argument " + name + " for function range"); + } + startEndStep[i] = value.get(); + param_set[i] = true; + } + if (!param_set[1]) { + throw std::runtime_error("Missing required argument 'end' for function range"); + } + int64_t start = param_set[0] ? startEndStep[0] : 0; + int64_t end = startEndStep[1]; + int64_t step = param_set[2] ? 
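+      // mirror Python's range(): omitted start defaults to 0, omitted step to 1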
startEndStep[2] : 1; + + auto res = Value::array(); + if (step > 0) { + for (int64_t i = start; i < end; i += step) { + res.push_back(Value(i)); + } + } else { + for (int64_t i = start; i > end; i += step) { + res.push_back(Value(i)); + } + } + return res; + })); + + return std::make_shared(std::move(globals)); +} + +inline std::shared_ptr Context::make(Value && values, const std::shared_ptr & parent) { + return std::make_shared(values.is_null() ? Value::object() : std::move(values), parent); +} + +} // namespace minja From e5113e8d746bfc10b70d956a3ae64dd460becfda Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 30 Dec 2024 03:40:34 +0000 Subject: [PATCH 02/42] Add --jinja and --chat-template-file flags --- Makefile | 2 + common/CMakeLists.txt | 2 + common/arg.cpp | 43 ++++++++++- common/common.cpp | 68 +++++++++++++++- common/common.h | 14 +++- examples/server/README.md | 2 +- examples/server/server.cpp | 67 ++++++++++++---- .../server/tests/unit/test_chat_completion.py | 15 ++-- examples/server/tests/utils.py | 7 +- examples/server/utils.hpp | 40 ++++++---- scripts/get_hf_chat_template.py | 77 +++++++++++++++++++ src/CMakeLists.txt | 2 +- 12 files changed, 289 insertions(+), 50 deletions(-) create mode 100755 scripts/get_hf_chat_template.py diff --git a/Makefile b/Makefile index 19ae0d5f1c87b..295522ba356b4 100644 --- a/Makefile +++ b/Makefile @@ -1361,7 +1361,9 @@ llama-server: \ examples/server/httplib.h \ examples/server/index.html.hpp \ examples/server/loading.html.hpp \ + common/chat-template.hpp \ common/json.hpp \ + common/minja.hpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index df1cdf9a59af3..24b7f8741aab4 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -56,6 +56,7 @@ add_library(${TARGET} STATIC arg.cpp arg.h base64.hpp + chat-template.hpp common.cpp common.h console.cpp @@ -64,6 +65,7 @@ add_library(${TARGET} STATIC json.hpp log.cpp log.h + minja.hpp ngram-cache.cpp ngram-cache.h sampling.cpp diff --git a/common/arg.cpp b/common/arg.cpp index deb11378657f4..edcda60e08e16 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1889,24 +1889,59 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--jinja"}, + "use jinja template for chat (default: disabled)", + [](common_params & params) { + params.use_jinja = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--chat-template"}, "JINJA_TEMPLATE", string_format( "set custom jinja chat template (default: template taken from model's metadata)\n" "if suffix/prefix are specified, template will be disabled\n" + "only commonly used templates are accepted (unless --jinja is set before this flag):\n" "list of built-in templates:\n%s", list_builtin_chat_templates().c_str() ), [](common_params & params, const std::string & value) { - if (!common_chat_verify_template(value)) { + if (!common_chat_verify_template(value, params.use_jinja)) { throw std::runtime_error(string_format( - "error: the supplied chat template is not supported: %s\n" - "note: llama.cpp does not use jinja parser, we only support commonly used templates\n", - value.c_str() + "error: the supplied chat template is not supported: %s%s\n", + value.c_str(), + params.use_jinja ? 
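+                        // with --jinja the template was already validated by
+                        // the minja parser above, so no extra caveat is needed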
"" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates" )); } params.chat_template = value; } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); + add_opt(common_arg( + {"--chat-template-file"}, "JINJA_TEMPLATE_FILE", + "set custom jinja chat template file (default: template taken from model's metadata)\n" + "if suffix/prefix are specified, template will be disabled\n" + "only commonly used templates are accepted (unless --jinja is set before this flag):\n" + "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template", + [](common_params & params, const std::string & value) { + std::ifstream file(value); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + std::string chat_template; + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(chat_template) + ); + if (!common_chat_verify_template(chat_template, params.use_jinja)) { + throw std::runtime_error(string_format( + "error: the supplied chat template is not supported: %s%s\n", + value.c_str(), + params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates" + )); + } + params.chat_template = chat_template; + } + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); add_opt(common_arg( {"-sps", "--slot-prompt-similarity"}, "SIMILARITY", string_format("how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity), diff --git a/common/common.cpp b/common/common.cpp index 20be9291161ca..6bdcd80a1b756 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1576,13 +1576,13 @@ std::vector common_tokenize( return result; } -std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { +static std::string _common_token_to_piece(const struct llama_model * model, llama_token token, bool special) { std::string piece; piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' - const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); if (n_chars < 0) { piece.resize(-n_chars); - int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); + int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); GGML_ASSERT(check == -n_chars); } else { @@ -1592,6 +1592,10 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token return piece; } +std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { + return _common_token_to_piece(llama_get_model(ctx), token, special); +} + std::string common_detokenize(llama_context * ctx, const std::vector & tokens, bool special) { std::string text; text.resize(std::max(text.capacity(), tokens.size())); @@ -1612,7 +1616,21 @@ std::string common_detokenize(llama_context * ctx, const std::vector", ""); + chat_template.apply({{ + {"role", "user"}, + {"content", "test"}, + }}, json(), true); + return true; + } catch (const std::exception & e) { + LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); + return false; + } + } + llama_chat_message chat[] = {{"user", "test"}}; 
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); return res >= 0; @@ -1693,6 +1711,48 @@ std::string common_chat_format_example(const struct llama_model * model, return common_chat_apply_template(model, tmpl, msgs, true); } +static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) { + int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0); + if (tlen > 0) { + std::vector curr_tmpl_buf(tlen + 1, 0); + if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) { + return std::string(curr_tmpl_buf.data(), tlen); + } + } + return ""; +} + +llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) +{ + auto bos_token = _common_token_to_piece(model, llama_token_bos(model), true); + auto eos_token = _common_token_to_piece(model, llama_token_eos(model), true); + std::string default_template_src = chat_template_override; + std::string tool_use_template_src = chat_template_override; + if (chat_template_override.empty()) { + default_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template"); + tool_use_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use"); + } + if (default_template_src.empty() || default_template_src == "chatml") { + if (!tool_use_template_src.empty()) { + default_template_src = tool_use_template_src; + } else { + default_template_src = R"( + {%- for message in messages -%} + {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} + {%- endfor -%} + {%- if add_generation_prompt -%} + {{- "<|im_start|>assistant\n" -}} + {%- endif -%} + )"; + } + } + return { + .default_template = { default_template_src, bos_token, eos_token }, + .tool_use_template = tool_use_template_src.empty() ? std::nullopt + : std::optional({ tool_use_template_src, bos_token, eos_token }), + }; +} + // // KV cache utils // diff --git a/common/common.h b/common/common.h index 1d2bd932c211d..7747d66d55b67 100644 --- a/common/common.h +++ b/common/common.h @@ -3,6 +3,7 @@ #pragma once #include "llama.h" +#include "chat-template.hpp" #include #include @@ -324,6 +325,7 @@ struct common_params { std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT std::string chat_template = ""; // NOLINT + bool use_jinja = false; // NOLINT bool enable_chat_template = true; std::vector api_keys; @@ -571,8 +573,8 @@ struct common_chat_msg { std::string content; }; -// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid -bool common_chat_verify_template(const std::string & tmpl); +// Check if the template is supported or not. 
Returns true if it's valid +bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml @@ -593,6 +595,14 @@ std::string common_chat_format_single(const struct llama_model * model, std::string common_chat_format_example(const struct llama_model * model, const std::string & tmpl); + +struct llama_chat_templates { + minja::chat_template default_template; + std::optional tool_use_template; +}; + +llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); + // // KV cache utils // diff --git a/examples/server/README.md b/examples/server/README.md index c7d91be9976c4..24ef85727092d 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -129,7 +129,7 @@ The project is under active development, and we are [looking for feedback and co | `--grammar GRAMMAR` | BNF-like grammar to constrain generations (see samples in grammars/ dir) (default: '') | | `--grammar-file FNAME` | file to read grammar from | | `-j, --json-schema SCHEMA` | JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object
For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead | - +| `--jinja` | Enable experimental Jinja templating engine (needed for tool use) | **Example-specific params** diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 30ff3b14957dc..cfa90056ae995 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1623,15 +1623,35 @@ struct server_context { return true; } - bool validate_model_chat_template() const { - std::vector model_template(2048, 0); // longest known template is about 1200 bytes - std::string template_key = "tokenizer.chat_template"; - int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); - if (res >= 0) { - llama_chat_message chat[] = {{"user", "test"}}; - std::string tmpl = std::string(model_template.data(), model_template.size()); - int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0); - return chat_res > 0; + bool validate_model_chat_template(bool use_jinja) const { + llama_chat_message chat[] = {{"user", "test"}}; + + if (use_jinja) { + auto templates = llama_chat_templates_from_model(model, ""); + try { + templates.default_template.apply({{ + {"role", "user"}, + {"content", "test"}, + }}, json(), true); + if (templates.tool_use_template) { + templates.tool_use_template->apply({{ + {"role", "user"}, + {"content", "test"}, + }}, json(), true); + } + return true; + } catch (const std::exception & e) { + SRV_ERR("failed to apply template: %s\n", e.what()); + } + } else { + std::vector model_template(2048, 0); // longest known template is about 1200 bytes + std::string template_key = "tokenizer.chat_template"; + int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); + if (res >= 0) { + std::string tmpl = std::string(model_template.data(), model_template.size()); + int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0); + return chat_res > 0; + } } return false; } @@ -3476,15 +3496,30 @@ int main(int argc, char ** argv) { } }; - const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + std::mutex chat_templates_mutex; + std::optional chat_templates; + + auto get_chat_templates = [&ctx_server, &chat_templates_mutex, &chat_templates]() -> const llama_chat_templates & { + std::lock_guard lock(chat_templates_mutex); + if (!chat_templates) { + chat_templates = llama_chat_templates_from_model(ctx_server.model, ctx_server.params_base.chat_template); + } + return *chat_templates; + }; + + const auto handle_props = [&ctx_server, &res_ok, &get_chat_templates](const httplib::Request &, httplib::Response & res) { // this endpoint is publicly available, please only return what is safe to be exposed + const auto & templates = get_chat_templates(); json data = { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model }, - { "chat_template", llama_get_chat_template(ctx_server.model) }, + { "chat_template", templates.default_template.source() }, { "build_info", build_info }, }; + if (ctx_server.params_base.use_jinja && templates.tool_use_template) { + data["chat_template_tool_use"] = templates.tool_use_template->source(); + } res_ok(res, data); }; @@ -3685,13 +3720,17 @@ int main(int argc, char ** argv) { return 
@@ -3685,13 +3720,17 @@ int main(int argc, char ** argv) {
         return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
     };
 
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic, &get_chat_templates](const httplib::Request & req, httplib::Response & res) {
         if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
-        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+        auto body = json::parse(req.body);
+        const auto & templates = get_chat_templates();
+        const auto & chat_template = body.contains("tools") && templates.tool_use_template ? *templates.tool_use_template : templates.default_template;
+        json data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, params.use_jinja);
+
         return handle_completions_generic(
             SERVER_TASK_TYPE_COMPLETION,
             data,
@@ -4111,7 +4150,7 @@ int main(int argc, char ** argv) {
 
     // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
     if (params.chat_template.empty()) {
-        if (!ctx_server.validate_model_chat_template()) {
+        if (!ctx_server.validate_model_chat_template(params.use_jinja)) {
             LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
             params.chat_template = "chatml";
         }

diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 88549708113e9..ef716cc1ab223 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -4,22 +4,24 @@
 
 server = ServerPreset.tinyllama2()
 
-
-@pytest.fixture(scope="module", autouse=True)
+@pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
 
 
 @pytest.mark.parametrize(
-    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
+    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja",
     [
-        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
-        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False),
+        (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False),
+        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True),
     ]
 )
-def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
+def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja):
     global server
+    server.jinja = jinja
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "model": model,
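Each test case above now runs both with and without `--jinja`. Server-side, the flag only changes which template variant and rendering path `handle_chat_completions` picks; the selection rule from the diff above, reduced to a self-contained sketch (`json` is nlohmann::ordered_json, as used by the server):

#include "json.hpp"
using json = nlohmann::ordered_json;

// Sketch of the rule in handle_chat_completions above: the tool-use
// template is preferred only when the request carries "tools" and the
// model actually ships such a variant.
static bool prefer_tool_use_template(const json & body, bool has_tool_use_variant) {
    return body.contains("tools") && has_tool_use_variant;
}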
@pytest.mark.parametrize("response_format,n_predicted,re_content", [ ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""), + ({"type": "json_schema", "json_schema": {"const": "42"}}, 6, "\"42\""), ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"), ({"type": "json_object"}, 10, "(\\{|John)+"), ({"type": "sound"}, 0, None), diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index 277125e88b534..f0fe7b15dbf68 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -68,8 +68,9 @@ class ServerProcess: pooling: str | None = None draft: int | None = None api_key: str | None = None - response_format: str | None = None lora_files: List[str] | None = None + chat_template_file: str | None = None + jinja: bool | None = None disable_ctx_shift: int | None = False draft_min: int | None = None draft_max: int | None = None @@ -154,6 +155,10 @@ def start(self, timeout_seconds: int = 10) -> None: if self.lora_files: for lora_file in self.lora_files: server_args.extend(["--lora", lora_file]) + if self.chat_template_file: + server_args.extend(["--chat-template-file", self.chat_template_file]) + if self.jinja: + server_args.append("--jinja") if self.disable_ctx_shift: server_args.extend(["--no-context-shift"]) if self.api_key: diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 334f2f19207ef..81a2d62e960bc 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -16,6 +16,8 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" +#include "minja.hpp" +#include "chat-template.hpp" #include #include @@ -382,19 +384,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri return formatted_chat; } -static std::string llama_get_chat_template(const struct llama_model * model) { - std::string template_key = "tokenizer.chat_template"; - // call with NULL buffer to get the total size of the string - int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0); - if (res < 2) { - return ""; - } else { - std::vector model_template(res + 1, 0); - llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); - return std::string(model_template.data(), model_template.size() - 1); - } -} - // // base64 utils (TODO: move to common in the future) // @@ -552,11 +541,21 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons static json oaicompat_completion_params_parse( const struct llama_model * model, const json & body, /* openai api json semantics */ - const std::string & chat_template) { + const minja::chat_template & tmpl, + bool use_jinja) +{ json llama_params; - // Apply chat template to the list of messages - llama_params["prompt"] = format_chat(model, chat_template, body.at("messages")); + auto tools = json_value(body, "tools", json()); + auto has_tools = tools.is_array() && !tools.empty(); + + if (has_tools) { + if (use_jinja) { + LOG_WRN("tools param is not fully supported yet\n"); + } else { + throw std::runtime_error("tools param requires --jinja flag"); + } + } // Handle "stop" field if (body.contains("stop") && body.at("stop").is_string()) { @@ -579,6 +578,13 @@ static json oaicompat_completion_params_parse( } } + // Apply chat template to the list of messages + if (use_jinja) { + llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true); + } else { + llama_params["prompt"] 
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 334f2f19207ef..81a2d62e960bc 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -16,6 +16,8 @@
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
+#include "minja.hpp"
+#include "chat-template.hpp"
 
 #include
 #include
@@ -382,19 +384,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
     return formatted_chat;
 }
 
-static std::string llama_get_chat_template(const struct llama_model * model) {
-    std::string template_key = "tokenizer.chat_template";
-    // call with NULL buffer to get the total size of the string
-    int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
-    if (res < 2) {
-        return "";
-    } else {
-        std::vector<char> model_template(res + 1, 0);
-        llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
-        return std::string(model_template.data(), model_template.size() - 1);
-    }
-}
-
 //
 // base64 utils (TODO: move to common in the future)
 //
@@ -552,11 +541,21 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
 static json oaicompat_completion_params_parse(
     const struct llama_model * model,
     const json & body, /* openai api json semantics */
-    const std::string & chat_template) {
+    const minja::chat_template & tmpl,
+    bool use_jinja)
+{
     json llama_params;
 
-    // Apply chat template to the list of messages
-    llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
+    auto tools = json_value(body, "tools", json());
+    auto has_tools = tools.is_array() && !tools.empty();
+
+    if (has_tools) {
+        if (use_jinja) {
+            LOG_WRN("tools param is not fully supported yet\n");
+        } else {
+            throw std::runtime_error("tools param requires --jinja flag");
+        }
+    }
 
     // Handle "stop" field
     if (body.contains("stop") && body.at("stop").is_string()) {
@@ -579,6 +578,13 @@
         }
     }
 
+    // Apply chat template to the list of messages
+    if (use_jinja) {
+        llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true);
+    } else {
+        llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages"));
+    }
+
     // Handle "n" field
     int n_choices = json_value(body, "n", 1);
     if (n_choices != 1) {
@@ -594,7 +600,7 @@
     }
 
     // Params supported by OAI but unsupported by llama.cpp
-    static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
+    static const std::vector<std::string> unsupported_params { "tool_choice" };
     for (const auto & param : unsupported_params) {
         if (body.contains(param)) {
             throw std::runtime_error("Unsupported param: " + param);

diff --git a/scripts/get_hf_chat_template.py b/scripts/get_hf_chat_template.py
new file mode 100755
index 0000000000000..820b84efc26b1
--- /dev/null
+++ b/scripts/get_hf_chat_template.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python
+'''
+    Fetches the Jinja chat template of a HuggingFace model.
+    If a model has multiple chat templates, you can specify the variant name.
+
+    Syntax:
+        ./scripts/get_hf_chat_template.py model_id [variant]
+
+    Examples:
+        ./scripts/get_hf_chat_template.py NousResearch/Meta-Llama-3-8B-Instruct
+        ./scripts/get_hf_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use
+        ./scripts/get_hf_chat_template.py meta-llama/Llama-3.2-3B-Instruct
+'''
+
+import json
+import re
+import sys
+
+
+def get_hf_chat_template(model_id, variant=None):
+    try:
+        # Use huggingface_hub library if available.
+        # Allows access to gated models if the user has access and ran `huggingface-cli login`.
+        from huggingface_hub import hf_hub_download
+        with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json")) as f:
+            config_str = f.read()
+    except ImportError:
+        import requests
+        assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}"
+        response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json")
+        if response.status_code == 401:
+            raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`')
+        response.raise_for_status()
+        config_str = response.text
+
+    try:
+        config = json.loads(config_str)
+    except json.JSONDecodeError:
+        # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json
+        # (Remove extra '}' near the end of the file)
+        config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str))
+
+    chat_template = config['chat_template']
+    if isinstance(chat_template, str):
+        return chat_template
+    else:
+        variants = {
+            ct['name']: ct['template']
+            for ct in chat_template
+        }
+
+        def format_variants():
+            return ', '.join(f'"{v}"' for v in variants.keys())
+
+        if variant is None:
+            if 'default' not in variants:
+                raise Exception(f'Please specify a chat template variant (one of {format_variants()})')
+            variant = 'default'
+            print(f'Note: picked "default" chat template variant (out of {format_variants()})', file=sys.stderr)
+        elif variant not in variants:
+            raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})")
+
+        return variants[variant]
+
+
+def main(args):
+    if len(args) < 1:
+        raise ValueError("Please provide a model ID and an optional variant name")
+    model_id = args[0]
+    variant = None if len(args) < 2 else args[1]
+
+    template = get_hf_chat_template(model_id, variant)
+    print(template, end=None)
+
+
+if __name__ == '__main__':
+    main(sys.argv[1:])
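A template fetched with the script above can be handed to the server via the new `--chat-template-file` flag, or loaded directly into a `minja::chat_template`. A sketch of the latter (the BOS/EOS strings below are placeholders; real code should take them from the model's vocabulary):

#include "chat-template.hpp"
#include <fstream>
#include <sstream>
#include <string>

// Sketch: parse a Jinja template saved by get_hf_chat_template.py.
static minja::chat_template load_template_file(const std::string & path) {
    std::ifstream file(path);
    std::ostringstream ss;
    ss << file.rdbuf(); // slurp the whole template source
    // Placeholder BOS/EOS tokens; query the model's vocab in real code.
    return minja::chat_template(ss.str(), /* bos_token= */ "<s>", /* eos_token= */ "</s>");
}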
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 2d3ea09945790..4bb58146ede32 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -17,7 +17,7 @@ add_library(llama
             unicode-data.cpp
             )
 
-target_include_directories(llama PUBLIC . ../include)
+target_include_directories(llama PUBLIC . ../include ../common)
 target_compile_features   (llama PUBLIC cxx_std_17) # don't bump
 
 target_link_libraries(llama PUBLIC ggml)

From 80138d90073f8ed3978f8688ed856a12e6509247 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 30 Dec 2024 04:10:20 +0000
Subject: [PATCH 03/42] Add missing include

---
 common/common.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/common/common.h b/common/common.h
index 7747d66d55b67..2693b805ec2fa 100644
--- a/common/common.h
+++ b/common/common.h
@@ -5,6 +5,7 @@
 
 #include "llama.h"
 #include "chat-template.hpp"
+#include <optional>
 #include
 #include
 #include

From 06b5159560de404c018026099bdc636f4d2930c6 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 30 Dec 2024 04:10:35 +0000
Subject: [PATCH 04/42] Avoid print in get_hf_chat_template.py

---
 scripts/get_hf_chat_template.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/get_hf_chat_template.py b/scripts/get_hf_chat_template.py
index 820b84efc26b1..23bb1de59acc3 100755
--- a/scripts/get_hf_chat_template.py
+++ b/scripts/get_hf_chat_template.py
@@ -56,7 +56,7 @@ def format_variants():
             if 'default' not in variants:
                 raise Exception(f'Please specify a chat template variant (one of {format_variants()})')
             variant = 'default'
-            print(f'Note: picked "default" chat template variant (out of {format_variants()})', file=sys.stderr)
+            sys.stderr.write(f'Note: picked "default" chat template variant (out of {format_variants()})\n')
         elif variant not in variants:
             raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})")
 
@@ -70,7 +70,7 @@ def main(args):
     variant = None if len(args) < 2 else args[1]
 
     template = get_hf_chat_template(model_id, variant)
-    print(template, end=None)
+    sys.stdout.write(template)
 
 
 if __name__ == '__main__':

From ce48584f7d1f3fb90e767f9d6ef4ddd69b05351b Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 30 Dec 2024 04:19:33 +0000
Subject: [PATCH 05/42] No designated initializers yet

---
 common/common.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 6bdcd80a1b756..45c8c9b525d96 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1747,8 +1747,8 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
         }
     }
     return {
-        .default_template = { default_template_src, bos_token, eos_token },
-        .tool_use_template = tool_use_template_src.empty() ? std::nullopt
+        /* .default_template = */ { default_template_src, bos_token, eos_token },
+        /* .tool_use_template = */ tool_use_template_src.empty() ? std::nullopt
             : std::optional<minja::chat_template>({ tool_use_template_src, bos_token, eos_token }),
     };
 }
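The commit above ("No designated initializers yet") is needed because designated initializers are a C++20 feature, while llama builds with `cxx_std_17` (see the CMake line above); the workaround keeps the field names readable as comments. Reduced to a toy example:

struct point { int x; int y; };

// C++20 designated initializer - rejected under the project's cxx_std_17:
//     point p = { .x = 1, .y = 2 };

// C++17-compatible positional initialization used by the patch instead,
// with the member names preserved as comments for readability:
point p = { /* .x = */ 1, /* .y = */ 2 };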
From 389d79b6b4c1065a03a12a3c27870cc4f9695b80 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 30 Dec 2024 04:39:35 +0000
Subject: [PATCH 06/42] Try and work around msvc++ non-macro max resolution quirk

---
 common/minja.hpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/common/minja.hpp b/common/minja.hpp
index 9d9a1a08faf4d..2639c15a0c738 100644
--- a/common/minja.hpp
+++ b/common/minja.hpp
@@ -2541,7 +2541,7 @@ inline std::shared_ptr<Context> Context::builtins() {
     }));
     globals.set("namespace", Value::callable([=](const std::shared_ptr<Context> &, ArgumentsValue & args) {
         auto ns = Value::object();
-        args.expectArgs("namespace", {0, 0}, {0, std::numeric_limits<size_t>::max()});
+        args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits<size_t>::max)()});
         for (auto & [name, value] : args.kwargs) {
             ns.set(name, value);
         }
@@ -2596,7 +2596,7 @@
     };
     // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject
     globals.set("reject", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-        args.expectArgs("reject", {2, std::numeric_limits<size_t>::max()}, {0, 0});
+        args.expectArgs("reject", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
         auto & items = args.args[0];
         auto filter_fn = context->get(args.args[1]);
         if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump());
@@ -2667,7 +2667,7 @@
         return out;
     }));
     globals.set("selectattr", Value::callable([=](const std::shared_ptr<Context> & context, ArgumentsValue & args) {
-        args.expectArgs("selectattr", {2, std::numeric_limits<size_t>::max()}, {0, 0});
+        args.expectArgs("selectattr", {2, (std::numeric_limits<size_t>::max)()}, {0, 0});
         auto & items = args.args[0];
         if (items.is_null()) return Value::array();

From 238b9689e04e5c5c31f7f38ba89302853ce6a93e Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 30 Dec 2024 04:59:13 +0000
Subject: [PATCH 07/42] Update test_chat_completion.py

---
 examples/server/tests/unit/test_chat_completion.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index ef716cc1ab223..996cd0aa01caf 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -104,7 +104,6 @@ def test_chat_completion_with_openai_library():
 
 @pytest.mark.parametrize("response_format,n_predicted,re_content", [
     ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
-    ({"type": "json_schema", "json_schema": {"const": "42"}}, 6, "\"42\""),
     ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
     ({"type": "json_object"}, 10, "(\\{|John)+"),
     ({"type": "sound"}, 0, None),
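PATCH 06 above exists because windows.h defines function-like `min`/`max` macros that clobber `std::numeric_limits<T>::max()`. A function-like macro only expands when its name is immediately followed by `(`, so parenthesizing the name suppresses the expansion:

#include <limits>

// Sketch of the workaround: in "(std::numeric_limits<size_t>::max)()" the
// name `max` is not directly followed by '(', so a function-like `max`
// macro cannot expand, and the call resolves to the real member function.
static size_t max_size_t() {
    return (std::numeric_limits<size_t>::max)();
}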
From 78861a3eb2f8583115cba378caad95b34c274b9c Mon Sep 17 00:00:00 2001
From: ochafik
Date: Mon, 13 Jan 2025 19:58:15 +0000
Subject: [PATCH 08/42] Wire LLM_KV_TOKENIZER_CHAT_TEMPLATE_N in llama_model_chat_template

---
 common/common.cpp                    | 16 ++--------------
 examples/run/run.cpp                 |  4 ++--
 examples/simple-chat/simple-chat.cpp |  2 +-
 include/llama.h                      |  2 +-
 src/llama-arch.cpp                   |  6 ++++--
 src/llama-arch.h                     |  4 +++-
 src/llama-model.cpp                  |  6 ++++--
 7 files changed, 17 insertions(+), 23 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 9cd3713269175..275aa7385b11f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1822,17 +1822,6 @@ std::string common_chat_format_example(const struct llama_model * model,
     return common_chat_apply_template(model, tmpl, msgs, true);
 }
 
-static std::string _llama_model_meta_val_str(const struct llama_model * model, const char * key) {
-    int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0);
-    if (tlen > 0) {
-        std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-        if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-            return std::string(curr_tmpl_buf.data(), tlen);
-        }
-    }
-    return "";
-}
-
 llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) {
     auto vocab = llama_model_get_vocab(model);
@@ -1841,9 +1830,8 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
     std::string default_template_src = chat_template_override;
     std::string tool_use_template_src = chat_template_override;
     if (chat_template_override.empty()) {
-        // TODO:
-        default_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template");
-        tool_use_template_src = _llama_model_meta_val_str(model, "tokenizer.chat_template.tool_use");
+        default_template_src = llama_model_chat_template(model, /* name */ nullptr);
+        tool_use_template_src = llama_model_chat_template(model, /* name */ "tool_use");
     }
     if (default_template_src.empty() || default_template_src == "chatml") {
         if (!tool_use_template_src.empty()) {

diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index 0ad8bb15b27fb..1c838aa777822 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 // Function to apply the chat template and resize `formatted` if needed
 static int apply_chat_template(LlamaData & llama_data, const bool append) {
     int result = llama_chat_apply_template(
-        llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
+        llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(), llama_data.messages.size(), append,
         append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
     if (append && result > static_cast<int>(llama_data.fmtted.size())) {
         llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
+        result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(),
                                            llama_data.messages.size(), append, llama_data.fmtted.data(), llama_data.fmtted.size());
     }

diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp
index e8eda9c223288..46aeae2a9073e 100644
--- a/examples/simple-chat/simple-chat.cpp
+++ b/examples/simple-chat/simple-chat.cpp
@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
             break;
         }
 
-        const char * tmpl = llama_model_chat_template(model);
+        const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
 
         // add the user input to the message list and format it
         messages.push_back({"user", strdup(user.c_str())});
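With PATCH 08's wiring (callers above, key lookup below), a null `name` resolves to the plain `tokenizer.chat_template` key and a non-null name to `tokenizer.chat_template.<name>` via the new `%s` format string. A caller-side sketch (`get_template_source` is a hypothetical helper):

#include "llama.h"
#include <string>

// Sketch: fetch a named template variant, falling back to empty when the
// model does not ship one (llama_model_chat_template returns nullptr then).
static std::string get_template_source(const struct llama_model * model, const char * name) {
    const char * src = llama_model_chat_template(model, name); // name may be nullptr
    return src ? src : "";
}
// e.g. get_template_source(model, nullptr)    -> "tokenizer.chat_template"
//      get_template_source(model, "tool_use") -> "tokenizer.chat_template.tool_use"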
diff --git a/include/llama.h b/include/llama.h
index a184884c77a51..b5462157f31f2 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -503,7 +503,7 @@ extern "C" {
     LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
 
     // Get the default chat template. Returns nullptr if not available
-    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
+    LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name);
 
     // Returns the total number of parameters in the model
     LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);

diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index d7d277e72977a..a7260f495d945 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -179,6 +179,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_HF_JSON,         "tokenizer.huggingface.json" },
     { LLM_KV_TOKENIZER_RWKV,            "tokenizer.rwkv.world" },
     { LLM_KV_TOKENIZER_CHAT_TEMPLATE,   "tokenizer.chat_template" },
+    { LLM_KV_TOKENIZER_CHAT_TEMPLATE_N, "tokenizer.chat_template.%s" },
     { LLM_KV_TOKENIZER_FIM_PRE_ID,      "tokenizer.ggml.fim_pre_token_id" },
     { LLM_KV_TOKENIZER_FIM_SUF_ID,      "tokenizer.ggml.fim_suf_token_id" },
     { LLM_KV_TOKENIZER_FIM_MID_ID,      "tokenizer.ggml.fim_mid_token_id" },
@@ -1443,10 +1444,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 };
 
-LLM_KV::LLM_KV(llm_arch arch) : arch(arch) {}
+LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
 
 std::string LLM_KV::operator()(llm_kv kv) const {
-    return ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
+    return suffix ? ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch), suffix)
+                  : ::format(LLM_KV_NAMES.at(kv), LLM_ARCH_NAMES.at(arch));
 }
 
 std::string LLM_TN_IMPL::str() const {

diff --git a/src/llama-arch.h b/src/llama-arch.h
index 349844790453f..122fdcebe0af6 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -177,6 +177,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
     LLM_KV_TOKENIZER_CHAT_TEMPLATE,
+    LLM_KV_TOKENIZER_CHAT_TEMPLATE_N,
     LLM_KV_TOKENIZER_FIM_PRE_ID,
     LLM_KV_TOKENIZER_FIM_SUF_ID,
     LLM_KV_TOKENIZER_FIM_MID_ID,
@@ -335,9 +336,10 @@ enum llm_tensor_layer {
 };
 
 struct LLM_KV {
-    LLM_KV(llm_arch arch);
+    LLM_KV(llm_arch arch, const char * suffix = nullptr);
 
     llm_arch arch;
+    const char * suffix;
 
     std::string operator()(llm_kv kv) const;
 };

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index f90f5e746077b..dea03c6f2979e 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -3912,8 +3912,10 @@ uint64_t llama_model_size(const struct llama_model * model) {
     return model->size();
 }
 
-const char * llama_model_chat_template(const struct llama_model * model) {
-    const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE));
+const char * llama_model_chat_template(const struct llama_model * model, const char * name) {
+    const auto key = name ?
LLM_KV(model->arch, name)(LLM_KV_TOKENIZER_CHAT_TEMPLATE_N) + : LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE); + const auto & it = model->gguf_kv.find(key); if (it == model->gguf_kv.end()) { return nullptr; } From 1aac99ad546b50def4a1ca64ad268d45cdf0f9a0 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 13 Jan 2025 20:11:27 +0000 Subject: [PATCH 09/42] Refactor test-chat-template --- tests/test-chat-template.cpp | 294 +++++++++++++++++++---------------- 1 file changed, 162 insertions(+), 132 deletions(-) diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 77d38695498f5..e15238d40bb06 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -9,7 +9,7 @@ #include "common.h" int main(void) { - llama_chat_message conversation[] = { + std::vector conversation { {"system", "You are a helpful assistant"}, {"user", "Hello"}, {"assistant", "Hi there"}, @@ -17,130 +17,161 @@ int main(void) { {"assistant", " I am an assistant "}, {"user", "Another question"}, }; - size_t message_count = 6; - std::vector templates = { - // teknium/OpenHermes-2.5-Mistral-7B - "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", - // mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt) - "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", - // TheBloke/FusionNet_34Bx2_MoE-AWQ - "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <>\\\\n' + messages[idx]['content'] + '\\\\n<>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}", - // bofenghuang/vigogne-2-70b-chat - "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\\\n' + system_message + '\\\\n<>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\\\n' + content.strip() + '\\\\n<>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - // mlabonne/AlphaMonarch-7B - "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", - // google/gemma-7b-it - "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\\n' + message['content'] | trim + '\\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\\n'}}{% endif %}", - // OrionStarAI/Orion-14B-Chat - "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", - // openchat/openchat-3.5-0106 - // The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d - // So we match against the included template but implement the suggested version. - "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", - // deepseek-ai/deepseek-coder-33b-instruct - "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}", - // eachadea/vicuna-13b-1.1 - // No template included in tokenizer_config.json, so this template likely needs to be manually set. - "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", - // Orca-Vicuna - // No template included in tokenizer_config.json, so this template likely needs to be manually set. - "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", - // CohereForAI/c4ai-command-r-plus - "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", - // Llama-3 - "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}", - //Phi-3-mini - "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", - //Phi-3-small - "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", - //Phi-3-medium - "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", - //Phi-3-vision - "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}", - // ChatGLM3 - "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", - // ChatGLM4 - u8"[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", - // MiniCPM-3B-OpenHermes-2.5-v2-GGUF - u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + ''}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}", - // DeepSeek-V2 - "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' 
}}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", - // ibm-granite/granite-3.0-8b-instruct - "{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}", - // mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt) - "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n", - // Mistral-Large-Instruct-2407 (mistralai 'v3' template) - "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index 
+ 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n", - // Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template) - "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message 
in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n", - // mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template) - "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}", - // ai-sage/GigaChat-20B-A3B-instruct - "{% if messages[0]['role'] == 'system' -%}\n {%- set loop_messages = messages[1:] -%}\n {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n {%- set loop_messages = messages -%}\n {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n {% endif %}\n \n {%- if loop.index0 == 0 -%}\n {{ system_message -}}\n {%- endif -%}\n {%- if message['role'] == 'user' -%}\n {{ message['role'] + 
additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if message['role'] == 'assistant' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if loop.last and add_generation_prompt -%}\n {{ 'assistant' + additional_special_tokens[0] -}}\n {%- endif -%}\n{%- endfor %}", - // Infinigence/Megrez-3B-Instruct - u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}", - // phi-4 - "{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}", + struct ChatTemplate { + std::string name; + std::string template_str; + std::string expected_output; }; - std::vector expected_output = { - // teknium/OpenHermes-2.5-Mistral-7B - "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n", - // mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt) - "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", - // TheBloke/FusionNet_34Bx2_MoE-AWQ - "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", - // bofenghuang/vigogne-2-70b-chat - "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST]Hi there[INST] Who are you [/INST]I am an assistant[INST] Another question [/INST]", - // mlabonne/AlphaMonarch-7B - "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", - // google/gemma-7b-it - "user\nYou are a helpful assistant\n\nHello\nmodel\nHi there\nuser\nWho are you\nmodel\nI am an assistant\nuser\nAnother question\nmodel\n", - // OrionStarAI/Orion-14B-Chat - "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", - // openchat/openchat-3.5-0106 - "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", - // deepseek-ai/deepseek-coder-33b-instruct - "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant 
\n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n", - // eachadea/vicuna-13b-1.1 - "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", - // Orca-Vicuna - "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", - // CohereForAI/c4ai-command-r-plus - "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", - // Llama 3 - "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", - //Phi-3-mini - "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", - //Phi-3-small - "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", - //Phi-3-medium - "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", - //Phi-3-vision - "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", - // ChatGLM3 - "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>", - // ChatGLM4 - "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", - // MiniCPM-3B-OpenHermes-2.5-v2-GGUF - u8"You are a helpful assistant<用户>HelloHi there<用户>Who are youI am an assistant<用户>Another question", - // DeepSeek-V2 - u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:", - // ibm-granite/granite-3.0-8b-instruct - "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I 
am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n", - // mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt) - " [INST] You are a helpful assistant\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", - // Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start) - "[INST] You are a helpful assistant\n\nHello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] Another question[/INST]", - // Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start) - "[INST]You are a helpful assistant\n\nHello[/INST]Hi there[INST]Who are you[/INST] I am an assistant [INST]Another question[/INST]", - // mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template) - "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there[INST] Who are you[/INST] I am an assistant [INST] Another question[/INST]", - // ai-sage/GigaChat-20B-A3B-instruct - "You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|> I am an assistant <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>", - // Infinigence/Megrez-3B-Instruct - "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>", - // phi-4 - "<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|> I am an assistant <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>", + std::vector templates { + { + /* .name= */ "teknium/OpenHermes-2.5-Mistral-7B", + /* .template_str= */ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", + /* .expected_output= */ "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n", + }, + { + /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)", + /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant 
roles are supported!') }}{% endif %}{% endfor %}", + /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + }, + { + /* .name= */ "TheBloke/FusionNet_34Bx2_MoE-AWQ", + /* .template_str= */ "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}", + /* .expected_output= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + }, + { + /* .name= */ "bofenghuang/vigogne-2-70b-chat", + /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", + /* .expected_output= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there[INST] Who are you [/INST]I am an assistant[INST] Another question [/INST]", + }, + { + /* .name= */ "mlabonne/AlphaMonarch-7B", + /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", + /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + }, + { + /* .name= */ "google/gemma-7b-it", + /* .template_str= */ "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}", + /* .expected_output= */ "<start_of_turn>user\nYou are a helpful 
assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n", + }, + { + /* .name= */ "OrionStarAI/Orion-14B-Chat", + /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", + /* .expected_output= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + }, + { + /* .name= */ "openchat/openchat-3.5-0106", + // The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d + // So we match against the included template but implement the suggested version. + /* .template_str= */ "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", + /* .expected_output= */ "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", + }, + { + /* .name= */ "deepseek-ai/deepseek-coder-33b-instruct", + /* .template_str= */ "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}", + /* .expected_output= */ "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n", + }, + { + /* .name= */ "eachadea/vicuna-13b-1.1", + // No template included in tokenizer_config.json, so this template likely needs to be manually set. 
+ /* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", + /* .expected_output= */ "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", + }, + { + /* .name= */ "Orca-Vicuna", + // No template included in tokenizer_config.json, so this template likely needs to be manually set. + /* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", + /* .expected_output= */ "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", + }, + { + /* .name= */ "CohereForAI/c4ai-command-r-plus", + /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", + /* .expected_output= */ "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + }, + { + /* .name= */ "Llama-3", + /* .template_str= */ "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}", + /* .expected_output= */ 
"<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", + }, + { + /* .name= */ "Phi-3-mini", + /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + }, + { + /* .name= */ "Phi-3-small", + /* .template_str= */ "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", + /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + }, + { + /* .name= */ "Phi-3-medium", + /* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", + /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + }, + { + /* .name= */ "Phi-3-vision", + /* .template_str= */ "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}", + /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + }, + { + /* .name= */ "ChatGLM3", + /* .template_str= */ "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", + /* .expected_output= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>", + }, + { + /* .name= */ "ChatGLM4", + /* .template_str= */ u8"[gMASK]{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% 
if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", + /* .expected_output= */ "[gMASK]<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", + }, + { + /* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF", + /* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + ''}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}", + /* .expected_output= */ u8"You are a helpful assistant<用户>HelloHi there<用户>Who are youI am an assistant<用户>Another question", + }, + { + /* .name= */ "DeepSeek-V2", + /* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", + /* .expected_output= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:", + }, + { + /* .name= */ "ibm-granite/granite-3.0-8b-instruct", + /* .template_str= */ "{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}", + /* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n", + }, + { + /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)", + /* .template_str= */ "{%- if 
messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n", + /* .expected_output= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + }, + { + /* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)", + /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 
%}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n", + /* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] Another question[/INST]", + }, + { + /* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)", + /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n 
{%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n", + /* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there[INST]Who are you[/INST] I am an assistant [INST]Another question[/INST]", + }, + { + /* .name= */ "mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)", + /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}", + /* .expected_output= */ "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there[INST] Who are you[/INST] I am an assistant [INST] Another question[/INST]", + }, + { + /* .name= */ "ai-sage/GigaChat-20B-A3B-instruct", + /* .template_str= */ "{% if messages[0]['role'] == 'system' -%}\n {%- set loop_messages = messages[1:] -%}\n {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n {%- set loop_messages = messages -%}\n {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n {% endif %}\n \n {%- if loop.index0 == 0 -%}\n {{ system_message -}}\n {%- endif -%}\n {%- if message['role'] == 'user' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if message['role'] == 'assistant' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if loop.last and add_generation_prompt -%}\n {{ 'assistant' + additional_special_tokens[0] -}}\n {%- endif -%}\n{%- endfor %}", + /* .expected_output= */ "You are a helpful 
assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|> I am an assistant <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>", + }, + { + /* .name= */ "Infinigence/Megrez-3B-Instruct", + /* .template_str= */ u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}", + /* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>", + }, + { + /* .name= */ "phi-4", + /* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') %}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}", + /* .expected_output= */ "<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|> I am an assistant <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>", + }, }; std::vector<char> formatted_chat(1024); int32_t res; @@ -157,17 +188,16 @@ int main(void) { } // test invalid chat template - res = llama_chat_apply_template("INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size()); + res = llama_chat_apply_template("INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size()); assert(res < 0); - for (size_t i = 0; i < templates.size(); i++) { - std::string custom_template = templates[i]; - std::string expected = expected_output[i]; + for (const auto & tmpl : templates) { + printf("\n\n=== %s ===\n\n", tmpl.name.c_str()); formatted_chat.resize(1024); res = llama_chat_apply_template( - custom_template.c_str(), - conversation, - message_count, + tmpl.template_str.c_str(), + conversation.data(), + conversation.size(), true, formatted_chat.data(), formatted_chat.size() @@ -176,7 +206,7 @@ int main(void) { std::string output(formatted_chat.data(), formatted_chat.size()); printf("%s\n", output.c_str()); printf("-------------------------\n"); - assert(output == expected); + assert(output == tmpl.expected_output); } From 7c84ebc231ce48fa052f0b08d6ef67559b7019da Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 13 Jan 2025 21:23:30 +0000 Subject: [PATCH 10/42] Test templates w/ minja --- tests/test-chat-template.cpp | 186 +++++++++++++++++++++++++++-------- 1 file changed, 145 insertions(+), 41 
deletions(-) diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index e15238d40bb06..cddc89f8e8f1e 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -7,6 +7,7 @@ #include "llama.h" #include "common.h" +#include "chat-template.hpp" int main(void) { std::vector<llama_chat_message> conversation { @@ -17,160 +18,232 @@ int main(void) { {"assistant", " I am an assistant "}, {"user", "Another question"}, }; - struct ChatTemplate { + struct TestCase { std::string name; std::string template_str; - std::string expected_output; + std::string expected_output_adhoc; + std::string expected_output_jinja; + std::string bos_token = ""; + std::string eos_token = ""; + bool supported_with_jinja = true; }; - std::vector<ChatTemplate> templates { + std::vector<TestCase> test_cases { { /* .name= */ "teknium/OpenHermes-2.5-Mistral-7B", /* .template_str= */ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\\n' }}{% endif %}", - /* .expected_output= */ "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n", + /* .expected_output_adhoc= */ "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi there<|im_end|>\n<|im_start|>user\nWho are you<|im_end|>\n<|im_start|>assistant\n I am an assistant <|im_end|>\n<|im_start|>user\nAnother question<|im_end|>\n<|im_start|>assistant\n", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (NOTE: Old pre-v1 without a system prompt)", /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", - /* .expected_output= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .expected_output_adhoc= */ "[INST] You are a helpful assistant\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "TheBloke/FusionNet_34Bx2_MoE-AWQ", - /* .template_str= */ "{%- for idx in range(0, messages|length) -%}\\n{%- if messages[idx]['role'] == 'user' -%}\\n{%- if idx > 1 -%}\\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\\n{%- else -%}\\n{{- messages[idx]['content'] + ' [/INST]' -}}\\n{%- endif -%}\\n{% elif messages[idx]['role'] == 'system' %}\\n{{- '[INST] <<SYS>>\\\\n' + messages[idx]['content'] + '\\\\n<</SYS>>\\\\n\\\\n' -}}\\n{%- elif messages[idx]['role'] == 'assistant' -%}\\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\\n{% endif %}\\n{% endfor %}", - /* .expected_output= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there[INST] Who are you [/INST] I am 
an assistant [INST] Another question [/INST]", + /* .template_str= */ "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <<SYS>>\\n' + messages[idx]['content'] + '\\n<</SYS>>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}", + /* .expected_output_adhoc= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there[INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .expected_output_jinja= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "bofenghuang/vigogne-2-70b-chat", - /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\\\n' + system_message + '\\\\n<</SYS>>\\\\n\\\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\\\n' + content.strip() + '\\\\n<</SYS>>\\\\n\\\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", - /* .expected_output= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there[INST] Who are you [/INST]I am an assistant[INST] Another question [/INST]", + /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif true == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'Vous êtes Vigogne, un assistant IA créé par Zaion Lab. Vous suivez extrêmement bien les instructions. Aidez autant que vous le pouvez.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", + /* .expected_output_adhoc= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]Hi there[INST] Who are you [/INST]I am an assistant[INST] Another question [/INST]", + /* .expected_output_jinja= */ "[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "mlabonne/AlphaMonarch-7B", /* .template_str= */ "{% for message in messages %}{{bos_token + message['role'] + '\\n' + message['content'] + eos_token + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ bos_token + 'assistant\\n' }}{% endif %}", - /* .expected_output= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .expected_output_adhoc= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .expected_output_jinja= */ "system\nYou are a helpful assistant\nuser\nHello\nassistant\nHi there\nuser\nWho are you\nassistant\n I am an assistant \nuser\nAnother question\nassistant\n", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "google/gemma-7b-it", /* .template_str= */ "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\\n' + message['content'] | trim + '<end_of_turn>\\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\\n'}}{% endif %}", - /* .expected_output= */ "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n", + /* .expected_output_adhoc= */ "<start_of_turn>user\nYou are a helpful assistant\n\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n", + /* .expected_output_jinja= */ "<start_of_turn>user\nYou are a helpful assistant\nHello<end_of_turn>\n<start_of_turn>model\nHi there<end_of_turn>\n<start_of_turn>user\nWho are you<end_of_turn>\n<start_of_turn>model\nI am an assistant<end_of_turn>\n<start_of_turn>user\nAnother question<end_of_turn>\n<start_of_turn>model\n", }, { /* .name= */ "OrionStarAI/Orion-14B-Chat", /* .template_str= */ "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}", - /* .expected_output= */ "Human: You 
are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + /* .expected_output_adhoc= */ "Human: You are a helpful assistant\n\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + /* .expected_output_jinja= */ "Human: You are a helpful assistant\nHello\n\nAssistant: Hi thereHuman: Who are you\n\nAssistant: I am an assistant Human: Another question\n\nAssistant: ", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "openchat/openchat-3.5-0106", // The included chat_template differs from the author's suggestions here: https://huggingface.co/openchat/openchat_3.5/discussions/5#65448109b4a3f3a2f486fd9d // So we match against the included template but implement the suggested version. /* .template_str= */ "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", - /* .expected_output= */ "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", + /* .expected_output_adhoc= */ "You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", + /* .expected_output_jinja= */ "GPT4 Correct System: You are a helpful assistant<|end_of_turn|>GPT4 Correct User: Hello<|end_of_turn|>GPT4 Correct Assistant: Hi there<|end_of_turn|>GPT4 Correct User: Who are you<|end_of_turn|>GPT4 Correct Assistant: I am an assistant <|end_of_turn|>GPT4 Correct User: Another question<|end_of_turn|>GPT4 Correct Assistant:", }, { /* .name= */ "deepseek-ai/deepseek-coder-33b-instruct", /* .template_str= */ "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. 
For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}", - /* .expected_output= */ "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n", + /* .expected_output_adhoc= */ "You are a helpful assistant### Instruction:\nHello\n### Response:\nHi there\n<|EOT|>\n### Instruction:\nWho are you\n### Response:\n I am an assistant \n<|EOT|>\n### Instruction:\nAnother question\n### Response:\n", + /* .expected_output_jinja= */ "", }, { /* .name= */ "eachadea/vicuna-13b-1.1", // No template included in tokenizer_config.json, so this template likely needs to be manually set. /* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{- '' + message['content'] + '\n\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", - /* .expected_output= */ "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", + /* .expected_output_adhoc= */ "You are a helpful assistant\n\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "Orca-Vicuna", // No template included in tokenizer_config.json, so this template likely needs to be manually set. /* .template_str= */ "{%- for message in messages %}{%- if message['role'] == 'system' -%}{{-'SYSTEM: ' + message['content'] + '\n' -}}{%- else -%}{%- if message['role'] == 'user' -%}{{-'USER: ' + message['content'] + '\n'-}}{%- else -%}{{-'ASSISTANT: ' + message['content'] + '\n' -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{-'ASSISTANT:'-}}{%- endif -%}", - /* .expected_output= */ "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", + /* .expected_output_adhoc= */ "SYSTEM: You are a helpful assistant\nUSER: Hello\nASSISTANT: Hi there\nUSER: Who are you\nASSISTANT: I am an assistant \nUSER: Another question\nASSISTANT:", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "CohereForAI/c4ai-command-r-plus", /* .template_str= */ "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", - /* .expected_output= */ "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + /* .expected_output_adhoc= */ "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Who are you<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>I am an assistant<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Another question<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + /* .expected_output_jinja= */ "", }, { /* .name= */ "Llama-3", /* .template_str= */ "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}", - /* .expected_output= */ "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", + /* .expected_output_adhoc= */ "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nHi there<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWho are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI am an assistant<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nAnother question<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", + /* .expected_output_jinja= */ "", }, { /* .name= */ "Phi-3-mini", /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + 
'\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", - /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output_adhoc= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output_jinja= */ "<|user|>\nYou are a helpful assistant\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", }, { /* .name= */ "Phi-3-small", /* .template_str= */ "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}", - /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output_adhoc= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output_jinja= */ "", }, { /* .name= */ "Phi-3-medium", /* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}", - /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output_adhoc= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output_jinja= */ "<|user|>\nYou are a helpful assistant\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", }, { /* .name= */ "Phi-3-vision", /* .template_str= */ "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}", - /* .expected_output= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output_adhoc= */ "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant 
<|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "ChatGLM3", /* .template_str= */ "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", - /* .expected_output= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>", + /* .expected_output_adhoc= */ "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>", + /* .expected_output_jinja= */ "[gMASK]sop<|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", }, { /* .name= */ "ChatGLM4", /* .template_str= */ u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", - /* .expected_output= */ "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", + /* .expected_output_adhoc= */ "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "MiniCPM-3B-OpenHermes-2.5-v2-GGUF", /* .template_str= */ u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}", - /* .expected_output= */ u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>", + /* .expected_output_adhoc= */ u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "DeepSeek-V2", /* .template_str= */ "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", - /* .expected_output= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:", 
+ /* .expected_output_adhoc= */ u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "<|end▁of▁sentence|>", }, { /* .name= */ "ibm-granite/granite-3.0-8b-instruct", /* .template_str= */ "{%- if tools %}\n {{- '<|start_of_role|>available_tools<|end_of_role|>\n' }}\n {%- for tool in tools %}\n {{- tool | tojson(indent=4) }}\n {%- if not loop.last %}\n {{- '\n\n' }}\n {%- endif %}\n {%- endfor %}\n {{- '<|end_of_text|>\n' }}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{- '<|start_of_role|>system<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'user' %}\n {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant' %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'assistant_tool_call' %}\n {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- elif message['role'] == 'tool_response' %}\n {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}\n {%- endif %}\n {%- if loop.last and add_generation_prompt %}\n {{- '<|start_of_role|>assistant<|end_of_role|>' }}\n {%- endif %}\n{%- endfor %}", - /* .expected_output= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n", + /* .expected_output_adhoc= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>\n", + /* .expected_output_jinja= */ "<|start_of_role|>system<|end_of_role|>You are a helpful assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Hello<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>Hi there<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Who are you<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|> I am an assistant <|end_of_text|>\n<|start_of_role|>user<|end_of_role|>Another question<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>", }, { /* .name= */ "mistralai/Mistral-7B-Instruct-v0.2 (mistralai 'v1' template with a system prompt)", /* .template_str= */ "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, 
conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n", - /* .expected_output= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .expected_output_adhoc= */ " [INST] You are a helpful assistant\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "Mistral-Large-Instruct-2407 (mistralai 'v3' template; modified to have system prompt at start)", /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS] [\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST] \" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST] \" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message.tool_calls is defined and message.tool_calls is not none %}\n {{- \"[TOOL_CALLS] [\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": 
\"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- \" \" + message[\"content\"]|trim + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS] {\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n", - /* .expected_output= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] Another question[/INST]", + /* .expected_output_adhoc= */ "[INST] You are a helpful assistant\n\nHello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] Another question[/INST]", + /* .expected_output_jinja= */ "[INST] Hello[/INST] Hi there[INST] Who are you[/INST] I am an assistant[INST] You are a helpful assistant\n\nAnother question[/INST]", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "Mistral-Nemo-Instruct-2407 (mistralai 'v3-tekken' template; modified to have system prompt at start)", /* .template_str= */ "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{#- This block checks for alternating user/assistant messages, skipping tool calling messages #}\n{%- set ns = namespace() %}\n{%- set ns.index = 0 %}\n{%- for message in loop_messages %}\n {%- if not (message.role == \"tool\" or message.role == \"tool_results\" or (message.tool_calls is defined and message.tool_calls is not none)) %}\n {%- if (message[\"role\"] == \"user\") != (ns.index % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n {%- set ns.index = ns.index + 1 %}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message 
+ \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif (message.tool_calls is defined and message.tool_calls is not none) %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in message.tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n", - /* .expected_output= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there[INST]Who are you[/INST] I am an assistant [INST]Another question[/INST]", + /* .expected_output_adhoc= */ "[INST]You are a helpful assistant\n\nHello[/INST]Hi there[INST]Who are you[/INST] I am an assistant [INST]Another question[/INST]", + /* .expected_output_jinja= */ "[INST]Hello[/INST]Hi there[INST]Who are you[/INST] I am an assistant [INST]You are a helpful assistant\n\nAnother question[/INST]", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "mistralai/Mistral-Large-Instruct-2411 (mistralai 'v7' template)", /* .template_str= */ "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'system' %}{{ '[SYSTEM_PROMPT] ' + message['content'] + '[/SYSTEM_PROMPT]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token }}{% else %}{{ raise_exception('Only user, system and assistant roles are supported!') }}{% endif %}{% endfor %}", - /* .expected_output= */ "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there[INST] Who are you[/INST] I am an assistant [INST] Another question[/INST]", + /* .expected_output_adhoc= */ "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT][INST] Hello[/INST] Hi there[INST] Who are you[/INST] I am an assistant [INST] Another question[/INST]", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "ai-sage/GigaChat-20B-A3B-instruct", /* .template_str= */ "{% if messages[0]['role'] == 'system' -%}\n {%- set loop_messages = messages[1:] -%}\n {%- set system_message = bos_token + messages[0]['content'] + additional_special_tokens[1] -%}\n{%- else -%}\n {%- set loop_messages = messages -%}\n {%- set system_message = bos_token + '' -%}\n{%- endif -%}\n{%- for message in loop_messages %}\n {% if (message['role'] == 'user') 
!= (loop.index0 % 2 == 0) %}\n {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}\n {% endif %}\n \n {%- if loop.index0 == 0 -%}\n {{ system_message -}}\n {%- endif -%}\n {%- if message['role'] == 'user' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {{ 'available functions' + additional_special_tokens[0] + additional_special_tokens[2] + additional_special_tokens[3] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if message['role'] == 'assistant' -%}\n {{ message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1] -}}\n {%- endif -%}\n {%- if loop.last and add_generation_prompt -%}\n {{ 'assistant' + additional_special_tokens[0] -}}\n {%- endif -%}\n{%- endfor %}", - /* .expected_output= */ "You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|> I am an assistant <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>", + /* .expected_output_adhoc= */ "You are a helpful assistant<|message_sep|>user<|role_sep|>Hello<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>Hi there<|message_sep|>user<|role_sep|>Who are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|> I am an assistant <|message_sep|>user<|role_sep|>Another question<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", + /* .supported_with_jinja= */ false, // Requires additional_special_tokens as extra context }, { /* .name= */ "Infinigence/Megrez-3B-Instruct", /* .template_str= */ u8"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|role_start|>system<|role_end|>你是Megrez-3B-Instruct,将针对用户的问题给出详细的、积极的回答。<|turn_end|>' }}{% endif %}{{ '<|role_start|>' + message['role'] + '<|role_end|>' + message['content'] + '<|turn_end|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|role_start|>assistant<|role_end|>' }}{% endif %}", - /* .expected_output= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>", + /* .expected_output_adhoc= */ "<|role_start|>system<|role_end|>You are a helpful assistant<|turn_end|><|role_start|>user<|role_end|>Hello<|turn_end|><|role_start|>assistant<|role_end|>Hi there<|turn_end|><|role_start|>user<|role_end|>Who are you<|turn_end|><|role_start|>assistant<|role_end|> I am an assistant <|turn_end|><|role_start|>user<|role_end|>Another question<|turn_end|><|role_start|>assistant<|role_end|>", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, { /* .name= */ "phi-4", /* .template_str= */ "{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|im_start|>system<|im_sep|>' + message['content'] + '<|im_end|>'}}{% elif (message['role'] == 'user') 
%}{{'<|im_start|>user<|im_sep|>' + message['content'] + '<|im_end|><|im_start|>assistant<|im_sep|>'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|im_end|>'}}{% endif %}{% endfor %}", - /* .expected_output= */ "<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|> I am an assistant <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>", + /* .expected_output_adhoc= */ "<|im_start|>system<|im_sep|>You are a helpful assistant<|im_end|><|im_start|>user<|im_sep|>Hello<|im_end|><|im_start|>assistant<|im_sep|>Hi there<|im_end|><|im_start|>user<|im_sep|>Who are you<|im_end|><|im_start|>assistant<|im_sep|> I am an assistant <|im_end|><|im_start|>user<|im_sep|>Another question<|im_end|><|im_start|>assistant<|im_sep|>", + /* .expected_output_jinja= */ "", + /* .bos_token= */ "", + /* .eos_token= */ "", }, }; std::vector<char> formatted_chat(1024); @@ -190,25 +263,56 @@ int main(void) { // test invalid chat template res = llama_chat_apply_template("INVALID TEMPLATE", conversation.data(), conversation.size(), true, formatted_chat.data(), formatted_chat.size()); assert(res < 0); + const auto add_generation_prompt = true; - for (const auto & tmpl : templates) { - printf("\n\n=== %s ===\n\n", tmpl.name.c_str()); + for (const auto & test_case : test_cases) { + printf("\n\n=== %s ===\n\n", test_case.name.c_str()); formatted_chat.resize(1024); res = llama_chat_apply_template( - tmpl.template_str.c_str(), + test_case.template_str.c_str(), conversation.data(), conversation.size(), - true, + add_generation_prompt, formatted_chat.data(), formatted_chat.size() ); formatted_chat.resize(res); std::string output(formatted_chat.data(), formatted_chat.size()); - printf("%s\n", output.c_str()); - printf("-------------------------\n"); - assert(output == tmpl.expected_output); + if (output != test_case.expected_output_adhoc) { + printf("Expected:\n%s\n", test_case.expected_output_adhoc.c_str()); + printf("-------------------------\n"); + printf("Actual:\n%s\n", output.c_str()); + assert(output == test_case.expected_output_adhoc); + } } + json messages = json::array(); + for (const auto & msg : conversation) { + messages.push_back({ + {"role", msg.role}, + {"content", msg.content}, + }); + } + for (const auto & test_case : test_cases) { + if (!test_case.supported_with_jinja) { + continue; + } + printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str()); + try { + minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token); + auto output = tmpl.apply(messages, json(), add_generation_prompt); + auto expected_output = test_case.expected_output_jinja.empty() ? 
test_case.expected_output_adhoc : test_case.expected_output_jinja; + if (output != expected_output) { + printf("Expected:\n%s\n", expected_output.c_str()); + printf("-------------------------\n"); + printf("Actual:\n%s\n", output.c_str()); + assert(output == expected_output); + } + } catch (const std::exception & e) { + printf("ERROR: %s\n", e.what()); + assert(false); + } + } // test llama_chat_format_single for system message printf("\n\n=== llama_chat_format_single (system message) ===\n\n"); From 18f257bf1a1aabea100935151a9e7eb09ff80f93 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 13 Jan 2025 21:30:48 +0000 Subject: [PATCH 11/42] Fix deprecation --- common/common.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 275aa7385b11f..763e931b199b0 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1825,8 +1825,8 @@ std::string common_chat_format_example(const struct llama_model * model, llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) { auto vocab = llama_model_get_vocab(model); - auto bos_token = common_token_to_piece(vocab, llama_token_bos(vocab), true); - auto eos_token = common_token_to_piece(vocab, llama_token_eos(vocab), true); + auto bos_token = common_token_to_piece(vocab, llama_vocab_bos(vocab), true); + auto eos_token = common_token_to_piece(vocab, llama_vocab_eos(vocab), true); std::string default_template_src = chat_template_override; std::string tool_use_template_src = chat_template_override; if (chat_template_override.empty()) { From 8dd4f334a4585de49d84070a2ac41e9befc1317d Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 13 Jan 2025 22:07:49 +0000 Subject: [PATCH 12/42] Add --jinja to llama-run --- common/common.cpp | 6 ++++-- examples/run/run.cpp | 40 +++++++++++++++++++++++++++++++--------- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 763e931b199b0..8009601dea431 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1830,8 +1830,10 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model * std::string default_template_src = chat_template_override; std::string tool_use_template_src = chat_template_override; if (chat_template_override.empty()) { - default_template_src = llama_model_chat_template(model, /* name */ nullptr); - tool_use_template_src = llama_model_chat_template(model, /* name */ "tool_use"); + auto str = llama_model_chat_template(model, /* name */ nullptr); + if (str) default_template_src = str; + str = llama_model_chat_template(model, /* name */ "tool_use"); + if (str) tool_use_template_src = str; } if (default_template_src.empty() || default_template_src == "chatml") { if (!tool_use_template_src.empty()) { diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 1c838aa777822..a06986df5beb5 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -103,6 +103,7 @@ class Opt { llama_model_params model_params; std::string model_; std::string user; + bool use_jinja = false; int context_size = -1, ngl = -1; float temperature = -1; bool verbose = false; @@ -154,6 +155,8 @@ class Opt { } else if (options_parsing && (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) { verbose = true; + } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) { + use_jinja = true; } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) { help = true; return 0; @@ -711,13 
+714,31 @@ static void add_message(const char * role, const std::string & text, LlamaData & } // Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(LlamaData & llama_data, const bool append) { +static int apply_chat_template(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { + if (use_jinja) { + json messages = json::array(); + for (const auto & msg : llama_data.messages) { + messages.push_back({ + {"role", msg.role}, + { "content", msg.content} + }); + } + try { + auto result = tmpl.apply(messages, /* tools= */ json(), append); + llama_data.fmtted.resize(result.size() + 1); + memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); + return llama_data.fmtted.size(); + } catch (const std::exception & e) { + printe("failed to render the chat template: %s\n", e.what()); + return -1; + } + } int result = llama_chat_apply_template( - llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(), llama_data.messages.size(), append, + tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append, append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); if (append && result > static_cast<int>(llama_data.fmtted.size())) { llama_data.fmtted.resize(result); - result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get(), /* name */ nullptr), llama_data.messages.data(), + result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append, llama_data.fmtted.data(), llama_data.fmtted.size()); } @@ -847,8 +868,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, } // Helper function to apply the chat template and handle errors -static int apply_chat_template_with_error_handling(LlamaData & llama_data, const bool append, int & output_length) { - const int new_len = apply_chat_template(llama_data, append); +static int apply_chat_template_with_error_handling(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { + const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); if (new_len < 0) { printe("failed to apply the chat template\n"); return -1; @@ -911,9 +932,10 @@ static int get_user_input(std::string & user_input, const std::string & user) { } // Main chat loop function -static int chat_loop(LlamaData & llama_data, const std::string & user) { +static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) { int prev_len = 0; llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); + auto chat_templates = llama_chat_templates_from_model(llama_data.model.get(), ""); static const bool stdout_a_terminal = is_stdout_a_terminal(); while (true) { // Get user input @@ -924,7 +946,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) { add_message("user", user.empty() ? 
user_input : user, llama_data); int new_len; - if (apply_chat_template_with_error_handling(llama_data, true, new_len) < 0) { + if (apply_chat_template_with_error_handling(chat_templates.default_template, llama_data, true, new_len, use_jinja) < 0) { return 1; } @@ -939,7 +961,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user) { } add_message("assistant", response, llama_data); - if (apply_chat_template_with_error_handling(llama_data, false, prev_len) < 0) { + if (apply_chat_template_with_error_handling(chat_templates.default_template, llama_data, false, prev_len, use_jinja) < 0) { return 1; } } @@ -999,7 +1021,7 @@ int main(int argc, const char ** argv) { return 1; } - if (chat_loop(llama_data, opt.user)) { + if (chat_loop(llama_data, opt.user, opt.use_jinja)) { return 1; } From a6afb2735f9764614db6ff69b31371abddce089b Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 13 Jan 2025 22:57:35 +0000 Subject: [PATCH 13/42] Update common_chat_format_example to use minja template wrapper --- common/common.cpp | 14 +++++++++++--- common/common.h | 2 +- examples/main/main.cpp | 5 +++-- examples/server/server.cpp | 4 ++-- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/common/common.cpp index 8009601dea431..b390f1df324f6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1811,15 +1811,23 @@ std::string common_chat_format_single(const struct llama_model * model, return ss.str(); } -std::string common_chat_format_example(const struct llama_model * model, - const std::string & tmpl) { +std::string common_chat_format_example(const struct llama_model * model, const minja::chat_template & tmpl, bool use_jinja) { std::vector<common_chat_msg> msgs = { {"system", "You are a helpful assistant"}, {"user", "Hello"}, {"assistant", "Hi there"}, {"user", "How are you?"}, }; - return common_chat_apply_template(model, tmpl, msgs, true); + const auto add_generation_prompt = true; + if (use_jinja) { + auto messages = json::array(); + for (const auto & msg : msgs) { + messages.push_back({{"role", msg.role}, {"content", msg.content}}); + } + return tmpl.apply(messages, /* tools= */ json(), add_generation_prompt); + } else { + return common_chat_apply_template(model, tmpl.source(), msgs, add_generation_prompt); + } } llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) diff --git a/common/common.h index dea779d09d1b9..24a91cfa96493 100644 --- a/common/common.h +++ b/common/common.h @@ -619,7 +619,7 @@ std::string common_chat_format_single(const struct llama_model * model, // Returns an example of formatted chat std::string common_chat_format_example(const struct llama_model * model, - const std::string & tmpl); + const minja::chat_template & tmpl, bool use_jinja); struct llama_chat_templates { diff --git a/examples/main/main.cpp index 39666a0e8a83a..11038a7c63ce8 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -165,6 +165,7 @@ int main(int argc, char ** argv) { } const llama_vocab * vocab = llama_model_get_vocab(model); + auto chat_templates = llama_chat_templates_from_model(model, params.chat_template); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); @@ -207,7 +208,7 @@ int main(int argc, char ** argv) { } // auto enable conversation mode if chat template is available - const bool has_chat_template = !common_get_builtin_chat_template(model).empty() || !params.chat_template.empty(); 
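+    // note: chat_templates.default_template.source() is never empty at this point, because llama_chat_templates_from_model falls back to a chatml template when the model ships none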
+ const bool has_chat_template = !chat_templates.default_template.source().empty(); if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) { if (has_chat_template) { LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__); @@ -225,7 +226,7 @@ int main(int argc, char ** argv) { // print chat template example in conversation mode if (params.conversation_mode) { if (params.enable_chat_template) { - LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str()); + LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, chat_templates.default_template, params.use_jinja).c_str()); } else { LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 15bcb7e0e1620..dc302ddc195b6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -4287,8 +4287,8 @@ int main(int argc, char ** argv) { // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(), - common_chat_format_example(ctx_server.model, params.chat_template).c_str()); + get_chat_templates().default_template.source().c_str(), + common_chat_format_example(ctx_server.model, get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str()); ctx_server.queue_tasks.on_new_task(std::bind( &server_context::process_single_task, &ctx_server, std::placeholders::_1)); From b4083e41556ae1faa6353e17adae33194840bedc Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 13 Jan 2025 23:10:52 +0000 Subject: [PATCH 14/42] Test chat_template in e2e test --- examples/server/tests/unit/test_chat_completion.py | 14 ++++++++------ examples/server/tests/utils.py | 14 +++++++------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 5a42c5133d26f..76cab4ef9f82a 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -11,17 +11,19 @@ def create_server(): @pytest.mark.parametrize( - "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja", + "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template", [ - (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False), - (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True), + (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False, None), + (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True, None), + (None, "Book", "What is the best book", 8, " blue and shin", 23, 8, "length", True, "This is not a chat template, it is"), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), + ("codellama70b", 
"You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), ] ) -def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja): +def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): global server server.jinja = jinja + server.chat_template = chat_template server.start() res = server.make_request("POST", "/chat/completions", data={ "model": model, diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index d1c1980636413..48474a0ce4048 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -70,13 +70,13 @@ class ServerProcess: draft: int | None = None api_key: str | None = None lora_files: List[str] | None = None - chat_template_file: str | None = None - jinja: bool | None = None disable_ctx_shift: int | None = False draft_min: int | None = None draft_max: int | None = None no_webui: bool | None = None + jinja: bool | None = None chat_template: str | None = None + chat_template_file: str | None = None # session variables process: subprocess.Popen | None = None @@ -157,10 +157,6 @@ def start(self, timeout_seconds: int = 10) -> None: if self.lora_files: for lora_file in self.lora_files: server_args.extend(["--lora", lora_file]) - if self.chat_template_file: - server_args.extend(["--chat-template-file", self.chat_template_file]) - if self.jinja: - server_args.append("--jinja") if self.disable_ctx_shift: server_args.extend(["--no-context-shift"]) if self.api_key: @@ -171,9 +167,13 @@ def start(self, timeout_seconds: int = 10) -> None: server_args.extend(["--draft-min", self.draft_min]) if self.no_webui: server_args.append("--no-webui") + if self.jinja: + server_args.append("--jinja") if self.chat_template: server_args.extend(["--chat-template", self.chat_template]) - + if self.chat_template_file: + server_args.extend(["--chat-template-file", self.chat_template_file]) + args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") From b7e21710c47b2c7d7abac030018d71300c7667b0 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 13 Jan 2025 23:11:57 +0000 Subject: [PATCH 15/42] Update utils.py --- examples/server/tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index 48474a0ce4048..93046b34db1ab 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -173,7 +173,7 @@ def start(self, timeout_seconds: int = 10) -> None: server_args.extend(["--chat-template", self.chat_template]) if self.chat_template_file: server_args.extend(["--chat-template-file", self.chat_template_file]) - + args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") From a57bb94e295a5cafccb35102a62d98a1287f8f87 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 13 Jan 2025 23:18:03 +0000 Subject: [PATCH 16/42] Update test_chat_completion.py --- examples/server/tests/unit/test_chat_completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 76cab4ef9f82a..2e15348dceecb 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -15,7 
+15,7 @@ def create_server(): [ (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", False, None), (None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length", True, None), - (None, "Book", "What is the best book", 8, " blue and shin", 23, 8, "length", True, "This is not a chat template, it is"), + (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), ] From 4daae0bfc7144cd814777a6193e1e0d32dde0d29 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 13 Jan 2025 23:26:31 +0000 Subject: [PATCH 17/42] Update run.cpp --- examples/run/run.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/run/run.cpp b/examples/run/run.cpp index a06986df5beb5..b4cbed9be6d35 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -720,14 +720,14 @@ static int apply_chat_template(const minja::chat_template & tmpl, LlamaData & ll for (const auto & msg : llama_data.messages) { messages.push_back({ {"role", msg.role}, - { "content", msg.content} + {"content", msg.content}, }); } try { auto result = tmpl.apply(messages, /* tools= */ json(), append); llama_data.fmtted.resize(result.size() + 1); memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); - return llama_data.fmtted.size(); + return result.size(); } catch (const std::exception & e) { printe("failed to render the chat template: %s\n", e.what()); return -1; From 1b3bb7eeb96ba3db513073ac0cf74edc09de7119 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Tue, 14 Jan 2025 00:07:18 +0000 Subject: [PATCH 18/42] Update arg.cpp --- common/arg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/arg.cpp b/common/arg.cpp index c379e78ef93cd..cb43b0d5255c8 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1919,7 +1919,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.use_jinja = true; } - ).set_examples({LLAMA_EXAMPLE_SERVER})); + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA")); add_opt(common_arg( {"--chat-template"}, "JINJA_TEMPLATE", string_format( From b75d0622e492b739d05530b0de67437e08a8d30f Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 18 Jan 2025 00:43:38 +0000 Subject: [PATCH 19/42] Refactor common_chat_* functions to accept minja template + use_jinja option --- common/common.cpp | 77 ++++++++++++++++-------------------- common/common.h | 27 +++++++------ examples/main/main.cpp | 24 +++++------ examples/run/run.cpp | 4 +- examples/server/server.cpp | 4 +- examples/server/utils.hpp | 9 ++--- tests/test-chat-template.cpp | 17 +++++--- 7 files changed, 82 insertions(+), 80 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index b390f1df324f6..a8eea91f92dd8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -74,6 +74,15 @@ #endif #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 +const char * LLAMA_CHATML_TEMPLATE = R"( + {%- for message in messages -%} + {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} + {%- endfor -%} + {%- if add_generation_prompt -%} + {{- "<|im_start|>assistant\n" -}} + {%- 
endif -%} +)"; + // // CURL utils // @@ -1748,56 +1757,56 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { return res >= 0; } -std::string common_chat_apply_template(const struct llama_model * model, - const std::string & tmpl, +std::string common_chat_apply_template( + const llama_chat_template & tmpl, const std::vector & msgs, - bool add_ass) { + bool add_ass, + bool use_jinja) { + if (use_jinja) { + auto messages = json::array(); + for (const auto & msg : msgs) { + messages.push_back({{"role", msg.role}, {"content", msg.content}}); + } + return tmpl.apply(messages, /* tools= */ json(), add_ass); + } + int alloc_size = 0; - bool fallback = false; // indicate if we must fallback to default chatml std::vector chat; for (const auto & msg : msgs) { chat.push_back({msg.role.c_str(), msg.content.c_str()}); alloc_size += (msg.role.size() + msg.content.size()) * 1.25; } - const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model, /* name */ nullptr) : tmpl.c_str(); std::vector buf(alloc_size); // run the first time to get the total output length - int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + int32_t res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); // error: chat template is not supported if (res < 0) { - if (ptr_tmpl != nullptr) { - // if the custom "tmpl" is not supported, we throw an error - // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() - throw std::runtime_error("this custom template is not supported"); - } - - // If the built-in template is not supported, we default to chatml - res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); - fallback = true; + // if the custom "tmpl" is not supported, we throw an error + // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() + throw std::runtime_error("this custom template is not supported"); } // if it turns out that our buffer is too small, we resize it if ((size_t) res > buf.size()) { buf.resize(res); - res = llama_chat_apply_template( - fallback ? "chatml" : ptr_tmpl, - chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + res = llama_chat_apply_template(tmpl.source().c_str(), chat.data(), chat.size(), add_ass, buf.data(), buf.size()); } std::string formatted_chat(buf.data(), res); return formatted_chat; } -std::string common_chat_format_single(const struct llama_model * model, - const std::string & tmpl, +std::string common_chat_format_single( + const llama_chat_template & tmpl, const std::vector & past_msg, const common_chat_msg & new_msg, - bool add_ass) { + bool add_ass, + bool use_jinja) { std::ostringstream ss; - auto fmt_past_msg = past_msg.empty() ? "" : common_chat_apply_template(model, tmpl, past_msg, false); + auto fmt_past_msg = past_msg.empty() ? 
"" : common_chat_apply_template(tmpl, past_msg, false, use_jinja); std::vector chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { @@ -1805,29 +1814,20 @@ std::string common_chat_format_single(const struct llama_model * model, }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = common_chat_apply_template(model, tmpl, chat_new, add_ass); + auto fmt_new_msg = common_chat_apply_template(tmpl, chat_new, add_ass, use_jinja); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); } -std::string common_chat_format_example(const struct llama_model * model, const minja::chat_template & tmpl, bool use_jinja) { +std::string common_chat_format_example(const llama_chat_template & tmpl, bool use_jinja) { std::vector msgs = { {"system", "You are a helpful assistant"}, {"user", "Hello"}, {"assistant", "Hi there"}, {"user", "How are you?"}, }; - const auto add_generation_prompt = true; - if (use_jinja) { - auto messages = json::array(); - for (const auto & msg : msgs) { - messages.push_back({{"role", msg.role}, {"content", msg.content}}); - } - return tmpl.apply(messages, /* tools= */ json(), add_generation_prompt); - } else { - return common_chat_apply_template(model, tmpl.source(), msgs, add_generation_prompt); - } + return common_chat_apply_template(tmpl, msgs, true, use_jinja); } llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) @@ -1847,14 +1847,7 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model * if (!tool_use_template_src.empty()) { default_template_src = tool_use_template_src; } else { - default_template_src = R"( - {%- for message in messages -%} - {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} - {%- endfor -%} - {%- if add_generation_prompt -%} - {{- "<|im_start|>assistant\n" -}} - {%- endif -%} - )"; + default_template_src = LLAMA_CHATML_TEMPLATE; } } return { diff --git a/common/common.h b/common/common.h index 24a91cfa96493..474b76473280b 100644 --- a/common/common.h +++ b/common/common.h @@ -26,6 +26,8 @@ #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" +extern const char * LLAMA_CHATML_TEMPLATE; + struct common_adapter_lora_info { std::string path; float scale; @@ -602,29 +604,32 @@ struct common_chat_msg { // Check if the template supplied via "--chat-template" is supported or not. 
Returns true if it's valid bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); +typedef minja::chat_template llama_chat_template; + // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml // If the custom "tmpl" is not supported, we throw an error -std::string common_chat_apply_template(const struct llama_model * model, - const std::string & tmpl, +std::string common_chat_apply_template( + const llama_chat_template & tmpl, const std::vector<common_chat_msg> & chat, - bool add_ass); + bool add_ass, + bool use_jinja); // Format single message, while taking into account the position of that message in chat history -std::string common_chat_format_single(const struct llama_model * model, - const std::string & tmpl, +std::string common_chat_format_single( + const llama_chat_template & tmpl, const std::vector<common_chat_msg> & past_msg, const common_chat_msg & new_msg, - bool add_ass); + bool add_ass, + bool use_jinja); // Returns an example of formatted chat -std::string common_chat_format_example(const struct llama_model * model, - const minja::chat_template & tmpl, bool use_jinja); + +std::string common_chat_format_example( + const llama_chat_template & tmpl, bool use_jinja); struct llama_chat_templates { - minja::chat_template default_template; - std::optional<minja::chat_template> tool_use_template; + llama_chat_template default_template; + std::optional<llama_chat_template> tool_use_template; }; llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); diff --git a/examples/main/main.cpp index 11038a7c63ce8..986e744cef911 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -84,14 +84,6 @@ static void sigint_handler(int signo) { } #endif -static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) { - common_chat_msg new_msg{role, content}; - auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user"); - chat_msgs.push_back({role, content}); - LOG_DBG("formatted: '%s'\n", formatted.c_str()); - return formatted; -} - int main(int argc, char ** argv) { common_params params; g_params = &params; @@ -226,7 +218,7 @@ int main(int argc, char ** argv) { // print chat template example in conversation mode if (params.conversation_mode) { if (params.enable_chat_template) { - LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, chat_templates.default_template, params.use_jinja).c_str()); + LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.default_template, params.use_jinja).c_str()); } else { LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); } @@ -270,10 +262,18 @@ int main(int argc, char ** argv) { std::vector<llama_token> embd_inp; + auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { + common_chat_msg new_msg{role, content}; + auto formatted = common_chat_format_single(chat_templates.default_template, chat_msgs, new_msg, role == "user", g_params->use_jinja); + chat_msgs.push_back({role, content}); + LOG_DBG("formatted: '%s'\n", formatted.c_str()); + return formatted; + }; + { auto prompt = (params.conversation_mode && params.enable_chat_template) // format the system prompt in conversation mode (fallback to default if empty) - ? 
chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt) + ? chat_add_and_format("system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt) // otherwise use the prompt as is : params.prompt; if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) { @@ -780,7 +780,7 @@ int main(int argc, char ** argv) { } if (params.enable_chat_template) { - chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str()); + chat_add_and_format("assistant", assistant_ss.str()); } is_interacting = true; LOG("\n"); @@ -845,7 +845,7 @@ int main(int argc, char ** argv) { bool format_chat = params.conversation_mode && params.enable_chat_template; std::string user_inp = format_chat - ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer)) + ? chat_add_and_format("user", std::move(buffer)) : std::move(buffer); // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix) const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true); diff --git a/examples/run/run.cpp b/examples/run/run.cpp index b4cbed9be6d35..64cc2d20d545e 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -714,7 +714,7 @@ static void add_message(const char * role, const std::string & text, LlamaData & } // Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { +static int apply_chat_template(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { if (use_jinja) { json messages = json::array(); for (const auto & msg : llama_data.messages) { @@ -868,7 +868,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, } // Helper function to apply the chat template and handle errors -static int apply_chat_template_with_error_handling(const minja::chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { +static int apply_chat_template_with_error_handling(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); if (new_len < 0) { printe("failed to apply the chat template\n"); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index dc302ddc195b6..885697fdf5c0f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3869,7 +3869,7 @@ int main(int argc, char ** argv) { auto body = json::parse(req.body); const auto & templates = get_chat_templates(); const auto & chat_template = body.contains("tools") && templates.tool_use_template ? 
*templates.tool_use_template : templates.default_template; - json data = oaicompat_completion_params_parse(ctx_server.model, body, chat_template, params.use_jinja); + json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja); return handle_completions_impl( SERVER_TASK_TYPE_COMPLETION, @@ -4288,7 +4288,7 @@ int main(int argc, char ** argv) { // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, get_chat_templates().default_template.source().c_str(), - common_chat_format_example(ctx_server.model, get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str()); + common_chat_format_example(get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str()); ctx_server.queue_tasks.on_new_task(std::bind( &server_context::process_single_task, &ctx_server, std::placeholders::_1)); diff --git a/examples/server/utils.hpp index b1d08a5cf1bf6..b6cec0eb81e2a 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -351,7 +351,7 @@ static llama_tokens format_infill( } // Format given chat. If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) { +inline std::string format_chat(const llama_chat_template & tmpl, const std::vector<json> & messages) { std::vector<common_chat_msg> chat; for (size_t i = 0; i < messages.size(); ++i) { @@ -379,7 +379,7 @@ inline std::string format_chat(const struct llama_model * model, const std::stri chat.push_back({role, content}); } - const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true); + const auto formatted_chat = common_chat_apply_template(tmpl, chat, true, /* use_jinja= */ false); LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); return formatted_chat; @@ -579,9 +579,8 @@ static json oaicompat_completion_params_parse(const json & body) { } static json oaicompat_completion_params_parse( - const struct llama_model * model, const json & body, /* openai api json semantics */ - const minja::chat_template & tmpl, + const llama_chat_template & tmpl, bool use_jinja) { json llama_params; @@ -622,7 +621,7 @@ static json oaicompat_completion_params_parse( if (use_jinja) { llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true); } else { - llama_params["prompt"] = format_chat(model, tmpl.source(), body.at("messages")); + llama_params["prompt"] = format_chat(tmpl, body.at("messages")); } // Handle "n" field diff --git a/tests/test-chat-template.cpp index 9560d4fa3ccd7..0c3f20f3df765 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -8,6 +8,7 @@ #include "llama.h" #include "common.h" #include "chat-template.hpp" +#include "llama-chat.h" int main(void) { std::vector<llama_chat_message> conversation { @@ -319,9 +320,10 @@ int main(void) { std::vector<common_chat_msg> chat2; common_chat_msg sys_msg{"system", "You are a helpful assistant"}; - auto fmt_sys = [&](std::string tmpl) { - auto output = common_chat_format_single(nullptr, tmpl, chat2, sys_msg, false); - printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str()); + auto fmt_sys = [&](std::string tmpl_str) { + minja::chat_template tmpl(tmpl_str, "", ""); + auto output = common_chat_format_single(tmpl, chat2, sys_msg, false, /* use_jinja= */ false); + printf("fmt_sys(%s) : %s\n", tmpl_str.c_str(), output.c_str()); 
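+        // bos/eos can be left empty when constructing the wrapper here: with use_jinja=false, common_chat_format_single only reads tmpl.source()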
printf("-------------------------\n"); return output; }; @@ -345,9 +347,10 @@ int main(void) { chat2.push_back({"assistant", "I am assistant"}); common_chat_msg new_msg{"user", "How are you"}; - auto fmt_single = [&](std::string tmpl) { - auto output = common_chat_format_single(nullptr, tmpl, chat2, new_msg, true); - printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str()); + auto fmt_single = [&](std::string tmpl_str) { + minja::chat_template tmpl(tmpl_str, "", ""); + auto output = common_chat_format_single(tmpl, chat2, new_msg, true, /* use_jinja= */ false); + printf("fmt_single(%s) : %s\n", tmpl_str.c_str(), output.c_str()); printf("-------------------------\n"); return output; }; @@ -362,5 +365,7 @@ int main(void) { assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>"); + assert(llm_chat_detect_template(LLAMA_CHATML_TEMPLATE) == LLM_CHAT_TEMPLATE_CHATML); + return 0; } From 81c0d437a5f10c6ef8777183efe9437ab84e5a00 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 18 Jan 2025 00:56:19 +0000 Subject: [PATCH 20/42] Attempt to fix linkage of LLAMA_CHATML_TEMPLATE --- common/common.cpp | 4 ++-- common/common.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 03128d8d5ed13..8dd8912e5a43e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -74,14 +74,14 @@ #endif #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 -const char * LLAMA_CHATML_TEMPLATE = R"( +const std::string LLAMA_CHATML_TEMPLATE(R"( {%- for message in messages -%} {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} {%- endfor -%} {%- if add_generation_prompt -%} {{- "<|im_start|>assistant\n" -}} {%- endif -%} -)"; +)"); // // CURL utils diff --git a/common/common.h b/common/common.h index 977819459d926..04e1272d6bcb6 100644 --- a/common/common.h +++ b/common/common.h @@ -26,7 +26,7 @@ #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" -extern const char * LLAMA_CHATML_TEMPLATE; +extern const std::string LLAMA_CHATML_TEMPLATE; struct common_adapter_lora_info { std::string path; From d5fa351a2494836742b935442aefc12fdc13b4ad Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 18 Jan 2025 01:04:12 +0000 Subject: [PATCH 21/42] Revert LLAMA_CHATML_TEMPLATE refactor --- common/common.cpp | 18 ++++++++---------- common/common.h | 2 -- tests/test-chat-template.cpp | 3 --- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 8dd8912e5a43e..b7770b02c414c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -74,15 +74,6 @@ #endif #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 -const std::string LLAMA_CHATML_TEMPLATE(R"( - {%- for message in messages -%} - {{- "<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} - {%- endfor -%} - {%- if add_generation_prompt -%} - {{- "<|im_start|>assistant\n" -}} - {%- endif -%} -)"); - // // CURL utils // @@ -1846,7 +1837,14 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model * if (!tool_use_template_src.empty()) { default_template_src = tool_use_template_src; } else { - default_template_src = LLAMA_CHATML_TEMPLATE; + default_template_src = R"( + {%- for message in messages -%} + {{- 
"<|im_start|>" + message.role + "\n" + message.content + "<|im_end|>\n" -}} + {%- endfor -%} + {%- if add_generation_prompt -%} + {{- "<|im_start|>assistant\n" -}} + {%- endif -%} + )"; } } return { diff --git a/common/common.h b/common/common.h index 04e1272d6bcb6..2a7c3ee3cf5ad 100644 --- a/common/common.h +++ b/common/common.h @@ -26,8 +26,6 @@ #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" -extern const std::string LLAMA_CHATML_TEMPLATE; - struct common_adapter_lora_info { std::string path; float scale; diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 0c3f20f3df765..3bd11a1f0cd56 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -8,7 +8,6 @@ #include "llama.h" #include "common.h" #include "chat-template.hpp" -#include "llama-chat.h" int main(void) { std::vector conversation { @@ -365,7 +364,5 @@ int main(void) { assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>"); - assert(llm_chat_detect_template(LLAMA_CHATML_TEMPLATE) == LLM_CHAT_TEMPLATE_CHATML); - return 0; } From ee1e10e21ea6b2f2a85b0244fc7923cdbbd2d4ae Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 18 Jan 2025 02:52:40 +0000 Subject: [PATCH 22/42] Normalize newlines in test-chat-templates for windows tests --- tests/test-chat-template.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 3bd11a1f0cd56..d9e25124092e5 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -9,6 +9,15 @@ #include "common.h" #include "chat-template.hpp" +static std::string normalize_newlines(const std::string & s) { +#ifdef _WIN32 + static const std::regex nl_regex("\r\n"); + return std::regex_replace(s, nl_regex, "\n"); +#else + return s; +#endif +} + int main(void) { std::vector conversation { {"system", "You are a helpful assistant"}, @@ -300,8 +309,8 @@ int main(void) { printf("\n\n=== %s (jinja) ===\n\n", test_case.name.c_str()); try { minja::chat_template tmpl(test_case.template_str, test_case.bos_token, test_case.eos_token); - auto output = tmpl.apply(messages, json(), add_generation_prompt); - auto expected_output = test_case.expected_output_jinja.empty() ? test_case.expected_output : test_case.expected_output_jinja; + auto output = normalize_newlines(tmpl.apply(messages, json(), add_generation_prompt)); + auto expected_output = normalize_newlines(test_case.expected_output_jinja.empty() ? 
test_case.expected_output : test_case.expected_output_jinja);
         if (output != expected_output) {
             printf("Expected:\n%s\n", expected_output.c_str());
             printf("-------------------------\n");
             printf("Actual:\n%s\n", output.c_str());

From e63520f37ac3fe55c1e25adc3be7ae9d5ad90dcb Mon Sep 17 00:00:00 2001
From: ochafik
Date: Sat, 18 Jan 2025 10:37:56 +0000
Subject: [PATCH 23/42] Forward decl minja::chat_template to avoid eager json dep

---
 common/common.cpp          | 20 +++++++++++++++-----
 common/common.h            | 16 ++++++++++------
 examples/main/main.cpp     |  7 ++++---
 examples/run/run.cpp       |  6 ++++--
 examples/server/server.cpp | 12 +++++++-----
 5 files changed, 40 insertions(+), 21 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index b7770b02c414c..881828bcd38f9 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -12,6 +12,7 @@
 #include "json.hpp"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "chat-template.hpp"

 #include
 #include
@@ -1827,11 +1828,18 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
     auto eos_token = common_token_to_piece(vocab, llama_vocab_eos(vocab), true);
     std::string default_template_src = chat_template_override;
     std::string tool_use_template_src = chat_template_override;
+    bool has_explicit_template = !chat_template_override.empty();
     if (chat_template_override.empty()) {
         auto str = llama_model_chat_template(model, /* name */ nullptr);
-        if (str) default_template_src = str;
+        if (str) {
+            default_template_src = str;
+            has_explicit_template = true;
+        }
         str = llama_model_chat_template(model, /* name */ "tool_use");
-        if (str) tool_use_template_src = str;
+        if (str) {
+            tool_use_template_src = str;
+            has_explicit_template = true;
+        }
     }
     if (default_template_src.empty() || default_template_src == "chatml") {
         if (!tool_use_template_src.empty()) {
@@ -1848,9 +1856,11 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model *
         }
     }
     return {
-        /* .default_template = */ { default_template_src, bos_token, eos_token },
-        /* .tool_use_template = */ tool_use_template_src.empty() ? std::nullopt
-            : std::optional<llama_chat_template>({ tool_use_template_src, bos_token, eos_token }),
+        has_explicit_template,
+        std::move(std::make_unique<minja::chat_template>(default_template_src, bos_token, eos_token)),
+        tool_use_template_src.empty()
+            ? nullptr
+            : std::move(std::make_unique<minja::chat_template>(tool_use_template_src, bos_token, eos_token))
     };
 }

diff --git a/common/common.h b/common/common.h
index 2a7c3ee3cf5ad..1c01cd9ef2297 100644
--- a/common/common.h
+++ b/common/common.h
@@ -3,7 +3,6 @@
 #pragma once

 #include "llama-cpp.h"
-#include "chat-template.hpp"

 #include
 #include
@@ -601,8 +600,18 @@ struct common_chat_msg {
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);

+namespace minja {
+    class chat_template;
+}
+
 typedef minja::chat_template llama_chat_template;

+struct llama_chat_templates {
+    bool has_explicit_template; // Model had builtin template or template override was specified.
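The point of the forward declaration above is compile-time isolation: common.h no longer drags chat-template.hpp (and, transitively, nlohmann/json) into every translation unit that includes it. The usual price of holding std::unique_ptr members to an incomplete type is that the owning struct may only be constructed and destroyed where the type is complete, which is why this patch adds #include "chat-template.hpp" to main.cpp and run.cpp. A generic sketch of the pattern, with hypothetical names:

// heavy_user.h: no heavy includes needed here
#include <memory>
namespace heavy { class engine; }   // forward declaration only

struct engine_holder {
    std::unique_ptr<heavy::engine> impl;
    engine_holder();
    ~engine_holder();               // must be defined where heavy::engine is complete
};

// heavy_user.cpp: the one TU that pays for the full header
// #include "heavy/engine.hpp"
// engine_holder::engine_holder() = default;
// engine_holder::~engine_holder() = default;  // unique_ptr's deleter instantiated here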
+ std::unique_ptr default_template; // always set (defaults to chatml) + std::unique_ptr tool_use_template; +}; + // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml // If the custom "tmpl" is not supported, we throw an error @@ -624,11 +633,6 @@ std::string common_chat_format_single( std::string common_chat_format_example( const llama_chat_template & tmpl, bool use_jinja); -struct llama_chat_templates { - llama_chat_template default_template; - std::optional tool_use_template; -}; - llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); // diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 986e744cef911..903a92faffe95 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -4,6 +4,7 @@ #include "log.h" #include "sampling.h" #include "llama.h" +#include "chat-template.hpp" #include #include @@ -200,7 +201,7 @@ int main(int argc, char ** argv) { } // auto enable conversation mode if chat template is available - const bool has_chat_template = !chat_templates.default_template.source().empty(); + const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.default_template; if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) { if (has_chat_template) { LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__); @@ -218,7 +219,7 @@ int main(int argc, char ** argv) { // print chat template example in conversation mode if (params.conversation_mode) { if (params.enable_chat_template) { - LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.default_template, params.use_jinja).c_str()); + LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.default_template, params.use_jinja).c_str()); } else { LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); } @@ -264,7 +265,7 @@ int main(int argc, char ** argv) { auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { common_chat_msg new_msg{role, content}; - auto formatted = common_chat_format_single(chat_templates.default_template, chat_msgs, new_msg, role == "user", g_params->use_jinja); + auto formatted = common_chat_format_single(*chat_templates.default_template, chat_msgs, new_msg, role == "user", g_params->use_jinja); chat_msgs.push_back({role, content}); LOG_DBG("formatted: '%s'\n", formatted.c_str()); return formatted; diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 64cc2d20d545e..46a9453472097 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -26,6 +26,7 @@ #include "common.h" #include "json.hpp" #include "llama-cpp.h" +#include "chat-template.hpp" #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) [[noreturn]] static void sigint_handler(int) { @@ -936,6 +937,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_ int prev_len = 0; llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); auto chat_templates = llama_chat_templates_from_model(llama_data.model.get(), ""); + GGML_ASSERT(chat_templates.default_template); static const bool stdout_a_terminal = is_stdout_a_terminal(); while (true) { // Get user input @@ -946,7 +948,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_ add_message("user", 
user.empty() ? user_input : user, llama_data); int new_len; - if (apply_chat_template_with_error_handling(chat_templates.default_template, llama_data, true, new_len, use_jinja) < 0) { + if (apply_chat_template_with_error_handling(*chat_templates.default_template, llama_data, true, new_len, use_jinja) < 0) { return 1; } @@ -961,7 +963,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_ } add_message("assistant", response, llama_data); - if (apply_chat_template_with_error_handling(chat_templates.default_template, llama_data, false, prev_len, use_jinja) < 0) { + if (apply_chat_template_with_error_handling(*chat_templates.default_template, llama_data, false, prev_len, use_jinja) < 0) { return 1; } } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 885697fdf5c0f..6d86338a8fe28 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1745,8 +1745,9 @@ struct server_context { if (use_jinja) { auto templates = llama_chat_templates_from_model(model, ""); + GGML_ASSERT(templates.default_template); try { - templates.default_template.apply({{ + templates.default_template->apply({{ {"role", "user"}, {"content", "test"}, }}, json(), true); @@ -3630,6 +3631,7 @@ int main(int argc, char ** argv) { std::lock_guard lock(chat_templates_mutex); if (!chat_templates) { chat_templates = llama_chat_templates_from_model(ctx_server.model, ctx_server.params_base.chat_template); + GGML_ASSERT(chat_templates->default_template); } return *chat_templates; }; @@ -3641,7 +3643,7 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model }, - { "chat_template", templates.default_template.source() }, + { "chat_template", templates.default_template->source() }, { "build_info", build_info }, }; if (ctx_server.params_base.use_jinja && templates.tool_use_template) { @@ -3868,7 +3870,7 @@ int main(int argc, char ** argv) { auto body = json::parse(req.body); const auto & templates = get_chat_templates(); - const auto & chat_template = body.contains("tools") && templates.tool_use_template ? *templates.tool_use_template : templates.default_template; + const auto & chat_template = body.contains("tools") && templates.tool_use_template ? 
*templates.tool_use_template : *templates.default_template; json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja); return handle_completions_impl( @@ -4287,8 +4289,8 @@ int main(int argc, char ** argv) { // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - get_chat_templates().default_template.source().c_str(), - common_chat_format_example(get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str()); + get_chat_templates().default_template->source().c_str(), + common_chat_format_example(*get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str()); ctx_server.queue_tasks.on_new_task(std::bind( &server_context::process_single_task, &ctx_server, std::placeholders::_1)); From 33322e823e783a9b22e350dd89727f8aa6b82073 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 18 Jan 2025 10:38:21 +0000 Subject: [PATCH 24/42] Flush stdout in chat template before potential crash --- tests/test-chat-template.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index d9e25124092e5..1906431362e9b 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -291,6 +291,7 @@ int main(void) { printf("Expected:\n%s\n", test_case.expected_output.c_str()); printf("-------------------------\n"); printf("Actual:\n%s\n", output.c_str()); + fflush(stdout); assert(output == test_case.expected_output); } } @@ -315,6 +316,7 @@ int main(void) { printf("Expected:\n%s\n", expected_output.c_str()); printf("-------------------------\n"); printf("Actual:\n%s\n", output.c_str()); + fflush(stdout); assert(output == expected_output); } } catch (const std::exception & e) { From 5074e6fecdab206787286c799629b1789e55b182 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 18 Jan 2025 10:48:03 +0000 Subject: [PATCH 25/42] Fix copy elision warning --- common/common.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 881828bcd38f9..9c535a1765131 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1857,10 +1857,10 @@ llama_chat_templates llama_chat_templates_from_model(const struct llama_model * } return { has_explicit_template, - std::move(std::make_unique(default_template_src, bos_token, eos_token)), + std::make_unique(default_template_src, bos_token, eos_token), tool_use_template_src.empty() ? 
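The copy-elision fix above is the standard -Wpessimizing-move cleanup: std::make_unique already yields a prvalue, and wrapping it in std::move turns guaranteed in-place construction into a move, which recent GCC and Clang warn about. A minimal illustration (make_owned is a hypothetical name):

#include <memory>
#include <string>

static std::unique_ptr<std::string> make_owned(const std::string & s) {
    // return std::move(std::make_unique<std::string>(s)); // -Wpessimizing-move
    return std::make_unique<std::string>(s);               // prvalue: constructed in place
}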
nullptr - : std::move(std::make_unique(tool_use_template_src, bos_token, eos_token)) + : std::make_unique(tool_use_template_src, bos_token, eos_token) }; } From fc60802b6e99862b7bef506e04eb9a8f99d0beea Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 18 Jan 2025 11:35:54 +0000 Subject: [PATCH 26/42] Rm unused optional include --- common/common.h | 1 - 1 file changed, 1 deletion(-) diff --git a/common/common.h b/common/common.h index 1c01cd9ef2297..a96a995311340 100644 --- a/common/common.h +++ b/common/common.h @@ -4,7 +4,6 @@ #include "llama-cpp.h" -#include #include #include #include From 0e74c9dabe31c91e1e3dd4909e25c3624793b124 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 18 Jan 2025 11:58:00 +0000 Subject: [PATCH 27/42] Add missing optional include to server.cpp --- examples/server/server.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6d86338a8fe28..189290df94e38 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include From e3c475cd127911eec9a0e8cc8aa33614d43cdfe1 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 18 Jan 2025 14:55:27 +0000 Subject: [PATCH 28/42] Disable jinja test that has a cryptic windows failure --- tests/test-chat-template.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 1906431362e9b..6b877f65901e3 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -68,6 +68,7 @@ int main(void) { /* .expected_output_jinja= */ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", /* .bos_token= */ "", /* .eos_token= */ "", + /* .supported_with_jinja= */ false, // Mysteriously fails on windows-latest in llama.cpp's CI, although that template works fine in Minja's CI on windows-latest }, { /* .name= */ "mlabonne/AlphaMonarch-7B", From cc503564702917992e101a9c79f15335dac1a5b0 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 18 Jan 2025 17:55:04 +0000 Subject: [PATCH 29/42] minja: fix vigogne (https://github.com/google/minja/pull/22) --- common/minja.hpp | 10 ++++------ tests/test-chat-template.cpp | 1 - 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/common/minja.hpp b/common/minja.hpp index 2639c15a0c738..c1c4212c74a16 100644 --- a/common/minja.hpp +++ b/common/minja.hpp @@ -1305,12 +1305,10 @@ struct ArgumentsExpression { }; static std::string strip(const std::string & s) { - static std::regex trailing_spaces_regex("^\\s+|\\s+$"); - return std::regex_replace(s, trailing_spaces_regex, ""); - // auto start = s.find_first_not_of(" \t\n\r"); - // if (start == std::string::npos) return ""; - // auto end = s.find_last_not_of(" \t\n\r"); - // return s.substr(start, end - start + 1); + auto start = s.find_first_not_of(" \t\n\r"); + if (start == std::string::npos) return ""; + auto end = s.find_last_not_of(" \t\n\r"); + return s.substr(start, end - start + 1); } static std::string html_escape(const std::string & s) { diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 6b877f65901e3..1906431362e9b 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -68,7 +68,6 @@ int main(void) { /* .expected_output_jinja= */ "[INST] <>\nYou are a helpful assistant\n<>\n\nHello [/INST] Hi there [INST] Who are you [/INST] I am an assistant [INST] Another question [/INST]", /* 
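The strip() rewrite above trades a std::regex, constructed and matched on every call, for two scans with find_first_not_of/find_last_not_of; behavior for leading and trailing ASCII whitespace is unchanged. A few edge cases worth keeping in mind, checked here against a local copy of the same logic (strip_copy is a name used only for this sketch):

#include <cassert>
#include <string>

static std::string strip_copy(const std::string & s) {  // same logic as minja's strip()
    auto start = s.find_first_not_of(" \t\n\r");
    if (start == std::string::npos) return "";
    auto end = s.find_last_not_of(" \t\n\r");
    return s.substr(start, end - start + 1);
}

int main() {
    assert(strip_copy("  hi \n") == "hi");
    assert(strip_copy("\t\r\n")  == "");      // all-whitespace input collapses to empty
    assert(strip_copy("a  b")    == "a  b");  // interior whitespace is preserved
    return 0;
}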
.bos_token= */ "", /* .eos_token= */ "", - /* .supported_with_jinja= */ false, // Mysteriously fails on windows-latest in llama.cpp's CI, although that template works fine in Minja's CI on windows-latest }, { /* .name= */ "mlabonne/AlphaMonarch-7B", From 153e8524113621d3ca90d146e6dc5d42a5c42160 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Mon, 20 Jan 2025 20:55:52 +0000 Subject: [PATCH 30/42] Apply suggestions from code review Co-authored-by: Xuan Son Nguyen Co-authored-by: Georgi Gerganov --- common/common.cpp | 6 +++--- common/common.h | 4 ++-- include/llama.h | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 9c535a1765131..ce023fc2be0cb 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1821,11 +1821,11 @@ std::string common_chat_format_example(const llama_chat_template & tmpl, bool us return common_chat_apply_template(tmpl, msgs, true, use_jinja); } -llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) +llama_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) { auto vocab = llama_model_get_vocab(model); - auto bos_token = common_token_to_piece(vocab, llama_vocab_bos(vocab), true); - auto eos_token = common_token_to_piece(vocab, llama_vocab_eos(vocab), true); + auto token_bos = common_token_to_piece(vocab, llama_vocab_bos(vocab), true); + auto token_eos = common_token_to_piece(vocab, llama_vocab_eos(vocab), true); std::string default_template_src = chat_template_override; std::string tool_use_template_src = chat_template_override; bool has_explicit_template = !chat_template_override.empty(); diff --git a/common/common.h b/common/common.h index a96a995311340..352cbb0fa9189 100644 --- a/common/common.h +++ b/common/common.h @@ -607,8 +607,8 @@ typedef minja::chat_template llama_chat_template; struct llama_chat_templates { bool has_explicit_template; // Model had builtin template or template overridde was specified. - std::unique_ptr default_template; // always set (defaults to chatml) - std::unique_ptr tool_use_template; + std::unique_ptr template_default; // always set (defaults to chatml) + std::unique_ptr template_tool_use; }; // CPP wrapper for llama_chat_apply_template diff --git a/include/llama.h b/include/llama.h index dca9314aa92f6..3b75e760780ef 100644 --- a/include/llama.h +++ b/include/llama.h @@ -510,6 +510,7 @@ extern "C" { LLAMA_API uint64_t llama_model_size(const struct llama_model * model); // Get the default chat template. 
Returns nullptr if not available + // If name is NULL, returns the default chat template LLAMA_API const char * llama_model_chat_template(const struct llama_model * model, const char * name); // Returns the total number of parameters in the model From db9dd0c1acc497766f5b0957f4d5d32c883d7904 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 20 Jan 2025 21:06:18 +0000 Subject: [PATCH 31/42] Finish suggested renamings --- common/common.cpp | 14 +++++++------- common/common.h | 2 +- examples/main/main.cpp | 8 ++++---- examples/run/run.cpp | 8 ++++---- examples/server/server.cpp | 26 +++++++++++++------------- 5 files changed, 29 insertions(+), 29 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index ce023fc2be0cb..2c0558b5b5b2b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1827,7 +1827,7 @@ llama_chat_templates common_chat_templates_from_model(const struct llama_model * auto token_bos = common_token_to_piece(vocab, llama_vocab_bos(vocab), true); auto token_eos = common_token_to_piece(vocab, llama_vocab_eos(vocab), true); std::string default_template_src = chat_template_override; - std::string tool_use_template_src = chat_template_override; + std::string template_tool_use_src = chat_template_override; bool has_explicit_template = !chat_template_override.empty(); if (chat_template_override.empty()) { auto str = llama_model_chat_template(model, /* name */ nullptr); @@ -1837,13 +1837,13 @@ llama_chat_templates common_chat_templates_from_model(const struct llama_model * } str = llama_model_chat_template(model, /* name */ "tool_use"); if (str) { - tool_use_template_src = str; + template_tool_use_src = str; has_explicit_template = true; } } if (default_template_src.empty() || default_template_src == "chatml") { - if (!tool_use_template_src.empty()) { - default_template_src = tool_use_template_src; + if (!template_tool_use_src.empty()) { + default_template_src = template_tool_use_src; } else { default_template_src = R"( {%- for message in messages -%} @@ -1857,10 +1857,10 @@ llama_chat_templates common_chat_templates_from_model(const struct llama_model * } return { has_explicit_template, - std::make_unique(default_template_src, bos_token, eos_token), - tool_use_template_src.empty() + std::make_unique(default_template_src, token_bos, token_eos), + template_tool_use_src.empty() ? 
nullptr - : std::make_unique(tool_use_template_src, bos_token, eos_token) + : std::make_unique(template_tool_use_src, token_bos, token_eos) }; } diff --git a/common/common.h b/common/common.h index 352cbb0fa9189..7b50c82d2a2e3 100644 --- a/common/common.h +++ b/common/common.h @@ -632,7 +632,7 @@ std::string common_chat_format_single( std::string common_chat_format_example( const llama_chat_template & tmpl, bool use_jinja); -llama_chat_templates llama_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); +llama_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); // // KV cache utils diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 903a92faffe95..da2a03ab9ba10 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -158,7 +158,7 @@ int main(int argc, char ** argv) { } const llama_vocab * vocab = llama_model_get_vocab(model); - auto chat_templates = llama_chat_templates_from_model(model, params.chat_template); + auto chat_templates = common_chat_templates_from_model(model, params.chat_template); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); @@ -201,7 +201,7 @@ int main(int argc, char ** argv) { } // auto enable conversation mode if chat template is available - const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.default_template; + const bool has_chat_template = chat_templates.has_explicit_template && chat_templates.template_default; if (params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO) { if (has_chat_template) { LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__); @@ -219,7 +219,7 @@ int main(int argc, char ** argv) { // print chat template example in conversation mode if (params.conversation_mode) { if (params.enable_chat_template) { - LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.default_template, params.use_jinja).c_str()); + LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(*chat_templates.template_default, params.use_jinja).c_str()); } else { LOG_INF("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__); } @@ -265,7 +265,7 @@ int main(int argc, char ** argv) { auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) { common_chat_msg new_msg{role, content}; - auto formatted = common_chat_format_single(*chat_templates.default_template, chat_msgs, new_msg, role == "user", g_params->use_jinja); + auto formatted = common_chat_format_single(*chat_templates.template_default, chat_msgs, new_msg, role == "user", g_params->use_jinja); chat_msgs.push_back({role, content}); LOG_DBG("formatted: '%s'\n", formatted.c_str()); return formatted; diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 46a9453472097..408bd7181a3d7 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -936,8 +936,8 @@ static int get_user_input(std::string & user_input, const std::string & user) { static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) { int prev_len = 0; llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); - auto chat_templates = llama_chat_templates_from_model(llama_data.model.get(), ""); - GGML_ASSERT(chat_templates.default_template); + auto chat_templates = 
common_chat_templates_from_model(llama_data.model.get(), ""); + GGML_ASSERT(chat_templates.template_default); static const bool stdout_a_terminal = is_stdout_a_terminal(); while (true) { // Get user input @@ -948,7 +948,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_ add_message("user", user.empty() ? user_input : user, llama_data); int new_len; - if (apply_chat_template_with_error_handling(*chat_templates.default_template, llama_data, true, new_len, use_jinja) < 0) { + if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) { return 1; } @@ -963,7 +963,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_ } add_message("assistant", response, llama_data); - if (apply_chat_template_with_error_handling(*chat_templates.default_template, llama_data, false, prev_len, use_jinja) < 0) { + if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) { return 1; } } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 189290df94e38..6717198c5415d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1745,15 +1745,15 @@ struct server_context { llama_chat_message chat[] = {{"user", "test"}}; if (use_jinja) { - auto templates = llama_chat_templates_from_model(model, ""); - GGML_ASSERT(templates.default_template); + auto templates = common_chat_templates_from_model(model, ""); + GGML_ASSERT(templates.template_default); try { - templates.default_template->apply({{ + templates.template_default->apply({{ {"role", "user"}, {"content", "test"}, }}, json(), true); - if (templates.tool_use_template) { - templates.tool_use_template->apply({{ + if (templates.template_tool_use) { + templates.template_tool_use->apply({{ {"role", "user"}, {"content", "test"}, }}, json(), true); @@ -3631,8 +3631,8 @@ int main(int argc, char ** argv) { auto get_chat_templates = [&ctx_server, &chat_templates_mutex, &chat_templates]() -> const llama_chat_templates & { std::lock_guard lock(chat_templates_mutex); if (!chat_templates) { - chat_templates = llama_chat_templates_from_model(ctx_server.model, ctx_server.params_base.chat_template); - GGML_ASSERT(chat_templates->default_template); + chat_templates = common_chat_templates_from_model(ctx_server.model, ctx_server.params_base.chat_template); + GGML_ASSERT(chat_templates->template_default); } return *chat_templates; }; @@ -3644,11 +3644,11 @@ int main(int argc, char ** argv) { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model }, - { "chat_template", templates.default_template->source() }, + { "chat_template", templates.template_default->source() }, { "build_info", build_info }, }; - if (ctx_server.params_base.use_jinja && templates.tool_use_template) { - data["chat_template_tool_use"] = templates.tool_use_template->source(); + if (ctx_server.params_base.use_jinja && templates.template_tool_use) { + data["chat_template_tool_use"] = templates.template_tool_use->source(); } res_ok(res, data); @@ -3871,7 +3871,7 @@ int main(int argc, char ** argv) { auto body = json::parse(req.body); const auto & templates = get_chat_templates(); - const auto & chat_template = body.contains("tools") && templates.tool_use_template ? 
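The server-side template selection shown here is deliberately conservative: the dedicated tool-use template is consulted only when the request body actually contains tools, and the default template, which is guaranteed non-null and falls back to chatml, covers everything else. As a standalone sketch of that rule, using this commit's field names and assuming the server's nlohmann json alias (pick_template is illustrative only):

// hedged sketch of the selection rule used above
static const llama_chat_template & pick_template(
        const llama_chat_templates & templates, const json & body) {
    if (body.contains("tools") && templates.template_tool_use) {
        return *templates.template_tool_use;  // request carries tools
    }
    return *templates.template_default;       // always set (defaults to chatml)
}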
*templates.tool_use_template : *templates.default_template; + const auto & chat_template = body.contains("tools") && templates.template_tool_use ? *templates.template_tool_use : *templates.template_default; json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja); return handle_completions_impl( @@ -4290,8 +4290,8 @@ int main(int argc, char ** argv) { // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - get_chat_templates().default_template->source().c_str(), - common_chat_format_example(*get_chat_templates().default_template, ctx_server.params_base.use_jinja).c_str()); + get_chat_templates().template_default->source().c_str(), + common_chat_format_example(*get_chat_templates().template_default, ctx_server.params_base.use_jinja).c_str()); ctx_server.queue_tasks.on_new_task(std::bind( &server_context::process_single_task, &ctx_server, std::placeholders::_1)); From c9e8fdd70e576c1c71635db645227a3d5738423a Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 20 Jan 2025 21:25:18 +0000 Subject: [PATCH 32/42] Move chat_templates inside server_context + remove mutex --- examples/server/server.cpp | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6717198c5415d..eabbf79408616 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1662,6 +1662,8 @@ struct server_context { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; + llama_chat_templates chat_templates; + ~server_context() { // Clear any sampling context for (server_slot & slot : slots) { @@ -1738,6 +1740,8 @@ struct server_context { cparams_dft.type_v = GGML_TYPE_F16; } + chat_templates = common_chat_templates_from_model(model, params_base.chat_template); + return true; } @@ -3625,30 +3629,17 @@ int main(int argc, char ** argv) { } }; - std::mutex chat_templates_mutex; - std::optional chat_templates; - - auto get_chat_templates = [&ctx_server, &chat_templates_mutex, &chat_templates]() -> const llama_chat_templates & { - std::lock_guard lock(chat_templates_mutex); - if (!chat_templates) { - chat_templates = common_chat_templates_from_model(ctx_server.model, ctx_server.params_base.chat_template); - GGML_ASSERT(chat_templates->template_default); - } - return *chat_templates; - }; - - const auto handle_props = [&ctx_server, &res_ok, &get_chat_templates](const httplib::Request &, httplib::Response & res) { + const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { // this endpoint is publicly available, please only return what is safe to be exposed - const auto & templates = get_chat_templates(); json data = { { "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "total_slots", ctx_server.params_base.n_parallel }, { "model_path", ctx_server.params_base.model }, - { "chat_template", templates.template_default->source() }, + { "chat_template", ctx_server.chat_templates.template_default->source() }, { "build_info", build_info }, }; - if (ctx_server.params_base.use_jinja && templates.template_tool_use) { - data["chat_template_tool_use"] = templates.template_tool_use->source(); + if (ctx_server.params_base.use_jinja && ctx_server.chat_templates.template_tool_use) { + data["chat_template_tool_use"] = ctx_server.chat_templates.template_tool_use->source(); } res_ok(res, data); @@ 
-3863,15 +3854,14 @@ int main(int argc, char ** argv) { OAICOMPAT_TYPE_NONE); // infill is not OAI compatible }; - const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_impl, &get_chat_templates](const httplib::Request & req, httplib::Response & res) { + const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { if (ctx_server.params_base.embedding) { res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return; } auto body = json::parse(req.body); - const auto & templates = get_chat_templates(); - const auto & chat_template = body.contains("tools") && templates.template_tool_use ? *templates.template_tool_use : *templates.template_default; + const auto & chat_template = body.contains("tools") && ctx_server.chat_templates.template_tool_use ? *ctx_server.chat_templates.template_tool_use : *ctx_server.chat_templates.template_default; json data = oaicompat_completion_params_parse(body, chat_template, params.use_jinja); return handle_completions_impl( @@ -4290,8 +4280,8 @@ int main(int argc, char ** argv) { // print sample chat example to make it clear which template is used LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - get_chat_templates().template_default->source().c_str(), - common_chat_format_example(*get_chat_templates().template_default, ctx_server.params_base.use_jinja).c_str()); + ctx_server.chat_templates.template_default->source().c_str(), + common_chat_format_example(*ctx_server.chat_templates.template_default, ctx_server.params_base.use_jinja).c_str()); ctx_server.queue_tasks.on_new_task(std::bind( &server_context::process_single_task, &ctx_server, std::placeholders::_1)); From 8c84aefd4d8609def3127cc37f091648a1af8820 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 20 Jan 2025 21:48:31 +0000 Subject: [PATCH 33/42] Update --chat-template-file w/ recent change to --chat-template --- common/arg.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index b46f205f69438..53bd32e3aeaff 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1966,10 +1966,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); add_opt(common_arg( {"--chat-template-file"}, "JINJA_TEMPLATE_FILE", - "set custom jinja chat template file (default: template taken from model's metadata)\n" - "if suffix/prefix are specified, template will be disabled\n" - "only commonly used templates are accepted (unless --jinja is set before this flag):\n" - "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template", + string_format( + "set custom jinja chat template file (default: template taken from model's metadata)\n" + "if suffix/prefix are specified, template will be disabled\n" + "only commonly used templates are accepted (unless --jinja is set before this flag):\n" + "list of built-in templates:\n%s", list_builtin_chat_templates().c_str() + ), [](common_params & params, const std::string & value) { std::ifstream file(value); if (!file) { From 154bfaaa390d537b4e84a9cc5f9c539bcb93bf2c Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 20 Jan 2025 21:54:34 +0000 Subject: [PATCH 34/42] Refactor chat template validation --- common/arg.cpp | 27 
+++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index 53bd32e3aeaff..5799d7832f1ba 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -323,6 +323,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context throw std::invalid_argument("error: either --embedding or --reranking can be specified, but not both"); } + if (!params.chat_template.empty() && !common_chat_verify_template(params.chat_template, params.use_jinja)) { + throw std::runtime_error(string_format( + "error: the supplied chat template is not supported: %s%s\n", + params.chat_template.c_str(), + params.use_jinja ? "" : "\nnote: llama.cpp was started without --jinja, we only support commonly used templates" + )); + } + return true; } @@ -1954,13 +1962,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "list of built-in templates:\n%s", list_builtin_chat_templates().c_str() ), [](common_params & params, const std::string & value) { - if (!common_chat_verify_template(value, params.use_jinja)) { - throw std::runtime_error(string_format( - "error: the supplied chat template is not supported: %s%s\n", - value.c_str(), - params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates" - )); - } params.chat_template = value; } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE")); @@ -1977,20 +1978,10 @@ common_params_context common_params_parser_init(common_params & params, llama_ex if (!file) { throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); } - std::string chat_template; std::copy( std::istreambuf_iterator(file), std::istreambuf_iterator(), - std::back_inserter(chat_template) - ); - if (!common_chat_verify_template(chat_template, params.use_jinja)) { - throw std::runtime_error(string_format( - "error: the supplied chat template is not supported: %s%s\n", - value.c_str(), - params.use_jinja ? "" : "\nnote: llama.cpp does not use jinja parser, we only support commonly used templates" - )); - } - params.chat_template = chat_template; + std::back_inserter(params.chat_template)); } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE_FILE")); add_opt(common_arg( From 54a669e09e8c565bb8b1b14bc6340da685632529 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 20 Jan 2025 22:50:08 +0000 Subject: [PATCH 35/42] Guard against missing eos/bos tokens (null token otherwise throws in llama_vocab::impl::token_get_attr) --- common/common.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 2c0558b5b5b2b..58529b63d5b2c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1824,8 +1824,9 @@ std::string common_chat_format_example(const llama_chat_template & tmpl, bool us llama_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) { auto vocab = llama_model_get_vocab(model); - auto token_bos = common_token_to_piece(vocab, llama_vocab_bos(vocab), true); - auto token_eos = common_token_to_piece(vocab, llama_vocab_eos(vocab), true); + // TODO: consider detecting if the template needs bos / eos tokens and warn / error when missing. + auto token_bos = llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL ? 
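The guard introduced here matters because llama_vocab_bos/llama_vocab_eos return LLAMA_TOKEN_NULL for vocabularies that define no such token, and, per the commit message, feeding that sentinel into token-to-piece conversion throws inside llama_vocab::impl::token_get_attr. The pattern as a small sketch (special_token_piece is a hypothetical helper name):

// sketch of the null-token guard added above
static std::string special_token_piece(const llama_vocab * vocab, llama_token tok) {
    if (tok == LLAMA_TOKEN_NULL) {
        return "";  // vocab has no such special token: substitute an empty string
    }
    return common_token_to_piece(vocab, tok, /* special= */ true);
}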
"" : common_token_to_piece(vocab, llama_vocab_bos(vocab), true); + auto token_eos = llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(vocab, llama_vocab_eos(vocab), true); std::string default_template_src = chat_template_override; std::string template_tool_use_src = chat_template_override; bool has_explicit_template = !chat_template_override.empty(); From 8348c605acc017fe46dd5fd2e460d7d69758a231 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 20 Jan 2025 23:00:47 +0000 Subject: [PATCH 36/42] Warn against missing eos / bos tokens when jinja template references them --- common/common.cpp | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 58529b63d5b2c..161e2aa35ff94 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1824,9 +1824,6 @@ std::string common_chat_format_example(const llama_chat_template & tmpl, bool us llama_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) { auto vocab = llama_model_get_vocab(model); - // TODO: consider detecting if the template needs bos / eos tokens and warn / error when missing. - auto token_bos = llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(vocab, llama_vocab_bos(vocab), true); - auto token_eos = llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(vocab, llama_vocab_eos(vocab), true); std::string default_template_src = chat_template_override; std::string template_tool_use_src = chat_template_override; bool has_explicit_template = !chat_template_override.empty(); @@ -1856,6 +1853,19 @@ llama_chat_templates common_chat_templates_from_model(const struct llama_model * )"; } } + const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { + if (token == LLAMA_TOKEN_NULL) { + if (default_template_src.find(jinja_variable_name) != std::string::npos + || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { + LOG_WRN("%s: warning: vocab does not have a %s token, jinja template won't work as intended.\n", __func__, name); + } + return std::string(); + } else { + return common_token_to_piece(vocab, token, true); + } + }; + auto token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); + auto token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); return { has_explicit_template, std::make_unique(default_template_src, token_bos, token_eos), From ee475d2f513b15956db8a18f5507fedeb04f171e Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 20 Jan 2025 23:42:07 +0000 Subject: [PATCH 37/42] rename: common_chat_template[s] --- common/common.cpp | 8 ++++---- common/common.h | 16 ++++++++-------- examples/run/run.cpp | 4 ++-- examples/server/server.cpp | 2 +- examples/server/utils.hpp | 4 ++-- 5 files changed, 17 insertions(+), 17 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 161e2aa35ff94..727ab0a109ec8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1749,7 +1749,7 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { } std::string common_chat_apply_template( - const llama_chat_template & tmpl, + const common_chat_template & tmpl, const std::vector & msgs, bool add_ass, bool use_jinja) { @@ -1791,7 +1791,7 @@ std::string common_chat_apply_template( } std::string common_chat_format_single( - const llama_chat_template & tmpl, + const common_chat_template & tmpl, const std::vector & past_msg, const common_chat_msg & 
new_msg, bool add_ass, @@ -1811,7 +1811,7 @@ std::string common_chat_format_single( return ss.str(); } -std::string common_chat_format_example(const llama_chat_template & tmpl, bool use_jinja) { +std::string common_chat_format_example(const common_chat_template & tmpl, bool use_jinja) { std::vector msgs = { {"system", "You are a helpful assistant"}, {"user", "Hello"}, @@ -1821,7 +1821,7 @@ std::string common_chat_format_example(const llama_chat_template & tmpl, bool us return common_chat_apply_template(tmpl, msgs, true, use_jinja); } -llama_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) +common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override) { auto vocab = llama_model_get_vocab(model); std::string default_template_src = chat_template_override; diff --git a/common/common.h b/common/common.h index ac25a6f65a81e..7c9d73ce1e49e 100644 --- a/common/common.h +++ b/common/common.h @@ -611,26 +611,26 @@ namespace minja { class chat_template; } -typedef minja::chat_template llama_chat_template; +typedef minja::chat_template common_chat_template; -struct llama_chat_templates { +struct common_chat_templates { bool has_explicit_template; // Model had builtin template or template overridde was specified. - std::unique_ptr template_default; // always set (defaults to chatml) - std::unique_ptr template_tool_use; + std::unique_ptr template_default; // always set (defaults to chatml) + std::unique_ptr template_tool_use; }; // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml // If the custom "tmpl" is not supported, we throw an error std::string common_chat_apply_template( - const llama_chat_template & tmpl, + const common_chat_template & tmpl, const std::vector & chat, bool add_ass, bool use_jinja); // Format single message, while taking into account the position of that message in chat history std::string common_chat_format_single( - const llama_chat_template & tmpl, + const common_chat_template & tmpl, const std::vector & past_msg, const common_chat_msg & new_msg, bool add_ass, @@ -638,9 +638,9 @@ std::string common_chat_format_single( // Returns an example of formatted chat std::string common_chat_format_example( - const llama_chat_template & tmpl, bool use_jinja); + const common_chat_template & tmpl, bool use_jinja); -llama_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); +common_chat_templates common_chat_templates_from_model(const struct llama_model * model, const std::string & chat_template_override); // // KV cache utils diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 4c72f22f9db0e..e567ad716a30d 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -717,7 +717,7 @@ static void add_message(const char * role, const std::string & text, LlamaData & } // Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { +static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) { if (use_jinja) { json messages = json::array(); for (const auto & msg : llama_data.messages) { @@ -893,7 +893,7 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt, } // Helper function to 
apply the chat template and handle errors -static int apply_chat_template_with_error_handling(const llama_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { +static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja); if (new_len < 0) { printe("failed to apply the chat template\n"); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 408e50e399e42..798b7faccaf4e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1689,7 +1689,7 @@ struct server_context { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; - llama_chat_templates chat_templates; + common_chat_templates chat_templates; ~server_context() { // Clear any sampling context diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index b6cec0eb81e2a..c5987250cce3a 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -351,7 +351,7 @@ static llama_tokens format_infill( } // Format given chat. If tmpl is empty, we take the template from model metadata -inline std::string format_chat(const llama_chat_template & tmpl, const std::vector & messages) { +inline std::string format_chat(const common_chat_template & tmpl, const std::vector & messages) { std::vector chat; for (size_t i = 0; i < messages.size(); ++i) { @@ -580,7 +580,7 @@ static json oaicompat_completion_params_parse(const json & body) { static json oaicompat_completion_params_parse( const json & body, /* openai api json semantics */ - const llama_chat_template & tmpl, + const common_chat_template & tmpl, bool use_jinja) { json llama_params; From 8a7c89e60c90be8c04f58335cd11ab5c91ae1ac7 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 20 Jan 2025 23:44:42 +0000 Subject: [PATCH 38/42] reinstate assert on chat_templates.template_default --- examples/server/server.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 798b7faccaf4e..865be4d8da669 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1771,6 +1771,7 @@ struct server_context { } chat_templates = common_chat_templates_from_model(model, params_base.chat_template); + GGML_ASSERT(chat_templates.template_default.get() != nullptr); return true; } From 8347da907d714a6df4ad0b9606e8cd0e43cbd753 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 20 Jan 2025 23:59:15 +0000 Subject: [PATCH 39/42] Update minja to https://github.com/google/minja/commit/b8437df626ac6cd0ce3b333b3c74ed1129c19f25 --- common/chat-template.hpp | 2 ++ common/minja.hpp | 25 ++++++++++++++++--------- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/common/chat-template.hpp b/common/chat-template.hpp index 302a173c29d95..b4a90145c9a89 100644 --- a/common/chat-template.hpp +++ b/common/chat-template.hpp @@ -113,6 +113,8 @@ class chat_template { } const std::string & source() const { return source_; } + const std::string & bos_token() const { return bos_token_; } + const std::string & eos_token() const { return eos_token_; } bool supports_tools() const { return supports_tools_; } bool supports_parallel_tool_calls() const { return supports_parallel_tool_calls_; } diff --git a/common/minja.hpp b/common/minja.hpp index c1c4212c74a16..aa0a5019d394c 100644 --- a/common/minja.hpp +++ b/common/minja.hpp @@ -366,13 +366,11 
@@ class Value : public std::enable_shared_from_this { throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); } } - void erase(size_t index) { - if (array_) throw std::runtime_error("Value is not an array: " + dump()); + Value pop(size_t index) { + if (!array_) throw std::runtime_error("Value is not an array: " + dump()); + auto value = array_->at(index); array_->erase(array_->begin() + index); - } - void erase(const std::string & key) { - if (object_) throw std::runtime_error("Value is not an object: " + dump()); - object_->erase(key); + return value; } const Value& at(const Value & index) const { return const_cast(this)->at(index); @@ -1353,6 +1351,15 @@ class MethodCallExpr : public Expression { if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method"); obj.insert(index, vargs.args[1]); return Value(); + } else if (method->get_name() == "pop") { + vargs.expectArgs("pop method", {0, 1}, {0, 0}); + if (vargs.args.empty()) { + return obj.pop(obj.size() - 1); + } else { + auto index = vargs.args[0].get(); + if (index < 0 || index >= (int64_t) obj.size()) throw std::runtime_error("Index out of range for pop method"); + return obj.pop(index); + } } } else if (obj.is_object()) { if (method->get_name() == "items") { @@ -2539,7 +2546,7 @@ inline std::shared_ptr Context::builtins() { })); globals.set("namespace", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { auto ns = Value::object(); - args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits::max)()}); + args.expectArgs("namespace", {0, 0}, {0, std::numeric_limits::max()}); for (auto & [name, value] : args.kwargs) { ns.set(name, value); } @@ -2594,7 +2601,7 @@ inline std::shared_ptr Context::builtins() { }; // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject globals.set("reject", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { - args.expectArgs("reject", {2, (std::numeric_limits::max)()}, {0, 0}); + args.expectArgs("reject", {2, std::numeric_limits::max()}, {0, 0}); auto & items = args.args[0]; auto filter_fn = context->get(args.args[1]); if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); @@ -2665,7 +2672,7 @@ inline std::shared_ptr Context::builtins() { return out; })); globals.set("selectattr", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { - args.expectArgs("selectattr", {2, (std::numeric_limits::max)()}, {0, 0}); + args.expectArgs("selectattr", {2, std::numeric_limits::max()}, {0, 0}); auto & items = args.args[0]; if (items.is_null()) return Value::array(); From ff2cce57ad3ca70fb5db629b88d8cc3a729ecf8d Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 21 Jan 2025 01:26:19 +0000 Subject: [PATCH 40/42] Update minja to https://github.com/google/minja/pull/25 --- common/minja.hpp | 61 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/common/minja.hpp b/common/minja.hpp index aa0a5019d394c..e8ac04ec64059 100644 --- a/common/minja.hpp +++ b/common/minja.hpp @@ -206,6 +206,38 @@ class Value : public std::enable_shared_from_this { throw std::runtime_error("Value is not an array: " + dump()); array_->push_back(v); } + Value pop(const Value& index) { + if (is_array()) { + if (array_->empty()) + throw std::runtime_error("pop from empty list"); + if (index.is_null()) { + auto ret = array_->back(); + array_->pop_back(); + return ret; + } else if 
(!index.is_number_integer()) {
+        throw std::runtime_error("pop index must be an integer: " + index.dump());
+      } else {
+        auto i = index.get<int64_t>();
+        if (i < 0 || i >= static_cast<int64_t>(array_->size()))
+          throw std::runtime_error("pop index out of range: " + index.dump());
+        auto it = array_->begin() + (i < 0 ? array_->size() + i : i);
+        auto ret = *it;
+        array_->erase(it);
+        return ret;
+      }
+    } else if (is_object()) {
+      if (!index.is_hashable())
+        throw std::runtime_error("Unhashable type: " + index.dump());
+      auto it = object_->find(index.primitive_);
+      if (it == object_->end())
+        throw std::runtime_error("Key not found: " + index.dump());
+      auto ret = it->second;
+      object_->erase(it);
+      return ret;
+    } else {
+      throw std::runtime_error("Value is not an array or object: " + dump());
+    }
+  }
   Value get(const Value& key) {
     if (array_) {
       if (!key.is_number_integer()) {
@@ -366,11 +398,13 @@ class Value : public std::enable_shared_from_this<Value> {
       throw std::runtime_error("contains can only be called on arrays and objects: " + dump());
     }
   }
-  Value pop(size_t index) {
+  void erase(size_t index) {
     if (!array_) throw std::runtime_error("Value is not an array: " + dump());
-    auto value = array_->at(index);
     array_->erase(array_->begin() + index);
-    return value;
+  }
+  void erase(const std::string & key) {
+    if (!object_) throw std::runtime_error("Value is not an object: " + dump());
+    object_->erase(key);
   }
   const Value& at(const Value & index) const {
     return const_cast<Value*>(this)->at(index);
@@ -1345,21 +1379,15 @@ class MethodCallExpr : public Expression {
             vargs.expectArgs("append method", {1, 1}, {0, 0});
             obj.push_back(vargs.args[0]);
             return Value();
+          } else if (method->get_name() == "pop") {
+            vargs.expectArgs("pop method", {0, 1}, {0, 0});
+            return obj.pop(vargs.args.empty() ?
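The unified Value::pop above follows Python semantics: for arrays, a null index means pop the last element, an explicit index removes and returns the element at that position, and popping from an empty list throws; for objects, the argument is the key, and a missing key throws. A hedged sketch of the resulting behavior, assuming Value's json-converting constructor accepts integer literals:

minja::Value items = minja::Value::array();
items.push_back(minja::Value(1));
items.push_back(minja::Value(2));
items.push_back(minja::Value(3));

minja::Value last  = items.pop(minja::Value());   // null index: removes and returns 3
minja::Value first = items.pop(minja::Value(0));  // explicit index: removes and returns 1
// items now holds only 2; one more pop empties it, and a further pop
// throws "pop from empty list"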
Value() : vargs.args[0]); } else if (method->get_name() == "insert") { vargs.expectArgs("insert method", {2, 2}, {0, 0}); auto index = vargs.args[0].get(); if (index < 0 || index > (int64_t) obj.size()) throw std::runtime_error("Index out of range for insert method"); obj.insert(index, vargs.args[1]); return Value(); - } else if (method->get_name() == "pop") { - vargs.expectArgs("pop method", {0, 1}, {0, 0}); - if (vargs.args.empty()) { - return obj.pop(obj.size() - 1); - } else { - auto index = vargs.args[0].get(); - if (index < 0 || index >= (int64_t) obj.size()) throw std::runtime_error("Index out of range for pop method"); - return obj.pop(index); - } } } else if (obj.is_object()) { if (method->get_name() == "items") { @@ -1369,6 +1397,9 @@ class MethodCallExpr : public Expression { result.push_back(Value::array({key, obj.at(key)})); } return result; + } else if (method->get_name() == "pop") { + vargs.expectArgs("pop method", {1, 1}, {0, 0}); + return obj.pop(vargs.args[0]); } else if (method->get_name() == "get") { vargs.expectArgs("get method", {1, 2}, {0, 0}); auto key = vargs.args[0]; @@ -2546,7 +2577,7 @@ inline std::shared_ptr Context::builtins() { })); globals.set("namespace", Value::callable([=](const std::shared_ptr &, ArgumentsValue & args) { auto ns = Value::object(); - args.expectArgs("namespace", {0, 0}, {0, std::numeric_limits::max()}); + args.expectArgs("namespace", {0, 0}, {0, (std::numeric_limits::max)()}); for (auto & [name, value] : args.kwargs) { ns.set(name, value); } @@ -2601,7 +2632,7 @@ inline std::shared_ptr Context::builtins() { }; // https://jinja.palletsprojects.com/en/3.0.x/templates/#jinja-filters.reject globals.set("reject", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { - args.expectArgs("reject", {2, std::numeric_limits::max()}, {0, 0}); + args.expectArgs("reject", {2, (std::numeric_limits::max)()}, {0, 0}); auto & items = args.args[0]; auto filter_fn = context->get(args.args[1]); if (filter_fn.is_null()) throw std::runtime_error("Undefined filter: " + args.args[1].dump()); @@ -2672,7 +2703,7 @@ inline std::shared_ptr Context::builtins() { return out; })); globals.set("selectattr", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { - args.expectArgs("selectattr", {2, std::numeric_limits::max()}, {0, 0}); + args.expectArgs("selectattr", {2, (std::numeric_limits::max)()}, {0, 0}); auto & items = args.args[0]; if (items.is_null()) return Value::array(); From 9d8ebd62c612d46187856880bd85137fa8c4c027 Mon Sep 17 00:00:00 2001 From: ochafik Date: Tue, 21 Jan 2025 03:18:06 +0000 Subject: [PATCH 41/42] Update minja from https://github.com/google/minja/pull/27 --- common/minja.hpp | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/common/minja.hpp b/common/minja.hpp index e8ac04ec64059..f0ee7a49a43e1 100644 --- a/common/minja.hpp +++ b/common/minja.hpp @@ -18,12 +18,6 @@ #include #include -#ifdef _WIN32 -#define ENDL "\r\n" -#else -#define ENDL "\n" -#endif - using json = nlohmann::ordered_json; namespace minja { @@ -38,7 +32,7 @@ struct Options { struct ArgumentsValue; -static std::string normalize_newlines(const std::string & s) { +inline std::string normalize_newlines(const std::string & s) { #ifdef _WIN32 static const std::regex nl_regex("\r\n"); return std::regex_replace(s, nl_regex, "\n"); @@ -91,7 +85,7 @@ class Value : public std::enable_shared_from_this { void dump(std::ostringstream & out, int indent = -1, int level = 0, bool to_json = false) const { 
auto print_indent = [&](int level) { if (indent > 0) { - out << ENDL; + out << "\n"; for (int i = 0, n = level * indent; i < n; ++i) out << ' '; } }; @@ -594,11 +588,11 @@ static std::string error_location_suffix(const std::string & source, size_t pos) auto max_line = std::count(start, end, '\n') + 1; auto col = pos - std::string(start, it).rfind('\n'); std::ostringstream out; - out << " at row " << line << ", column " << col << ":" ENDL; - if (line > 1) out << get_line(line - 1) << ENDL; - out << get_line(line) << ENDL; - out << std::string(col - 1, ' ') << "^" << ENDL; - if (line < max_line) out << get_line(line + 1) << ENDL; + out << " at row " << line << ", column " << col << ":\n"; + if (line > 1) out << get_line(line - 1) << "\n"; + out << get_line(line) << "\n"; + out << std::string(col - 1, ' ') << "^\n"; + if (line < max_line) out << get_line(line + 1) << "\n"; return out.str(); } @@ -833,7 +827,7 @@ class TemplateNode { std::string render(const std::shared_ptr & context) const { std::ostringstream out; render(out, context); - return normalize_newlines(out.str()); + return out.str(); } }; @@ -2695,11 +2689,11 @@ inline std::shared_ptr Context::builtins() { while (std::getline(iss, line, '\n')) { auto needs_indent = !is_first || first; if (is_first) is_first = false; - else out += ENDL; + else out += "\n"; if (needs_indent) out += indent; out += line; } - if (!text.empty() && text.back() == '\n') out += ENDL; + if (!text.empty() && text.back() == '\n') out += "\n"; return out; })); globals.set("selectattr", Value::callable([=](const std::shared_ptr & context, ArgumentsValue & args) { From cbb9b819da848453471c0afd1b33004386670e61 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Tue, 21 Jan 2025 12:29:51 +0000 Subject: [PATCH 42/42] rm unused optional header --- examples/server/server.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 865be4d8da669..5f08c4eccd54a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -26,7 +26,6 @@ #include #include #include -#include #include #include #include
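Finally, the (std::numeric_limits::max)() spelling that the last minja sync restores is not noise: on Windows, <windows.h> defines function-style min/max macros unless NOMINMAX is set, and an unparenthesized std::numeric_limits<T>::max() gets macro-expanded mid-parse. Wrapping the name in parentheses blocks function-macro expansion, as this small illustration (largest_index is a hypothetical name) shows:

#include <cstdint>
#include <limits>
// #include <windows.h>  // would define max(a, b) unless NOMINMAX is set first

static int64_t largest_index() {
    // return std::numeric_limits<int64_t>::max();  // breaks when max is a macro
    return (std::numeric_limits<int64_t>::max)();   // parentheses suppress macro expansion
}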