Commit 73fbd67

llama_chat_apply_template: use term "chat" everywhere

1 parent dba4337

2 files changed: +23 −15 lines

llama.cpp (+17 −11)

@@ -12459,7 +12459,10 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
     return 0;
 }
 
-int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass);
+int32_t llama_chat_apply_template_internal(
+    const std::string & chat_template,
+    const std::vector<const llama_chat_message *> & chat,
+    std::string & dest, bool add_ass);
 
 // trim whitespace from the beginning and end of a string
 static std::string trim(const std::string & str) {
@@ -12476,12 +12479,15 @@ static std::string trim(const std::string & str) {
 
 // Simple version of "llama_apply_chat_template" that only works with strings
 // This function uses heuristic checks to determine commonly used template. It is not a jinja parser.
-int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_template, std::vector<const llama_chat_message *> conversation, bool add_ass) {
+int32_t llama_chat_apply_template_internal(
+    const std::string & chat_template,
+    const std::vector<const llama_chat_message *> & chat,
+    std::string & dest, bool add_ass) {
     // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
     std::stringstream ss;
     if (chat_template.find("<|im_start|>") != std::string::npos) {
         // chatml template
-        for (auto message : conversation) {
+        for (auto message : chat) {
             ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
         }
         if (add_ass) {
@@ -12500,7 +12506,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t
         // construct the prompt
         bool is_inside_turn = true; // skip BOS at the beginning
         ss << "[INST] ";
-        for (auto message : conversation) {
+        for (auto message : chat) {
             std::string content = strip_message ? trim(message->content) : message->content;
             std::string role(message->role);
             if (!is_inside_turn) {
@@ -12524,7 +12530,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t
         // llama2 templates seem to not care about "add_generation_prompt"
     } else if (chat_template.find("<|user|>") != std::string::npos) {
         // zephyr template
-        for (auto message : conversation) {
+        for (auto message : chat) {
             ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
         }
         if (add_ass) {
@@ -12541,7 +12547,7 @@ int32_t llama_chat_apply_template_internal(std::string &dest, std::string chat_t
 LLAMA_API int32_t llama_chat_apply_template(
         const struct llama_model * model,
         const char * custom_template,
-        const struct llama_chat_message * msg,
+        const struct llama_chat_message * chat,
         size_t n_msg,
         bool add_ass,
         char * buf,
@@ -12560,14 +12566,14 @@ LLAMA_API int32_t llama_chat_apply_template(
             current_template = std::string(model_template.data(), model_template.size());
         }
     }
-    // format the conversation to string
-    std::vector<const llama_chat_message *> conversation_vec;
-    conversation_vec.resize(n_msg);
+    // format the chat to string
+    std::vector<const llama_chat_message *> chat_vec;
+    chat_vec.resize(n_msg);
     for (size_t i = 0; i < n_msg; i++) {
-        conversation_vec[i] = &msg[i];
+        chat_vec[i] = &chat[i];
     }
     std::string formatted_chat;
-    int32_t res = llama_chat_apply_template_internal(formatted_chat, current_template, conversation_vec, add_ass);
+    int32_t res = llama_chat_apply_template_internal(current_template, chat_vec, formatted_chat, add_ass);
     if (res < 0) {
         return res;
     }
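
For context, a minimal sketch of how the reordered internal function is now called and what its chatml branch produces. This is not part of the commit; the sample messages and the template string are assumptions for illustration:

    // Hypothetical call site showing the new parameter order:
    // template first, chat second, output destination third.
    llama_chat_message msgs[2] = {
        { "system", "You are a helpful assistant." },
        { "user",   "Hello!" },
    };
    std::vector<const llama_chat_message *> chat = { &msgs[0], &msgs[1] };
    std::string formatted;
    int32_t res = llama_chat_apply_template_internal("<|im_start|>", chat, formatted, /*add_ass=*/false);
    // The "<|im_start|>" substring selects the chatml branch, so with
    // add_ass == false, "formatted" becomes:
    //   <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
    //   <|im_start|>user\nHello!<|im_end|>\n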

llama.h (+6 −4)

@@ -704,18 +704,20 @@ extern "C" {
             char * buf,
             int32_t length);
 
-    /// Apply chat template and maybe tokenize it. Inspired by hf apply_chat_template() on python.
+    /// Apply chat template. Inspired by hf apply_chat_template() on python.
     /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
     /// NOTE: This function only support some known jinja templates. It is not a jinja parser.
-    /// @param custom_template A Jinja template to use for this conversion. If this is nullptr, the model’s default chat template will be used instead.
-    /// @param msg Pointer to a list of multiple llama_chat_message
+    /// @param custom_template A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead.
+    /// @param chat Pointer to a list of multiple llama_chat_message
+    /// @param n_msg Number of llama_chat_message in this chat
     /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.
     /// @param buf A buffer to hold the output formatted prompt. The recommended alloc size is 2 * (total number of characters of all messages)
+    /// @param length The size of the allocated buffer
    /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
     LLAMA_API int32_t llama_chat_apply_template(
             const struct llama_model * model,
             const char * custom_template,
-            const struct llama_chat_message * msg,
+            const struct llama_chat_message * chat,
             size_t n_msg,
             bool add_ass,
             char * buf,
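
A caller-side sketch of the renamed public API, including the grow-and-retry pattern implied by the @return note and the recommended 2x buffer sizing. This is an assumption for illustration, not part of the commit; the helper name and the use of the model's default template are hypothetical:

    #include <cstring>
    #include <string>
    #include <vector>
    #include "llama.h"

    // Hypothetical helper: format a chat with the model's default template
    // (custom_template == nullptr). Assumes "model" carries a chat template.
    static std::string format_chat(const llama_model * model,
                                   const std::vector<llama_chat_message> & chat) {
        // Recommended alloc size: 2 * total characters of all messages.
        size_t total = 0;
        for (const auto & m : chat) {
            total += strlen(m.role) + strlen(m.content);
        }
        std::vector<char> buf(2 * total);
        int32_t res = llama_chat_apply_template(model, nullptr, chat.data(), chat.size(),
                                                /*add_ass=*/true, buf.data(), (int32_t) buf.size());
        if (res > (int32_t) buf.size()) {
            // Formatted prompt was larger than the buffer:
            // re-alloc and re-apply, per the @return doc.
            buf.resize(res);
            res = llama_chat_apply_template(model, nullptr, chat.data(), chat.size(),
                                            /*add_ass=*/true, buf.data(), (int32_t) buf.size());
        }
        return res < 0 ? std::string() : std::string(buf.data(), res);
    }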
