@@ -692,6 +692,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
            if (params.logdir.back() != DIRECTORY_SEPARATOR) {
                params.logdir += DIRECTORY_SEPARATOR;
            }
+        } else if (arg == "-lcs" || arg == "--lookup-cache-static") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lookup_cache_static = argv[i];
        } else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
            if (++i >= argc) {
                invalid_param = true;
@@ -1064,6 +1070,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
    printf("                        draft model for speculative decoding\n");
    printf("  -ld LOGDIR, --logdir LOGDIR\n");
    printf("                        path under which to save YAML logs (no logging if unset)\n");
+    printf("  -lcs FNAME, --lookup-cache-static FNAME\n");
+    printf("                        path to static lookup cache to use for lookup decoding\n");
    printf("  --override-kv KEY=TYPE:VALUE\n");
    printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
    printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
@@ -1805,3 +1813,228 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
    printf("\n=== Done dumping\n");
}
+
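+// Update the ngram caches in ncs with the token sequence inp.
+// ncs[i] holds ngrams of size ngram_min+i; only the last nnew tokens of inp are scanned,
+// so the caches can be updated incrementally as new tokens arrive.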
+void llama_ngram_cache_update(std::vector<llama_ngram_cache> & ncs, int ngram_min,
+                              std::vector<llama_token> & inp, int nnew, bool print_progress) {
+    const int64_t t_start_ms = ggml_time_ms();
+    const int ngram_max = ngram_min + ncs.size()-1;
+    const int inp_size  = inp.size();
+
+    for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
+        llama_ngram_cache & nc = ncs[ngram_size - ngram_min];
+
+        const int i_start = std::max(inp_size - nnew, ngram_size);
+        for (int i = i_start; i < inp_size; ++i) {
+            const int ngram_start = i - ngram_size;
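+            // Pack the ngram's tokens into a single 64-bit key, 16 bits per token
+            // (token ids >= 2^16 would collide, presumably the reason for the FIXME below).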
+            uint64_t ngram = inp[ngram_start];
+            for (int j = ngram_start+1; j < ngram_start + ngram_size; ++j) { // FIXME
+                const uint64_t ngram_part = inp[j];
+                ngram <<= 16;
+                ngram |= ngram_part;
+            }
+            const llama_token token = inp[i];
+
+            llama_ngram_cache::iterator part_it = nc.find(ngram);
+            if (part_it == nc.end()) {
+                llama_ngram_cache_part part;
+                part.emplace(token, 1);
+                nc.emplace(ngram, part);
+            } else {
+                llama_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
+                if (token_count_it == part_it->second.end()) {
+                    part_it->second.emplace(token, 1);
+                } else {
+                    token_count_it->second++;
+                }
+            }
+            if (print_progress && i % 10000000 == 0) {
+                const int64_t t_now_ms = ggml_time_ms();
+                const int64_t eta_ms   = (inp_size - i) * (t_now_ms - t_start_ms) / i;
+                const int64_t eta_min  = eta_ms / (60*1000);
+                const int64_t eta_s    = (eta_ms - 60*1000*eta_min) / 1000;
+
+                fprintf(stderr, "%s: %d/%d done, ETA: %02ld:%02ld\n", __func__, i, inp_size, eta_min, eta_s);
+            }
+        }
+    }
+}
+
+// Helper function to get a token from the combined, speculative sequence of inp and draft.
+static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
+    return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
+}
+
+// If the sample size or the share of the most frequent continuation token are below these thresholds the draft is aborted early:
+constexpr int draft_min_sample_size[LLAMA_NGRAM_MAX] = { 2,  2,  1,  1};
+constexpr int draft_min_percent[LLAMA_NGRAM_MAX]     = {50, 50, 50, 50};
+
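+// Draft up to n_draft tokens: the tail of inp plus the tokens drafted so far is matched
+// against the caches in ncs_t1 (one per ngram size starting at ngram_min) and the 2-gram cache nc_t2.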
+void llama_ngram_cache_draft(
+    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft,
+    std::vector<llama_ngram_cache> & ncs_t1, int ngram_min, llama_ngram_cache & nc_t2
+) {
+    const int inp_size  = inp.size();
+    const int ngram_max = ngram_min + ncs_t1.size()-1;
+
+    while ((int) draft.size()-1 < n_draft) {
+        bool draft_success = false;
+
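+        // Look up the trailing 2-gram of the sequence drafted so far in nc_t2;
+        // its counts are used below to weight candidate continuations.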
+        const int ngram_start_t2 = inp_size-2 + draft.size()-1;
+        uint64_t ngram_t2 = get_token(inp, draft, ngram_start_t2);
+        for (int j = ngram_start_t2+1; j < ngram_start_t2 + 2; ++j) {
+            const uint64_t token = get_token(inp, draft, j);
+            ngram_t2 <<= 16;
+            ngram_t2 |= token;
+        }
+        llama_ngram_cache::iterator part_t2_it = nc_t2.find(ngram_t2);
+        llama_ngram_cache_part part_t2;
+        if (part_t2_it != nc_t2.end()) {
+            part_t2 = part_t2_it->second;
+        }
+
+        for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
+            if (ngram_size > inp_size) {
+                continue;
+            }
+
+            llama_ngram_cache & nc_t1 = ncs_t1[ngram_size - ngram_min];
+
+            const int ngram_start_t1 = inp_size-ngram_size + draft.size()-1;
+            uint64_t ngram_t1 = get_token(inp, draft, ngram_start_t1);
+            for (int j = ngram_start_t1+1; j < ngram_start_t1 + ngram_size; ++j) {
+                const uint64_t token = get_token(inp, draft, j);
+                ngram_t1 <<= 16;
+                ngram_t1 |= token;
+            }
+
+            llama_ngram_cache::iterator part_t1_it = nc_t1.find(ngram_t1);
+            if (part_t1_it == nc_t1.end()) {
+                continue;
+            }
+            const llama_ngram_cache_part part_t1 = part_t1_it->second;
+
+            int max_count_t1 = 0;
+            int max_count_t2 = 0;
+            int sum_count_t1 = 0;
+            llama_token max_token = -1;
+
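+            // Pick the continuation token with the highest combined score: its count in the
+            // context cache times its count in the 2-gram cache (scaled by 100, or 1 if absent).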
+            for (std::pair<llama_token, int> token_count_t1 : part_t1) {
+                const llama_token token = token_count_t1.first;
+
+                llama_ngram_cache_part::iterator token_count_t2_it = part_t2.find(token);
+                const int32_t count_t1 = token_count_t1.second;
+                const int32_t count_t2 = token_count_t2_it != part_t2.end() ? 100*token_count_t2_it->second : 1;
+
+                if (count_t1*count_t2 > max_count_t1*max_count_t2) {
+                    max_token    = token;
+                    max_count_t1 = count_t1;
+                    max_count_t2 = count_t2;
+                }
+                sum_count_t1 += count_t1;
+            }
+            // Skip this candidate if the sample size is too low:
+            if (sum_count_t1 < draft_min_sample_size[ngram_size-1]) {
+                continue;
+            }
+            // Skip this candidate if the empirically most likely token following this ngram is not likely enough:
+            if (100*max_count_t1 < draft_min_percent[ngram_size-1]*sum_count_t1) {
+                continue;
+            }
+
+            LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count_t1);
+            draft.push_back(max_token);
+            draft_success = true;
+            break;
+        }
+
+        if (!draft_success) {
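+            // No context ngram produced a candidate; fall back to the 2-gram cache alone.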
+            int max_count_t2 = 0;
+            int sum_count_t2 = 0;
+            llama_token max_token = -1;
+
+            for (std::pair<llama_token, int> token_count_t2 : part_t2) {
+                const llama_token token = token_count_t2.first;
+                const int32_t count_t2  = token_count_t2.second;
+
+                if (count_t2 > max_count_t2) {
+                    max_token    = token;
+                    max_count_t2 = count_t2;
+                }
+                sum_count_t2 += count_t2;
+            }
+
+            // Stop drafting if the sample size is too low:
+            if (sum_count_t2 < draft_min_sample_size[2-1]) {
+                break;
+            }
+            // Stop drafting if the empirically most likely token following this 2-gram is not likely enough:
+            if (100*max_count_t2 < draft_min_percent[2-1]*sum_count_t2) {
+                break;
+            }
+
+            LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count_t2);
+            draft.push_back(max_token);
+            draft_success = true;
+            break;
+        }
+
+        if (!draft_success) {
+            break;
+        }
+    }
+}
+
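+// Write a single ngram cache to a binary file as a flat sequence of records:
+// [uint64_t ngram][int32_t n_tokens] followed by n_tokens pairs of [llama_token][int32_t count].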
+void llama_ngram_cache_save(std::vector<llama_ngram_cache> & ngram_cache, std::string & filename) {
+    GGML_ASSERT(ngram_cache.size() == 1);
+    std::ofstream file_out(filename, std::ios::binary);
+    for (std::pair<uint64_t, llama_ngram_cache_part> item : ngram_cache[0]) {
+        const uint64_t ngram = item.first;
+        llama_ngram_cache_part token_counts = item.second;
+        GGML_ASSERT(!token_counts.empty());
+        const int32_t ntokens = token_counts.size();
+
+        file_out.write(reinterpret_cast<const char *>(&ngram),   sizeof(uint64_t));
+        file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
+        for (std::pair<llama_token, int32_t> item2 : token_counts) {
+            const llama_token token = item2.first;
+            const int32_t     count = item2.second;
+            file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
+            file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
+        }
+    }
+}
+
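+// Read an ngram cache back from a binary file written by llama_ngram_cache_save.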
+llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
+    std::ifstream hashmap_file(filename, std::ios::binary);
+    if (!hashmap_file) {
+        fprintf(stderr, "error: failed to open file '%s'\n", filename.c_str());
+        exit(1);
+    }
+    llama_ngram_cache ngram_cache;
+
+    uint64_t    ngram;
+    int32_t     ntokens;
+    llama_token token;
+    int32_t     count;
+
+    char * ngramc   = reinterpret_cast<char *>(&ngram);
+    char * ntokensc = reinterpret_cast<char *>(&ntokens);
+    char * tokenc   = reinterpret_cast<char *>(&token);
+    char * countc   = reinterpret_cast<char *>(&count);
+    while (hashmap_file.read(ngramc, sizeof(uint64_t))) {
+        GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
+        llama_ngram_cache_part token_counts;
+
+        for (int i = 0; i < ntokens; ++i) {
+            GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
+            GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
+            token_counts.emplace(token, count);
+        }
+
+        ngram_cache.emplace(ngram, token_counts);
+    }
+    GGML_ASSERT(hashmap_file.eof());
+
+    return ngram_cache;
+}