
Commit 63cd5b5

experiment with extra tokens
1 parent bca40e9

File tree

1 file changed (+28 −7)

examples/simple/simple.cpp

Lines changed: 28 additions & 7 deletions
@@ -26,19 +26,29 @@ int main(int argc, char ** argv) {
         params.prompt = "Hello my name is";
     }
 
-    // total length of the sequence including the prompt
-    const int n_len = 32;
+
+    llama_model_params model_params = llama_model_default_params();
+
+    // how many fake tokens to add on each iteration. 0 is no-op.
+    int mock_tokens = 0;
+    model_params.n_gpu_layers = 0; //99; CPU vs GPU
+    if (argc >= 4) {
+        mock_tokens = atoi(argv[3]);
+    }
+
+    if (argc >= 5) {
+        model_params.n_gpu_layers = atoi(argv[4]);
+    }
 
     // init LLM
+    // total length of the sequence including the prompt
+    const int n_len = 256;
 
     llama_backend_init();
     llama_numa_init(params.numa);
 
     // initialize the model
 
-    llama_model_params model_params = llama_model_default_params();
-
-    // model_params.n_gpu_layers = 99; // offload all layers to the GPU
 
     llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
 
@@ -115,11 +125,14 @@ int main(int argc, char ** argv) {
 
     const auto t_main_start = ggml_time_us();
 
+    // we'll use logits from this position to determine next token
+    int logit_idx = batch.n_tokens - 1;
+
     while (n_cur <= n_len) {
         // sample the next token
         {
             auto   n_vocab = llama_n_vocab(model);
-            auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+            auto * logits  = llama_get_logits_ith(ctx, logit_idx);
 
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
@@ -149,6 +162,12 @@ int main(int argc, char ** argv) {
             // push this new token for next evaluation
             llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
 
+            for (int fl = 0; fl < mock_tokens; fl++) {
+                llama_batch_add(batch, new_token_id, n_cur + 1 + fl, { 0 }, true);
+            }
+            // we still use the 'original' token to sample on next iteration
+            logit_idx = batch.n_tokens - mock_tokens - 1;
+
             n_decode += 1;
         }
 
@@ -159,6 +178,8 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
+        // remove the cached entries from mock tokens
+        llama_kv_cache_seq_rm(ctx, 0, n_cur, -1);
     }
 
     LOG_TEE("\n");
@@ -168,7 +189,7 @@ int main(int argc, char ** argv) {
     LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
-    llama_print_timings(ctx);
+    //llama_print_timings(ctx);
 
     fprintf(stderr, "\n");
 

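What the change does, in brief: each iteration now pushes the sampled token plus mock_tokens extra copies into the batch, decodes them in one call, samples only from the real token's logits (logit_idx), and then evicts the extra positions from the KV cache so they do not affect later steps. The new arguments are read as argv[3] = mock_tokens and argv[4] = n_gpu_layers, presumably after the example's existing model-path and prompt arguments. Below is a condensed, hypothetical helper illustrating that flow; it is not code from the commit, and it assumes the llama.h/common.h calls the diff already uses (llama_batch_add, llama_decode, llama_get_logits_ith, llama_kv_cache_seq_rm) plus the llama_batch_clear helper from the unchanged part of the example.

#include "llama.h"
#include "common.h"

// Hypothetical helper (not in the commit): one decode step with extra "mock" tokens.
// Pushes the sampled token at position n_cur plus mock_tokens padding copies,
// decodes the batch, evicts the padding from the KV cache, and returns the logits
// of the real token so the caller can sample the next one. Returns nullptr on failure.
static float * decode_with_mock_tokens(llama_context * ctx, llama_batch & batch,
                                       llama_token new_token_id, llama_pos n_cur,
                                       int mock_tokens) {
    llama_batch_clear(batch); // common.h helper, as in the unchanged example code

    // the token we actually want to append to sequence 0
    llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);

    // extra tokens that only add work to this decode call
    for (int fl = 0; fl < mock_tokens; fl++) {
        llama_batch_add(batch, new_token_id, n_cur + 1 + fl, { 0 }, true);
    }

    // index of the real token's logits, not those of the last mock token
    const int logit_idx = batch.n_tokens - mock_tokens - 1;

    if (llama_decode(ctx, batch) != 0) {
        return nullptr;
    }

    // drop the KV-cache entries written for the mock positions (>= n_cur + 1)
    llama_kv_cache_seq_rm(ctx, 0, n_cur + 1, -1);

    return llama_get_logits_ith(ctx, logit_idx);
}

With mock_tokens = 0 this reduces to the original single-token step, matching the "0 is no-op" comment in the diff.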