Skip to content

Commit 1f95cf7

Browse files
committed
hparams : move vocab params to llama_vocab (#11159)
ggml-ci
1 parent 0f02297 commit 1f95cf7

File tree

6 files changed

+25
-25
lines changed

6 files changed

+25
-25
lines changed

src/llama-context.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -469,11 +469,12 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
469469
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
470470
const auto & cparams = lctx.cparams;
471471
const auto & hparams = lctx.model.hparams;
472+
const auto & vocab = lctx.model.vocab;
472473

473474
const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
474475

475476
const auto n_batch = cparams.n_batch;
476-
const auto n_vocab = hparams.n_vocab;
477+
const auto n_vocab = vocab.n_vocab();
477478
const auto n_embd = hparams.n_embd;
478479

479480
// TODO: use a per-batch flag for logits presence instead
@@ -540,7 +541,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
540541
void llama_output_reorder(struct llama_context & ctx) {
541542
std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
542543
if (!out_ids.empty()) {
543-
const uint32_t n_vocab = ctx.model.hparams.n_vocab;
544+
const uint32_t n_vocab = ctx.model.vocab.n_vocab();
544545
const uint32_t n_embd = ctx.model.hparams.n_embd;
545546

546547
const int32_t n_outputs = ctx.n_outputs;
@@ -724,7 +725,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
724725
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
725726
}
726727

727-
return ctx->logits + j*ctx->model.hparams.n_vocab;
728+
return ctx->logits + j*ctx->model.vocab.n_vocab();
728729
} catch (const std::exception & err) {
729730
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
730731
#ifndef NDEBUG
@@ -884,7 +885,7 @@ struct llama_data_write {
884885
}
885886

886887
void write_logits(const struct llama_context * ctx) {
887-
const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab);
888+
const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_vocab());
888889

889890
write(&logits_size, sizeof(logits_size));
890891

src/llama-hparams.h

-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ struct llama_hparams {
3030
bool use_par_res;
3131
bool swin_norm;
3232

33-
uint32_t n_vocab = 0;
3433
uint32_t n_ctx_train; // context size the model was trained on
3534
uint32_t n_embd;
3635
uint32_t n_embd_features = 0;
@@ -41,7 +40,6 @@ struct llama_hparams {
4140
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
4241
uint32_t n_expert = 0;
4342
uint32_t n_expert_used = 0;
44-
uint32_t n_vocab_type = 0; // for BERT-style token types
4543
uint32_t n_rel_attn_bkts = 0;
4644

4745
// for WavTokenizer

src/llama-model.cpp

+9-12
Original file line numberDiff line numberDiff line change
@@ -402,9 +402,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
402402
// get general kv
403403
ml.get_key(LLM_KV_GENERAL_NAME, name, false);
404404

405-
// get hparams kv
406-
ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false);
407-
408405
// everything past this point is not vocab-related
409406
if (hparams.vocab_only) {
410407
return;
@@ -500,6 +497,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
500497
hparams.n_embd_head_v = 0;
501498
}
502499

500+
// for differentiating model types
501+
uint32_t n_vocab = 0;
502+
ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false);
503+
503504
// arch-specific KVs
504505
switch (arch) {
505506
case LLM_ARCH_LLAMA:
@@ -519,7 +520,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
519520
case 26: type = LLM_TYPE_3B; break;
520521
case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B
521522
// granite uses a vocab with len 49152
522-
case 32: type = hparams.n_vocab == 49152 ? LLM_TYPE_3B : (hparams.n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break;
523+
case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break;
523524
case 36: type = LLM_TYPE_8B; break; // granite
524525
case 40: type = LLM_TYPE_13B; break;
525526
case 48: type = LLM_TYPE_34B; break;
@@ -621,7 +622,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
621622
{
622623
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
623624
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
624-
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
625625
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
626626

627627
switch (hparams.n_layer) {
@@ -644,7 +644,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
644644
{
645645
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
646646
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
647-
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
648647
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false);
649648
hparams.f_max_alibi_bias = 8.0f;
650649

@@ -658,7 +657,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
658657
{
659658
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
660659
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
661-
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, hparams.n_vocab_type);
662660
ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
663661

664662
if (hparams.n_layer == 12 && hparams.n_embd == 768) {
@@ -1365,8 +1363,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
13651363
const int64_t n_embd_head_v = hparams.n_embd_head_v;
13661364
const int64_t n_ff = hparams.n_ff();
13671365
const int64_t n_embd_gqa = n_embd_v_gqa;
1368-
const int64_t n_vocab = hparams.n_vocab;
1369-
const int64_t n_vocab_type = hparams.n_vocab_type;
1366+
const int64_t n_vocab = vocab.n_vocab();
1367+
const int64_t n_token_types = vocab.n_token_types();
13701368
const int64_t n_rot = hparams.n_rot;
13711369
const int64_t n_expert = hparams.n_expert;
13721370
const int64_t n_expert_used = hparams.n_expert_used;
@@ -1811,7 +1809,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
18111809
case LLM_ARCH_NOMIC_BERT:
18121810
{
18131811
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
1814-
type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0);
1812+
type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
18151813

18161814
if (arch == LLM_ARCH_BERT) {
18171815
pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);
@@ -1865,7 +1863,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
18651863
case LLM_ARCH_JINA_BERT_V2:
18661864
{
18671865
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
1868-
type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_vocab_type}, 0); // token_type_embeddings
1866+
type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings
18691867

18701868
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm
18711869
tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias
@@ -3494,7 +3492,6 @@ void llama_model::print_info() const {
34943492

34953493
// hparams
34963494
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str());
3497-
LLAMA_LOG_INFO("%s: n_vocab (hp) = %u\n", __func__, hparams.n_vocab);
34983495
LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only);
34993496

35003497
if (!hparams.vocab_only) {

src/llama-vocab.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,7 @@ struct fragment_buffer_variant {
12051205

12061206
struct llama_vocab::impl {
12071207
uint32_t n_vocab = 0;
1208+
uint32_t n_token_types = 0; // for BERT-style token types
12081209

12091210
std::unordered_map<std::string, llama_token> token_to_id;
12101211
std::vector<token_data> id_to_token;
@@ -1286,6 +1287,7 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
12861287
struct gguf_context * ctx = ml.meta.get();
12871288

12881289
auto & n_vocab = pimpl->n_vocab;
1290+
auto & n_token_types = pimpl->n_token_types;
12891291
auto & id_to_token = pimpl->id_to_token;
12901292
auto & token_to_id = pimpl->token_to_id;
12911293
auto & special_eog_ids = pimpl->special_eog_ids;
@@ -1300,6 +1302,8 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
13001302
ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
13011303
ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
13021304

1305+
ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, n_token_types, false);
1306+
13031307
if (tokenizer_model == "no_vocab" || tokenizer_model == "none") {
13041308
type = LLAMA_VOCAB_TYPE_NONE;
13051309

@@ -2013,6 +2017,10 @@ uint32_t llama_vocab::n_vocab() const {
20132017
return (uint32_t) pimpl->id_to_token.size();
20142018
}
20152019

2020+
uint32_t llama_vocab::n_token_types() const {
2021+
return (uint32_t) pimpl->n_token_types;
2022+
}
2023+
20162024
std::string llama_vocab::type_name() const{
20172025
switch (type) {
20182026
case LLAMA_VOCAB_TYPE_NONE: return "no vocab";

src/llama-vocab.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ struct llama_vocab {
2424
enum llama_vocab_type get_type() const;
2525
enum llama_vocab_pre_type get_pre_type() const;
2626

27-
// TODO: how to deduplicate with llama_hparams.n_vocab ?
2827
uint32_t n_vocab() const;
28+
uint32_t n_token_types() const;
2929

3030
std::string type_name() const;
3131

src/llama.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,6 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam
6565
model.load_stats(ml);
6666
model.print_info();
6767

68-
if (model.vocab.get_type() != LLAMA_VOCAB_TYPE_NONE &&
69-
model.hparams.n_vocab != model.vocab.n_vocab()) {
70-
throw std::runtime_error("vocab size mismatch");
71-
}
72-
7368
if (params.vocab_only) {
7469
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
7570
return 0;
@@ -8342,6 +8337,7 @@ static int llama_decode_impl(
83428337
const uint32_t n_tokens_all = batch.n_tokens;
83438338

83448339
const auto & model = lctx.model;
8340+
const auto & vocab = model.vocab;
83458341
const auto & hparams = model.hparams;
83468342
const auto & cparams = lctx.cparams;
83478343

@@ -8369,7 +8365,7 @@ static int llama_decode_impl(
83698365
llama_kv_slot_restorer kv_slot_restorer(kv_self);
83708366

83718367
const int64_t n_embd = hparams.n_embd;
8372-
const int64_t n_vocab = hparams.n_vocab;
8368+
const int64_t n_vocab = vocab.n_vocab();
83738369

83748370
uint32_t n_outputs = 0;
83758371
uint32_t n_outputs_prev = 0;

0 commit comments

Comments
 (0)