@@ -179,7 +179,7 @@ static llama_token llama_sampling_sample_impl(
179
179
struct llama_context * ctx_main,
180
180
struct llama_context * ctx_cfg,
181
181
const int idx,
182
- bool is_resampling) { // Add a parameter to indicate if we are resampling
182
+ bool is_resampling) {
183
183
const llama_sampling_params & params = ctx_sampling->params ;
184
184
185
185
const float temp = params.temp ;
@@ -188,8 +188,8 @@ static llama_token llama_sampling_sample_impl(
188
188
const float mirostat_eta = params.mirostat_eta ;
189
189
190
190
std::vector<float > original_logits;
191
- auto cur_p = llama_sampling_prepare (ctx_sampling, ctx_main, ctx_cfg, idx, ! is_resampling, &original_logits);
192
- if (!is_resampling) {
191
+ auto cur_p = llama_sampling_prepare (ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
192
+ if (ctx_sampling-> grammar != NULL && !is_resampling) {
193
193
GGML_ASSERT (!original_logits.empty ());
194
194
}
195
195
llama_token id = 0 ;
@@ -252,7 +252,7 @@ static llama_token llama_sampling_sample_impl(
252
252
// Restore logits from the copy
253
253
std::copy (original_logits.begin (), original_logits.end (), logits);
254
254
255
- return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, true ); // Pass true for is_resampling
255
+ return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true );
256
256
}
257
257
}
258
258
@@ -285,7 +285,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
285
285
// Get a pointer to the logits
286
286
float * logits = llama_get_logits_ith (ctx_main, idx);
287
287
288
- if (apply_grammar && original_logits != NULL ) {
288
+ if (ctx_sampling->grammar != NULL && !apply_grammar) {
289
+ GGML_ASSERT (original_logits != NULL );
289
290
// Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
290
291
*original_logits = {logits, logits + llama_n_vocab (llama_get_model (ctx_main))};
291
292
}
@@ -342,7 +343,7 @@ llama_token llama_sampling_sample(
342
343
struct llama_context * ctx_cfg,
343
344
const int idx) {
344
345
// Call the implementation function with is_resampling set to false by default
345
- return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, false );
346
+ return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false );
346
347
}
347
348
348
349
llama_token_data_array llama_sampling_prepare (
0 commit comments