
Commit 5535683

llama: reverting kv_cache in case of failed compute
1 parent 58106a6 · commit 5535683

1 file changed: src/llama.cpp (+49 -10 lines)
@@ -2815,6 +2815,42 @@ struct llama_kv_cache {
     }
 };
 
+class llama_kv_cache_state {
+    struct llama_kv_cache_state_short {
+        uint32_t head = 0;
+        uint32_t size = 0;
+        uint32_t used = 0;
+        uint32_t n    = 0;
+
+        std::vector<llama_kv_cell> cells;
+    } old_state;
+
+    bool saved = false;
+
+public:
+    void save_state(const llama_kv_cache & cache) {
+        old_state.head  = cache.head;
+        old_state.size  = cache.size;
+        old_state.used  = cache.used;
+        old_state.n     = cache.n;
+        old_state.cells = cache.cells;
+
+        saved = true;
+    }
+
+    void restore(llama_kv_cache & cache) {
+        if (saved) {
+            cache.head = old_state.head;
+            cache.size = old_state.size;
+            cache.used = old_state.used;
+            cache.n    = old_state.n;
+            cache.cells = std::move(old_state.cells);
+
+            saved = false;
+        }
+    }
+};
+
 struct llama_control_vector {
     std::vector<struct ggml_tensor *> tensors; // per layer
     std::vector<struct ggml_context *> ctxs;
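The new llama_kv_cache_state class is a small memento: save_state() copies the cache's head/size/used/n counters plus its cells vector, and restore() writes them back (moving the cells to avoid a second copy), but only if a snapshot was actually taken. A minimal self-contained sketch of the same pattern, using toy stand-ins for the llama.cpp types (toy_kv_cell, toy_kv_cache, and main are illustrative, not part of the commit):

#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-ins for llama_kv_cell / llama_kv_cache: just enough fields to
// demonstrate the snapshot-and-rollback mechanics from the commit.
struct toy_kv_cell {
    int32_t pos = -1;
};

struct toy_kv_cache {
    uint32_t head = 0;
    uint32_t size = 0;
    uint32_t used = 0;
    uint32_t n    = 0;
    std::vector<toy_kv_cell> cells;
};

class toy_kv_cache_state {
    toy_kv_cache old_state;  // snapshot of everything a decode step mutates
    bool saved = false;
public:
    void save_state(const toy_kv_cache & cache) {
        old_state = cache;   // copies the cells vector, as the real save_state does
        saved = true;
    }
    void restore(toy_kv_cache & cache) {
        if (saved) {         // no-op unless save_state() ran, mirroring the commit
            cache = std::move(old_state);
            saved = false;
        }
    }
};

int main() {
    toy_kv_cache cache;
    cache.size = 4;
    cache.cells.resize(cache.size);

    toy_kv_cache_state holder;
    holder.save_state(cache);

    // simulate a decode step that claims two cells, then fails mid-compute
    cache.head = 2;
    cache.used = 2;
    cache.cells[0].pos = 0;
    cache.cells[1].pos = 1;

    holder.restore(cache);   // roll back to the pre-decode snapshot
    assert(cache.head == 0 && cache.used == 0 && cache.cells[1].pos == -1);
    return 0;
}

In the commit itself the snapshot holds the four counters plus a full copy of cells, so saving is linear in the cache size; restore() moves the cells back rather than copying them a second time.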
@@ -17184,6 +17220,7 @@ static int llama_decode_internal(
     lctx.n_queued_tokens += n_tokens_all;
 
     auto & kv_self = lctx.kv_self;
+    llama_kv_cache_state kv_cache_state_holder;
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
@@ -17261,6 +17298,7 @@ static int llama_decode_internal(
         // non-causal masks do not use the KV cache
         if (hparams.causal_attn) {
             llama_kv_cache_update(&lctx);
+            kv_cache_state_holder.save_state(kv_self);
 
             // if we have enough unused cells before the current head ->
             //   better to start searching from the beginning of the cache, hoping to fill it
@@ -17318,16 +17356,17 @@ static int llama_decode_internal(
         llama_set_inputs(lctx, ubatch);
 
         const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
-        switch (compute_status) {
-            case GGML_STATUS_SUCCESS:
-                break;
-            case GGML_STATUS_ABORTED:
-                return 2;
-            case GGML_STATUS_ALLOC_FAILED:
-                return -2;
-            case GGML_STATUS_FAILED:
-            default:
-                return -3;
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            kv_cache_state_holder.restore(kv_self);
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
         }
 
         // update the kv ring buffer
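The key behavioral change is that the cache is restored before the early return, so a caller that gets a non-zero status back can retry the batch without the KV cache having drifted. A small caller-side sketch of the code mapping (the decode_status_str helper and its messages are ours for illustration; only the numeric codes 2, -2, -3 come from the diff above):

#include <cstdio>

// Illustrative mapping of the return codes surfaced by llama_decode_internal
// after this change; not part of llama.cpp.
static const char * decode_status_str(int ret) {
    switch (ret) {
        case  0: return "success";
        case  2: return "compute aborted (GGML_STATUS_ABORTED)";
        case -2: return "allocation failed (GGML_STATUS_ALLOC_FAILED)";
        case -3: return "compute failed (GGML_STATUS_FAILED)";
        default: return "other status";
    }
}

int main() {
    // in real code, ret would come from llama_decode(); here we just
    // exercise the mapping for each code the diff can return
    for (int ret : { 0, 2, -2, -3 }) {
        std::printf("decode returned %3d: %s\n", ret, decode_status_str(ret));
    }
    return 0;
}

Note that GGML_STATUS_SUCCESS no longer needs its own case: on success, execution falls past the if and on to the KV ring-buffer update, while every failure path now rolls the cache back before returning.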
