
Commit 49e794f

move code to common, CLI arg, lookup-stats
1 parent 0cf40f3 commit 49e794f

6 files changed: +450 -302 lines changed

6 files changed

+450
-302
lines changed

Makefile

+2
@@ -746,6 +746,8 @@ lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
 	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-create.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp) -o lookup-create $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-stats.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp) -o lookup-stats $(LDFLAGS)
 
 passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

common/common.cpp

+233
@@ -692,6 +692,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             if (params.logdir.back() != DIRECTORY_SEPARATOR) {
                 params.logdir += DIRECTORY_SEPARATOR;
             }
+        } else if (arg == "-lcs" || arg == "--lookup-cache-static") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.lookup_cache_static = argv[i];
         } else if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1064,6 +1070,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" draft model for speculative decoding\n");
     printf(" -ld LOGDIR, --logdir LOGDIR\n");
     printf(" path under which to save YAML logs (no logging if unset)\n");
+    printf(" -lcs FNAME, --lookup-cache-static FNAME\n");
+    printf(" path to static lookup cache to use for lookup decoding\n");
     printf(" --override-kv KEY=TYPE:VALUE\n");
     printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
     printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
@@ -1805,3 +1813,228 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
 
     printf("\n=== Done dumping\n");
 }
+
+void llama_ngram_cache_update(std::vector<llama_ngram_cache> & ncs, int ngram_min,
+                              std::vector<llama_token> & inp, int nnew, bool print_progress) {
+    const int64_t t_start_ms = ggml_time_ms();
+    const int ngram_max = ngram_min + ncs.size()-1;
+    const int inp_size = inp.size();
+
+    for (int ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
+        llama_ngram_cache & nc = ncs[ngram_size - ngram_min];
+
+        const int i_start = std::max(inp_size - nnew, ngram_size);
+        for (int i = i_start; i < inp_size; ++i) {
+            const int ngram_start = i - ngram_size;
+            uint64_t ngram = inp[ngram_start];
+            for (int j = ngram_start+1; j < ngram_start + ngram_size; ++j) { // FIXME
+                const uint64_t ngram_part = inp[j];
+                ngram <<= 16;
+                ngram |= ngram_part;
+            }
+            const llama_token token = inp[i];
+
+            llama_ngram_cache::iterator part_it = nc.find(ngram);
+            if (part_it == nc.end()) {
+                llama_ngram_cache_part part;
+                part.emplace(token, 1);
+                nc.emplace(ngram, part);
+            } else {
+                llama_ngram_cache_part::iterator token_count_it = part_it->second.find(token);
+                if (token_count_it == part_it->second.end()) {
+                    part_it->second.emplace(token, 1);
+                } else {
+                    token_count_it->second++;
+                }
+            }
+            if (print_progress && i % 10000000 == 0) {
+                const int64_t t_now_ms = ggml_time_ms();
+                const int64_t eta_ms = (inp_size - i) * (t_now_ms - t_start_ms) / i;
+                const int64_t eta_min = eta_ms / (60*1000);
+                const int64_t eta_s = (eta_ms - eta_min) / 1000;
+
+                fprintf(stderr, "%s: %d/%d done, ETA: %02ld:%02ld\n", __func__, i, inp_size, eta_min, eta_s);
+            }
+        }
+    }
+}
+
+// Helper function to get a token from the combined, speculative sequence of inp and draft.
+static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
+    return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
+};
+
+// If sample size or percentage in context are below these thresholds the draft is aborted early:
+constexpr int draft_min_sample_size[LLAMA_NGRAM_MAX] = { 2, 2, 1, 1};
+constexpr int draft_min_percent[LLAMA_NGRAM_MAX] = {50, 50, 50, 50};
+
+void llama_ngram_cache_draft(
+    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft,
+    std::vector<llama_ngram_cache> & ncs_t1, int ngram_min, llama_ngram_cache & nc_t2
+) {
+    const int inp_size = inp.size();
+    const int ngram_max = ngram_min + ncs_t1.size()-1;
+
+    while ((int) draft.size()-1 < n_draft) {
+        bool draft_success = false;
+
+        const int ngram_start_t2 = inp_size-2 + draft.size()-1;
+        uint64_t ngram_t2 = get_token(inp, draft, ngram_start_t2);
+        for (int j = ngram_start_t2+1; j < ngram_start_t2 + 2; ++j) {
+            const uint64_t token = get_token(inp, draft, j);
+            ngram_t2 <<= 16;
+            ngram_t2 |= token;
+        }
+        llama_ngram_cache::iterator part_t2_it = nc_t2.find(ngram_t2);
+        llama_ngram_cache_part part_t2;
+        if (part_t2_it != nc_t2.end()) {
+            part_t2 = part_t2_it->second;
+        }
+
+        for (int ngram_size = ngram_max; ngram_size >= ngram_min; --ngram_size) {
+            if (ngram_size > inp_size) {
+                continue;
+            }
+
+            llama_ngram_cache & nc_t1 = ncs_t1[ngram_size - ngram_min];
+
+            const int ngram_start_t1 = inp_size-ngram_size + draft.size()-1;
+            uint64_t ngram_t1 = get_token(inp, draft, ngram_start_t1);
+            for (int j = ngram_start_t1+1; j < ngram_start_t1 + ngram_size; ++j) {
+                const uint64_t token = get_token(inp, draft, j);
+                ngram_t1 <<= 16;
+                ngram_t1 |= token;
+            }
+
+            llama_ngram_cache::iterator part_t1_it = nc_t1.find(ngram_t1);
+            if (part_t1_it == nc_t1.end()) {
+                continue;
+            }
+            const llama_ngram_cache_part part_t1 = part_t1_it->second;
+
+            int max_count_t1 = 0;
+            int max_count_t2 = 0;
+            int sum_count_t1 = 0;
+            llama_token max_token = -1;
+
+            for (std::pair<llama_token, int> token_count_t1 : part_t1) {
+                const llama_token token = token_count_t1.first;
+
+                llama_ngram_cache_part::iterator token_count_t2_it = part_t2.find(token);
+                const int32_t count_t1 = token_count_t1.second;
+                const int32_t count_t2 = token_count_t2_it != part_t2.end() ? 100*token_count_t2_it->second : 1;
+
+                if (count_t1*count_t2 > max_count_t1*max_count_t2) {
+                    max_token = token;
+                    max_count_t1 = count_t1;
+                    max_count_t2 = count_t2;
+                }
+                sum_count_t1 += count_t1;
+            }
+            // Skip this candidate if the sample size is too low:
+            if (sum_count_t1 < draft_min_sample_size[ngram_size-1]) {
+                continue;
+            }
+            // skip this candidate if the empirically most likely token following this token is not likely enough:
+            if (100*max_count_t1 < draft_min_percent[ngram_size-1]*sum_count_t1) {
+                continue;
+            }
+
+            LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count_t1);
+            draft.push_back(max_token);
+            draft_success = true;
+            break;
+        }
+
+        if (!draft_success) {
+            int max_count_t2 = 0;
+            int sum_count_t2 = 0;
+            llama_token max_token = -1;
+
+            for (std::pair<llama_token, int> token_count_t2 : part_t2) {
+                const llama_token token = token_count_t2.first;
+                const int32_t count_t2 = token_count_t2.second;
+
+                if (count_t2 > max_count_t2) {
+                    max_token = token;
+                    max_count_t2 = count_t2;
+                }
+                sum_count_t2 += count_t2;
+            }
+
+            // Skip this candidate if the sample size is too low:
+            if (sum_count_t2 < draft_min_sample_size[2-1]) {
+                break;
+            }
+            // skip this candidate if the empirically most likely token following this token is not likely enough:
+            if (100*max_count_t2 < draft_min_percent[2-1]*sum_count_t2) {
+                break;
+            }
+
+            LOG(" - draft candidate: token=%d count=%d\n", max_token, max_count_t2);
+            draft.push_back(max_token);
+            draft_success = true;
+            break;
+        }
+
+        if (!draft_success) {
+            break;
+        }
+    }
+};
+
+void llama_ngram_cache_save(std::vector<llama_ngram_cache> & ngram_cache, std::string & filename) {
+    GGML_ASSERT(ngram_cache.size() == 1);
+    std::ofstream file_out(filename, std::ios::binary);
+    for (std::pair<uint64_t, llama_ngram_cache_part> item : ngram_cache[0]) {
+        const uint64_t ngram = item.first;
+        llama_ngram_cache_part token_counts = item.second;
+        GGML_ASSERT(!token_counts.empty());
+        const int32_t ntokens = token_counts.size();
+
+
+        file_out.write(reinterpret_cast<const char *>(&ngram), sizeof(uint64_t));
+        file_out.write(reinterpret_cast<const char *>(&ntokens), sizeof(int32_t));
+        for (std::pair<llama_token, int32_t> item2 : token_counts) {
+            const llama_token token = item2.first;
+            const int32_t count = item2.second;
+            file_out.write(reinterpret_cast<const char *>(&token), sizeof(llama_token));
+            file_out.write(reinterpret_cast<const char *>(&count), sizeof(int32_t));
+        }
+    }
+
+}
+
+llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
+    std::ifstream hashmap_file(filename, std::ios::binary);
+    if (!hashmap_file) {
+        fprintf(stderr, "error: failed to open file '%s'\n", filename.c_str());
+        exit(1);
+    }
+    llama_ngram_cache ngram_cache;
+
+    uint64_t ngram;
+    int32_t ntokens;
+    llama_token token;
+    int32_t count;
+
+    char * ngramc = reinterpret_cast<char*>(&ngram);
+    char * ntokensc = reinterpret_cast<char*>(&ntokens);
+    char * tokenc = reinterpret_cast<char*>(&token);
+    char * countc = reinterpret_cast<char*>(&count);
+    while(hashmap_file.read(ngramc, sizeof(uint64_t))) {
+        GGML_ASSERT(hashmap_file.read(ntokensc, sizeof(int32_t)));
+        llama_ngram_cache_part token_counts;
+
+        for (int i = 0; i < ntokens; ++i) {
+            GGML_ASSERT(hashmap_file.read(tokenc, sizeof(llama_token)));
+            GGML_ASSERT(hashmap_file.read(countc, sizeof(int32_t)));
+            token_counts.emplace(token, count);
+        }
+
+        ngram_cache.emplace(ngram, token_counts);
+    }
+    GGML_ASSERT(hashmap_file.eof());
+
+    return ngram_cache;
+}
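The packing loops in llama_ngram_cache_update and llama_ngram_cache_draft above both build the hash-map key by shifting 16-bit token slices into a uint64_t. A minimal standalone sketch of that encoding, illustrative only and not part of this commit (the helper name pack_ngram and the uint16_t cast are my own), looks like this:

// Hypothetical helper, not from this commit: the n-gram -> uint64_t encoding that
// llama_ngram_cache_update/_draft perform inline (up to 4 tokens, 16 bits per token).
#include <cstdint>
#include <vector>

typedef int32_t llama_token; // as in llama.h

static uint64_t pack_ngram(const std::vector<llama_token> & tokens, int start, int ngram_size) {
    uint64_t ngram = (uint16_t) tokens[start]; // assumption: token ids fit into 16 bits
    for (int j = start + 1; j < start + ngram_size; ++j) {
        ngram <<= 16;                          // make room for the next 16-bit slot
        ngram |= (uint16_t) tokens[j];         // append the next token id
    }
    return ngram;
}

Note that the in-tree loops do not mask the token id to 16 bits, so token ids of 65536 or more would bleed into the neighbouring slots; that limitation is presumably what the // FIXME in llama_ngram_cache_update refers to.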

common/common.h

+39 -11
@@ -80,17 +80,18 @@ struct gpt_params {
     // // sampling parameters
     struct llama_sampling_params sparams;
 
-    std::string model = "models/7B/ggml-model-f16.gguf"; // model path
-    std::string model_draft = ""; // draft model for speculative decoding
-    std::string model_alias = "unknown"; // model alias
-    std::string prompt = "";
-    std::string prompt_file = ""; // store the external prompt file name
-    std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
-    std::string input_prefix = ""; // string to prefix user inputs with
-    std::string input_suffix = ""; // string to suffix user inputs with
-    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
-    std::string logdir = ""; // directory in which to save YAML log files
-    std::string logits_file = ""; // file for saving *all* logits
+    std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_draft = ""; // draft model for speculative decoding
+    std::string model_alias = "unknown"; // model alias
+    std::string prompt = "";
+    std::string prompt_file = ""; // store the external prompt file name
+    std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
+    std::string input_prefix = ""; // string to prefix user inputs with
+    std::string input_suffix = ""; // string to suffix user inputs with
+    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
+    std::string logdir = ""; // directory in which to save YAML log files
+    std::string lookup_cache_static = ""; // path of ngram cache file for lookup decoding
+    std::string logits_file = ""; // file for saving *all* logits
 
     std::vector<llama_model_kv_override> kv_overrides;
 
@@ -258,3 +259,30 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
 
 // Dump the KV cache view showing individual sequences in each cell (long output).
 void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
+
+#define LLAMA_NGRAM_MAX 4
+
+// Data structures to map n-grams to empirical token probabilities:
+typedef std::unordered_map<llama_token, int32_t> llama_ngram_cache_part; // token -> number of times token has been seen
+typedef std::unordered_map<uint64_t, llama_ngram_cache_part> llama_ngram_cache; // n-gram -> empirical distribution of following tokens
+// n-grams are encoded as 64 bit integers with each of the 4 16 bit sections representing a token id.
+// This way no custom hashing function for the n-grams is needed.
+
+// Update an ngram cache with tokens.
+// ncs = ngram caches: the hashmaps to modify.
+// ngram_min/ngram_max: the min/max size of the ngrams in ncs.
+// inp_data: the token sequence on which the hashmaps are based.
+// nnew: how many new tokens have been appended to inp_data since the last call to this function.
+// print_progress: whether to print progress to stderr.
+//
+// In order to get correct results inp_data can ONLY BE APPENDED TO.
+// Changes in the middle need a complete rebuild.
+void llama_ngram_cache_update(std::vector<llama_ngram_cache> & ncs, int ngram_min,
+                              std::vector<llama_token> & inp_data, int nnew, bool print_progress);
+
+void llama_ngram_cache_draft(
+    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft,
+    std::vector<llama_ngram_cache> & ncs_t1, int ngram_min, llama_ngram_cache & nc_t2);
+
+void llama_ngram_cache_save(std::vector<llama_ngram_cache> & ngram_cache, std::string & filename);
+llama_ngram_cache llama_ngram_cache_load(std::string & filename);
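For orientation, here is a caller-side sketch of the API declared above. It is not code from this commit: the function names and signatures come from common.h, but the wiring, the variable names, and the parameter choices (ngram_min, n_draft) are assumptions about how a lookup example would call it. Judging by llama_ngram_cache_draft, nc_t2 is keyed on 2-grams, which in this commit would correspond to the static cache passed via -lcs/--lookup-cache-static.

// Illustrative usage sketch only (assumed wiring, not part of this commit).
#include "common.h"

static void example_lookup_draft(std::vector<llama_token> & inp, std::string & static_cache_path) {
    const int ngram_min = 1;               // assumed smallest n-gram size for the context caches
    const int ngram_max = LLAMA_NGRAM_MAX; // 4
    const int n_draft   = 8;               // assumed number of tokens to draft

    // One dynamic cache per n-gram size in [ngram_min, ngram_max], built from the context so far:
    std::vector<llama_ngram_cache> ncs_context(ngram_max - ngram_min + 1);
    llama_ngram_cache_update(ncs_context, ngram_min, inp, (int) inp.size(), false);

    // Static cache loaded from disk (the --lookup-cache-static file):
    llama_ngram_cache nc_static = llama_ngram_cache_load(static_cache_path);

    // get_token() in common.cpp treats the speculative sequence as inp followed by draft[1:],
    // so draft[0] duplicates the most recent context token.
    std::vector<llama_token> draft = { inp.back() };
    llama_ngram_cache_draft(inp, draft, n_draft, ncs_context, ngram_min, nc_static);

    // draft now holds up to n_draft proposed continuation tokens after draft[0],
    // ready to be verified by the target model in a speculative decoding step.
}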
