@@ -857,21 +857,23 @@ struct common_init_result common_init_from_params(common_params & params) {
         return iparams;
     }
 
+    const llama_vocab * vocab = llama_get_vocab(model);
+
     if (params.reranking) {
         bool ok = true;
 
-        if (llama_token_bos(model) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: model does not have a BOS token, reranking will not work\n", __func__);
+        if (llama_token_bos(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
             ok = false;
         }
 
-        if (llama_token_eos(model) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__);
+        if (llama_token_eos(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
             ok = false;
         }
 
-        if (llama_token_sep(model) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: model does not have a SEP token, reranking will not work\n", __func__);
+        if (llama_token_sep(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
             ok = false;
         }
 
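Review note (not part of the patch): this hunk resolves the vocab once with llama_get_vocab(model) and performs every special-token check against it rather than against the model. A minimal standalone sketch of the same check, assuming the vocab-based accessors exactly as they appear in this diff:

    // Sketch only: relies on llama_get_vocab() and the vocab-taking
    // llama_token_bos/eos/sep() accessors shown in the hunk above.
    #include "llama.h"
    #include <cstdio>

    // true if the vocab carries all special tokens that reranking relies on
    static bool vocab_supports_reranking(const llama_model * model) {
        const llama_vocab * vocab = llama_get_vocab(model);

        bool ok = true;
        if (llama_token_bos(vocab) == LLAMA_TOKEN_NULL) { std::fprintf(stderr, "missing BOS token\n"); ok = false; }
        if (llama_token_eos(vocab) == LLAMA_TOKEN_NULL) { std::fprintf(stderr, "missing EOS token\n"); ok = false; }
        if (llama_token_sep(vocab) == LLAMA_TOKEN_NULL) { std::fprintf(stderr, "missing SEP token\n"); ok = false; }
        return ok;
    }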
@@ -941,14 +943,14 @@ struct common_init_result common_init_from_params(common_params & params) {
         common_lora_adapters_apply(lctx, params.lora_adapters);
     }
 
-    if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
-        LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
+    if (params.sampling.ignore_eos && llama_token_eos(vocab) == LLAMA_TOKEN_NULL) {
+        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
         params.sampling.ignore_eos = false;
     }
 
     if (params.sampling.ignore_eos) {
-        for (llama_token i = 0; i < llama_n_vocab(model); i++) {
-            if (llama_token_is_eog(model, i)) {
+        for (llama_token i = 0; i < llama_n_vocab(vocab); i++) {
+            if (llama_token_is_eog(vocab, i)) {
                 LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
                 params.sampling.logit_bias.push_back({i, -INFINITY});
             }
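Review note (not part of the patch): --ignore-eos now enumerates end-of-generation tokens through the vocab. A compact sketch of that loop in isolation, assuming llama_n_vocab() and llama_token_is_eog() take a vocab pointer as shown above:

    // Sketch only: collects every end-of-generation token id so the caller
    // can push {id, -INFINITY} logit biases, mirroring the loop above.
    #include "llama.h"
    #include <vector>

    static std::vector<llama_token> eog_token_ids(const llama_vocab * vocab) {
        std::vector<llama_token> out;
        for (llama_token i = 0; i < llama_n_vocab(vocab); i++) {
            if (llama_token_is_eog(vocab, i)) {
                out.push_back(i);
            }
        }
        return out;
    }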
@@ -969,8 +971,9 @@ struct common_init_result common_init_from_params(common_params & params) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 
         std::vector<llama_token> tmp;
-        llama_token bos = llama_token_bos(model);
-        llama_token eos = llama_token_eos(model);
+        llama_token bos = llama_token_bos(vocab);
+        llama_token eos = llama_token_eos(vocab);
+
         // some models (e.g. T5) don't have a BOS token
         if (bos != LLAMA_TOKEN_NULL) {
             tmp.push_back(bos);
@@ -1559,21 +1562,23 @@ std::vector<llama_token> common_tokenize(
         const std::string & text,
                        bool add_special,
                        bool parse_special) {
-    return common_tokenize(llama_get_model(ctx), text, add_special, parse_special);
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_get_vocab(model);
+    return common_tokenize(vocab, text, add_special, parse_special);
 }
 
 std::vector<llama_token> common_tokenize(
-    const struct llama_model * model,
+    const struct llama_vocab * vocab,
         const std::string & text,
                        bool add_special,
                        bool parse_special) {
     // upper limit for the number of tokens
     int n_tokens = text.length() + 2 * add_special;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -1582,12 +1587,18 @@ std::vector<llama_token> common_tokenize(
 }
 
 std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_get_vocab(model);
+    return common_token_to_piece(vocab, token, special);
+}
+
+std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) {
     std::string piece;
     piece.resize(piece.capacity());  // using string internal cache, 15 bytes + '\n'
-    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+    const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
     if (n_chars < 0) {
         piece.resize(-n_chars);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+        int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
         GGML_ASSERT(check == -n_chars);
     }
     else {
@@ -1597,13 +1608,19 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token
     return piece;
 }
 
-std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_get_vocab(model);
+    return common_detokenize(vocab, tokens, special);
+}
+
+std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
-    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
     if (n_chars < 0) {
         text.resize(-n_chars);
-        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
         GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
     }
 
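Review note (not part of the patch): with the vocab-based overloads above, tokenization helpers no longer require a llama_context. A usage sketch of a tokenize/detokenize round trip, assuming the common_tokenize/common_detokenize overloads added in this diff and a model loaded elsewhere:

    // Sketch only: exercises the new vocab-based common_* overloads.
    #include "common.h"
    #include "llama.h"
    #include <string>
    #include <vector>

    static std::string roundtrip(const llama_model * model, const std::string & text) {
        const llama_vocab * vocab = llama_get_vocab(model);

        // tokenize with special tokens enabled, then detokenize back to text
        std::vector<llama_token> toks = common_tokenize(vocab, text, /*add_special=*/true, /*parse_special=*/true);
        return common_detokenize(vocab, toks, /*special=*/false);
    }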
@@ -1631,7 +1648,7 @@ std::string common_get_builtin_chat_template(const struct llama_model * model) {
 
 bool common_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
+    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
 
@@ -1642,35 +1659,34 @@ std::string common_chat_apply_template(const struct llama_model * model,
     int alloc_size = 0;
     bool fallback = false;  // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
-    for (auto & msg : msgs) {
+    for (const auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
         alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }
 
-    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
     std::vector<char> buf(alloc_size);
 
     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
 
     // error: chat template is not supported
     if (res < 0) {
         if (ptr_tmpl != nullptr) {
             // if the custom "tmpl" is not supported, we throw an error
             // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
             throw std::runtime_error("this custom template is not supported");
-        } else {
-            // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-            fallback = true;
         }
+
+        // If the built-in template is not supported, we default to chatml
+        res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        fallback = true;
     }
 
     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
         res = llama_chat_apply_template(
-            fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
             chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }
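Review note (not part of the patch): llama_chat_apply_template() no longer takes the model pointer; the built-in template is obtained explicitly via llama_model_chat_template(model) and passed as a string, with "chatml" as the fallback. A condensed sketch of the resulting call pattern, assuming the signatures shown in this hunk:

    // Sketch only: mirrors the two-pass sizing pattern above using the
    // template-string-first llama_chat_apply_template() signature.
    #include "llama.h"
    #include <string>
    #include <vector>

    static std::string apply_template_or_chatml(const llama_model * model,
                                                const std::vector<llama_chat_message> & chat,
                                                bool add_ass) {
        const char * tmpl = llama_model_chat_template(model);  // may be nullptr if the model ships none

        std::vector<char> buf(1024);
        int32_t res = llama_chat_apply_template(tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
        if (res < 0) {
            // built-in template not supported -> default to chatml, as the hunk above does
            tmpl = "chatml";
            res = llama_chat_apply_template(tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
        }
        if (res > 0 && (size_t) res > buf.size()) {
            // first pass returned the required size; resize and run again
            buf.resize(res);
            res = llama_chat_apply_template(tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
        }
        return res < 0 ? std::string() : std::string(buf.data(), (size_t) res);
    }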