@@ -2815,6 +2815,42 @@ struct llama_kv_cache {
2815
2815
}
2816
2816
};
2817
2817
2818
+ class llama_kv_cache_state {
2819
+ struct llama_kv_cache_state_short {
2820
+ uint32_t head = 0;
2821
+ uint32_t size = 0;
2822
+ uint32_t used = 0;
2823
+ uint32_t n = 0;
2824
+
2825
+ std::vector<llama_kv_cell> cells;
2826
+ } old_state;
2827
+
2828
+ bool saved = false;
2829
+
2830
+ public:
2831
+ void save_state(const llama_kv_cache& cache) {
2832
+ old_state.head = cache.head;
2833
+ old_state.size = cache.size;
2834
+ old_state.used = cache.used;
2835
+ old_state.n = cache.n;
2836
+ old_state.cells = cache.cells;
2837
+
2838
+ saved = true;
2839
+ }
2840
+
2841
+ void restore(llama_kv_cache& cache) {
2842
+ if (saved) {
2843
+ cache.head = old_state.head;
2844
+ cache.size = old_state.size;
2845
+ cache.used = old_state.used;
2846
+ cache.n = old_state.n;
2847
+ cache.cells = std::move(old_state.cells);
2848
+
2849
+ saved = false;
2850
+ }
2851
+ }
2852
+ };
2853
+
2818
2854
struct llama_control_vector {
2819
2855
std::vector<struct ggml_tensor *> tensors; // per layer
2820
2856
std::vector<struct ggml_context *> ctxs;
@@ -17184,6 +17220,7 @@ static int llama_decode_internal(
17184
17220
lctx.n_queued_tokens += n_tokens_all;
17185
17221
17186
17222
auto & kv_self = lctx.kv_self;
17223
+ llama_kv_cache_state kv_cache_state_holder;
17187
17224
17188
17225
const int64_t n_embd = hparams.n_embd;
17189
17226
const int64_t n_vocab = hparams.n_vocab;
@@ -17261,6 +17298,7 @@ static int llama_decode_internal(
17261
17298
// non-causal masks do not use the KV cache
17262
17299
if (hparams.causal_attn) {
17263
17300
llama_kv_cache_update(&lctx);
17301
+ kv_cache_state_holder.save_state(kv_self);
17264
17302
17265
17303
// if we have enough unused cells before the current head ->
17266
17304
// better to start searching from the beginning of the cache, hoping to fill it
@@ -17318,16 +17356,17 @@ static int llama_decode_internal(
17318
17356
llama_set_inputs(lctx, ubatch);
17319
17357
17320
17358
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
17321
- switch (compute_status) {
17322
- case GGML_STATUS_SUCCESS:
17323
- break;
17324
- case GGML_STATUS_ABORTED:
17325
- return 2;
17326
- case GGML_STATUS_ALLOC_FAILED:
17327
- return -2;
17328
- case GGML_STATUS_FAILED:
17329
- default:
17330
- return -3;
17359
+ if (compute_status != GGML_STATUS_SUCCESS) {
17360
+ kv_cache_state_holder.restore(kv_self);
17361
+ switch (compute_status) {
17362
+ case GGML_STATUS_ABORTED:
17363
+ return 2;
17364
+ case GGML_STATUS_ALLOC_FAILED:
17365
+ return -2;
17366
+ case GGML_STATUS_FAILED:
17367
+ default:
17368
+ return -3;
17369
+ }
17331
17370
}
17332
17371
17333
17372
// update the kv ring buffer
0 commit comments