8
8
#include < cstring>
9
9
#include < ctime>
10
10
#include < cfloat>
11
+ #include < chrono>
11
12
#include < cmath>
12
13
#include < numeric>
13
14
#include < random>
@@ -162,6 +163,19 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
162
163
cur_p->size = k;
163
164
}
164
165
166
+ static uint32_t get_rng_seed (uint32_t seed) {
167
+ if (seed == LLAMA_DEFAULT_SEED) {
168
+ // use system clock if std::random_device is not a true RNG
169
+ static bool is_rd_prng = std::random_device ().entropy () == 0 ;
170
+ if (is_rd_prng) {
171
+ return (uint32_t ) std::chrono::system_clock::now ().time_since_epoch ().count ();
172
+ }
173
+ std::random_device rd;
174
+ return rd ();
175
+ }
176
+ return seed;
177
+ }
178
+
165
179
// llama_sampler API
166
180
167
181
const char * llama_sampler_name (const struct llama_sampler * smpl) {
@@ -387,6 +401,7 @@ struct llama_sampler * llama_sampler_init_greedy() {
387
401
388
402
struct llama_sampler_dist {
389
403
const uint32_t seed;
404
+ uint32_t seed_cur;
390
405
391
406
std::mt19937 rng;
392
407
};
@@ -416,7 +431,8 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
416
431
417
432
static void llama_sampler_dist_reset (struct llama_sampler * smpl) {
418
433
auto * ctx = (llama_sampler_dist *) smpl->ctx ;
419
- ctx->rng = std::mt19937 (ctx->seed );
434
+ ctx->seed_cur = get_rng_seed (ctx->seed );
435
+ ctx->rng .seed (ctx->seed_cur );
420
436
}
421
437
422
438
static void llama_sampler_dist_free (struct llama_sampler * smpl) {
@@ -433,11 +449,13 @@ static struct llama_sampler_i llama_sampler_dist_i = {
433
449
};
434
450
435
451
struct llama_sampler * llama_sampler_init_dist (uint32_t seed) {
452
+ auto seed_cur = get_rng_seed (seed);
436
453
return new llama_sampler {
437
454
/* .iface = */ &llama_sampler_dist_i,
438
455
/* .ctx = */ new llama_sampler_dist {
439
- /* .seed = */ seed,
440
- /* .rng = */ std::mt19937 (seed),
456
+ /* .seed = */ seed,
457
+ /* .seed_cur = */ seed_cur,
458
+ /* .rng = */ std::mt19937 (seed_cur),
441
459
},
442
460
};
443
461
}
@@ -1032,6 +1050,7 @@ struct llama_sampler_mirostat {
1032
1050
const int32_t n_vocab;
1033
1051
1034
1052
const uint32_t seed;
1053
+ uint32_t seed_cur;
1035
1054
1036
1055
const float tau;
1037
1056
const float eta;
@@ -1100,7 +1119,8 @@ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sa
1100
1119
static void llama_sampler_mirostat_reset (struct llama_sampler * smpl) {
1101
1120
auto * ctx = (llama_sampler_mirostat *) smpl->ctx ;
1102
1121
ctx->mu = 2 .0f *ctx->tau ;
1103
- ctx->rng = std::mt19937 (ctx->seed );
1122
+ ctx->seed_cur = get_rng_seed (ctx->seed );
1123
+ ctx->rng .seed (ctx->seed_cur );
1104
1124
}
1105
1125
1106
1126
static void llama_sampler_mirostat_free (struct llama_sampler * smpl) {
@@ -1117,16 +1137,18 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
1117
1137
};
1118
1138
1119
1139
struct llama_sampler * llama_sampler_init_mirostat (int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
1140
+ auto seed_cur = get_rng_seed (seed);
1120
1141
return new llama_sampler {
1121
1142
/* .iface = */ &llama_sampler_mirostat_i,
1122
1143
/* .ctx = */ new llama_sampler_mirostat {
1123
- /* .n_vocab = */ n_vocab,
1124
- /* .seed = */ seed,
1125
- /* .tau = */ tau,
1126
- /* .eta = */ eta,
1127
- /* .m = */ m,
1128
- /* .mu = */ 2 .0f *tau,
1129
- /* .rng = */ std::mt19937 (seed),
1144
+ /* .n_vocab = */ n_vocab,
1145
+ /* .seed = */ seed,
1146
+ /* .seed_cur = */ seed_cur,
1147
+ /* .tau = */ tau,
1148
+ /* .eta = */ eta,
1149
+ /* .m = */ m,
1150
+ /* .mu = */ 2 .0f *tau,
1151
+ /* .rng = */ std::mt19937 (seed_cur),
1130
1152
},
1131
1153
};
1132
1154
}
@@ -1135,6 +1157,7 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
1135
1157
1136
1158
struct llama_sampler_mirostat_v2 {
1137
1159
const uint32_t seed;
1160
+ uint32_t seed_cur;
1138
1161
1139
1162
const float tau;
1140
1163
const float eta;
@@ -1179,7 +1202,8 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
1179
1202
static void llama_sampler_mirostat_v2_reset (struct llama_sampler * smpl) {
1180
1203
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx ;
1181
1204
ctx->mu = 2 .0f *ctx->tau ;
1182
- ctx->rng = std::mt19937 (ctx->seed );
1205
+ ctx->seed_cur = get_rng_seed (ctx->seed );
1206
+ ctx->rng .seed (ctx->seed_cur );
1183
1207
}
1184
1208
1185
1209
static struct llama_sampler * llama_sampler_mirostat_v2_clone (const struct llama_sampler * smpl) {
@@ -1212,14 +1236,16 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
1212
1236
};
1213
1237
1214
1238
struct llama_sampler * llama_sampler_init_mirostat_v2 (uint32_t seed, float tau, float eta) {
1239
+ auto seed_cur = get_rng_seed (seed);
1215
1240
return new llama_sampler {
1216
1241
/* .iface = */ &llama_sampler_mirostat_v2_i,
1217
1242
/* .ctx = */ new llama_sampler_mirostat_v2 {
1218
- /* .seed = */ seed,
1219
- /* .tau = */ tau,
1220
- /* .eta = */ eta,
1221
- /* .mu = */ 2 .0f *tau,
1222
- /* .rng = */ std::mt19937 (seed),
1243
+ /* .seed = */ seed,
1244
+ /* .seed_cur = */ seed_cur,
1245
+ /* .tau = */ tau,
1246
+ /* .eta = */ eta,
1247
+ /* .mu = */ 2 .0f *tau,
1248
+ /* .rng = */ std::mt19937 (seed_cur),
1223
1249
},
1224
1250
};
1225
1251
}
@@ -1505,6 +1531,8 @@ struct llama_sampler * llama_sampler_init_penalties(
1505
1531
ignore_eos = false ;
1506
1532
}
1507
1533
1534
+ penalty_last_n = std::max (penalty_last_n, 0 );
1535
+
1508
1536
return new llama_sampler {
1509
1537
/* .iface = */ &llama_sampler_penalties_i,
1510
1538
/* .ctx = */ new llama_sampler_penalties {
@@ -1568,6 +1596,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
1568
1596
}
1569
1597
}
1570
1598
}
1599
+
1571
1600
static struct llama_sampler * llama_sampler_logit_bias_clone (const struct llama_sampler * smpl) {
1572
1601
const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx ;
1573
1602
return llama_sampler_init_logit_bias (ctx->n_vocab , ctx->logit_bias .size (), ctx->logit_bias .data ());
@@ -1599,3 +1628,31 @@ struct llama_sampler * llama_sampler_init_logit_bias(
1599
1628
},
1600
1629
};
1601
1630
}
1631
+
1632
+ // utils
1633
+
1634
+ uint32_t llama_sampler_get_seed (const struct llama_sampler * smpl) {
1635
+ if (smpl->iface == &llama_sampler_dist_i) {
1636
+ return ((const llama_sampler_dist *) smpl->ctx )->seed_cur ;
1637
+ }
1638
+
1639
+ if (smpl->iface == &llama_sampler_mirostat_i) {
1640
+ return ((const llama_sampler_mirostat *) smpl->ctx )->seed_cur ;
1641
+ }
1642
+
1643
+ if (smpl->iface == &llama_sampler_mirostat_v2_i) {
1644
+ return ((const llama_sampler_mirostat_v2 *) smpl->ctx )->seed_cur ;
1645
+ }
1646
+
1647
+ if (smpl->iface == &llama_sampler_chain_i) {
1648
+ const auto * ctx = (const llama_sampler_chain *) smpl->ctx ;
1649
+ for (auto it = ctx->samplers .rbegin (); it != ctx->samplers .rend (); ++it) {
1650
+ const uint32_t seed = llama_sampler_get_seed (*it);
1651
+ if (seed != LLAMA_DEFAULT_SEED) {
1652
+ return seed;
1653
+ }
1654
+ }
1655
+ }
1656
+
1657
+ return LLAMA_DEFAULT_SEED;
1658
+ }
0 commit comments