
Commit 614c74d

Server: enable lookup decoding

1 parent 3ea0d36 · commit 614c74d

14 files changed: +332 −95 lines
Makefile (+1, −1)

@@ -825,7 +825,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

-server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/server/json-schema-to-grammar.mjs.hpp common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
+server: examples/server/server.cpp examples/server/utils.hpp examples/server/httplib.h common/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/server/json-schema-to-grammar.mjs.hpp common/stb_image.h ggml.o llama.o ngram-cache.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h %.hpp $<,$^) -Iexamples/server $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) $(LWINSOCK2)
common/common.cpp (+2, −2)

@@ -1584,9 +1584,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" -ld LOGDIR, --logdir LOGDIR\n");
     printf("     path under which to save YAML logs (no logging if unset)\n");
     printf(" -lcs FNAME, --lookup-cache-static FNAME\n");
-    printf("     path to static lookup cache to use for lookup decoding (not updated by generation)\n");
+    printf("     path to static lookup cache to use for n-gram lookup decoding (not updated by generation)\n");
     printf(" -lcd FNAME, --lookup-cache-dynamic FNAME\n");
-    printf("     path to dynamic lookup cache to use for lookup decoding (updated by generation)\n");
+    printf("     path to dynamic lookup cache to use for n-gram lookup decoding (updated by generation)\n");
     printf(" --override-kv KEY=TYPE:VALUE\n");
     printf("     advanced option to override model metadata by key. may be specified multiple times.\n");
     printf("     types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");

common/ngram-cache.cpp (+12, −15)

@@ -6,19 +6,18 @@
 #include <fstream>

 void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max,
-        std::vector<llama_token> & inp, int nnew, bool print_progress) {
+        llama_token * inp_data, int inp_size, int nnew, bool print_progress) {
     const int64_t t_start_ms = ggml_time_ms();
-    const int64_t inp_size = inp.size();

     const int64_t n_todo = inp_size * (ngram_max - ngram_min + 1);
     int64_t n_done = 0;

     for (int64_t ngram_size = ngram_min; ngram_size <= ngram_max; ++ngram_size) {
-        const int64_t i_start = std::max(inp_size - nnew, ngram_size);
+        const int64_t i_start = std::max((int64_t)(inp_size - nnew), ngram_size);
         for (int64_t i = i_start; i < inp_size; ++i) {
             const int64_t ngram_start = i - ngram_size;
-            llama_ngram ngram(&inp[ngram_start], ngram_size);
-            const llama_token token = inp[i];
+            llama_ngram ngram(inp_data + ngram_start, ngram_size);
+            const llama_token token = inp_data[i];

             llama_ngram_cache::iterator part_it = ngram_cache.find(ngram);
             if (part_it == ngram_cache.end()) {
@@ -48,8 +47,8 @@ void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, in
 }

 // Helper function to get a token from the combined, speculative sequence of inp and draft.
-static llama_token get_token(const std::vector<llama_token> & inp, const std::vector<llama_token> & draft, const size_t i) {
-    return i < inp.size() ? inp[i] : draft[1 + i - inp.size()];
+static llama_token get_token(const llama_token * inp_data, const int inp_size, const std::vector<llama_token> & draft, const int i) {
+    return i < inp_size ? inp_data[i] : draft[1 + i - inp_size];
 }

 // If sample size or percentage are below these thresholds the draft is aborted early:
@@ -140,11 +139,10 @@ static llama_token try_draft(
 }

 void llama_ngram_cache_draft(
-    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    llama_token * inp_data, int inp_size, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
     llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
 ) {
     GGML_ASSERT(draft.size() == 1);
-    const int inp_size = inp.size();

     if (inp_size < LLAMA_NGRAM_STATIC) {
         return;
@@ -156,7 +154,7 @@ void llama_ngram_cache_draft(
     const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size()-1;
     llama_ngram ngram_static;
     for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
-        ngram_static.tokens[j-ngram_start_static] = get_token(inp, draft, j);
+        ngram_static.tokens[j-ngram_start_static] = get_token(inp_data, inp_size, draft, j);
     }
     llama_ngram_cache::iterator part_static_it = nc_static.find(ngram_static);
     llama_ngram_cache_part part_static;
@@ -170,7 +168,7 @@ void llama_ngram_cache_draft(
         const int ngram_start_cd = inp_size-ngram_size_cd + draft.size()-1;
         llama_ngram ngram_cd;
         for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
-            ngram_cd.tokens[j-ngram_start_cd] = get_token(inp, draft, j);
+            ngram_cd.tokens[j-ngram_start_cd] = get_token(inp_data, inp_size, draft, j);
         }
         ngrams_cd.push_back(ngram_cd);
     }
@@ -216,12 +214,11 @@ void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filen

 }

-llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
+bool llama_ngram_cache_load(llama_ngram_cache & ngram_cache, std::string & filename) {
     std::ifstream hashmap_file(filename, std::ios::binary);
     if (!hashmap_file) {
-        throw std::ifstream::failure("Unable to open file " + filename);
+        return false;
     }
-    llama_ngram_cache ngram_cache;

     llama_ngram ngram;
     int32_t ntokens;
@@ -251,7 +248,7 @@ llama_ngram_cache llama_ngram_cache_load(std::string & filename) {
     }
     GGML_ASSERT(hashmap_file.eof());

-    return ngram_cache;
+    return true;
 }

 void llama_ngram_cache_merge(llama_ngram_cache & ngram_cache_target, llama_ngram_cache & ngram_cache_add) {
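The practical effect of these signature changes: callers now pass a raw token pointer plus an explicit size, and loading a cache reports failure through its return value instead of throwing. Below is a minimal caller-side sketch of the new conventions; the function name `example_caller` and the variable names are illustrative, not taken from this commit.

```cpp
#include <string>
#include <vector>

#include "llama.h"        // llama_token
#include "ngram-cache.h"  // common/ngram-cache.h

// Illustrative only: demonstrates the calling conventions after this commit.
static void example_caller(std::vector<llama_token> & inp, std::string cache_path) {
    llama_ngram_cache nc_context;
    llama_ngram_cache nc_dynamic;
    llama_ngram_cache nc_static;

    // llama_ngram_cache_load() now fills a caller-owned cache and returns false
    // on failure instead of throwing std::ifstream::failure.
    if (!llama_ngram_cache_load(nc_static, cache_path)) {
        // handle a missing or unreadable cache file here
    }

    // llama_ngram_cache_update() takes a pointer + size instead of a vector,
    // so std::vector callers pass .data()/.size() (compare lookup-create.cpp below).
    llama_ngram_cache_update(nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
                             inp.data(), inp.size(), inp.size(), false);

    // llama_ngram_cache_draft() likewise takes the context as pointer + size;
    // the draft still starts out as a single seed token (see the GGML_ASSERT above).
    std::vector<llama_token> draft = { inp.back() };
    llama_ngram_cache_draft(inp.data(), inp.size(), draft, /*n_draft =*/ 5,
                            LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
                            nc_context, nc_dynamic, nc_static);
}
```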

common/ngram-cache.h (+10, −6)

@@ -39,10 +39,13 @@ struct llama_ngram {

 struct llama_ngram_hash_function {
     size_t operator()(const llama_ngram & ngram) const {
-        size_t hash = 0;
-        for (int i = 0; i < LLAMA_NGRAM_MAX; ++i) {
-            hash ^= std::hash<llama_token>{}(ngram.tokens[i]);
+        size_t hash = ngram.tokens[0];
+
+        for (int i = 1; i < LLAMA_NGRAM_MAX; ++i) {
+            hash <<= 15;
+            hash ^= ngram.tokens[i];
         }
+
         return hash;
     }
 };
@@ -64,7 +67,7 @@ typedef std::unordered_map<llama_ngram, llama_ngram_cache_part, llama_ngram_hash
 // In order to get correct results inp_data can ONLY BE APPENDED TO.
 // Changes in the middle need a complete rebuild.
 void llama_ngram_cache_update(
-    llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max, std::vector<llama_token> & inp_data, int nnew, bool print_progress);
+    llama_ngram_cache & ngram_cache, int ngram_min, int ngram_max, llama_token * inp_data, int inp_size, int nnew, bool print_progress);

 // Try to draft tokens from ngram caches.
 // inp: the tokens generated so far.
@@ -75,7 +78,7 @@ void llama_ngram_cache_update(
 // nc_dynamic: ngram cache based on previous user generations.
 // nc_static: ngram cache generated from a large text corpus, used for validation.
 void llama_ngram_cache_draft(
-    std::vector<llama_token> & inp, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
+    llama_token * inp_data, int inp_size, std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
     llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static);

 // Save an ngram cache to a file.
@@ -84,9 +87,10 @@ void llama_ngram_cache_draft(
 void llama_ngram_cache_save(llama_ngram_cache & ngram_cache, std::string & filename);

 // Load an ngram cache saved with llama_ngram_cache_save.
+// ngram_cache: the ngram cache to load the data into.
 // filename: the path from which to load the ngram cache.
 // returns: an ngram cache containing the information saved to filename.
-llama_ngram_cache llama_ngram_cache_load(std::string & filename);
+bool llama_ngram_cache_load(llama_ngram_cache & ngram_cache, std::string & filename);

 // Merge two ngram caches.
 // ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add.
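The replaced hash combines the tokens positionally (shift left by 15 bits, then XOR) instead of XOR-ing per-token `std::hash` values, so n-grams containing the same tokens in a different order no longer collide. A small self-contained illustration of that difference follows; the token values and the hard-coded n-gram length of 4 are example assumptions only.

```cpp
#include <cstdint>
#include <cstdio>
#include <functional>

typedef int32_t llama_token; // stand-in for the llama.cpp typedef

// Old combine: order-insensitive, so permutations of the same tokens collide.
static size_t hash_old(const llama_token tokens[4]) {
    size_t hash = 0;
    for (int i = 0; i < 4; ++i) {
        hash ^= std::hash<llama_token>{}(tokens[i]);
    }
    return hash;
}

// New combine from this commit: position-dependent shift-xor.
static size_t hash_new(const llama_token tokens[4]) {
    size_t hash = tokens[0];
    for (int i = 1; i < 4; ++i) {
        hash <<= 15;
        hash ^= tokens[i];
    }
    return hash;
}

int main() {
    const llama_token a[4] = {10, 20, 30, 40};
    const llama_token b[4] = {40, 30, 20, 10}; // same tokens, reversed

    printf("old: %zu vs %zu (collision)\n", hash_old(a), hash_old(b));
    printf("new: %zu vs %zu (distinct)\n",  hash_new(a), hash_new(b));
    return 0;
}
```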

examples/lookup/README.md (+75, −6)

@@ -1,13 +1,82 @@
 # llama.cpp/examples/lookup

-Demonstration of Prompt Lookup Decoding
+Demonstration of speculative decoding using n-gram lookup.
+Initial version was based on https://github.com/apoorvumang/prompt-lookup-decoding .
+The current version uses three separate types of "n-gram caches".
+Each of these caches maps how frequently a given n-gram is followed by a specific token.
+The difference between the caches lies in what data is used to build them:

-https://github.com/apoorvumang/prompt-lookup-decoding
+* The "context" cache is built using the tokens in the current context of a user generation.
+* The "dynamic" cache is built by merging the context caches of previous user generations.
+* The "static" cache is built from a large text corpus with no relation to the current context.

-The key parameters for lookup decoding are `ngram_min`, `ngram_max` and `n_draft`. The first two determine the size of the ngrams to search for in the prompt for a match. The latter specifies how many subsequent tokens to draft if a match is found.
+The tradeoff between these caches lies in relevance to the current context vs. the amount of input data.
+When trying to draft a new token using n-gram lookup the algorithm is as follows:

-More info:
+* Try to draft a suitable token from the context cache. If a static cache is available, use it to validate the draft candidates. This is done by simply multiplying the frequencies of the two caches.
+* Try to draft a suitable token from the dynamic cache, validate with static cache if available.
+* Try to draft a suitable token from the static cache.

-https://github.com/ggerganov/llama.cpp/pull/4484
-https://github.com/ggerganov/llama.cpp/issues/4226
+Only a single token sequence with the most likely token candidates is drafted.
+All tokens must pass thresholds for frequency and sample size in order to be drafted.

+Relevant command line arguments:
+
+- `--draft`: maximum number of additional tokens to draft using n-gram lookup. Default: 5. Set to 0 to disable n-gram lookup. **Results are not deterministic with n-gram lookup enabled due to varying batch size.**
+- `-lcs FNAME, --lookup-cache-static FNAME`: optional path to static lookup cache to use for n-gram lookup. Created from a large, unspecific text corpus using `lookup-create`.
+- `-lcd FNAME, --lookup-cache-dynamic FNAME`: optional path to dynamic lookup cache to use for n-gram lookup. Contains data from previous generations. Automatically created and filled while the server is running but by default discarded on server exit. Setting this argument tries to initialize the dynamic cache from a file and saves it to said file on server shutdown.
+
+N-gram lookup caches saved to disk are compatible between models as long as they use the same tokenizer
+(but for dynamic caches the resulting drafted tokens may be wrong which means there is no speedup).
+Furthermore, the data format for both types of caches is the same so they can be used interchangeably (but probably not with good results).
+
+## Usage Examples
+
+### `lookup`
+
+Generation using n-gram lookup:
+
+``` sh
+./lookup --model models/opt/llama_2-7b-q4_0.gguf -ngl 99 --n-predict 256 --ignore-eos --draft 3 --color --prompt "Write a love story about two stars that tragically ends in a type Ia supernova. Use a lot of emotional and dramatic language."
+```
+
+The `--color` flag highlights the successfully predicted tokens.
+The `--lookup-cache-static` and `--lookup-cache-dynamic` arguments can be set to provide static/dynamic caches.
+
+### `lookup-stats`
+
+Determine n-gram lookup effectiveness for a given text corpus (similar to `perplexity`):
+
+``` sh
+./lookup-stats --model /opt/models/llama_2-7b-q4_0.gguf --file wikitext-2-raw/wiki.test.raw --draft 3
+```
+
+The `--lookup-cache-static` and `--lookup-cache-dynamic` arguments can be set to provide static/dynamic caches.
+
+### `lookup-create`
+
+Create a static lookup cache from a text corpus:
+
+``` sh
+./lookup-create --model /opt/models/llama_2-7b-q4_0.gguf --lookup-cache-static wt103-llama_2.lcs --file wikitext-103-raw/wiki.train.raw
+```
+
+The `--lookup-cache-static` argument must be set to provide the path to which the static lookup cache will be saved.
+The tokenizer for which to create the cache is taken from the provided model.
+
+### `lookup-merge`
+
+Merge two lookup caches into one:
+
+``` sh
+./lookup-merge cache_1.lcs cache_2.lcs cache_merged.lcs
+```
+
+Can be used for both static and dynamic lookup caches.
+
+## More info:
+
+* https://github.com/ggerganov/llama.cpp/pull/4484
+* https://github.com/ggerganov/llama.cpp/issues/4226
+* https://github.com/ggerganov/llama.cpp/pull/5479
+* https://github.com/ggerganov/llama.cpp/pull/6828
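The validation step the README describes (multiplying the frequency observed in the context/dynamic cache with the frequency from the static cache, plus frequency and sample-size thresholds) is only given in prose. The following is an illustrative C++ reconstruction of that idea, not the actual `try_draft` implementation from common/ngram-cache.cpp; the threshold values are placeholders.

```cpp
#include <cstdint>
#include <unordered_map>

typedef int32_t llama_token;

// One cache "part": how often each token followed a particular n-gram.
typedef std::unordered_map<llama_token, int32_t> ngram_part;

// Placeholder thresholds; the real constants live in common/ngram-cache.cpp.
static const int   draft_min_sample_size = 4;
static const float draft_min_percent     = 0.66f;

// Pick the most promising continuation from the context/dynamic part,
// validated against the static part by multiplying the two frequencies.
// Returns -1 if no token clears the thresholds.
static llama_token try_draft_sketch(const ngram_part & part_primary, const ngram_part & part_static) {
    int64_t     sum_primary = 0;
    int64_t     best_score  = 0;
    int32_t     best_count  = 0;
    llama_token best_token  = -1;

    for (const auto & kv : part_primary) {
        const llama_token token = kv.first;
        const int32_t     count = kv.second;
        sum_primary += count;

        // "Validation by multiplication": weight the primary frequency with the
        // frequency seen in the static cache (treated as 1 if unseen there).
        const auto    it    = part_static.find(token);
        const int64_t score = (int64_t) count * (it == part_static.end() ? 1 : it->second);
        if (score > best_score) {
            best_score = score;
            best_count = count;
            best_token = token;
        }
    }

    // Frequency and sample-size thresholds abort the draft early.
    if (sum_primary < draft_min_sample_size) {
        return -1;
    }
    if ((float) best_count / sum_primary < draft_min_percent) {
        return -1;
    }
    return best_token;
}
```

A token accepted this way is appended to the single drafted sequence; when none of the caches yields a token that clears its thresholds, drafting stops for the current position.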

examples/lookup/lookup-create.cpp (+1, −1)

@@ -34,7 +34,7 @@ int main(int argc, char ** argv){



     llama_ngram_cache ngram_cache;
-    llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
+    llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp.data(), inp.size(), inp.size(), true);
     fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());

     llama_ngram_cache_save(ngram_cache, params.lookup_cache_static);

examples/lookup/lookup-merge.cpp (+4, −2)

@@ -33,11 +33,13 @@ int main(int argc, char ** argv){
     }

     fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str());
-    llama_ngram_cache ngram_cache_merged = llama_ngram_cache_load(args[0]);
+    llama_ngram_cache ngram_cache_merged;
+    GGML_ASSERT(llama_ngram_cache_load(ngram_cache_merged, args[0]));

    for (size_t i = 1; i < args.size()-1; ++i) {
         fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str());
-        llama_ngram_cache ngram_cache = llama_ngram_cache_load(args[i]);
+        llama_ngram_cache ngram_cache;
+        GGML_ASSERT(llama_ngram_cache_load(ngram_cache, args[i]));

         llama_ngram_cache_merge(ngram_cache_merged, ngram_cache);
     }

examples/lookup/lookup-stats.cpp (+10, −9)

@@ -46,18 +46,15 @@ int main(int argc, char ** argv){
         const int64_t t_start_draft_us = ggml_time_us();

         if (!params.lookup_cache_static.empty()) {
-            try {
-                ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
-            } catch (std::ifstream::failure const &) {
+            if(!llama_ngram_cache_load(ngram_cache_static, params.lookup_cache_static)) {
                 fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
                 exit(1);
             }
         }

         if (!params.lookup_cache_dynamic.empty()) {
-            try {
-                ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
-            } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
+            // If the dynamic lookup cache doesn't exist it will be created at the end of the program:
+            llama_ngram_cache_load(ngram_cache_dynamic, params.lookup_cache_dynamic);
         }

         t_draft_flat_us += ggml_time_us() - t_start_draft_us;
@@ -85,7 +82,9 @@ int main(int argc, char ** argv){

             {
                 const int64_t t_start_draft_us = ggml_time_us();
-                llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+                llama_ngram_cache_draft(
+                    pseudo_output.data(), pseudo_output.size(), draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
+                    ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
                 t_draft_us += ggml_time_us() - t_start_draft_us;
             }

@@ -104,7 +103,8 @@ int main(int argc, char ** argv){

             {
                 const int64_t t_start_draft_us = ggml_time_us();
-                llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
+                llama_ngram_cache_update(
+                    ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output.data(), pseudo_output.size(), 1, false);
                 t_draft_us += ggml_time_us() - t_start_draft_us;
             }
         }
@@ -114,7 +114,8 @@ int main(int argc, char ** argv){
         pseudo_output.push_back(inp_slice[pseudo_output.size()]);
         {
             const int64_t t_start_draft_us = ggml_time_us();
-            llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
+            llama_ngram_cache_update(
+                ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output.data(), pseudo_output.size(), 1, false);
             t_draft_us += ggml_time_us() - t_start_draft_us;
         }
     }
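The server.cpp changes themselves are not included in this excerpt (only 8 of the 14 changed files are shown), so the following is a hypothetical outline of how a server slot could drive lookup decoding, based purely on the README description and the API in common/ngram-cache.h; names such as `lookup_state`, `slot_tokens`, and `draft_for_slot` are invented for illustration and do not appear in this commit.

```cpp
#include <string>
#include <vector>

#include "llama.h"
#include "ngram-cache.h" // common/ngram-cache.h

// Hypothetical outline of per-slot state for n-gram lookup decoding.
struct lookup_state {
    llama_ngram_cache nc_context; // rebuilt from the slot's current context
    llama_ngram_cache nc_dynamic; // optionally loaded from --lookup-cache-dynamic
    llama_ngram_cache nc_static;  // optionally loaded from --lookup-cache-static
};

static void draft_for_slot(lookup_state & ls, std::vector<llama_token> & slot_tokens,
                           std::vector<llama_token> & draft, int n_draft) {
    // Keep the context cache in sync with the newest token only (nnew = 1).
    llama_ngram_cache_update(ls.nc_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
                             slot_tokens.data(), slot_tokens.size(), 1, false);

    // Draft up to n_draft additional tokens; draft already holds the seed token.
    llama_ngram_cache_draft(slot_tokens.data(), slot_tokens.size(), draft, n_draft,
                            LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
                            ls.nc_context, ls.nc_dynamic, ls.nc_static);
}

// On shutdown, the dynamic cache would only be written back when a
// --lookup-cache-dynamic path was given, matching the README description.
static void save_dynamic_cache(lookup_state & ls, std::string path) {
    if (!path.empty()) {
        llama_ngram_cache_merge(ls.nc_dynamic, ls.nc_context);
        llama_ngram_cache_save(ls.nc_dynamic, path);
    }
}
```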
