@@ -2811,22 +2811,6 @@ struct llama_kv_cache {
2811
2811
}
2812
2812
};
2813
2813
2814
- // saves the kv_cache state for future recovery
2815
- // used to preserve the kv_cache state before searching for a slot
2816
- struct llama_kv_slot_restorer {
2817
- struct llama_kv_cache_state {
2818
- uint32_t head = 0;
2819
- uint32_t size = 0;
2820
- uint32_t used = 0;
2821
- uint32_t n = 0;
2822
- } old_state;
2823
-
2824
- std::vector<llama_kv_cell> recurrent_cells; // for recurrent models only
2825
- std::pair<uint32_t, uint32_t> slot_boundaries; // for non-recurrent models only
2826
-
2827
- bool restore = false;
2828
- };
2829
-
2830
2814
struct llama_control_vector {
2831
2815
std::vector<struct ggml_tensor *> tensors; // per layer
2832
2816
std::vector<ggml_context_ptr> ctxs;
@@ -3522,21 +3506,24 @@ static bool llama_kv_cache_init(
3522
3506
// updates the cache head
3523
3507
// Note: On success, it's important that cache.head points
3524
3508
// to the first cell of the slot.
3525
- static bool llama_kv_cache_find_slot(
3509
+ struct llama_kv_cache_slot_info {
3510
+ std::pair<uint32_t, uint32_t> boundaries;
3511
+ bool found = false;
3512
+
3513
+ explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
3514
+ llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
3515
+
3516
+ operator bool() const { return found; }
3517
+ };
3518
+ static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
3519
+
3520
+ static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
3526
3521
struct llama_kv_cache & cache,
3527
- const struct llama_ubatch & batch,
3528
- struct llama_kv_slot_restorer * slot_restorer = nullptr) {
3522
+ const struct llama_ubatch & batch) {
3529
3523
const uint32_t n_tokens = batch.n_tokens;
3530
3524
const uint32_t n_seqs = batch.n_seqs;
3531
3525
const uint32_t n_seq_tokens = batch.n_seq_tokens;
3532
3526
3533
- if (slot_restorer != nullptr) {
3534
- slot_restorer->old_state.head = cache.head;
3535
- slot_restorer->old_state.size = cache.size;
3536
- slot_restorer->old_state.used = cache.used;
3537
- slot_restorer->old_state.n = cache.n;
3538
- }
3539
-
3540
3527
if (cache.recurrent) {
3541
3528
// For recurrent state architectures (like Mamba or RWKV),
3542
3529
// each cache cell can store the state for a whole sequence.
@@ -3545,11 +3532,6 @@ static bool llama_kv_cache_find_slot(
3545
3532
// can only process batches with an equal number of new tokens in each sequence
3546
3533
GGML_ASSERT(batch.equal_seqs);
3547
3534
3548
- if (slot_restorer != nullptr) {
3549
- slot_restorer->recurrent_cells = cache.cells;
3550
- slot_restorer->restore = true;
3551
- }
3552
-
3553
3535
int32_t min = cache.size - 1;
3554
3536
int32_t max = 0;
3555
3537
@@ -3563,7 +3545,7 @@ static bool llama_kv_cache_find_slot(
3563
3545
// too big seq_id
3564
3546
// TODO: would it be possible to resize the cache instead?
3565
3547
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
3566
- return false ;
3548
+ return llama_kv_cache_slot_info_failed ;
3567
3549
}
3568
3550
if (j > 0) {
3569
3551
llama_kv_cell & seq = cache.cells[seq_id];
@@ -3698,15 +3680,17 @@ static bool llama_kv_cache_find_slot(
3698
3680
// allow getting the range of used cells, from head to head + n
3699
3681
cache.head = min;
3700
3682
cache.n = max - min + 1;
3683
+ cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
3684
+ [](const llama_kv_cell& cell){ return !cell.is_empty(); });
3701
3685
3702
3686
// sanity check
3703
- return cache.n >= n_seqs;
3687
+ return llama_kv_cache_slot_info( cache.n >= n_seqs) ;
3704
3688
}
3705
3689
// otherwise, one cell per token.
3706
3690
3707
3691
if (n_tokens > cache.size) {
3708
3692
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
3709
- return false ;
3693
+ return llama_kv_cache_slot_info_failed ;
3710
3694
}
3711
3695
3712
3696
uint32_t n_tested = 0;
@@ -3734,15 +3718,10 @@ static bool llama_kv_cache_find_slot(
3734
3718
3735
3719
if (n_tested >= cache.size) {
3736
3720
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
3737
- return false ;
3721
+ return llama_kv_cache_slot_info_failed ;
3738
3722
}
3739
3723
}
3740
3724
3741
- if (slot_restorer != nullptr) {
3742
- slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens);
3743
- slot_restorer->restore = true;
3744
- }
3745
-
3746
3725
for (uint32_t s = 0; s < n_seqs; s++) {
3747
3726
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
3748
3727
uint32_t k = s*n_seq_tokens + i;
@@ -3756,7 +3735,7 @@ static bool llama_kv_cache_find_slot(
3756
3735
3757
3736
cache.used += n_tokens;
3758
3737
3759
- return true ;
3738
+ return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens) ;
3760
3739
}
3761
3740
3762
3741
// find how many cells are currently in use
@@ -4032,22 +4011,47 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
4032
4011
return cparams.flash_attn ? 256u : 32u;
4033
4012
}
4034
4013
4035
- static void llama_kv_cache_slot_restore(
4036
- const struct llama_kv_slot_restorer & restorer,
4037
- struct llama_kv_cache & cache) {
4038
- if (restorer.restore) {
4039
- cache.head = restorer.old_state.head;
4040
- cache.size = restorer.old_state.size;
4041
- cache.used = restorer.old_state.used;
4042
- cache.n = restorer.old_state.n;
4043
-
4044
- if (cache.recurrent) {
4045
- cache.cells = restorer.recurrent_cells;
4046
- } else {
4047
- llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1);
4014
+ // saves the kv_cache state for future recovery.
4015
+ // used to rollback llama_kv_cache_find_slot changes.
4016
+ struct llama_kv_slot_restorer {
4017
+ struct llama_kv_cache_state {
4018
+ uint32_t head = 0;
4019
+ uint32_t n = 0;
4020
+ } old_state;
4021
+
4022
+ std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries; // for non-recurrent models only
4023
+
4024
+ bool do_restore = false;
4025
+
4026
+ explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
4027
+ old_state.head = cache.head;
4028
+ old_state.n = cache.n;
4029
+ }
4030
+
4031
+ void save(const struct llama_kv_cache_slot_info& slot) {
4032
+ if (slot) {
4033
+ do_restore = true;
4034
+ if (slot.boundaries.first != slot.boundaries.second) {
4035
+ slot_boundaries.push_back(slot.boundaries);
4036
+ }
4048
4037
}
4049
4038
}
4050
- }
4039
+
4040
+ void restore(struct llama_kv_cache & cache) {
4041
+ if (do_restore) {
4042
+ cache.head = old_state.head;
4043
+ cache.n = old_state.n;
4044
+
4045
+ if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
4046
+ llama_kv_cache_seq_rm(cache, -1, -1, -1);
4047
+ } else {
4048
+ for (auto & slot : slot_boundaries) {
4049
+ llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
4050
+ }
4051
+ }
4052
+ }
4053
+ }
4054
+ };
4051
4055
4052
4056
//
4053
4057
// model loading and saving
@@ -17307,7 +17311,7 @@ static int llama_decode_internal(
17307
17311
lctx.n_queued_tokens += n_tokens_all;
17308
17312
17309
17313
auto & kv_self = lctx.kv_self;
17310
- llama_kv_slot_restorer kv_slot_restorer;
17314
+ llama_kv_slot_restorer kv_slot_restorer(kv_self) ;
17311
17315
17312
17316
const int64_t n_embd = hparams.n_embd;
17313
17317
const int64_t n_vocab = hparams.n_vocab;
@@ -17392,9 +17396,11 @@ static int llama_decode_internal(
17392
17396
kv_self.head = 0;
17393
17397
}
17394
17398
17395
- if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) {
17399
+ const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
17400
+ if (!slot) {
17396
17401
return 1;
17397
17402
}
17403
+ kv_slot_restorer.save(slot);
17398
17404
17399
17405
if (!kv_self.recurrent) {
17400
17406
// a heuristic, to avoid attending the full cache if it is not yet utilized
@@ -17443,7 +17449,7 @@ static int llama_decode_internal(
17443
17449
17444
17450
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
17445
17451
if (compute_status != GGML_STATUS_SUCCESS) {
17446
- llama_kv_cache_slot_restore( kv_slot_restorer, kv_self);
17452
+ kv_slot_restorer.restore( kv_self);
17447
17453
switch (compute_status) {
17448
17454
case GGML_STATUS_ABORTED:
17449
17455
return 2;
0 commit comments