
Commit e402de3

grammars: fix resampling logic regression (#7424)
1 parent fcf6538 commit e402de3


2 files changed (+9, -8 lines)

common/sampling.cpp (+7, -6)

@@ -179,7 +179,7 @@ static llama_token llama_sampling_sample_impl(
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
                   const int idx,
-                  bool is_resampling) { // Add a parameter to indicate if we are resampling
+                  bool is_resampling) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const float temp = params.temp;
@@ -188,8 +188,8 @@ static llama_token llama_sampling_sample_impl(
     const float mirostat_eta = params.mirostat_eta;
 
     std::vector<float> original_logits;
-    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
-    if (!is_resampling) {
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
+    if (ctx_sampling->grammar != NULL && !is_resampling) {
         GGML_ASSERT(!original_logits.empty());
     }
     llama_token id = 0;
@@ -252,7 +252,7 @@ static llama_token llama_sampling_sample_impl(
             // Restore logits from the copy
             std::copy(original_logits.begin(), original_logits.end(), logits);
 
-            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
+            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
         }
     }
 
@@ -285,7 +285,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    if (apply_grammar && original_logits != NULL) {
+    if (ctx_sampling->grammar != NULL && !apply_grammar) {
+        GGML_ASSERT(original_logits != NULL);
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
         *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
     }
@@ -342,7 +343,7 @@ llama_token llama_sampling_sample(
                   struct llama_context * ctx_cfg,
                   const int idx) {
     // Call the implementation function with is_resampling set to false by default
-    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
+    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
 }
 
 llama_token_data_array llama_sampling_prepare(
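
Taken together, these hunks change the `apply_grammar` argument passed to `llama_sampling_prepare` from `!is_resampling` to `is_resampling`, so grammar constraints are applied only on the re-entrant resampling call, while the original logits are copied (and asserted) only when a grammar is attached and the first pass has not yet applied it. Below is a minimal, self-contained sketch of that two-pass flow under those assumptions; the names used (`Grammar`, `apply_grammar_mask`, `sample_token`) are placeholders for illustration, not the llama.cpp API.

// A minimal, self-contained sketch of the two-pass flow described above.
// The types and helpers below are placeholders, not the llama.cpp API:
// the first pass samples without grammar constraints but keeps a copy of
// the logits; only if the sampled token violates the grammar are the logits
// restored and the sample repeated with the grammar applied.
#include <algorithm>
#include <cstdio>
#include <set>
#include <vector>

struct Grammar {
    std::set<int> allowed;                          // tokens the grammar currently accepts
    bool accepts(int tok) const { return allowed.count(tok) > 0; }
};

// Stand-in for grammar-constrained sampling: mask out rejected tokens.
static void apply_grammar_mask(const Grammar & grammar, std::vector<float> & logits) {
    for (size_t i = 0; i < logits.size(); ++i) {
        if (!grammar.accepts((int) i)) {
            logits[i] = -1e9f;
        }
    }
}

static int sample_token(std::vector<float> & logits, const Grammar * grammar, bool is_resampling) {
    std::vector<float> original_logits;
    if (grammar != nullptr && !is_resampling) {
        original_logits = logits;                   // first pass: keep a copy for a possible resample
    }
    if (is_resampling) {
        apply_grammar_mask(*grammar, logits);       // grammar only constrains the second pass
    }

    // A greedy pick stands in for the full sampler chain.
    const int id = (int) (std::max_element(logits.begin(), logits.end()) - logits.begin());

    if (grammar != nullptr && !is_resampling && !grammar->accepts(id)) {
        logits = original_logits;                   // restore, then resample with the grammar applied
        return sample_token(logits, grammar, /* is_resampling= */ true);
    }
    return id;
}

int main() {
    std::vector<float> logits = {0.1f, 2.0f, 0.5f, 1.5f};
    const Grammar grammar = {{0, 3}};               // rejects token 1, the unconstrained argmax
    std::printf("sampled token: %d\n", sample_token(logits, &grammar, /* is_resampling= */ false));
    return 0;
}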

examples/main/main.cpp (+2, -2)

@@ -707,7 +707,7 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
-            llama_sampling_accept(ctx_sampling, ctx, id, true);
+            llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
 
             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
 
@@ -728,7 +728,7 @@ int main(int argc, char ** argv) {
 
                 // push the prompt in the sampling context in order to apply repetition penalties later
                 // for the prompt, we don't apply grammar rules
-                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
 
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
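
These two call sites differ only in the `apply_grammar` flag, which the patch now names explicitly: tokens the model generates are accepted with the flag set to `true`, prompt tokens with `false`, since the prompt is only tracked for repetition penalties. The following self-contained sketch illustrates that distinction; `SamplingContext` and its members are placeholders for illustration, not the llama.cpp implementation.

// Sketch (placeholder types, not the llama.cpp implementation) of what the
// apply_grammar flag distinguishes: every accepted token feeds the
// repetition-penalty history, but only tokens accepted with the flag set to
// true are fed to the grammar as well.
#include <cstdio>
#include <vector>

struct SamplingContext {
    std::vector<int> prev;           // recent tokens, used for repetition penalties
    std::vector<int> grammar_seen;   // tokens the grammar has been advanced with

    void accept(int id, bool apply_grammar) {
        prev.push_back(id);          // always tracked for penalties
        if (apply_grammar) {
            grammar_seen.push_back(id);  // only generated tokens drive the grammar
        }
    }
};

int main() {
    SamplingContext ctx_sampling;

    const std::vector<int> prompt = {10, 11, 12};
    for (int tok : prompt) {
        ctx_sampling.accept(tok, /* apply_grammar= */ false);   // prompt tokens: penalties only
    }
    ctx_sampling.accept(42, /* apply_grammar= */ true);          // generated token: grammar too

    std::printf("penalty history: %zu tokens, grammar advanced by: %zu tokens\n",
                ctx_sampling.prev.size(), ctx_sampling.grammar_seen.size());   // prints 4 and 1
    return 0;
}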
