@@ -168,76 +168,19 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
     const float   temp            = params.temp;
-    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
-    const float   penalty_repeat  = params.penalty_repeat;
-    const float   penalty_freq    = params.penalty_freq;
-    const float   penalty_present = params.penalty_present;
     const int     mirostat        = params.mirostat;
     const float   mirostat_tau    = params.mirostat_tau;
     const float   mirostat_eta    = params.mirostat_eta;
-    const bool    penalize_nl     = params.penalize_nl;
 
-    auto & prev = ctx_sampling->prev;
-    auto & cur  = ctx_sampling->cur;
-
-    llama_token id = 0;
-
-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
-
-    // Declare original_logits at the beginning of the function scope
     std::vector<float> original_logits;
-
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
     if (!is_resampling) {
-        // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
-        original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
-    }
-
-    // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
-    }
-
-    if (ctx_cfg) {
-        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
-    }
-
-    cur.clear();
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-
-    // apply penalties
-    const auto & penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
-
-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
-
-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
-                }
-            }
-        }
-    }
-
-    // If we are in the resampling phase, apply grammar checks before sampling logic
-    if (is_resampling && ctx_sampling->grammar != NULL) {
-        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+        GGML_ASSERT(!original_logits.empty());
     }
+    llama_token id = 0;
+    // Get a pointer to the logits
+    float * logits = llama_get_logits_ith(ctx_main, idx);
 
     if (temp < 0.0) {
         // greedy sampling, with probs
@@ -302,11 +245,13 @@ static llama_token llama_sampling_sample_impl(
     return id;
 }
 
-static llama_token_data_array llama_sample_probability_distribution_impl(
+static llama_token_data_array llama_sampling_prepare_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -315,6 +260,7 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     const float   penalty_repeat  = params.penalty_repeat;
     const float   penalty_freq    = params.penalty_freq;
     const float   penalty_present = params.penalty_present;
+
     const bool    penalize_nl     = params.penalize_nl;
 
     auto & prev = ctx_sampling->prev;
@@ -323,8 +269,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    // Declare original_logits at the beginning of the function scope
-    std::vector<float> original_logits;
+    if (apply_grammar && original_logits != NULL) {
+        // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
+        *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
+    }
 
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -364,12 +312,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
         }
     }
 
-    // apply grammar checks
-    if (ctx_sampling->grammar != NULL) {
+    // apply grammar checks before sampling logic
+    if (apply_grammar && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
-    llama_sample_softmax(ctx_main, &cur_p);
     return cur_p;
 }
 
@@ -382,12 +329,14 @@ llama_token llama_sampling_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
 }
 
-llama_token_data_array llama_sampling_probability_distribution(
+llama_token_data_array llama_sampling_prepare(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
-    return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
+    return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
 }
 
 void llama_sampling_accept(
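
Taken together, these hunks move candidate preparation (logit bias, CFG guidance, repetition penalties, optional grammar) out of llama_sampling_sample_impl and behind the new llama_sampling_prepare entry point, which can also hand back the pre-grammar logits through the original_logits out-parameter and no longer applies a final softmax. A minimal caller-side sketch of the new API follows; the llama_context * ctx and llama_sampling_context * ctx_sampling names are assumptions for illustration, standing in for contexts already set up and fed a prompt elsewhere:

    // Sketch: inspect the post-penalty candidate distribution without sampling.
    std::vector<float> original_logits;

    // Build the candidate array for output index 0. With apply_grammar == false,
    // the prepare step skips both the grammar pass and the original_logits copy.
    llama_token_data_array cur_p = llama_sampling_prepare(
            ctx_sampling, ctx, /*ctx_cfg =*/ NULL, /*idx =*/ 0,
            /*apply_grammar =*/ false, &original_logits);

    // The refactor removed the unconditional llama_sample_softmax call from the
    // impl, so normalize explicitly when probabilities (not raw logits) are needed.
    llama_sample_softmax(ctx, &cur_p);

Note the behavioral consequence visible in the second-to-last hunk: callers that previously relied on llama_sampling_probability_distribution returning a softmaxed array must now normalize the result themselves, as above.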