
Commit 1439bad

llama : add struct llama_vocab to the API (#11156)
ggml-ci
1 parent: bfe781a

46 files changed: +592 / -495 lines
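
The central change is that token- and vocab-level queries now go through a `llama_vocab` handle obtained from the model instead of the `llama_model` itself. Below is a minimal sketch of the migrated call pattern, using only calls that appear in this diff; the helper name and the printed fields are illustrative, not part of the commit.

```cpp
// Minimal sketch of the vocab-based call pattern introduced by this commit.
// Assumes `model` was loaded elsewhere; only calls visible in this diff are used.
#include <cstdio>
#include "llama.h"

static void print_vocab_info(const llama_model * model) {
    // Token-level queries now take a llama_vocab handle, not the model itself.
    const llama_vocab * vocab = llama_get_vocab(model);

    std::printf("n_vocab = %d, bos = %d, eos = %d\n",
                llama_n_vocab(vocab),    // was: llama_n_vocab(model)
                llama_token_bos(vocab),  // was: llama_token_bos(model)
                llama_token_eos(vocab)); // was: llama_token_eos(model)
}
```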

common/common.cpp (+46 -30)

@@ -857,21 +857,23 @@ struct common_init_result common_init_from_params(common_params & params) {
         return iparams;
     }
 
+    const llama_vocab * vocab = llama_get_vocab(model);
+
     if (params.reranking) {
         bool ok = true;
 
-        if (llama_token_bos(model) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: model does not have a BOS token, reranking will not work\n", __func__);
+        if (llama_token_bos(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
             ok = false;
         }
 
-        if (llama_token_eos(model) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__);
+        if (llama_token_eos(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
             ok = false;
         }
 
-        if (llama_token_sep(model) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: model does not have a SEP token, reranking will not work\n", __func__);
+        if (llama_token_sep(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
             ok = false;
         }
 
@@ -941,14 +943,14 @@ struct common_init_result common_init_from_params(common_params & params) {
         common_lora_adapters_apply(lctx, params.lora_adapters);
     }
 
-    if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
-        LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
+    if (params.sampling.ignore_eos && llama_token_eos(vocab) == LLAMA_TOKEN_NULL) {
+        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
         params.sampling.ignore_eos = false;
     }
 
     if (params.sampling.ignore_eos) {
-        for (llama_token i = 0; i < llama_n_vocab(model); i++) {
-            if (llama_token_is_eog(model, i)) {
+        for (llama_token i = 0; i < llama_n_vocab(vocab); i++) {
+            if (llama_token_is_eog(vocab, i)) {
                 LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
                 params.sampling.logit_bias.push_back({i, -INFINITY});
             }
@@ -969,8 +971,9 @@ struct common_init_result common_init_from_params(common_params & params) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
         std::vector<llama_token> tmp;
-        llama_token bos = llama_token_bos(model);
-        llama_token eos = llama_token_eos(model);
+        llama_token bos = llama_token_bos(vocab);
+        llama_token eos = llama_token_eos(vocab);
+
         // some models (e.g. T5) don't have a BOS token
         if (bos != LLAMA_TOKEN_NULL) {
             tmp.push_back(bos);
@@ -1559,21 +1562,23 @@ std::vector<llama_token> common_tokenize(
         const std::string & text,
         bool add_special,
         bool parse_special) {
-    return common_tokenize(llama_get_model(ctx), text, add_special, parse_special);
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_get_vocab(model);
+    return common_tokenize(vocab, text, add_special, parse_special);
 }
 
 std::vector<llama_token> common_tokenize(
-        const struct llama_model * model,
+        const struct llama_vocab * vocab,
         const std::string & text,
         bool add_special,
         bool parse_special) {
     // upper limit for the number of tokens
     int n_tokens = text.length() + 2 * add_special;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -1582,12 +1587,18 @@ std::vector<llama_token> common_tokenize(
 }
 
 std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_get_vocab(model);
+    return common_token_to_piece(vocab, token, special);
+}
+
+std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) {
     std::string piece;
     piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
-    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+    const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
     if (n_chars < 0) {
         piece.resize(-n_chars);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+        int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
         GGML_ASSERT(check == -n_chars);
     }
     else {
@@ -1597,13 +1608,19 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token
     return piece;
 }
 
-std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_get_vocab(model);
+    return common_detokenize(vocab, tokens, special);
+}
+
+std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
-    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
     if (n_chars < 0) {
         text.resize(-n_chars);
-        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
         GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
     }
 
@@ -1631,7 +1648,7 @@ std::string common_get_builtin_chat_template(const struct llama_model * model) {
 
 bool common_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
+    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
@@ -1642,35 +1659,34 @@ std::string common_chat_apply_template(const struct llama_model * model,
     int alloc_size = 0;
     bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
-    for (auto & msg : msgs) {
+    for (const auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
         alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }
 
-    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
         if (ptr_tmpl != nullptr) {
             // if the custom "tmpl" is not supported, we throw an error
             // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
             throw std::runtime_error("this custom template is not supported");
-        } else {
-            // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-            fallback = true;
         }
+
+        // If the built-in template is not supported, we default to chatml
+        res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        fallback = true;
     }
 
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
         res = llama_chat_apply_template(
-            fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
             chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }

common/common.h (+12 -2)

@@ -541,7 +541,7 @@ std::vector<llama_token> common_tokenize(
         bool parse_special = false);
 
 std::vector<llama_token> common_tokenize(
-        const struct llama_model * model,
+        const struct llama_vocab * vocab,
         const std::string & text,
         bool add_special,
         bool parse_special = false);
@@ -553,11 +553,21 @@ std::string common_token_to_piece(
         llama_token token,
         bool special = true);
 
+std::string common_token_to_piece(
+        const struct llama_vocab * vocab,
+        llama_token token,
+        bool special = true);
+
 // detokenizes a vector of tokens into a string
 // should work similar to Python's `tokenizer.decode`
 // optionally renders special/control tokens
 std::string common_detokenize(
-        llama_context * ctx,
+        const struct llama_context * ctx,
+        const std::vector<llama_token> & tokens,
+        bool special = true);
+
+std::string common_detokenize(
+        const struct llama_vocab * vocab,
         const std::vector<llama_token> & tokens,
         bool special = true);

common/sampling.cpp (+11 -6)

@@ -113,7 +113,10 @@ struct common_sampler {
     void set_logits(struct llama_context * ctx, int idx) {
         const auto * logits = llama_get_logits_ith(ctx, idx);
 
-        const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+        const llama_model * model = llama_get_model(ctx);
+        const llama_vocab * vocab = llama_get_vocab(model);
+
+        const int n_vocab = llama_n_vocab(vocab);
 
         cur.resize(n_vocab);
 
@@ -142,13 +145,15 @@ std::string common_params_sampling::print() const {
 }
 
 struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
+    const llama_vocab * vocab = llama_get_vocab(model);
+
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
 
     lparams.no_perf = params.no_perf;
 
     auto * result = new common_sampler {
         /* .params = */ params,
-        /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
+        /* .grmr = */ llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"),
         /* .chain = */ llama_sampler_chain_init(lparams),
         /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur = */ {},
@@ -157,7 +162,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     llama_sampler_chain_add(result->chain,
             llama_sampler_init_logit_bias(
-                llama_n_vocab(model),
+                llama_n_vocab(vocab),
                 params.logit_bias.size(),
                 params.logit_bias.data()));
 
@@ -172,7 +177,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                         c_breakers.push_back(str.c_str());
                     }
 
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_dry    (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_dry    (vocab, llama_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                 }
                 break;
             case COMMON_SAMPLER_TYPE_TOP_K:
@@ -194,7 +199,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
             case COMMON_SAMPLER_TYPE_INFILL:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
+                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
                 break;
             case COMMON_SAMPLER_TYPE_PENALTIES:
                 llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
@@ -206,7 +211,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
     } else if (params.mirostat == 1) {
         llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
     } else if (params.mirostat == 2) {
         llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
         llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
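
The sampler constructors touched above now take the vocab, and the DRY sampler additionally receives the model's training context length. A hedged sketch of that construction path; the grammar text and numeric parameters are illustrative, not values taken from this commit.

```cpp
// Hedged sketch of building a sampler chain against the vocab, mirroring the
// calls in this diff; grammar text and numeric parameters are illustrative.
#include "llama.h"

static llama_sampler * make_chain(const llama_model * model) {
    const llama_vocab * vocab = llama_get_vocab(model);

    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // Grammar-constrained sampling is now constructed from the vocab.
    llama_sampler_chain_add(chain, llama_sampler_init_grammar(vocab, "root ::= [a-z]+", "root"));

    // DRY now takes the vocab plus the model's training context length.
    const char * breakers[] = { "\n" };
    llama_sampler_chain_add(chain, llama_sampler_init_dry(vocab, llama_n_ctx_train(model),
                                                          /*multiplier=*/0.8f, /*base=*/1.75f,
                                                          /*allowed_length=*/2, /*penalty_last_n=*/512,
                                                          breakers, 1));

    llama_sampler_chain_add(chain, llama_sampler_init_dist(/*seed=*/1234));
    return chain;
}
```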

common/speculative.cpp (+18 -15)

@@ -79,10 +79,13 @@ bool common_speculative_are_compatible(
     const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
     const struct llama_model * model_dft = llama_get_model(ctx_dft);
 
-    const bool vocab_type_tgt = llama_vocab_type(model_tgt);
+    const struct llama_vocab * vocab_tgt = llama_get_vocab(model_tgt);
+    const struct llama_vocab * vocab_dft = llama_get_vocab(model_dft);
+
+    const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
     LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
 
-    const bool vocab_type_dft = llama_vocab_type(model_dft);
+    const bool vocab_type_dft = llama_vocab_type(vocab_dft);
     LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
 
     if (vocab_type_tgt != vocab_type_dft) {
@@ -91,34 +94,34 @@ bool common_speculative_are_compatible(
         return false;
     }
 
-    if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
-        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
-        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
-        llama_token_eos(model_tgt) != llama_token_eos(model_dft)) {
-        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
-        LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_tgt), llama_add_bos_token(model_tgt), llama_token_eos(model_tgt), llama_add_eos_token(model_tgt));
-        LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_dft), llama_add_bos_token(model_dft), llama_token_eos(model_dft), llama_add_eos_token(model_dft));
+    if (llama_add_bos_token(vocab_tgt) != llama_add_bos_token(vocab_dft) ||
+        llama_add_eos_token(vocab_tgt) != llama_add_eos_token(vocab_dft) ||
+        llama_token_bos(vocab_tgt) != llama_token_bos(vocab_dft) ||
+        llama_token_eos(vocab_tgt) != llama_token_eos(vocab_dft)) {
+        LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
+        LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(vocab_tgt), llama_add_bos_token(vocab_tgt), llama_token_eos(vocab_tgt), llama_add_eos_token(vocab_tgt));
+        LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(vocab_dft), llama_add_bos_token(vocab_dft), llama_token_eos(vocab_dft), llama_add_eos_token(vocab_dft));
         return false;
     }
 
     {
-        const int n_vocab_tgt = llama_n_vocab(model_tgt);
-        const int n_vocab_dft = llama_n_vocab(model_dft);
+        const int n_vocab_tgt = llama_n_vocab(vocab_tgt);
+        const int n_vocab_dft = llama_n_vocab(vocab_dft);
 
         const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
 
         if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
             LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
                     "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
-                    __func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+                    __func__, n_vocab_tgt, llama_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
             return false;
         }
 
         for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
-            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
-            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            const char * token_text_tgt = llama_token_get_text(vocab_tgt, i);
+            const char * token_text_dft = llama_token_get_text(vocab_dft, i);
             if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
-                LOG_ERR("%s: draft model vocab must match target model to use speculation but "
+                LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
                         "token %d content differs - target '%s', draft '%s'\n", __func__, i,
                         common_token_to_piece(ctx_tgt, i).c_str(),
                         common_token_to_piece(ctx_dft, i).c_str());
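
With both vocab handles in hand, the compatibility check reduces to comparing vocab-level properties. A simplified sketch that mirrors the calls above but omits the size-difference threshold and the `SPEC_VOCAB_CHECK_START_TOKEN_ID` offset used in the real code.

```cpp
// Simplified sketch of a vocab compatibility check, following the calls above.
// It omits the size-difference threshold and the start-token offset.
#include <algorithm>
#include <cstring>
#include "llama.h"

static bool vocabs_match(const llama_model * model_tgt, const llama_model * model_dft) {
    const llama_vocab * vocab_tgt = llama_get_vocab(model_tgt);
    const llama_vocab * vocab_dft = llama_get_vocab(model_dft);

    if (llama_vocab_type(vocab_tgt) != llama_vocab_type(vocab_dft)) return false;
    if (llama_token_bos(vocab_tgt)  != llama_token_bos(vocab_dft))  return false;
    if (llama_token_eos(vocab_tgt)  != llama_token_eos(vocab_dft))  return false;

    const int n = std::min(llama_n_vocab(vocab_tgt), llama_n_vocab(vocab_dft));
    for (int i = 0; i < n; ++i) {
        // Compare the raw token text entry by entry.
        if (std::strcmp(llama_token_get_text(vocab_tgt, i), llama_token_get_text(vocab_dft, i)) != 0) {
            return false;
        }
    }
    return true;
}
```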

examples/batched.swift/Sources/main.swift (+2 -2)

@@ -23,12 +23,12 @@ defer {
 }
 
 let model_params = llama_model_default_params()
-guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), model_params) else {
+guard let model = llama_model_load_from_file(modelPath.cString(using: .utf8), model_params) else {
     print("Failed to load model")
     exit(1)
 }
 defer {
-    llama_free_model(model)
+    llama_model_free(model)
 }
 
 var tokens = tokenize(text: prompt, add_bos: true)
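
The Swift example only picks up the renamed C entry points; the same load/free pair in C++ form looks roughly as follows (model path illustrative, error handling minimal).

```cpp
// The renamed loader/teardown entry points used by the Swift example, shown in C++.
// The model path is illustrative; error handling beyond the null check is omitted.
#include "llama.h"

int main() {
    const llama_model_params mparams = llama_model_default_params();

    // Renamed from llama_load_model_from_file.
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == nullptr) {
        return 1;
    }

    // Renamed from llama_free_model.
    llama_model_free(model);
    return 0;
}
```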
