Skip to content

Commit 7adfa18

Browse files
committed
Re-prompt after tool call
1 parent c8843da commit 7adfa18

File tree

1 file changed: +53 −28 lines changed

examples/main/main.cpp

+53 −28
Original file line numberDiff line numberDiff line change
@@ -90,39 +90,51 @@ static void sigint_handler(int signo) {
9090

9191
class chat_formatter {
9292
public:
93-
chat_formatter(common_params & params, std::vector<common_chat_msg> & chat_msgs, struct common_chat_templates * chat_templates)
93+
94+
struct result {
95+
std::string formatted;
96+
bool tool_was_called;
97+
};
98+
99+
chat_formatter(common_params & params,
100+
std::vector<common_chat_msg> & chat_msgs,
101+
struct common_chat_templates * chat_templates)
102+
94103
: params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates) {}
95104

96105
#ifdef LLAMA_USE_TOOLCALL
97106
chat_formatter(common_params & params,
98107
std::vector<common_chat_msg> & chat_msgs,
99108
struct common_chat_templates * chat_templates,
100109
const llama_vocab * vocab,
101-
toolcall::client::ptr tc_client,
102-
common_chat_format * chat_format)
110+
toolcall::client::ptr tc_client)
103111

104-
: params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_client_(tc_client), chat_format_(chat_format) {}
112+
: params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates),
113+
vocab_(vocab), tc_client_(tc_client),
114+
chat_format_(COMMON_CHAT_FORMAT_CONTENT_ONLY),
115+
formatted_() {}
105116
#endif
106117

107-
std::string operator () (const std::string & role, const std::string & content, [[maybe_unused]] bool use_toolcalls = false) {
118+
chat_formatter::result operator() (const std::string & role, const std::string & content) {
119+
120+
common_chat_msg new_msg = common_chat_parse(content, chat_format_);
121+
new_msg.role = role;
108122

109123
common_chat_templates_inputs cinputs;
110124
cinputs.use_jinja = params_.use_jinja;
111125
cinputs.add_generation_prompt = (role == "user");
112126
#ifdef LLAMA_USE_TOOLCALL
113-
if (tc_client_ != nullptr && use_toolcalls) {
127+
if (tc_client_ != nullptr) {
114128
cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tc_client_->tool_choice());
115129
cinputs.tools = common_chat_tools_parse_oaicompat(tc_client_->tool_list());
116130
}
117131
#endif
118-
for (const auto & msg : chat_msgs_) {
119-
cinputs.messages.push_back(common_chat_msg(msg));
120-
}
121-
122-
common_chat_msg new_msg = common_chat_parse(content, *chat_format_);
123-
new_msg.role = role;
132+
cinputs.messages.assign(chat_msgs_.cbegin(), chat_msgs_.cend());
133+
cinputs.messages.push_back(new_msg);
134+
chat_msgs_.push_back(new_msg);
124135

125-
if (! new_msg.tool_calls.empty()) {
136+
bool tool_was_called = false;
137+
if (! new_msg.tool_calls.empty()) { // Call tool and re-prompt
126138
nlohmann::json result_array = nlohmann::json::array();
127139
for (const auto & tc : new_msg.tool_calls) {
128140
toolcall::result_set res = tc_client_->call(tc.name, tc.arguments, tc.id);
@@ -132,21 +144,28 @@ class chat_formatter {
132144
}
133145
}
134146
}
135-
new_msg.content += result_array.dump(-1);
147+
common_chat_msg toolcall_msg;
148+
toolcall_msg.role = "tool";
149+
toolcall_msg.content = result_array.dump(-1);
150+
151+
cinputs.add_generation_prompt = true;
152+
cinputs.messages.push_back(toolcall_msg);
153+
chat_msgs_.push_back(toolcall_msg);
154+
155+
tool_was_called = true;
136156
}
137157

138-
cinputs.messages.push_back(new_msg);
139158
common_chat_params cparams = common_chat_templates_apply(chat_templates_, cinputs);
159+
std::string formatted = cparams.prompt.substr(formatted_.size(), cparams.prompt.size());
160+
formatted_ = cparams.prompt;
140161

141-
auto formatted = cparams.prompt;
142-
chat_msgs_.push_back(new_msg);
143162
LOG_DBG("formatted: '%s'\n", formatted.c_str());
144163

145164
#ifdef LLAMA_USE_TOOLCALL
146-
if (chat_format_) *chat_format_ = cparams.format;
165+
chat_format_ = cparams.format;
147166
common_chat_grammar_to_sampler(&cparams, vocab_, &params_.sampling);
148167
#endif
149-
return formatted;
168+
return chat_formatter::result{std::move(formatted), tool_was_called};
150169
}
151170

152171
private:
@@ -157,7 +176,8 @@ class chat_formatter {
157176
#ifdef LLAMA_USE_TOOLCALL
158177
const llama_vocab * vocab_;
159178
toolcall::client::ptr tc_client_;
160-
common_chat_format * chat_format_;
179+
common_chat_format chat_format_;
180+
std::string formatted_;
161181
#endif
162182
};
163183

@@ -355,8 +375,7 @@ int main(int argc, char ** argv) {
355375
if (tc_client) {
356376
tc_client->initialize();
357377
}
358-
common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
359-
chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client, &chat_format);
378+
chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client);
360379
#else
361380
chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get());
362381
#endif
@@ -366,12 +385,12 @@ int main(int argc, char ** argv) {
366385
if (params.conversation_mode && params.enable_chat_template) {
367386
if (!params.system_prompt.empty()) {
368387
// format the system prompt (will use template default if empty)
369-
chat_add_and_format("system", params.system_prompt, true);
388+
chat_add_and_format("system", params.system_prompt);
370389
}
371390

372391
if (!params.prompt.empty()) {
373392
// format and append the user prompt
374-
chat_add_and_format("user", params.prompt, true);
393+
chat_add_and_format("user", params.prompt);
375394
} else {
376395
waiting_for_first_input = true;
377396
}
@@ -905,9 +924,15 @@ int main(int argc, char ** argv) {
905924
}
906925

907926
if (params.enable_chat_template) {
908-
chat_add_and_format("assistant", assistant_ss.str(), true);
909-
is_interacting = true;
910-
LOG("\n");
927+
auto format_res = chat_add_and_format("assistant", assistant_ss.str());
928+
if (format_res.tool_was_called) {
929+
auto format_res_tok = common_tokenize(ctx, format_res.formatted, false, true);
930+
embd_inp.insert(embd_inp.end(), format_res_tok.begin(), format_res_tok.end());
931+
932+
} else {
933+
is_interacting = true;
934+
LOG("\n");
935+
}
911936
}
912937
}
913938
}
@@ -975,7 +1000,7 @@ int main(int argc, char ** argv) {
9751000

9761001
bool format_chat = params.conversation_mode && params.enable_chat_template;
9771002
std::string user_inp = format_chat
978-
? chat_add_and_format("user", std::move(buffer))
1003+
? chat_add_and_format("user", std::move(buffer)).formatted
9791004
: std::move(buffer);
9801005
// TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
9811006
const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);

0 commit comments

Comments
 (0)