Skip to content

Commit b990379

Browse files
VJHack authored and ggerganov committed
sampling: add Top-nσ sampler (ggml-org#11223)
* initial sampling changes: * completed top nsigma sampler implementation * apply parameter to only llama-cli * updated readme * added tests and fixed nsigma impl * cleaned up pr * format * format * format * removed commented tests * cleanup pr and remove explicit floats * added top-k sampler to improve performance * changed sigma to float * fixed string format to float * Update src/llama-sampling.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update common/sampling.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update src/llama-sampling.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update src/llama-sampling.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update src/llama-sampling.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * Update src/llama-sampling.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * added llama_sampler_init --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 091adba commit b990379

File tree

7 files changed

+147
-40
lines changed

7 files changed

+147
-40
lines changed

common/arg.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -946,6 +946,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
946946
params.sampling.min_p = std::stof(value);
947947
}
948948
).set_sparam());
949+
add_opt(common_arg(
950+
{"--top-nsigma"}, "N",
951+
string_format("top-n-sigma sampling (default: %.1f, -1.0 = disabled)", params.sampling.top_n_sigma),
952+
[](common_params & params, const std::string & value) {
953+
params.sampling.top_n_sigma = std::stof(value);
954+
}
955+
).set_examples({LLAMA_EXAMPLE_MAIN}).set_sparam());
949956
add_opt(common_arg(
950957
{"--xtc-probability"}, "N",
951958
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),

common/common.h

+1
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ struct common_params_sampling {
140140
int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty
141141
int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size)
142142
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
143+
float top_n_sigma = -1.00f;// -1.0 = disabled
143144
float mirostat_tau = 5.00f; // target entropy
144145
float mirostat_eta = 0.10f; // learning rate
145146
bool ignore_eos = false;

common/sampling.cpp

+46-40
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,11 @@ std::string common_params_sampling::print() const {
134134
snprintf(result, sizeof(result),
135135
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
136136
"\tdry_multiplier = %.3f, dry_base = %.3f, dry_allowed_length = %d, dry_penalty_last_n = %d\n"
137-
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, temp = %.3f\n"
137+
"\ttop_k = %d, top_p = %.3f, min_p = %.3f, xtc_probability = %.3f, xtc_threshold = %.3f, typical_p = %.3f, top_n_sigma = %.3f, temp = %.3f\n"
138138
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
139139
penalty_last_n, penalty_repeat, penalty_freq, penalty_present,
140140
dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n,
141-
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, temp,
141+
top_k, top_p, min_p, xtc_probability, xtc_threshold, typ_p, top_n_sigma, temp,
142142
mirostat, mirostat_eta, mirostat_tau);
143143

144144
return std::string(result);
@@ -188,45 +188,51 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
188188
params.logit_bias.data()));
189189

190190
if (params.mirostat == 0) {
191-
for (const auto & cnstr : params.samplers) {
192-
switch (cnstr) {
193-
case COMMON_SAMPLER_TYPE_DRY:
194-
{
195-
std::vector<const char *> c_breakers;
196-
c_breakers.reserve(params.dry_sequence_breakers.size());
197-
for (const auto & str : params.dry_sequence_breakers) {
198-
c_breakers.push_back(str.c_str());
191+
if (params.top_n_sigma >= 0) {
192+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
193+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
194+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
195+
} else {
196+
for (const auto & cnstr : params.samplers) {
197+
switch (cnstr) {
198+
case COMMON_SAMPLER_TYPE_DRY:
199+
{
200+
std::vector<const char *> c_breakers;
201+
c_breakers.reserve(params.dry_sequence_breakers.size());
202+
for (const auto & str : params.dry_sequence_breakers) {
203+
c_breakers.push_back(str.c_str());
204+
}
205+
206+
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
199207
}
200-
201-
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
202-
}
203-
break;
204-
case COMMON_SAMPLER_TYPE_TOP_K:
205-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
206-
break;
207-
case COMMON_SAMPLER_TYPE_TOP_P:
208-
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
209-
break;
210-
case COMMON_SAMPLER_TYPE_MIN_P:
211-
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
212-
break;
213-
case COMMON_SAMPLER_TYPE_XTC:
214-
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
215-
break;
216-
case COMMON_SAMPLER_TYPE_TYPICAL_P:
217-
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
218-
break;
219-
case COMMON_SAMPLER_TYPE_TEMPERATURE:
220-
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
221-
break;
222-
case COMMON_SAMPLER_TYPE_INFILL:
223-
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
224-
break;
225-
case COMMON_SAMPLER_TYPE_PENALTIES:
226-
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
227-
break;
228-
default:
229-
GGML_ASSERT(false && "unknown sampler type");
208+
break;
209+
case COMMON_SAMPLER_TYPE_TOP_K:
210+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
211+
break;
212+
case COMMON_SAMPLER_TYPE_TOP_P:
213+
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
214+
break;
215+
case COMMON_SAMPLER_TYPE_MIN_P:
216+
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
217+
break;
218+
case COMMON_SAMPLER_TYPE_XTC:
219+
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
220+
break;
221+
case COMMON_SAMPLER_TYPE_TYPICAL_P:
222+
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
223+
break;
224+
case COMMON_SAMPLER_TYPE_TEMPERATURE:
225+
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
226+
break;
227+
case COMMON_SAMPLER_TYPE_INFILL:
228+
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
229+
break;
230+
case COMMON_SAMPLER_TYPE_PENALTIES:
231+
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
232+
break;
233+
default:
234+
GGML_ASSERT(false && "unknown sampler type");
235+
}
230236
}
231237
}
232238
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));

examples/main/README.md

+8
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,14 @@ Being experimental and unique, XTC is disabled by default. The recommended combi
265265

266266
Example usage: `--xtc-probability 0.5 --xtc-threshold 0.1`
267267

268+
### Top-nσ Sampling
269+
270+
- `--top-nsigma N`: Limit the next token selection to a subset of tokens with pre-softmax logits that are within n * σ less than the max logit (default: -1, -1 = disabled).
271+
272+
Top-nσ sampling is a text generation method that selects tokens based on a statistical threshold in pre-softmax logits. It works by only sampling from tokens with logits that are within n * σ of the maximum logit. This method helps maintain a stable sampling space regardless of temperature scaling, allowing it to perform well on reasoning tasks even in high temperatures. Without complex probability manipulation, it efficiently filters tokens directly on the pre-softmax logits. A higher value for top-nsigma (e.g., 5) will take more noisy tokens into consideration, while a lower value (e.g., 1) will focus on the more informative region of the sampling space.
273+
274+
Example usage: `--top-nsigma 1`
275+
268276
### Logit Bias
269277

270278
- `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.

include/llama.h

+3
Original file line numberDiff line numberDiff line change
@@ -1172,6 +1172,9 @@ extern "C" {
11721172
/// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
11731173
LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed);
11741174

1175+
/// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
1176+
LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n);
1177+
11751178
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
11761179
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
11771180
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.

src/llama-sampling.cpp

+67
Original file line numberDiff line numberDiff line change
@@ -1698,6 +1698,73 @@ struct llama_sampler * llama_sampler_init_penalties(
16981698
);
16991699
}
17001700

1701+
// top-n-sigma
1702+
1703+
struct llama_sampler_top_n_sigma {
1704+
const float n;
1705+
};
1706+
1707+
static const char * llama_sampler_top_n_sigma_name(const struct llama_sampler * /*smpl*/) {
1708+
return "top-n-sigma";
1709+
}
1710+
1711+
static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
1712+
const auto * ctx = (llama_sampler_top_n_sigma *) smpl->ctx;
1713+
1714+
// find max logit and calculate mean
1715+
float max = cur_p->data[0].logit;
1716+
float logits_sum = 0;
1717+
for (size_t i = 0; i < cur_p->size; ++i) {
1718+
if (cur_p->data[i].logit > max) {
1719+
max = cur_p->data[i].logit;
1720+
}
1721+
logits_sum += cur_p->data[i].logit;
1722+
}
1723+
float mean = logits_sum/cur_p->size;
1724+
1725+
// calculate standard deviation
1726+
float acc = 0;
1727+
for (size_t i = 0; i < cur_p->size; ++i) {
1728+
acc += pow(cur_p->data[i].logit - mean, 2);
1729+
}
1730+
float std = sqrt(acc/cur_p->size);
1731+
1732+
//apply mask
1733+
for (size_t i = 0; i < cur_p->size; ++i) {
1734+
if (cur_p->data[i].logit < max - (ctx->n * std)) {
1735+
cur_p->data[i].logit = -INFINITY;
1736+
}
1737+
}
1738+
llama_sampler_softmax_impl(cur_p);
1739+
}
1740+
1741+
static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) {
1742+
const auto * ctx = (const llama_sampler_top_n_sigma *) smpl->ctx;
1743+
return llama_sampler_init_top_n_sigma(ctx->n);
1744+
}
1745+
1746+
static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
1747+
delete (llama_sampler_top_n_sigma *) smpl->ctx;
1748+
}
1749+
1750+
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
1751+
/* .name = */ llama_sampler_top_n_sigma_name,
1752+
/* .accept = */ nullptr,
1753+
/* .apply = */ llama_sampler_top_n_sigma_apply,
1754+
/* .reset = */ nullptr,
1755+
/* .clone = */ llama_sampler_top_n_sigma_clone,
1756+
/* .free = */ llama_sampler_top_n_sigma_free,
1757+
};
1758+
1759+
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
1760+
return llama_sampler_init(
1761+
/* .iface = */ &llama_sampler_top_n_sigma_i,
1762+
/* .ctx = */ new llama_sampler_top_n_sigma {
1763+
/* .n = */ n,
1764+
}
1765+
);
1766+
}
1767+
17011768
// DRY
17021769

17031770
struct llama_sampler_dry {

tests/test-sampling.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,17 @@ static void test_dry(
181181
tester.check();
182182
}
183183

184+
// Applies the top-n-sigma sampler followed by dist sampling and verifies the
// resulting token probabilities against probs_expected.
// NOTE: n is float (was int) so fractional sigma thresholds are not silently
// truncated; all existing call sites pass float literals, so this is
// backward compatible.
static void test_top_n_sigma(const std::vector<float> & probs, const std::vector<float> & probs_expected, float n) {
    sampler_tester tester(probs, probs_expected);

    DUMP(&tester.cur_p);
    tester.apply(llama_sampler_init_top_n_sigma(n));
    tester.apply(llama_sampler_init_dist (0));
    DUMP(&tester.cur_p);

    tester.check();
}
194+
184195
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
185196
) {
186197
sampler_tester tester(n_vocab);
@@ -348,6 +359,10 @@ int main(void) {
348359
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 1}, {0.241818f, 0.241818f, 0.241818f, 0.241818f, 0.032727f}, 2.0f, 1.1f, 2, 5, {});
349360
test_dry({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 3, 4, 0, 1}, {0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, 1.0f, 1.1f, 4, 7, {});
350361

362+
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.571429f, 0.428571f, 0.0f, 0.0f}, 1.00f);
363+
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.00f);
364+
test_top_n_sigma({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 3.00f);
365+
351366
test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f);
352367
test_sampler_queue(10000, "k", 1, 1.0f, 1.0f);
353368
test_sampler_queue(10000, "p", 10000, 1.0f, 1.0f);

0 commit comments

Comments
 (0)