@@ -26,19 +26,29 @@ int main(int argc, char ** argv) {
         params.prompt = "Hello my name is";
     }
 
-    // total length of the sequence including the prompt
-    const int n_len = 32;
+
+    llama_model_params model_params = llama_model_default_params();
+
+    // how many fake tokens to add on each iteration. 0 is no-op.
+    int mock_tokens = 0;
+    model_params.n_gpu_layers = 0; // 99; CPU vs GPU
+    if (argc >= 4) {
+        mock_tokens = atoi(argv[3]);
+    }
+
+    if (argc >= 5) {
+        model_params.n_gpu_layers = atoi(argv[4]);
+    }
 
     // init LLM
+    // total length of the sequence including the prompt
+    const int n_len = 256;
 
     llama_backend_init();
     llama_numa_init(params.numa);
 
     // initialize the model
 
-    llama_model_params model_params = llama_model_default_params();
-
-    // model_params.n_gpu_layers = 99; // offload all layers to the GPU
 
     llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
 
@@ -115,11 +125,14 @@ int main(int argc, char ** argv) {
 
     const auto t_main_start = ggml_time_us();
 
+    // we'll use logits from this position to determine next token
+    int logit_idx = batch.n_tokens - 1;
+
     while (n_cur <= n_len) {
         // sample the next token
         {
             auto   n_vocab = llama_n_vocab(model);
-            auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+            auto * logits  = llama_get_logits_ith(ctx, logit_idx);
 
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
@@ -149,6 +162,12 @@ int main(int argc, char ** argv) {
             // push this new token for next evaluation
             llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
 
+            for (int fl = 0; fl < mock_tokens; fl++) {
+                llama_batch_add(batch, new_token_id, n_cur + 1 + fl, { 0 }, true);
+            }
+            // we still use the 'original' token to sample on next iteration
+            logit_idx = batch.n_tokens - mock_tokens - 1;
+
             n_decode += 1;
         }
 
@@ -159,6 +178,8 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
+        // remove the cached entries from mock tokens
+        llama_kv_cache_seq_rm(ctx, 0, n_cur, -1);
     }
 
     LOG_TEE("\n");
@@ -168,7 +189,7 @@ int main(int argc, char ** argv) {
     LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
-    llama_print_timings(ctx);
+    // llama_print_timings(ctx);
 
     fprintf(stderr, "\n");
 
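
Taken together, the hunks above amount to one per-iteration pattern: pad the decode batch with mock_tokens extra copies of the freshly sampled token, keep sampling only from the real token's logits, and evict the padding from the KV cache once the batch has been decoded. The sketch below consolidates that pattern into a single helper purely for illustration; it is not part of the commit. It assumes llama.cpp's common batch helpers (llama_batch_clear and llama_batch_add from common.h) alongside the C API already used in this file, and the decode_with_padding name is made up.

// Hypothetical consolidation of the change above, not part of the commit.
#include "common.h"
#include "llama.h"

// Decode one real token plus `mock_tokens` padding copies of it.
// Returns 0 on success, mirroring llama_decode.
static int decode_with_padding(llama_context * ctx, llama_batch & batch,
                               llama_token tok, int n_cur, int mock_tokens) {
    llama_batch_clear(batch);

    // the real token: its logits are the ones sampled from on the next iteration
    llama_batch_add(batch, tok, n_cur, { 0 }, true);

    // padding: the same token repeated at the following positions, present only
    // to inflate the batch that llama_decode has to process
    for (int fl = 0; fl < mock_tokens; fl++) {
        llama_batch_add(batch, tok, n_cur + 1 + fl, { 0 }, true);
    }

    if (llama_decode(ctx, batch) != 0) {
        return 1;
    }

    // drop the KV-cache entries created by the padding; everything at
    // position n_cur + 1 and later belongs to the mock tokens
    llama_kv_cache_seq_rm(ctx, 0, n_cur + 1, -1);

    return 0;
}

In the commit itself the eviction call is llama_kv_cache_seq_rm(ctx, 0, n_cur, -1); that is equivalent because the stock simple.cpp loop advances n_cur past the real token before llama_decode runs, so only the padded positions are removed. Likewise, logit_idx = batch.n_tokens - mock_tokens - 1 reduces to index 0 of the per-iteration batch, i.e. the real token's logits.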