
Commit 68cf1e5

mscheong01 authored and tybalex committed
sampling : deduplicated code for probability distribution access (ggml-org#6240)
* sampling: remove duplicated code for probability distribution access
* free original_logits
* fix original_logits allocation
* fixes based on review @cebtenzzre
* change function name to `llama_sampling_prepare`
1 parent bd69ff2 commit 68cf1e5

File tree

4 files changed: +28 -76 lines changed

common/sampling.cpp (+21 -72)
```diff
@@ -168,76 +168,19 @@ static llama_token llama_sampling_sample_impl(
                   bool is_resampling) {  // Add a parameter to indicate if we are resampling
     const llama_sampling_params & params = ctx_sampling->params;
 
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
     const float temp = params.temp;
-    const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
-    const float penalty_repeat = params.penalty_repeat;
-    const float penalty_freq = params.penalty_freq;
-    const float penalty_present = params.penalty_present;
     const int mirostat = params.mirostat;
     const float mirostat_tau = params.mirostat_tau;
     const float mirostat_eta = params.mirostat_eta;
-    const bool penalize_nl = params.penalize_nl;
 
-    auto & prev = ctx_sampling->prev;
-    auto & cur = ctx_sampling->cur;
-
-    llama_token id = 0;
-
-    // Get a pointer to the logits
-    float * logits = llama_get_logits_ith(ctx_main, idx);
-
-    // Declare original_logits at the beginning of the function scope
     std::vector<float> original_logits;
-
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
     if (!is_resampling) {
-        // Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
-        original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
-    }
-
-    // apply params.logit_bias map
-    for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
-        logits[it->first] += it->second;
-    }
-
-    if (ctx_cfg) {
-        float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
-        llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
-    }
-
-    cur.clear();
-
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-    }
-
-    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
-
-    // apply penalties
-    const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
-    const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
-    if (penalty_tokens_used_size) {
-        const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
-
-        llama_sample_repetition_penalties(ctx_main, &cur_p,
-                penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
-                penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
-
-        if (!penalize_nl) {
-            for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
-                    cur_p.data[idx].logit = nl_logit;
-                    break;
-                }
-            }
-        }
-    }
-
-    // If we are in the resampling phase, apply grammar checks before sampling logic
-    if (is_resampling && ctx_sampling->grammar != NULL) {
-        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+        GGML_ASSERT(!original_logits.empty());
     }
+    llama_token id = 0;
+    // Get a pointer to the logits
+    float * logits = llama_get_logits_ith(ctx_main, idx);
 
     if (temp < 0.0) {
         // greedy sampling, with probs
```
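
For context, the `original_logits` copy exists so that grammar-aware resampling can start from unmodified logits: later in `llama_sampling_sample_impl` (outside this hunk), a sampled token that fails the grammar check triggers a retry with `is_resampling = true`. A minimal sketch of that pattern; `grammar_rejected()` is a hypothetical stand-in for the actual check (the real code runs the token through `llama_sample_grammar` and tests its logit for `-INFINITY`):

```cpp
// Sketch only: restore the untouched logits, then retry the whole
// sampling pass in resampling mode.
if (!is_resampling && grammar_rejected(id)) {
    std::copy(original_logits.begin(), original_logits.end(), logits);
    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
}
```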
```diff
@@ -302,11 +245,13 @@ static llama_token llama_sampling_sample_impl(
     return id;
 }
 
-static llama_token_data_array llama_sample_probability_distribution_impl(
+static llama_token_data_array llama_sampling_prepare_impl(
                   struct llama_sampling_context * ctx_sampling,
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
-                  const int idx) {
+                  const int idx,
+                  bool apply_grammar,
+                  std::vector<float> * original_logits) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -315,6 +260,7 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     const float penalty_repeat = params.penalty_repeat;
     const float penalty_freq = params.penalty_freq;
     const float penalty_present = params.penalty_present;
+
     const bool penalize_nl = params.penalize_nl;
 
     auto & prev = ctx_sampling->prev;
@@ -323,8 +269,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    // Declare original_logits at the beginning of the function scope
-    std::vector<float> original_logits;
+    if (apply_grammar && original_logits != NULL) {
+        // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
+        *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
+    }
 
     // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -364,12 +312,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
         }
     }
 
-    // apply grammar checks
-    if (ctx_sampling->grammar != NULL) {
+    // apply grammar checks before sampling logic
+    if (apply_grammar && ctx_sampling->grammar != NULL) {
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
-    llama_sample_softmax(ctx_main, &cur_p);
     return cur_p;
 }
 
@@ -382,12 +329,14 @@ llama_token llama_sampling_sample(
     return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
 }
 
-llama_token_data_array llama_sampling_probability_distribution(
+llama_token_data_array llama_sampling_prepare(
     struct llama_sampling_context * ctx_sampling,
     struct llama_context * ctx_main,
     struct llama_context * ctx_cfg,
-    const int idx) {
-    return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
+    const int idx,
+    bool apply_grammar,
+    std::vector<float> * original_logits) {
+    return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
 }
 
 void llama_sampling_accept(
```
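
Note the removed `llama_sample_softmax` call: `llama_sampling_prepare` now returns candidates carrying adjusted logits rather than a normalized distribution, so normalization becomes the caller's job. A minimal sketch of how a caller recovers the old `llama_sampling_probability_distribution` behavior under the renamed API (assuming valid `ctx_sampling`, `ctx_main`, and `idx`):

```cpp
// Prepare the candidate set (logit bias, penalties, grammar), then
// normalize explicitly; cur_p.data[i].p is only meaningful after this.
llama_token_data_array cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, /* ctx_cfg= */ NULL, idx);
llama_sample_softmax(ctx_main, &cur_p);
```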

common/sampling.h (+5 -3)
```diff
@@ -131,12 +131,14 @@ llama_token llama_sampling_sample(
     struct llama_context * ctx_cfg,
     int idx = 0);
 
-// returns the probability that token of given id will be sampled
-llama_token_data_array llama_sampling_probability_distribution(
+// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
+llama_token_data_array llama_sampling_prepare(
     struct llama_sampling_context * ctx_sampling,
     struct llama_context * ctx_main,
     struct llama_context * ctx_cfg,
-    int idx = 0);
+    int idx = 0,
+    bool apply_grammar = true,
+    std::vector<float> * original_logits = nullptr);
 
 void llama_sampling_accept(
     struct llama_sampling_context * ctx_sampling,
```
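
Both new parameters are defaulted, so existing call sites keep compiling unchanged. A hedged sketch of the `original_logits` contract implied by the implementation above: the vector is written only when `apply_grammar` is true and the pointer is non-null, and the copy is taken before biases and penalties are applied:

```cpp
std::vector<float> original_logits;
// Request a copy of the raw logits alongside candidate preparation:
llama_token_data_array cur_p = llama_sampling_prepare(
        ctx_sampling, ctx_main, /* ctx_cfg= */ NULL, /* idx= */ 0,
        /* apply_grammar= */ true, &original_logits);
// With apply_grammar == false the vector would be left empty.
```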

examples/speculative/speculative.cpp (+2 -1)
```diff
@@ -219,7 +219,8 @@ int main(int argc, char ** argv) {
         if (params.sparams.temp > 0) {
             // stochastic verification
 
-            llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
+            llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
+            llama_sample_softmax(ctx_tgt, &dist_tgt);
             float p_tgt = 0, p_dft = 0;
 
             // GGML_ASSERT(dist_tgt.size() == dist_dft.size());
```
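
Because `llama_sampling_prepare` no longer normalizes, the explicit `llama_sample_softmax` call keeps stochastic verification working on a true probability distribution; afterwards each candidate's probability is in its `.p` field. A hypothetical helper (not part of this diff) showing the kind of lookup the verification loop performs for a drafted token:

```cpp
// Return the normalized probability of `tok` in a softmaxed candidate array.
static float token_prob(const llama_token_data_array & dist, llama_token tok) {
    for (size_t i = 0; i < dist.size; ++i) {
        if (dist.data[i].id == tok) {
            return dist.data[i].p;
        }
    }
    return 0.0f;
}
```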

retrieval (binary file, 1.56 MB; not shown)
