@@ -90,39 +90,51 @@ static void sigint_handler(int signo) {
 
 class chat_formatter {
 public:
-    chat_formatter(common_params & params, std::vector<common_chat_msg> & chat_msgs, struct common_chat_templates * chat_templates)
+
+    struct result {
+        std::string formatted;
+        bool tool_was_called;
+    };
+
+    chat_formatter(common_params & params,
+                   std::vector<common_chat_msg> & chat_msgs,
+                   struct common_chat_templates * chat_templates)
+
         : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates) {}
 
 #ifdef LLAMA_USE_TOOLCALL
     chat_formatter(common_params & params,
                    std::vector<common_chat_msg> & chat_msgs,
                    struct common_chat_templates * chat_templates,
                    const llama_vocab * vocab,
-                   toolcall::client::ptr tc_client,
-                   common_chat_format * chat_format)
+                   toolcall::client::ptr tc_client)
 
-        : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates), vocab_(vocab), tc_client_(tc_client), chat_format_(chat_format) {}
+        : params_(params), chat_msgs_(chat_msgs), chat_templates_(chat_templates),
+          vocab_(vocab), tc_client_(tc_client),
+          chat_format_(COMMON_CHAT_FORMAT_CONTENT_ONLY),
+          formatted_() {}
 #endif
 
-    std::string operator()(const std::string & role, const std::string & content, [[maybe_unused]] bool use_toolcalls = false) {
+    chat_formatter::result operator()(const std::string & role, const std::string & content) {
+
+        common_chat_msg new_msg = common_chat_parse(content, chat_format_);
+        new_msg.role = role;
 
         common_chat_templates_inputs cinputs;
         cinputs.use_jinja = params_.use_jinja;
         cinputs.add_generation_prompt = (role == "user");
 #ifdef LLAMA_USE_TOOLCALL
-        if (tc_client_ != nullptr && use_toolcalls) {
+        if (tc_client_ != nullptr) {
             cinputs.tool_choice = common_chat_tool_choice_parse_oaicompat(tc_client_->tool_choice());
             cinputs.tools = common_chat_tools_parse_oaicompat(tc_client_->tool_list());
         }
 #endif
-        for (const auto & msg : chat_msgs_) {
-            cinputs.messages.push_back(common_chat_msg(msg));
-        }
-
-        common_chat_msg new_msg = common_chat_parse(content, *chat_format_);
-        new_msg.role = role;
+        cinputs.messages.assign(chat_msgs_.cbegin(), chat_msgs_.cend());
+        cinputs.messages.push_back(new_msg);
+        chat_msgs_.push_back(new_msg);
 
-        if (!new_msg.tool_calls.empty()) {
+        bool tool_was_called = false;
+        if (!new_msg.tool_calls.empty()) { // Call tool and re-prompt
             nlohmann::json result_array = nlohmann::json::array();
             for (const auto & tc : new_msg.tool_calls) {
                 toolcall::result_set res = tc_client_->call(tc.name, tc.arguments, tc.id);
@@ -132,21 +144,28 @@ class chat_formatter {
                     }
                 }
             }
-            new_msg.content += result_array.dump(-1);
+            common_chat_msg toolcall_msg;
+            toolcall_msg.role = "tool";
+            toolcall_msg.content = result_array.dump(-1);
+
+            cinputs.add_generation_prompt = true;
+            cinputs.messages.push_back(toolcall_msg);
+            chat_msgs_.push_back(toolcall_msg);
+
+            tool_was_called = true;
         }
 
-        cinputs.messages.push_back(new_msg);
         common_chat_params cparams = common_chat_templates_apply(chat_templates_, cinputs);
+        std::string formatted = cparams.prompt.substr(formatted_.size(), cparams.prompt.size());
+        formatted_ = cparams.prompt;
 
-        auto formatted = cparams.prompt;
-        chat_msgs_.push_back(new_msg);
         LOG_DBG("formatted: '%s'\n", formatted.c_str());
 
 #ifdef LLAMA_USE_TOOLCALL
-        if (chat_format_) *chat_format_ = cparams.format;
+        chat_format_ = cparams.format;
         common_chat_grammar_to_sampler(&cparams, vocab_, &params_.sampling);
 #endif
-        return formatted;
+        return chat_formatter::result{std::move(formatted), tool_was_called};
     }
 
 private:
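Note on the tool-result handling above: the outputs of all executed tool calls are gathered into result_array and serialized into one "tool"-role message with dump(-1), nlohmann's most compact (single-line) form. A minimal sketch of that shape follows; the entry field names ("name", "content") are illustrative guesses, since the code that fills result_array from toolcall::result_set is elided between the two hunks above, and msg is a local stand-in for common_chat_msg:

#include <nlohmann/json.hpp>
#include <cstdio>
#include <string>

int main() {
    // Pretend one tool call already ran and produced a JSON payload.
    // The "name"/"content" keys are assumptions for illustration only.
    nlohmann::json result_array = nlohmann::json::array();
    result_array.push_back({{"name", "get_weather"}, {"content", "{\"temp_c\":21}"}});

    struct { std::string role, content; } msg;   // stand-in for common_chat_msg
    msg.role    = "tool";
    msg.content = result_array.dump(-1);         // -1 => compact, no indentation

    std::printf("%s: %s\n", msg.role.c_str(), msg.content.c_str());
}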
@@ -157,7 +176,8 @@ class chat_formatter {
 #ifdef LLAMA_USE_TOOLCALL
     const llama_vocab * vocab_;
     toolcall::client::ptr tc_client_;
-    common_chat_format * chat_format_;
+    common_chat_format chat_format_;
+    std::string formatted_;
 #endif
 };
 
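Note on the formatted_ member introduced above: common_chat_templates_apply re-renders the entire conversation on every call, so cparams.prompt always contains the full prompt so far. The formatter therefore remembers the previously returned prompt in formatted_ and hands back only the new suffix (substr's count argument is clamped to the string length, so passing the full size is safe). A stand-alone sketch of that delta technique, assuming the rendered prompt only ever grows:

#include <cassert>
#include <string>

// Minimal illustration of chat_formatter's high-water-mark trick; take_delta()
// plays the role of the substr/assign pair in operator() above.
struct delta_tracker {
    std::string formatted_;  // everything returned so far

    std::string take_delta(const std::string & full_prompt) {
        std::string delta = full_prompt.substr(formatted_.size(), full_prompt.size());
        formatted_ = full_prompt;  // advance the high-water mark
        return delta;
    }
};

int main() {
    delta_tracker t;
    assert(t.take_delta("<s>user: hi\n") == "<s>user: hi\n");
    assert(t.take_delta("<s>user: hi\nassistant: hello\n") == "assistant: hello\n");
}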
@@ -355,8 +375,7 @@ int main(int argc, char ** argv) {
     if (tc_client) {
         tc_client->initialize();
     }
-    common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
-    chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client, &chat_format);
+    chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get(), vocab, tc_client);
 #else
     chat_formatter chat_add_and_format(params, chat_msgs, chat_templates.get());
 #endif
@@ -366,12 +385,12 @@ int main(int argc, char ** argv) {
     if (params.conversation_mode && params.enable_chat_template) {
         if (!params.system_prompt.empty()) {
             // format the system prompt (will use template default if empty)
-            chat_add_and_format("system", params.system_prompt, true);
+            chat_add_and_format("system", params.system_prompt);
         }
 
         if (!params.prompt.empty()) {
             // format and append the user prompt
-            chat_add_and_format("user", params.prompt, true);
+            chat_add_and_format("user", params.prompt);
         } else {
             waiting_for_first_input = true;
         }
@@ -905,9 +924,15 @@ int main(int argc, char ** argv) {
                 }
 
                 if (params.enable_chat_template) {
-                    chat_add_and_format("assistant", assistant_ss.str(), true);
-                    is_interacting = true;
-                    LOG("\n");
+                    auto format_res = chat_add_and_format("assistant", assistant_ss.str());
+                    if (format_res.tool_was_called) {
+                        auto format_res_tok = common_tokenize(ctx, format_res.formatted, false, true);
+                        embd_inp.insert(embd_inp.end(), format_res_tok.begin(), format_res_tok.end());
+
+                    } else {
+                        is_interacting = true;
+                        LOG("\n");
+                    }
                 }
             }
         }
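For context on the hunk above: previously the end-of-generation path always returned control to the user. With the new result struct, a reply that triggered a tool call leaves the freshly rendered tool output (already appended to the history by operator()) in format_res.formatted; tokenizing that and appending it to embd_inp makes the model keep generating from the tool result instead of waiting for input. A hypothetical condensation of that loop, with stubs standing in for the real sampling loop (generate) and for common_tokenize:

#include <string>
#include <vector>

struct format_result { std::string formatted; bool tool_was_called; };

// Stubs so the sketch compiles; the real versions live in main()'s loop.
static std::string generate(const std::vector<int> &) { return "..."; }
static std::vector<int> tokenize(const std::string &) { return {1, 2, 3}; }
static format_result chat_add_and_format(const std::string &, const std::string &) {
    return {"", false};  // pretend the reply contained no tool call
}

int main() {
    std::vector<int> embd_inp;
    for (;;) {
        std::string reply = generate(embd_inp);              // run the model
        auto res = chat_add_and_format("assistant", reply);  // may execute tools
        if (!res.tool_was_called) break;                     // hand control back to the user
        auto toks = tokenize(res.formatted);                 // tool result + fresh prompt
        embd_inp.insert(embd_inp.end(), toks.begin(), toks.end());
    }
}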
@@ -975,7 +1000,7 @@ int main(int argc, char ** argv) {
 
             bool format_chat = params.conversation_mode && params.enable_chat_template;
             std::string user_inp = format_chat
-                ? chat_add_and_format("user", std::move(buffer))
+                ? chat_add_and_format("user", std::move(buffer)).formatted
                 : std::move(buffer);
             // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
             const auto line_pfx = common_tokenize(ctx, params.input_prefix, false, true);