Commit 8dcd771

lookup: evaluation tools, use corpus/previous gens
1 parent f9c7ba3 commit 8dcd771

10 files changed: +396 -61 lines
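Alongside the new ngram-cache module that the lookup target now links, the commit adds three evaluation tools: lookup-create builds a static ngram cache from a text corpus, lookup-merge combines several cache files into one, and lookup-stats replays a corpus as simulated generations to measure how well the caches would draft tokens. Per the new help text, the static cache is not updated by generation while the dynamic cache is. A rough sketch of the intended workflow, assuming the usual llama.cpp -m/-f model and prompt-file flags; the binary and flag names come from this commit, all file names are placeholders:

    ./lookup-create -m model.gguf -f corpus.txt -lcs corpus-cache-part-1.bin
    ./lookup-merge corpus-cache-part-1.bin corpus-cache-part-2.bin corpus-cache.bin
    ./lookup-stats -m model.gguf -f eval.txt -lcs corpus-cache.bin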

Makefile (+11 -2)

@@ -669,14 +669,17 @@ grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h
 train.o: common/train.cpp common/train.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@

+ngram-cache.o: common/ngram-cache.cpp common/ngram-cache.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 libllama.so: llama.o ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)

 libllama.a: llama.o ggml.o $(OBJS) $(COMMON_DEPS)
 	ar rcs libllama.a llama.o ggml.o $(OBJS) $(COMMON_DEPS)

 clean:
-	rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
+	rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
 	find examples pocs -type f -name "*.o" -delete

 #
@@ -806,9 +809,15 @@ lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

-lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+lookup: examples/lookup/lookup.cpp ggml.o llama.o ngram-cache.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-create.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp) -o lookup-create $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-merge.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-merge.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-merge.cpp) -o lookup-merge $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-stats.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp) -o lookup-stats $(LDFLAGS)

 passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

common/CMakeLists.txt (+2 -0)

@@ -62,6 +62,8 @@ add_library(${TARGET} STATIC
     grammar-parser.cpp
     train.h
     train.cpp
+    ngram-cache.h
+    ngram-cache.cpp
     )

 if (BUILD_SHARED_LIBS)

common/common.cpp (+20 -0)

@@ -948,6 +948,22 @@ static bool gpt_params_find_arg(int argc, char ** argv, gpt_params & params, int
         }
         return true;
     }
+    if (arg == "-lcs" || arg == "--lookup-cache-static") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.lookup_cache_static = argv[i];
+        return true;
+    }
+    if (arg == "-lcd" || arg == "--lookup-cache-dynamic") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.lookup_cache_dynamic = argv[i];
+        return true;
+    }
     if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1410,6 +1426,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        draft model for speculative decoding\n");
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
+    printf("  -lcs FNAME, --lookup-cache-static FNAME\n");
+    printf("                        path to static lookup cache to use for lookup decoding (not updated by generation)\n");
+    printf("  -lcd FNAME, --lookup-cache-dynamic FNAME\n");
+    printf("                        path to dynamic lookup cache to use for lookup decoding (updated by generation)\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                        advanced option to override model metadata by key. may be specified multiple times.\n");
     printf("                        types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");

common/common.h (+13 -11)

@@ -88,18 +88,20 @@ struct gpt_params {
     // // sampling parameters
     struct llama_sampling_params sparams;

-    std::string model             = "models/7B/ggml-model-f16.gguf"; // model path
-    std::string model_url         = ""; // model url to download
-    std::string model_draft       = ""; // draft model for speculative decoding
-    std::string model_alias       = "unknown"; // model alias
-    std::string prompt            = "";
-    std::string prompt_file       = ""; // store the external prompt file name
-    std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
-    std::string input_prefix      = ""; // string to prefix user inputs with
-    std::string input_suffix      = ""; // string to suffix user inputs with
+    std::string model                = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_url            = ""; // model url to download
+    std::string model_draft          = ""; // draft model for speculative decoding
+    std::string model_alias          = "unknown"; // model alias
+    std::string prompt               = "";
+    std::string prompt_file          = ""; // store the external prompt file name
+    std::string path_prompt_cache    = ""; // path to file for saving/loading prompt eval state
+    std::string input_prefix         = ""; // string to prefix user inputs with
+    std::string input_suffix         = ""; // string to suffix user inputs with
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
-    std::string logdir            = ""; // directory in which to save YAML log files
-    std::string logits_file       = ""; // file for saving *all* logits
+    std::string logdir               = ""; // directory in which to save YAML log files
+    std::string lookup_cache_static  = ""; // path of static ngram cache file for lookup decoding
+    std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding
+    std::string logits_file          = ""; // file for saving *all* logits

     std::vector<llama_model_kv_override> kv_overrides;

examples/lookup/CMakeLists.txt (+18 -0)

@@ -3,3 +3,21 @@ add_executable(${TARGET} lookup.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET lookup-create)
+add_executable(${TARGET} lookup-create.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET lookup-merge)
+add_executable(${TARGET} lookup-merge.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+set(TARGET lookup-stats)
+add_executable(${TARGET} lookup-stats.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/lookup/lookup-create.cpp (new file, +43)

@@ -0,0 +1,43 @@
+#include "ggml.h"
+#include "llama.h"
+#include "common.h"
+#include "ngram-cache.h"
+
+#include <cstdint>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+int main(int argc, char ** argv){
+    gpt_params params;
+
+    if (!gpt_params_parse(argc, argv, params)) {
+        return 1;
+    }
+    // init llama.cpp
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    llama_model * model = NULL;
+    llama_context * ctx = NULL;
+
+    // load the model
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    GGML_ASSERT(model != nullptr);
+
+    // tokenize the prompt
+    const bool add_bos = llama_should_add_bos_token(model);
+
+    std::vector<llama_token> inp;
+    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+    fprintf(stderr, "%s: tokenization done\n", __func__);
+
+
+    llama_ngram_cache ngram_cache;
+    llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
+    fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
+
+    llama_ngram_cache_save(ngram_cache, params.lookup_cache_static);
+}
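lookup-create loads the model only so that the corpus passed as the prompt is tokenized with the right vocabulary; it never generates anything. The token sequence is hashed into a static ngram cache (LLAMA_NGRAM_STATIC order) and written to the path given by --lookup-cache-static. A hypothetical invocation, assuming the usual -m/-f model and prompt-file flags and placeholder file names:

    ./lookup-create -m model.gguf -f corpus.txt -lcs corpus-cache.bin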

examples/lookup/lookup-merge.cpp (new file, +47)

@@ -0,0 +1,47 @@
+#include "ggml.h"
+#include "llama.h"
+#include "common.h"
+#include "ngram-cache.h"
+
+#include <cstdint>
+#include <cstdio>
+#include <fstream>
+#include <iostream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+static void print_usage() {
+    fprintf(stderr, "Merges multiple lookup cache files into a single one.\n");
+    fprintf(stderr, "Usage: lookup-merge [--help] lookup_part_1.bin lookup_part_2.bin ... lookup_merged.bin\n");
+}
+
+int main(int argc, char ** argv){
+    if (argc < 3) {
+        print_usage();
+        exit(1);
+    }
+
+    std::vector<std::string> args;
+    args.resize(argc-1);
+    for (int i = 0; i < argc-1; ++i) {
+        args[i] = argv[i+1];
+        if (args[i] == "-h" || args[i] == "--help") {
+            print_usage();
+            exit(0);
+        }
+    }
+
+    fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str());
+    llama_ngram_cache ngram_cache_merged = llama_ngram_cache_load(args[0]);
+
+    for (size_t i = 1; i < args.size()-1; ++i) {
+        fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str());
+        llama_ngram_cache ngram_cache = llama_ngram_cache_load(args[i]);
+
+        llama_ngram_cache_merge(ngram_cache_merged, ngram_cache);
+    }
+
+    fprintf(stderr, "lookup-merge: saving file %s\n", args.back().c_str());
+    llama_ngram_cache_save(ngram_cache_merged, args.back());
+}
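As the usage string says, the last argument is the output file: every preceding cache is merged into the first one, and the result is saved to the final path, for example (placeholder file names):

    ./lookup-merge corpus-cache-part-1.bin corpus-cache-part-2.bin corpus-cache.bin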

examples/lookup/lookup-stats.cpp (new file, +163)

@@ -0,0 +1,163 @@
+#include "ggml.h"
+#include "common.h"
+#include "llama.h"
+#include "log.h"
+#include "ngram-cache.h"
+
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <fstream>
+#include <string>
+#include <vector>
+#include <unordered_map>
+
+int main(int argc, char ** argv){
+    gpt_params params;
+
+    if (!gpt_params_parse(argc, argv, params)) {
+        return 1;
+    }
+
+    const int n_draft = params.n_draft;
+
+    // init llama.cpp
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    llama_model * model = NULL;
+    llama_context * ctx = NULL;
+
+    // load the model
+    std::tie(model, ctx) = llama_init_from_gpt_params(params);
+    llama_set_rng_seed(ctx, params.seed);
+    GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
+
+    // tokenize the prompt
+    const bool add_bos = llama_should_add_bos_token(model);
+    LOG("add_bos tgt: %d\n", add_bos);
+
+    std::vector<llama_token> inp;
+    inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
+
+    llama_ngram_cache ngram_cache_context;
+    llama_ngram_cache ngram_cache_dynamic;
+    llama_ngram_cache ngram_cache_static;
+    int64_t t_draft_flat_us = 0;
+    int64_t t_draft_us = 0;
+
+    {
+        const int64_t t_start_draft_us = ggml_time_us();
+
+        if (!params.lookup_cache_static.empty()) {
+            try {
+                ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
+            } catch (std::system_error const &) {
+                fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
+                exit(1);
+            }
+        }
+
+        if (!params.lookup_cache_dynamic.empty()) {
+            try {
+                ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
+            } catch (std::system_error const &) {} // if the file does not exist it will simply be created at the end of the program
+        }
+
+        t_draft_flat_us += ggml_time_us() - t_start_draft_us;
+    }
+
+    const int n_input = inp.size();
+    const int n_ctx = params.n_ctx;
+
+    int n_drafted = 0;
+    int n_accept = 0;
+
+    const int64_t t_start_ms = ggml_time_ms();
+
+    // Iterate over input tokens in chunks of size n_ctx.
+    // Each chunk is treated as if a sequential generation but with pre-determined tokens to ensure reproducibility.
+    for (int i_start = 0; i_start + n_ctx < n_input; i_start += n_ctx) {
+        const std::vector<llama_token> inp_slice(inp.begin() + i_start, inp.begin() + i_start + n_ctx);
+        std::vector<llama_token> pseudo_output;
+        pseudo_output.push_back(inp_slice[0]);
+
+        while ((int) pseudo_output.size() < n_ctx) {
+            // Simulate drafting and decoding from draft:
+            std::vector<llama_token> draft;
+            draft.push_back(pseudo_output.back());
+
+            {
+                const int64_t t_start_draft_us = ggml_time_us();
+                llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+                t_draft_us += ggml_time_us() - t_start_draft_us;
+            }
+
+            n_drafted += draft.size() - 1;
+
+            for (size_t j = 1; j < draft.size() && (int) pseudo_output.size() < n_ctx; ++j) {
+                const llama_token ground_truth = inp_slice[pseudo_output.size()];
+                const llama_token drafted = draft[j];
+
+                if (ground_truth != drafted) {
+                    break;
+                }
+
+                ++n_accept;
+                pseudo_output.push_back(ground_truth);
+
+                {
+                    const int64_t t_start_draft_us = ggml_time_us();
+                    llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
+                    t_draft_us += ggml_time_us() - t_start_draft_us;
+                }
+            }
+
+            // After each simulated batch decoding simulate the sampling of a single token:
+            if ((int) pseudo_output.size() < n_ctx) {
+                pseudo_output.push_back(inp_slice[pseudo_output.size()]);
+                {
+                    const int64_t t_start_draft_us = ggml_time_us();
+                    llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
+                    t_draft_us += ggml_time_us() - t_start_draft_us;
+                }
+            }
+
+            draft.erase(draft.begin());
+
+        }
+        if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) {
+            const int64_t t_now_ms = ggml_time_ms();
+            const int64_t eta_ms = (n_input - i_start) * (t_now_ms - t_start_ms) / i_start;
+            const int64_t eta_min = eta_ms / (60*1000);
+            const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000;
+
+            LOG_TEE("%d/%d done, ETA: %02ld:%02ld\n", i_start, n_input, eta_min, eta_s);
+        }
+
+        // After each chunk, update the dynamic ngram cache with the context ngram cache:
+        llama_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
+        ngram_cache_context.clear();
+    }
+
+    LOG_TEE("\n");
+
+    LOG_TEE("\n");
+    LOG_TEE("n_draft = %d\n", n_draft);
+    LOG_TEE("n_predict = %d\n", n_input - n_input % n_ctx);
+    LOG_TEE("n_drafted = %d\n", n_drafted);
+    LOG_TEE("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3);
+    LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n",
+            t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us));
+    LOG_TEE("n_accept = %d\n", n_accept);
+    LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+    llama_free(ctx);
+    llama_free_model(model);
+
+    llama_backend_free();
+
+    fprintf(stderr, "\n\n");
+
+    return 0;
+}
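lookup-stats never runs the model to generate tokens; the model is loaded only to tokenize the input, which is then walked in n_ctx-sized chunks and treated as a pre-determined generation. Drafted tokens are compared against the corpus itself, so the reported accept figure is 100 * n_accept / n_drafted; for instance, 3500 accepted out of 10000 drafted tokens would print accept = 35.000%. A hypothetical run against an evaluation text (aside from -lcs/-lcd, the flags are the usual llama.cpp options and all file names are placeholders):

    ./lookup-stats -m model.gguf -f eval.txt -lcs corpus-cache.bin -lcd previous-gens.bin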
