@@ -2927,6 +2927,90 @@ struct server_context {
                 continue;
             }
 
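+            // run speculative decoding for slots that are generating and have a draft model attached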
+            if (slot.state == SLOT_STATE_GENERATING && slot.is_processing() && slot.can_speculate()) {
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                //       also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
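+                // (hence the -2: one slot for `id` itself, one spare for a possible context shift)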
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
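+                // (the -1 leaves room for the extra token that the target model samples
+                //  after the last accepted draft position)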
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max >= slot.params.speculative.n_min) {
+                    llama_token id = slot.sampled;
+
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min   = slot.params.speculative.p_min;
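+                    // presumably: n_reuse caps how much of the draft context is reused across
+                    // calls, sized so that n_max positions stay free for newly drafted tokens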
+
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+
+                    // keep track of total number of tokens generated in the draft
+                    slot.n_draft_total += draft.size();
+
+                    // ignore small drafts
+                    if (slot.params.speculative.n_min <= (int) draft.size()) {
+                        // construct the speculation batch
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
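+                        // the batch holds the freshly sampled token first, then the draft tokens
+                        // at consecutive positions, all with logits requested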
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                        }
+
+                        SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        // the accepted tokens from the speculation
+                        const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
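+                        // ids contains every accepted draft token plus one token sampled by the
+                        // target model after the last accepted position, so ids.size() >= 1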
+
+                        slot.n_past    += ids.size();
+                        slot.n_decoded += ids.size();
+
+                        // update how many tokens out of draft was accepted
+                        slot.n_draft_accepted += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+                        slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
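+                        // cache `id` and the accepted draft tokens; the last element of ids is
+                        // left out since it plays the role of the next freshly sampled token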
+
+                        llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
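+                        // drop the rejected draft tokens (positions >= new n_past) from the KV cache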
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            result.tok          = ids[i];
+                            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                            result.prob         = 1.0f; // set later
+
+                            // TODO: set result.probs
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+                        continue;
+                    }
+
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+                } else {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+                }
+            }
+
             // check if we can batch this slot with the previous one
             if (!slot_batched) {
                 slot_batched = &slot;
@@ -3300,102 +3384,6 @@ struct server_context {
                 continue;
             }
         }
-
-        // do speculative decoding
-        for (auto & slot : slots) {
-            if (!slot.is_processing() || !slot.can_speculate()) {
-                continue;
-            }
-
-            if (slot.state != SLOT_STATE_GENERATING) {
-                continue;
-            }
-
-            // determine the max draft that fits the current slot state
-            int n_draft_max = slot.params.speculative.n_max;
-
-            // note: n_past is not yet increased for the `id` token sampled above
-            //       also, need to leave space for 1 extra token to allow context shifts
-            n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
-
-            if (slot.n_remaining > 0) {
-                n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-            }
-
-            SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-            if (n_draft_max < slot.params.speculative.n_min) {
-                SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
-
-                continue;
-            }
-
-            llama_token id = slot.sampled;
-
-            struct common_speculative_params params_spec;
-            params_spec.n_draft = n_draft_max;
-            params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-            params_spec.p_min   = slot.params.speculative.p_min;
-
-            llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
-
-            // keep track of total number of tokens generated in the draft
-            slot.n_draft_total += draft.size();
-
-            // ignore small drafts
-            if (slot.params.speculative.n_min > (int) draft.size()) {
-                SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
-
-                continue;
-            }
-
-            // construct the speculation batch
-            common_batch_clear(slot.batch_spec);
-            common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
-
-            for (size_t i = 0; i < draft.size(); ++i) {
-                common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
-            }
-
-            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-            llama_decode(ctx, slot.batch_spec);
-
-            // the accepted tokens from the speculation
-            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-
-            slot.n_past    += ids.size();
-            slot.n_decoded += ids.size();
-
-            // update how many tokens out of draft was accepted
-            slot.n_draft_accepted += ids.size() - 1;
-
-            slot.cache_tokens.push_back(id);
-            slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
-
-            llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
-
-            for (size_t i = 0; i < ids.size(); ++i) {
-                completion_token_output result;
-
-                result.tok          = ids[i];
-                result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-                result.prob         = 1.0f; // set later
-
-                // TODO: set result.probs
-
-                if (!process_token(result, slot)) {
-                    // release slot because of stop condition
-                    slot.release();
-                    slot.print_timings();
-                    send_final_response(slot);
-                    metrics.on_prediction(slot);
-                    break;
-                }
-            }
-
-            SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
-        }
         }
 
         SRV_DBG("%s", "run slots completed\n");