Commit d41f314

grammars: move token caches to llama_context

author: ochafik (committed)
1 parent: 9f13623

File tree: 2 files changed, +16 -16 lines

llama.cpp (+16 -11)
@@ -2325,6 +2325,11 @@ struct llama_context {
     // control vectors
     struct llama_control_vector cvec;
 
+    // caching token pieces & their decoded codepoints.
+    std::vector<std::string> token_pieces;
+    std::vector<std::pair<std::vector<uint32_t>,
+                          llama_partial_utf8>> token_codepoints;
+
 #ifdef GGML_USE_MPI
     ggml_mpi_context * ctx_mpi = NULL;
 #endif
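These caches depend only on the model vocabulary, not on any particular grammar, which is why they can live on the llama_context and be shared by every grammar sampled against it. For orientation, here are the two members in isolation as a standalone sketch; the field layout of llama_partial_utf8 is an assumption consistent with the {0, 0} initializer and the n_remain checks later in this diff, not something this commit shows:

    #include <cstdint>
    #include <string>
    #include <utility>
    #include <vector>

    // Assumed layout: bits of an in-progress codepoint plus the number of
    // continuation bytes still expected (0 = last piece ended on a boundary).
    struct llama_partial_utf8 {
        uint32_t value;
        int      n_remain;
    };

    // Indexed by token id; sized to the vocab on first use (see below).
    std::vector<std::string> token_pieces;   // raw piece bytes per token
    std::vector<std::pair<std::vector<uint32_t>,
                          llama_partial_utf8>> token_codepoints; // decoded pieces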
@@ -13051,15 +13056,15 @@ struct llama_grammar * llama_grammar_init(
         }
     } while (true);
 
-    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {}, {}, {} };
+    return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
 }
 
 void llama_grammar_free(struct llama_grammar * grammar) {
     delete grammar;
 }
 
 struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) {
-    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8, grammar->token_pieces, grammar->token_codepoints };
+    llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 };
 
     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
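With the two cache members gone, the aggregate initializers above shrink from five elements to three, and llama_grammar_copy no longer deep-copies two vocab-sized vectors per clone. The diff shows only the first line of the stack fix-up loop that follows; this is a hedged reconstruction of the idea with stand-in element types: stack entries are non-owning pointers into the source grammar's rules, so each must be rebased onto the copy's rules at the same (rule, element) position.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    struct grammar_element_sketch { int type; uint32_t value; }; // stand-in

    struct grammar_sketch {
        std::vector<std::vector<grammar_element_sketch>> rules;
        std::vector<std::vector<const grammar_element_sketch *>> stacks;
    };

    // Rebase every stack pointer from src's rule storage onto dst's.
    static void rebase_stacks(const grammar_sketch & src, grammar_sketch & dst) {
        for (size_t is = 0; is < dst.stacks.size(); is++) {
            for (size_t ie = 0; ie < dst.stacks[is].size(); ie++) {
                for (size_t ir0 = 0; ir0 < src.rules.size(); ir0++) {
                    for (size_t ir1 = 0; ir1 < src.rules[ir0].size(); ir1++) {
                        if (src.stacks[is][ie] == &src.rules[ir0][ir1]) {
                            dst.stacks[is][ie] = &dst.rules[ir0][ir1];
                        }
                    }
                }
            }
        }
    }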
@@ -13552,14 +13557,14 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         }
     }
 
-    if (grammar->token_codepoints.empty()) {
+    if (ctx->token_codepoints.empty()) {
         auto n_vocab = llama_n_vocab(llama_get_model(ctx));
-        grammar->token_codepoints.resize(n_vocab);
-        grammar->token_pieces.resize(n_vocab);
+        ctx->token_codepoints.resize(n_vocab);
+        ctx->token_pieces.resize(n_vocab);
         for (llama_token id = 0; id < n_vocab; ++id) {
             const std::string piece = llama_token_to_piece(ctx, id, false);
-            grammar->token_pieces[id] = piece;
-            grammar->token_codepoints[id] = decode_utf8(piece, {0, 0});
+            ctx->token_pieces[id] = piece;
+            ctx->token_codepoints[id] = decode_utf8(piece, {0, 0});
         }
     }
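The cache is now filled once per context, on the first grammar-constrained sampling call, instead of once per llama_grammar instance, so grammars that are created, copied, and destroyed repeatedly (e.g. one per request) no longer redo a vocab-sized decode each time. A self-contained sketch of the pattern, with stand-in names (Cache and decode_ascii are illustrative, not llama.cpp API):

    #include <cstdint>
    #include <string>
    #include <utility>
    #include <vector>

    struct PartialUtf8 { uint32_t value; int n_remain; };

    struct Cache {
        std::vector<std::string> pieces;
        std::vector<std::pair<std::vector<uint32_t>, PartialUtf8>> codepoints;
    };

    // ASCII-only stand-in for decode_utf8(); the real decoder handles
    // multi-byte sequences and likewise appends a terminating 0.
    static std::pair<std::vector<uint32_t>, PartialUtf8>
    decode_ascii(const std::string & piece) {
        std::vector<uint32_t> cps(piece.begin(), piece.end());
        cps.push_back(0);
        return { cps, PartialUtf8{0, 0} };
    }

    // Populate on first use; subsequent calls are no-ops.
    static void ensure_token_caches(Cache & c, const std::vector<std::string> & vocab) {
        if (!c.codepoints.empty()) {
            return;
        }
        c.pieces.resize(vocab.size());
        c.codepoints.resize(vocab.size());
        for (size_t id = 0; id < vocab.size(); ++id) {
            c.pieces[id]     = vocab[id];
            c.codepoints[id] = decode_ascii(vocab[id]);
        }
    }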

@@ -13572,15 +13577,15 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
 
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id = candidates->data[i].id;
-        const auto & piece = grammar->token_pieces[id];
+        const auto & piece = ctx->token_pieces[id];
         if (llama_token_is_eog(&ctx->model, id)) {
             if (!allow_eog) {
                 candidates->data[i].logit = -INFINITY;
             }
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
         } else if (grammar->partial_utf8.n_remain == 0){
-            const auto & decoded = grammar->token_codepoints.at(id);
+            const auto & decoded = ctx->token_codepoints.at(id);
             candidates_grammar.push_back({ i, decoded.first.data(), decoded.second });
         } else {
             candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
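Note that candidates_grammar stores decoded.first.data(), a non-owning pointer into the cached codepoint vector. Moving the cache to the context keeps this safe: the vectors are sized once and not mutated during sampling, so the pointers remain valid for the whole call. The candidate's shape below is inferred from the push_back above, not shown by this commit:

    #include <cstddef>
    #include <cstdint>

    struct PartialUtf8 { uint32_t value; int n_remain; }; // as sketched earlier

    // Inferred: index into the candidates array, a NON-OWNING pointer into
    // ctx->token_codepoints[id].first, and the UTF-8 decoder state.
    struct grammar_candidate_sketch {
        std::size_t      index;
        const uint32_t * code_points;
        PartialUtf8      partial_utf8;
    };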
@@ -13778,11 +13783,11 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const auto & piece = grammar->token_pieces.at(token);
+    const auto & piece = ctx->token_pieces.at(token);
 
     // Note terminating 0 in decoded string
     const auto decoded = grammar->partial_utf8.n_remain == 0
-        ? grammar->token_codepoints[token]
+        ? ctx->token_codepoints[token]
         : decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
     std::vector<std::vector<const llama_grammar_element *>> tmp_new_stacks;
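The cached decode is computed from a clean {0, 0} state, so it is only valid when no multi-byte UTF-8 sequence is pending from previously accepted tokens; mid-sequence, the piece must be re-decoded against the pending state, exactly as the ternary above does. A self-contained sketch of that branch (names are stand-ins, not llama.cpp API):

    #include <cstdint>
    #include <string>
    #include <utility>
    #include <vector>

    struct PartialUtf8 { uint32_t value; int n_remain; };
    using Decoded = std::pair<std::vector<uint32_t>, PartialUtf8>;

    // ASCII-only stand-in for decode_utf8(), resuming from `start`.
    static Decoded decode_stub(const std::string & piece, PartialUtf8 start) {
        std::vector<uint32_t> cps(piece.begin(), piece.end());
        cps.push_back(0); // mirror the terminating 0 noted above
        return { cps, start };
    }

    static Decoded decode_for_accept(const std::vector<Decoded> & cache,
                                     const std::string & piece, int token,
                                     PartialUtf8 pending) {
        return pending.n_remain == 0
            ? cache[token]                 // fast path: precomputed from {0, 0}
            : decode_stub(piece, pending); // slow path: mid-codepoint boundary
    }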

llama.h (-5)
@@ -1099,11 +1099,6 @@ struct llama_grammar {
 
     // buffer for partially generated UTF-8 sequence from accepted tokens
     llama_partial_utf8 partial_utf8;
-
-    // caching the token pieces & their decoded codepoints.
-    std::vector<std::string> token_pieces;
-    std::vector<std::pair<std::vector<uint32_t>,
-                          llama_partial_utf8>> token_codepoints;
 };
 
 struct llama_grammar_candidate {
