
Commit 0b32f64

Optimize speculative decoding performance of llama-server
1 parent 0770ec9 commit 0b32f64
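
In effect, the commit moves the speculative decoding block out of a standalone "// do speculative decoding" pass over the slots (previously run near the end of the run-slots loop, just before the "run slots completed" debug message; removed in the second hunk below) and inlines it into the per-slot loop that queues each slot's sampled token for the next batch (first hunk). The early-exit guards become nested conditions, so a slot simply falls through to normal batching when no useful draft is available.

The block only runs when a large enough draft still fits the slot's budget. For reference, a minimal, self-contained sketch of that clamping arithmetic; the function name and example numbers are illustrative and not part of the commit:

    // clamp the draft size the way the server block does, using plain ints
    #include <algorithm>
    #include <cstdio>

    int max_draft(int n_max, int n_ctx, int n_past, int n_remaining) {
        // n_past does not yet include the token just sampled, and one extra
        // position is kept free so a context shift is still possible
        int n = std::min(n_max, n_ctx - n_past - 2);
        if (n_remaining > 0) {
            n = std::min(n, n_remaining - 1);
        }
        return n;
    }

    int main() {
        printf("%d\n", max_draft(16, 4096, 4090, -1)); // near the context limit -> 4
        printf("%d\n", max_draft(16, 4096, 1024, 3));  // only 3 tokens left to generate -> 2
        return 0;
    }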

File tree

1 file changed: +84 -96 lines changed


examples/server/server.cpp

+84-96
@@ -2927,6 +2927,90 @@ struct server_context {
                 continue;
             }

+            if (slot.state == SLOT_STATE_GENERATING && slot.is_processing() && slot.can_speculate()) {
+                // determine the max draft that fits the current slot state
+                int n_draft_max = slot.params.speculative.n_max;
+
+                // note: n_past is not yet increased for the `id` token sampled above
+                // also, need to leave space for 1 extra token to allow context shifts
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+                if (slot.n_remaining > 0) {
+                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+                }
+
+                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+                if (n_draft_max >= slot.params.speculative.n_min) {
+                    llama_token id = slot.sampled;
+
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = n_draft_max;
+                    params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+                    params_spec.p_min = slot.params.speculative.p_min;
+
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+
+                    // keep track of total number of tokens generated in the draft
+                    slot.n_draft_total += draft.size();
+
+                    // ignore small drafts
+                    if (slot.params.speculative.n_min <= (int) draft.size()) {
+                        // construct the speculation batch
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                        }
+
+                        SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        // the accepted tokens from the speculation
+                        const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+                        slot.n_past += ids.size();
+                        slot.n_decoded += ids.size();
+
+                        // update how many tokens out of draft was accepted
+                        slot.n_draft_accepted += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+                        slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+
+                        llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            result.tok = ids[i];
+                            result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
+                            result.prob = 1.0f; // set later
+
+                            // TODO: set result.probs
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
+                        continue;
+                    }
+
+                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+                } else {
+                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+                }
+            }
+
             // check if we can batch this slot with the previous one
             if (!slot_batched) {
                 slot_batched = &slot;
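
One detail worth noting in the block above: the speculative batch writes KV-cache entries for every drafted position, so after verification the llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1) call trims the slot's sequence back to the new n_past, discarding the rejected draft positions (in the llama.cpp sequence-removal API a negative end bound means "to the end"). A toy stand-in for that truncation, using a plain vector instead of the real cache:

    #include <cstdio>
    #include <vector>

    int main() {
        // pretend the sequence holds 8 cached positions after the speculative decode
        std::vector<int> cached_pos = {0, 1, 2, 3, 4, 5, 6, 7};
        int n_past = 5; // positions 0..4 were accepted

        // drop everything from n_past onwards, like seq_rm(seq, n_past, -1)
        cached_pos.erase(cached_pos.begin() + n_past, cached_pos.end());

        printf("kept %zu cached positions\n", cached_pos.size()); // kept 5
        return 0;
    }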
@@ -3300,102 +3384,6 @@ struct server_context {
                     continue;
                 }
             }
-
-            // do speculative decoding
-            for (auto & slot : slots) {
-                if (!slot.is_processing() || !slot.can_speculate()) {
-                    continue;
-                }
-
-                if (slot.state != SLOT_STATE_GENERATING) {
-                    continue;
-                }
-
-                // determine the max draft that fits the current slot state
-                int n_draft_max = slot.params.speculative.n_max;
-
-                // note: n_past is not yet increased for the `id` token sampled above
-                // also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
-
-                if (slot.n_remaining > 0) {
-                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
-                }
-
-                SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
-
-                if (n_draft_max < slot.params.speculative.n_min) {
-                    SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                llama_token id = slot.sampled;
-
-                struct common_speculative_params params_spec;
-                params_spec.n_draft = n_draft_max;
-                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
-                params_spec.p_min = slot.params.speculative.p_min;
-
-                llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
-
-                // keep track of total number of tokens generated in the draft
-                slot.n_draft_total += draft.size();
-
-                // ignore small drafts
-                if (slot.params.speculative.n_min > (int) draft.size()) {
-                    SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
-
-                    continue;
-                }
-
-                // construct the speculation batch
-                common_batch_clear(slot.batch_spec);
-                common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
-
-                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
-                }
-
-                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-
-                llama_decode(ctx, slot.batch_spec);
-
-                // the accepted tokens from the speculation
-                const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-
-                slot.n_past += ids.size();
-                slot.n_decoded += ids.size();
-
-                // update how many tokens out of draft was accepted
-                slot.n_draft_accepted += ids.size() - 1;
-
-                slot.cache_tokens.push_back(id);
-                slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
-
-                llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
-
-                for (size_t i = 0; i < ids.size(); ++i) {
-                    completion_token_output result;
-
-                    result.tok = ids[i];
-                    result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
-                    result.prob = 1.0f; // set later
-
-                    // TODO: set result.probs
-
-                    if (!process_token(result, slot)) {
-                        // release slot because of stop condition
-                        slot.release();
-                        slot.print_timings();
-                        send_final_response(slot);
-                        metrics.on_prediction(slot);
-                        break;
-                    }
-                }
-
-                SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
-            }
         }

         SRV_DBG("%s", "run slots completed\n");
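
For orientation, the relocated block performs one speculative step per generated token: the draft model proposes a continuation (common_speculative_gen_draft), the target model decodes the freshly sampled token plus the whole draft in a single batch, and common_sampler_sample_and_accept_n keeps the longest prefix the target agrees with; its result ends with one target-sampled token, which is why the counters in the diff use ids.size() - 1. A rough, self-contained sketch of that accept loop, with hypothetical stand-ins (gen_draft, target_sample_at) in place of the real draft and target models:

    #include <cstdio>
    #include <vector>

    using token = int;

    // hypothetical draft model: propose n_draft continuation tokens
    static std::vector<token> gen_draft(token last, int n_draft) {
        std::vector<token> d;
        for (int i = 0; i < n_draft; ++i) {
            d.push_back(last + 1 + i);
        }
        return d;
    }

    // hypothetical target model: what it samples at each drafted position
    static token target_sample_at(int pos) {
        return pos < 3 ? 101 + pos : 999; // agrees with the draft for 3 steps, then diverges
    }

    int main() {
        token id = 100; // the token the target model just sampled
        std::vector<token> draft = gen_draft(id, 8);

        // keep the longest draft prefix the target agrees with; the result always
        // contains one final target-sampled token, so accepted = ids.size() - 1
        std::vector<token> ids;
        for (size_t i = 0; i <= draft.size(); ++i) {
            const token t = target_sample_at((int) i);
            ids.push_back(t);
            if (i == draft.size() || t != draft[i]) {
                break;
            }
        }

        printf("accepted %zu/%zu draft tokens\n", ids.size() - 1, draft.size()); // 3/8
        return 0;
    }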
