
Commit 83988df

Use smart pointers in simple-chat
Avoid manual memory cleanups. Fewer memory leaks in the code now. Avoid printing multiple dots. Split code into smaller functions. Use C-style I/O rather than a mix of C++ streams and C-style I/O. No exception handling.

Signed-off-by: Eric Curtin <ecurtin@redhat.com>
1 parent 1842922 commit 83988df
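
The core of the change is owning each llama.cpp C handle with a std::unique_ptr whose custom deleter is the matching C free function, so cleanup happens automatically on every return path. A minimal sketch of that pattern, reusing the llama_load_model_from_file / llama_free_model calls that appear in the new file below (the load_model helper name here is illustrative only, not part of the commit):

```cpp
#include <cstdio>
#include <memory>

#include "llama.h"

// Own the C handle with a unique_ptr; the deleter is the matching C free
// function, so the model is released automatically when the last owner
// goes out of scope, with no manual llama_free_model() calls.
static std::unique_ptr<llama_model, decltype(&llama_free_model)> load_model(const char * path) {
    llama_model_params params = llama_model_default_params();
    auto model = std::unique_ptr<llama_model, decltype(&llama_free_model)>(
        llama_load_model_from_file(path, params), llama_free_model);
    if (!model) {
        fprintf(stderr, "unable to load model: %s\n", path);
    }

    return model;
}
```

The same idea is applied below to llama_context (freed with llama_free) and llama_sampler (freed with llama_sampler_free).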

File tree: 6 files changed (+672, -138 lines)


Makefile (+6)

@@ -34,6 +34,7 @@ BUILD_TARGETS = \
     llama-server \
     llama-simple \
     llama-simple-chat \
+    llama-ramalama-core \
     llama-speculative \
     llama-tokenize \
     llama-vdot \
@@ -1382,6 +1383,11 @@ llama-infill: examples/infill/infill.cpp \
     $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
     $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

+llama-ramalama-core: examples/ramalama-core/ramalama-core.cpp \
+    $(OBJ_ALL)
+    $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
+    $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
+
 llama-simple: examples/simple/simple.cpp \
     $(OBJ_ALL)
     $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)

examples/CMakeLists.txt (+1)

@@ -47,6 +47,7 @@ else()
         add_subdirectory(sycl)
     endif()
     add_subdirectory(save-load-state)
+    add_subdirectory(ramalama-core)
     add_subdirectory(simple)
     add_subdirectory(simple-chat)
     add_subdirectory(speculative)

examples/ramalama-core/CMakeLists.txt (+5)

@@ -0,0 +1,5 @@
+set(TARGET llama-ramalama-core)
+add_executable(${TARGET} ramalama-core.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/ramalama-core/README.md (+7)

@@ -0,0 +1,7 @@
+# llama.cpp/example/ramalama-core
+
+The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the chat template from the GGUF file.
+
+```bash
+./llama-ramalama-core -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048
+...
examples/ramalama-core/ramalama-core.cpp (+356)

@@ -0,0 +1,356 @@
+#include <climits>
+#include <cstdio>
+#include <cstring>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "llama.h"
+
+// Add a message to `messages` and store its content in `owned_content`
+static void add_message(const std::string & role, const std::string & text, std::vector<llama_chat_message> & messages,
+                        std::vector<std::unique_ptr<char[]>> & owned_content) {
+    auto content = std::unique_ptr<char[]>(new char[text.size() + 1]);
+    std::strcpy(content.get(), text.c_str());
+    messages.push_back({role.c_str(), content.get()});
+    owned_content.push_back(std::move(content));
+}
+
+// Function to apply the chat template and resize `formatted` if needed
+static int apply_chat_template(const llama_model * model, const std::vector<llama_chat_message> & messages,
+                               std::vector<char> & formatted, const bool append) {
+    int result = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), append, formatted.data(),
+                                           formatted.size());
+    if (result > static_cast<int>(formatted.size())) {
+        formatted.resize(result);
+        result = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), append, formatted.data(),
+                                           formatted.size());
+    }
+
+    return result;
+}
+
+// Function to tokenize the prompt
+static int tokenize_prompt(const llama_model * model, const std::string & prompt,
+                           std::vector<llama_token> & prompt_tokens) {
+    const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true);
+    prompt_tokens.resize(n_prompt_tokens);
+    if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) <
+        0) {
+        GGML_ABORT("failed to tokenize the prompt\n");
+    }
+
+    return n_prompt_tokens;
+}
+
+// Check if we have enough space in the context to evaluate this batch
+static int check_context_size(const llama_context * ctx, const llama_batch & batch) {
+    const int n_ctx = llama_n_ctx(ctx);
+    const int n_ctx_used = llama_get_kv_cache_used_cells(ctx);
+    if (n_ctx_used + batch.n_tokens > n_ctx) {
+        printf("\033[0m\n");
+        fprintf(stderr, "context size exceeded\n");
+        return 1;
+    }
+
+    return 0;
+}
+
+// convert the token to a string
+static int convert_token_to_string(const llama_model * model, const llama_token token_id, std::string & piece) {
+    char buf[256];
+    int n = llama_token_to_piece(model, token_id, buf, sizeof(buf), 0, true);
+    if (n < 0) {
+        GGML_ABORT("failed to convert token to piece\n");
+    }
+
+    piece = std::string(buf, n);
+    return 0;
+}
+
+static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) {
+    printf("%s", piece.c_str());
+    fflush(stdout);
+    response += piece;
+}
+
+// helper function to evaluate a prompt and generate a response
+static int generate(const llama_model * model, llama_sampler * smpl, llama_context * ctx, const std::string & prompt,
+                    std::string & response) {
+    std::vector<llama_token> prompt_tokens;
+    const int n_prompt_tokens = tokenize_prompt(model, prompt, prompt_tokens);
+    if (n_prompt_tokens < 0) {
+        return 1;
+    }
+
+    // prepare a batch for the prompt
+    llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
+    llama_token new_token_id;
+    while (true) {
+        // stop if this batch would no longer fit into the context
+        if (check_context_size(ctx, batch)) {
+            return 1;
+        }
+
+        if (llama_decode(ctx, batch)) {
+            GGML_ABORT("failed to decode\n");
+        }
+
+        // sample the next token and check whether it is an end-of-generation token
+        new_token_id = llama_sampler_sample(smpl, ctx, -1);
+        if (llama_token_is_eog(model, new_token_id)) {
+            break;
+        }
+
+        std::string piece;
+        if (convert_token_to_string(model, new_token_id, piece)) {
+            return 1;
+        }
+
+        print_word_and_concatenate_to_response(piece, response);
+
+        // prepare the next batch with the sampled token
+        batch = llama_batch_get_one(&new_token_id, 1);
+    }
+
+    return 0;
+}
+
+static void print_usage(int, const char ** argv) {
+    printf("\nexample usage:\n");
+    printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]);
+    printf("\n");
+}
+
+static int parse_int_arg(const char * arg, int & value) {
+    char * end;
+    long val = std::strtol(arg, &end, 10);
+    if (*end == '\0' && val >= INT_MIN && val <= INT_MAX) {
+        value = static_cast<int>(val);
+        return 0;
+    }
+
+    return 1;
+}
+
+static int handle_model_path(const int argc, const char ** argv, int & i, std::string & model_path) {
+    if (i + 1 < argc) {
+        model_path = argv[++i];
+        return 0;
+    }
+
+    print_usage(argc, argv);
+    return 1;
+}
+
+static int handle_n_ctx(const int argc, const char ** argv, int & i, int & n_ctx) {
+    if (i + 1 < argc) {
+        // parse_int_arg() returns 0 on success
+        if (parse_int_arg(argv[++i], n_ctx) == 0) {
+            return 0;
+        } else {
+            fprintf(stderr, "error: invalid value for -c: %s\n", argv[i]);
+            print_usage(argc, argv);
+        }
+    } else {
+        print_usage(argc, argv);
+    }
+
+    return 1;
+}
+
+static int handle_ngl(const int argc, const char ** argv, int & i, int & ngl) {
+    if (i + 1 < argc) {
+        // parse_int_arg() returns 0 on success
+        if (parse_int_arg(argv[++i], ngl) == 0) {
+            return 0;
+        } else {
+            fprintf(stderr, "error: invalid value for -ngl: %s\n", argv[i]);
+            print_usage(argc, argv);
+        }
+    } else {
+        print_usage(argc, argv);
+    }
+
+    return 1;
+}
+
+static int parse_arguments(const int argc, const char ** argv, std::string & model_path, int & n_ctx, int & ngl) {
+    for (int i = 1; i < argc; ++i) {
+        if (strcmp(argv[i], "-m") == 0) {
+            if (handle_model_path(argc, argv, i, model_path)) {
+                return 1;
+            }
+        } else if (strcmp(argv[i], "-c") == 0) {
+            if (handle_n_ctx(argc, argv, i, n_ctx)) {
+                return 1;
+            }
+        } else if (strcmp(argv[i], "-ngl") == 0) {
+            if (handle_ngl(argc, argv, i, ngl)) {
+                return 1;
+            }
+        } else {
+            print_usage(argc, argv);
+            return 1;
+        }
+    }
+
+    if (model_path.empty()) {
+        print_usage(argc, argv);
+        return 1;
+    }
+
+    return 0;
+}
+
+static int read_user_input(std::string & user_input) {
+    // Use unique_ptr with free as the deleter
+    std::unique_ptr<char, decltype(&free)> buffer(nullptr, &free);
+
+    size_t buffer_size = 0;
+    char * raw_buffer = nullptr;
+
+    // Use getline to dynamically allocate the buffer and get input
+    const ssize_t line_size = getline(&raw_buffer, &buffer_size, stdin);
+
+    // Transfer ownership to unique_ptr
+    buffer.reset(raw_buffer);
+
+    if (line_size > 0) {
+        // Remove the trailing newline character if present
+        if (buffer.get()[line_size - 1] == '\n') {
+            buffer.get()[line_size - 1] = '\0';
+        }
+
+        user_input = std::string(buffer.get());
+
+        return 0; // Success
+    }
+
+    user_input.clear();
+
+    return 1; // Indicate an error or empty input
+}
+
+// Function to generate a response based on the prompt
+static int generate_response(llama_model * model, llama_sampler * sampler, llama_context * context,
+                             const std::string & prompt, std::string & response) {
+    // Set response color
+    printf("\033[33m");
+    if (generate(model, sampler, context, prompt, response)) {
+        fprintf(stderr, "failed to generate response\n");
+        return 1;
+    }
+
+    // End response with color reset and newline
+    printf("\n\033[0m");
+    return 0;
+}
+
+// The main chat loop where user inputs are processed and responses generated.
+static int chat_loop(llama_model * model, llama_sampler * sampler, llama_context * context,
+                     std::vector<llama_chat_message> & messages) {
+    std::vector<std::unique_ptr<char[]>> owned_content;
+    std::vector<char> formatted(llama_n_ctx(context));
+    int prev_len = 0;
+
+    while (true) {
+        // Print prompt for user input
+        printf("\033[32m> \033[0m");
+        std::string user;
+        if (read_user_input(user)) {
+            break;
+        }
+
+        add_message("user", user, messages, owned_content);
+        int new_len = apply_chat_template(model, messages, formatted, true);
+        if (new_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
+
+        std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len);
+        std::string response;
+        if (generate_response(model, sampler, context, prompt, response)) {
+            return 1;
+        }
+
+        add_message("assistant", response, messages, owned_content);
+        prev_len = apply_chat_template(model, messages, formatted, false);
+        if (prev_len < 0) {
+            fprintf(stderr, "failed to apply the chat template\n");
+            return 1;
+        }
+    }
+
+    return 0;
+}
+
+static void log_callback(const enum ggml_log_level level, const char * text, void *) {
+    if (level == GGML_LOG_LEVEL_ERROR) {
+        fprintf(stderr, "%s", text);
+    }
+}
+
+// Initializes the model and returns a unique pointer to it.
+static std::unique_ptr<llama_model, decltype(&llama_free_model)> initialize_model(const std::string & model_path,
+                                                                                  int ngl) {
+    llama_model_params model_params = llama_model_default_params();
+    model_params.n_gpu_layers = ngl;
+
+    auto model = std::unique_ptr<llama_model, decltype(&llama_free_model)>(
+        llama_load_model_from_file(model_path.c_str(), model_params), llama_free_model);
+    if (!model) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+    }
+
+    return model;
+}
+
+// Initializes the context with the specified parameters.
+static std::unique_ptr<llama_context, decltype(&llama_free)> initialize_context(llama_model * model, int n_ctx) {
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.n_ctx = n_ctx;
+    ctx_params.n_batch = n_ctx;
+
+    auto context = std::unique_ptr<llama_context, decltype(&llama_free)>(
+        llama_new_context_with_model(model, ctx_params), llama_free);
+    if (!context) {
+        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
+    }
+
+    return context;
+}
+
+// Initializes and configures the sampler.
+static std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)> initialize_sampler() {
+    auto sampler = std::unique_ptr<llama_sampler, decltype(&llama_sampler_free)>(
+        llama_sampler_chain_init(llama_sampler_chain_default_params()), llama_sampler_free);
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
+    llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
+
+    return sampler;
+}
+
+int main(int argc, const char ** argv) {
+    std::string model_path;
+    int ngl = 99;
+    int n_ctx = 2048;
+    if (parse_arguments(argc, argv, model_path, n_ctx, ngl)) {
+        return 1;
+    }
+
+    llama_log_set(log_callback, nullptr);
+    auto model = initialize_model(model_path, ngl);
+    if (!model) {
+        return 1;
+    }
+
+    auto context = initialize_context(model.get(), n_ctx);
+    if (!context) {
+        return 1;
+    }
+
+    auto sampler = initialize_sampler();
+    std::vector<llama_chat_message> messages;
+    if (chat_loop(model.get(), sampler.get(), context.get(), messages)) {
+        return 1;
+    }
+
+    return 0;
+}
