Skip to content

Commit ee599f9

Browse files
committed
llama: correct reverting of the entire batch.
Also updates `llama_kv_cache_find_slot`, which will now correctly count the number of `used` cells for recurrent models.
1 parent 0026c81 commit ee599f9

File tree

1 file changed

+64
-58
lines changed

1 file changed

+64
-58
lines changed

src/llama.cpp

+64-58
Original file line number | Diff line number | Diff line change
@@ -2811,22 +2811,6 @@ struct llama_kv_cache {
28112811
}
28122812
};
28132813

2814-
// saves the kv_cache state for future recovery
2815-
// used to preserve the kv_cache state before searching for a slot
2816-
struct llama_kv_slot_restorer {
2817-
struct llama_kv_cache_state {
2818-
uint32_t head = 0;
2819-
uint32_t size = 0;
2820-
uint32_t used = 0;
2821-
uint32_t n = 0;
2822-
} old_state;
2823-
2824-
std::vector<llama_kv_cell> recurrent_cells; // for recurrent models only
2825-
std::pair<uint32_t, uint32_t> slot_boundaries; // for non-recurrent models only
2826-
2827-
bool restore = false;
2828-
};
2829-
28302814
struct llama_control_vector {
28312815
std::vector<struct ggml_tensor *> tensors; // per layer
28322816
std::vector<ggml_context_ptr> ctxs;
@@ -3522,21 +3506,24 @@ static bool llama_kv_cache_init(
35223506
// updates the cache head
35233507
// Note: On success, it's important that cache.head points
35243508
// to the first cell of the slot.
3525-
static bool llama_kv_cache_find_slot(
3509+
struct llama_kv_cache_slot_info {
3510+
std::pair<uint32_t, uint32_t> boundaries;
3511+
bool found = false;
3512+
3513+
explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
3514+
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
3515+
3516+
operator bool() const { return found; }
3517+
};
3518+
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
3519+
3520+
static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
35263521
struct llama_kv_cache & cache,
3527-
const struct llama_ubatch & batch,
3528-
struct llama_kv_slot_restorer * slot_restorer = nullptr) {
3522+
const struct llama_ubatch & batch) {
35293523
const uint32_t n_tokens = batch.n_tokens;
35303524
const uint32_t n_seqs = batch.n_seqs;
35313525
const uint32_t n_seq_tokens = batch.n_seq_tokens;
35323526

3533-
if (slot_restorer != nullptr) {
3534-
slot_restorer->old_state.head = cache.head;
3535-
slot_restorer->old_state.size = cache.size;
3536-
slot_restorer->old_state.used = cache.used;
3537-
slot_restorer->old_state.n = cache.n;
3538-
}
3539-
35403527
if (cache.recurrent) {
35413528
// For recurrent state architectures (like Mamba or RWKV),
35423529
// each cache cell can store the state for a whole sequence.
@@ -3545,11 +3532,6 @@ static bool llama_kv_cache_find_slot(
35453532
// can only process batches with an equal number of new tokens in each sequence
35463533
GGML_ASSERT(batch.equal_seqs);
35473534

3548-
if (slot_restorer != nullptr) {
3549-
slot_restorer->recurrent_cells = cache.cells;
3550-
slot_restorer->restore = true;
3551-
}
3552-
35533535
int32_t min = cache.size - 1;
35543536
int32_t max = 0;
35553537

@@ -3563,7 +3545,7 @@ static bool llama_kv_cache_find_slot(
35633545
// too big seq_id
35643546
// TODO: would it be possible to resize the cache instead?
35653547
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
3566-
return false;
3548+
return llama_kv_cache_slot_info_failed;
35673549
}
35683550
if (j > 0) {
35693551
llama_kv_cell & seq = cache.cells[seq_id];
@@ -3698,15 +3680,17 @@ static bool llama_kv_cache_find_slot(
36983680
// allow getting the range of used cells, from head to head + n
36993681
cache.head = min;
37003682
cache.n = max - min + 1;
3683+
cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
3684+
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
37013685

37023686
// sanity check
3703-
return cache.n >= n_seqs;
3687+
return llama_kv_cache_slot_info(cache.n >= n_seqs);
37043688
}
37053689
// otherwise, one cell per token.
37063690

37073691
if (n_tokens > cache.size) {
37083692
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
3709-
return false;
3693+
return llama_kv_cache_slot_info_failed;
37103694
}
37113695

37123696
uint32_t n_tested = 0;
@@ -3734,15 +3718,10 @@ static bool llama_kv_cache_find_slot(
37343718

37353719
if (n_tested >= cache.size) {
37363720
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
3737-
return false;
3721+
return llama_kv_cache_slot_info_failed;
37383722
}
37393723
}
37403724

3741-
if (slot_restorer != nullptr) {
3742-
slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens);
3743-
slot_restorer->restore = true;
3744-
}
3745-
37463725
for (uint32_t s = 0; s < n_seqs; s++) {
37473726
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
37483727
uint32_t k = s*n_seq_tokens + i;
@@ -3756,7 +3735,7 @@ static bool llama_kv_cache_find_slot(
37563735

37573736
cache.used += n_tokens;
37583737

3759-
return true;
3738+
return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
37603739
}
37613740

37623741
// find how many cells are currently in use
@@ -4032,22 +4011,47 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
40324011
return cparams.flash_attn ? 256u : 32u;
40334012
}
40344013

4035-
static void llama_kv_cache_slot_restore(
4036-
const struct llama_kv_slot_restorer & restorer,
4037-
struct llama_kv_cache & cache) {
4038-
if (restorer.restore) {
4039-
cache.head = restorer.old_state.head;
4040-
cache.size = restorer.old_state.size;
4041-
cache.used = restorer.old_state.used;
4042-
cache.n = restorer.old_state.n;
4043-
4044-
if (cache.recurrent) {
4045-
cache.cells = restorer.recurrent_cells;
4046-
} else {
4047-
llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1);
4014+
// saves the kv_cache state for future recovery.
4015+
// used to rollback llama_kv_cache_find_slot changes.
4016+
struct llama_kv_slot_restorer {
4017+
struct llama_kv_cache_state {
4018+
uint32_t head = 0;
4019+
uint32_t n = 0;
4020+
} old_state;
4021+
4022+
std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries; // for non-recurrent models only
4023+
4024+
bool do_restore = false;
4025+
4026+
explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
4027+
old_state.head = cache.head;
4028+
old_state.n = cache.n;
4029+
}
4030+
4031+
void save(const struct llama_kv_cache_slot_info& slot) {
4032+
if (slot) {
4033+
do_restore = true;
4034+
if (slot.boundaries.first != slot.boundaries.second) {
4035+
slot_boundaries.push_back(slot.boundaries);
4036+
}
40484037
}
40494038
}
4050-
}
4039+
4040+
void restore(struct llama_kv_cache & cache) {
4041+
if (do_restore) {
4042+
cache.head = old_state.head;
4043+
cache.n = old_state.n;
4044+
4045+
if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
4046+
llama_kv_cache_seq_rm(cache, -1, -1, -1);
4047+
} else {
4048+
for (auto & slot : slot_boundaries) {
4049+
llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
4050+
}
4051+
}
4052+
}
4053+
}
4054+
};
40514055

40524056
//
40534057
// model loading and saving
@@ -17307,7 +17311,7 @@ static int llama_decode_internal(
1730717311
lctx.n_queued_tokens += n_tokens_all;
1730817312

1730917313
auto & kv_self = lctx.kv_self;
17310-
llama_kv_slot_restorer kv_slot_restorer;
17314+
llama_kv_slot_restorer kv_slot_restorer(kv_self);
1731117315

1731217316
const int64_t n_embd = hparams.n_embd;
1731317317
const int64_t n_vocab = hparams.n_vocab;
@@ -17392,9 +17396,11 @@ static int llama_decode_internal(
1739217396
kv_self.head = 0;
1739317397
}
1739417398

17395-
if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) {
17399+
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
17400+
if (!slot) {
1739617401
return 1;
1739717402
}
17403+
kv_slot_restorer.save(slot);
1739817404

1739917405
if (!kv_self.recurrent) {
1740017406
// a heuristic, to avoid attending the full cache if it is not yet utilized
@@ -17443,7 +17449,7 @@ static int llama_decode_internal(
1744317449

1744417450
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
1744517451
if (compute_status != GGML_STATUS_SUCCESS) {
17446-
llama_kv_cache_slot_restore(kv_slot_restorer, kv_self);
17452+
kv_slot_restorer.restore(kv_self);
1744717453
switch (compute_status) {
1744817454
case GGML_STATUS_ABORTED:
1744917455
return 2;

0 commit comments

Comments
 (0)