
Commit 50ccaf5

lookup: complement data from context with general text statistics (#5479)
* lookup: evaluation tools, use corpus/previous gens
* fixup! lookup: evaluation tools, use corpus/previous gens
* fixup! lookup: evaluation tools, use corpus/previous gens
* fixup! lookup: evaluation tools, use corpus/previous gens
* fixup! lookup: evaluation tools, use corpus/previous gens
1 parent 56a00f0 commit 50ccaf5

13 files changed: +774 -63 lines changed
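Background for the change: in lookup decoding, draft tokens are proposed by matching the tail of the current context against n-gram statistics instead of running a draft model. This commit lets those statistics come from more than the current context alone: a static cache built offline from a text corpus (passed with the new -lcs/--lookup-cache-static flag and never modified by generation) and a dynamic cache accumulated across previous generations (passed with -lcd/--lookup-cache-dynamic and updated as the model generates), plus new lookup-create, lookup-merge and lookup-stats tools that, judging by their names, build, merge and evaluate such caches. The sketch below is a simplified, self-contained illustration of the idea only; every type and function name in it is invented for illustration, and the commit's actual logic lives in the new common/ngram-cache.{h,cpp}.

// Simplified, self-contained illustration of n-gram lookup drafting with a
// context/dynamic/static cache hierarchy. All names here are illustrative;
// the commit's real implementation is in common/ngram-cache.{h,cpp}.
#include <algorithm>
#include <cstdint>
#include <map>
#include <vector>

using token       = int32_t;
using ngram       = std::vector<token>;                     // a short window of recent tokens
using ngram_cache = std::map<ngram, std::map<token, int>>;  // n-gram -> {next token: count}

// Propose a draft token for the given n-gram by consulting the caches in order:
// statistics from the current context, then from previous generations, then from
// a static corpus. (A deliberately simple first-hit policy for illustration.)
static bool draft_next(const ngram & key,
                       const ngram_cache & cache_context,
                       const ngram_cache & cache_dynamic,
                       const ngram_cache & cache_static,
                       token & out) {
    for (const ngram_cache * cache : { &cache_context, &cache_dynamic, &cache_static }) {
        const auto it = cache->find(key);
        if (it == cache->end() || it->second.empty()) {
            continue;
        }
        // pick the continuation most frequently observed after this n-gram
        const auto best = std::max_element(
            it->second.begin(), it->second.end(),
            [](const auto & a, const auto & b) { return a.second < b.second; });
        out = best->first;
        return true;
    }
    return false; // no cache has seen this n-gram; fall back to normal decoding
}

The commit's actual drafting code presumably combines the caches with more careful heuristics than this first-hit policy; the sketch only shows the division of roles that the new -lcs and -lcd options expose.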

.gitignore

+3
@@ -58,6 +58,9 @@ models-mnt
 /llava-cli
 /lookahead
 /lookup
+/lookup-create
+/lookup-merge
+/lookup-stats
 /main
 /metal
 /passkey

Makefile

+11 -2
@@ -676,14 +676,17 @@ json-schema-to-grammar.o: common/json-schema-to-grammar.cpp common/json-schema-t
 train.o: common/train.cpp common/train.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@

+ngram-cache.o: common/ngram-cache.cpp common/ngram-cache.h
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 libllama.so: llama.o ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)

 libllama.a: llama.o ggml.o $(OBJS) $(COMMON_DEPS)
 	ar rcs libllama.a llama.o ggml.o $(OBJS) $(COMMON_DEPS)

 clean:
-	rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
+	rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
 	find examples pocs -type f -name "*.o" -delete

 #

@@ -813,9 +816,15 @@ lookahead: examples/lookahead/lookahead.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

-lookup: examples/lookup/lookup.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
+lookup: examples/lookup/lookup.cpp ggml.o llama.o ngram-cache.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-create.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-create.cpp) -o lookup-create $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-merge.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-merge.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-merge.cpp) -o lookup-merge $(LDFLAGS)
+	$(CXX) $(CXXFLAGS) -c examples/lookup/lookup-stats.cpp -o $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, examples/lookup/lookup-stats.cpp) -o lookup-stats $(LDFLAGS)

 passkey: examples/passkey/passkey.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
 	$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

common/CMakeLists.txt

+2
@@ -65,6 +65,8 @@ add_library(${TARGET} STATIC
     json.hpp
     train.h
     train.cpp
+    ngram-cache.h
+    ngram-cache.cpp
     )

 if (BUILD_SHARED_LIBS)

common/common.cpp

+20
@@ -963,6 +963,22 @@ static bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg,
        }
        return true;
    }
+   if (arg == "-lcs" || arg == "--lookup-cache-static") {
+       if (++i >= argc) {
+           invalid_param = true;
+           return true;
+       }
+       params.lookup_cache_static = argv[i];
+       return true;
+   }
+   if (arg == "-lcd" || arg == "--lookup-cache-dynamic") {
+       if (++i >= argc) {
+           invalid_param = true;
+           return true;
+       }
+       params.lookup_cache_dynamic = argv[i];
+       return true;
+   }
    if (arg == "--save-all-logits" || arg == "--kl-divergence-base") {
        if (++i >= argc) {
            invalid_param = true;

@@ -1436,6 +1452,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
    printf(" Hugging Face model file (default: unused)\n");
    printf(" -ld LOGDIR, --logdir LOGDIR\n");
    printf(" path under which to save YAML logs (no logging if unset)\n");
+   printf(" -lcs FNAME, --lookup-cache-static FNAME\n");
+   printf(" path to static lookup cache to use for lookup decoding (not updated by generation)\n");
+   printf(" -lcd FNAME, --lookup-cache-dynamic FNAME\n");
+   printf(" path to dynamic lookup cache to use for lookup decoding (updated by generation)\n");
    printf(" --override-kv KEY=TYPE:VALUE\n");
    printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
    printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
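The pattern above is the same one gpt_params_find_arg uses for the existing options: each flag consumes the next argv entry as the file name and sets invalid_param when that entry is missing. A stand-alone sketch of just that behavior, for readers without the full file at hand; the struct and main() here are illustrative stand-ins, not part of the commit, where the fields live in gpt_params.

// Stand-alone illustration of the argument handling added above. The stand-in
// struct and main() are not part of the commit; the real fields are gpt_params
// members and the parsing happens inside gpt_params_find_arg.
#include <cstdio>
#include <string>

struct lookup_params {
    std::string lookup_cache_static;   // set by -lcs / --lookup-cache-static
    std::string lookup_cache_dynamic;  // set by -lcd / --lookup-cache-dynamic
};

int main(int argc, char ** argv) {
    lookup_params params;
    bool invalid_param = false;

    for (int i = 1; i < argc; i++) {
        const std::string arg = argv[i];
        if (arg == "-lcs" || arg == "--lookup-cache-static") {
            if (++i >= argc) { invalid_param = true; break; }  // flag given without a file name
            params.lookup_cache_static = argv[i];
        } else if (arg == "-lcd" || arg == "--lookup-cache-dynamic") {
            if (++i >= argc) { invalid_param = true; break; }
            params.lookup_cache_dynamic = argv[i];
        }
    }
    if (invalid_param) {
        fprintf(stderr, "error: lookup cache option given without a file name\n");
        return 1;
    }
    printf("static lookup cache:  %s\n", params.lookup_cache_static.c_str());
    printf("dynamic lookup cache: %s\n", params.lookup_cache_dynamic.c_str());
    return 0;
}

Compiled on its own, running the sketch with "-lcs corpus.bin -lcd session.bin" prints the two paths; passing "-lcs" as the final argument exercises the invalid_param branch.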

common/common.h

+15 -13
@@ -88,20 +88,22 @@ struct gpt_params {
    // // sampling parameters
    struct llama_sampling_params sparams;

-   std::string model = "models/7B/ggml-model-f16.gguf"; // model path
-   std::string model_draft = ""; // draft model for speculative decoding
-   std::string model_alias = "unknown"; // model alias
-   std::string model_url = ""; // model url to download
-   std::string hf_repo = ""; // HF repo
-   std::string hf_file = ""; // HF file
-   std::string prompt = "";
-   std::string prompt_file = ""; // store the external prompt file name
-   std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
-   std::string input_prefix = ""; // string to prefix user inputs with
-   std::string input_suffix = ""; // string to suffix user inputs with
+   std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+   std::string model_draft = ""; // draft model for speculative decoding
+   std::string model_alias = "unknown"; // model alias
+   std::string model_url = ""; // model url to download
+   std::string hf_repo = ""; // HF repo
+   std::string hf_file = ""; // HF file
+   std::string prompt = "";
+   std::string prompt_file = ""; // store the external prompt file name
+   std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
+   std::string input_prefix = ""; // string to prefix user inputs with
+   std::string input_suffix = ""; // string to suffix user inputs with
    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
-   std::string logdir = ""; // directory in which to save YAML log files
-   std::string logits_file = ""; // file for saving *all* logits
+   std::string logdir = ""; // directory in which to save YAML log files
+   std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding
+   std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding
+   std::string logits_file = ""; // file for saving *all* logits

    std::vector<llama_model_kv_override> kv_overrides;

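The two new strings only carry file paths into the examples; the n-gram cache handling itself lives in the new common/ngram-cache.{h,cpp} files added by this commit, which are not shown in this excerpt. Below is a hedged sketch of how a consumer such as examples/lookup/lookup.cpp might treat the two paths differently, in line with the help text above (static: not updated by generation; dynamic: updated by generation). The cache type and the load/save helpers are stand-ins invented for the sketch, not the commit's real API.

// Hedged sketch of consuming the new gpt_params fields. The cache type and the
// load/save helpers are stand-ins; the commit's real API is in common/ngram-cache.h.
#include <cstdio>
#include <fstream>
#include <map>
#include <string>

using ngram_cache = std::map<std::string, int>;  // stand-in for the real n-gram cache type

static bool cache_load(ngram_cache & cache, const std::string & path) {
    std::ifstream fin(path);
    if (!fin) {
        return false;
    }
    std::string key;
    int count;
    while (fin >> key >> count) {
        cache[key] = count;
    }
    return true;
}

static void cache_save(const ngram_cache & cache, const std::string & path) {
    std::ofstream fout(path);
    for (const auto & kv : cache) {
        fout << kv.first << ' ' << kv.second << '\n';
    }
}

struct lookup_state {
    ngram_cache cache_static;   // from lookup_cache_static: read once, never written back
    ngram_cache cache_dynamic;  // from lookup_cache_dynamic: updated while generating, saved at the end
};

static void init_caches(lookup_state & st,
                        const std::string & lookup_cache_static,
                        const std::string & lookup_cache_dynamic) {
    if (!lookup_cache_static.empty() && !cache_load(st.cache_static, lookup_cache_static)) {
        fprintf(stderr, "warning: could not read static lookup cache '%s'\n", lookup_cache_static.c_str());
    }
    if (!lookup_cache_dynamic.empty()) {
        // a missing dynamic cache is fine: it will be created the first time it is saved
        cache_load(st.cache_dynamic, lookup_cache_dynamic);
    }
}

static void save_dynamic_cache(const lookup_state & st, const std::string & lookup_cache_dynamic) {
    if (!lookup_cache_dynamic.empty()) {
        cache_save(st.cache_dynamic, lookup_cache_dynamic);  // persist what this run learned
    }
}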
0 commit comments
