sampling: remove duplicated code for probability distribution access #6240

Merged
Changes from 1 commit
95 changes: 22 additions & 73 deletions common/sampling.cpp
@@ -168,77 +168,20 @@ static llama_token llama_sampling_sample_impl(
bool is_resampling) { // Add a parameter to indicate if we are resampling
const llama_sampling_params & params = ctx_sampling->params;

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

const float temp = params.temp;
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;

auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;

std::vector<float>* original_logits = nullptr;
auto cur_p = llama_sampling_configure_token_candidates(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
if (!is_resampling) {
GGML_ASSERT(original_logits != nullptr);
}
llama_token id = 0;

// Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);

// Declare original_logits at the beginning of the function scope
std::vector<float> original_logits;

if (!is_resampling) {
// Only make a copy of the original logits if we are not in the resampling phase, not sure if I actually have to do this.
original_logits = std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
}

// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}

if (ctx_cfg) {
float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
}

cur.clear();

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}

llama_token_data_array cur_p = { cur.data(), cur.size(), false };

// apply penalties
const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];

llama_sample_repetition_penalties(ctx_main, &cur_p,
penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);

if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
cur_p.data[idx].logit = nl_logit;
break;
}
}
}
}

// If we are in the resampling phase, apply grammar checks before sampling logic
if (is_resampling && ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}

if (temp < 0.0) {
// greedy sampling, with probs
llama_sample_softmax(ctx_main, &cur_p);
@@ -293,7 +236,7 @@ static llama_token llama_sampling_sample_impl(
LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str());

// Restore logits from the copy
std::copy(original_logits.begin(), original_logits.end(), logits);
std::copy((*original_logits).begin(), (*original_logits).end(), logits);

return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true); // Pass true for is_resampling
}
@@ -302,11 +245,13 @@ static llama_token llama_sampling_sample_impl(
return id;
}

static llama_token_data_array llama_sample_probability_distribution_impl(
static llama_token_data_array llama_sampling_configure_token_candidates_impl(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
const int idx,
bool apply_grammar,
std::vector<float>** original_logits) {
const llama_sampling_params & params = ctx_sampling->params;

const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
@@ -315,6 +260,7 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;

const bool penalize_nl = params.penalize_nl;

auto & prev = ctx_sampling->prev;
@@ -323,8 +269,10 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
// Get a pointer to the logits
float * logits = llama_get_logits_ith(ctx_main, idx);

// Declare original_logits at the beginning of the function scope
std::vector<float> original_logits;
if (apply_grammar && original_logits != nullptr) {
// Only make a copy of the original logits when grammar checks will be applied (the first sampling pass), so they can be restored if we need to resample.
*original_logits = new std::vector<float>(logits, logits + llama_n_vocab(llama_get_model(ctx_main)));
}

// apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
@@ -364,12 +312,11 @@ static llama_token_data_array llama_sample_probability_distribution_impl(
}
}

// apply grammar checks
if (ctx_sampling->grammar != NULL) {
// apply grammar checks before sampling logic
if (apply_grammar && ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}

llama_sample_softmax(ctx_main, &cur_p);
return cur_p;
}

@@ -382,12 +329,14 @@ llama_token llama_sampling_sample(
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
}

llama_token_data_array llama_sampling_probability_distribution(
llama_token_data_array llama_sampling_configure_token_candidates(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
const int idx) {
return llama_sample_probability_distribution_impl(ctx_sampling,ctx_main, ctx_cfg, idx);
const int idx,
bool apply_grammar,
std::vector<float>** original_logits) {
return llama_sampling_configure_token_candidates_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits);
}

void llama_sampling_accept(
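The hunks above replace the duplicated candidate-preparation logic in llama_sampling_sample_impl (logit bias, CFG guidance, repetition penalties, grammar) with a single call into the shared implementation. One consequence visible in the diff: when apply_grammar is true and a non-null original_logits out-parameter is passed, the implementation heap-allocates the logits copy with new, so ownership passes to the caller. A minimal sketch of the resulting first-pass/resample flow, assuming the signatures shown in this diff; pick_token and token_fits_grammar are hypothetical stand-ins, and surrounding setup and cleanup are simplified:

std::vector<float> * original_logits = nullptr;

// First pass: grammar will be applied, so request a copy of the raw logits.
llama_token_data_array cur_p = llama_sampling_configure_token_candidates(
        ctx_sampling, ctx_main, ctx_cfg, idx, /*apply_grammar =*/ true, &original_logits);

llama_token id = pick_token(cur_p); // hypothetical: temperature/mirostat/greedy sampling step

if (ctx_sampling->grammar != NULL && !token_fits_grammar(id)) { // hypothetical grammar check
    // Restore the untouched logits before resampling, as the diff does above.
    float * logits = llama_get_logits_ith(ctx_main, idx);
    std::copy(original_logits->begin(), original_logits->end(), logits);
    // ... then recurse with is_resampling = true ...
}

delete original_logits; // allocated with new inside the helper, so the caller must free it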
8 changes: 5 additions & 3 deletions common/sampling.h
@@ -131,12 +131,14 @@ llama_token llama_sampling_sample(
struct llama_context * ctx_cfg,
int idx = 0);

// returns the probability that token of given id will be sampled
llama_token_data_array llama_sampling_probability_distribution(
// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
llama_token_data_array llama_sampling_configure_token_candidates(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
int idx = 0);
int idx = 0,
bool apply_grammar = true,
std::vector<float>** original_logits = nullptr);

void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
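In the header, both new parameters are defaulted (apply_grammar = true, original_logits = nullptr), so call sites that only need the prepared candidate array can stay compact. A hedged example of the minimal form, with context setup omitted:

// Uses the defaults declared above: grammar is applied and no logits copy is
// requested, so nothing is heap-allocated on behalf of the caller.
llama_token_data_array cur_p =
        llama_sampling_configure_token_candidates(ctx_sampling, ctx_main, /*ctx_cfg =*/ NULL);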
3 changes: 2 additions & 1 deletion examples/speculative/speculative.cpp
@@ -219,7 +219,8 @@ int main(int argc, char ** argv) {
if (params.sparams.temp > 0) {
// stochastic verification

llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
llama_token_data_array dist_tgt = llama_sampling_configure_token_candidates(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
llama_sample_softmax(ctx_tgt, &dist_tgt);
float p_tgt = 0, p_dft = 0;

// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
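Because the shared implementation no longer ends with llama_sample_softmax (the call was dropped so the helper only configures candidates), callers that need a normalized, sorted distribution now apply it explicitly, which is exactly what the speculative example does above. A sketch of the pattern, with the batch-index variable name assumed:

// The helper returns biased/penalized logits in the candidate array; normalize
// explicitly when actual probabilities are needed (e.g. stochastic verification).
llama_token_data_array dist_tgt = llama_sampling_configure_token_candidates(
        ctx_sampling, ctx_tgt, NULL, i_batch_tgt, /*apply_grammar =*/ true, NULL);
llama_sample_softmax(ctx_tgt, &dist_tgt); // sorts candidates and fills the .p fields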