Skip to content

Commit feb3bb4

Browse files
markmcmawong-amd
authored andcommitted
[V1][Spec Decoding] Include bonus tokens in mean acceptance length (vllm-project#17908)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 01d5d45 commit feb3bb4

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

examples/offline_inference/eagle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ def main():
118118
acceptance_counts[step] += count
119119

120120
print("-" * 50)
121-
print(f"mean acceptance length: \
122-
{sum(acceptance_counts) / acceptance_counts[0]:.2f}")
121+
print(f"mean acceptance length (including bonus tokens): \
122+
{1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}")
123123
print("-" * 50)
124124

125125
# print acceptance at each token position

vllm/v1/spec_decode/metrics.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ def log(self, log_fn=logger.info):
7373

7474
draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens *
7575
100 if num_draft_tokens > 0 else float("nan"))
76-
mean_acceptance_length = (num_accepted_tokens / num_drafts)
76+
77+
# Conventionally, mean acceptance length includes the bonus token
78+
mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts)
7779

7880
pos_matrix = np.array(self.accepted_tokens_per_pos_lists)
7981
acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts
@@ -103,10 +105,12 @@ class SpecDecodingProm:
103105
rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
104106
rate(vllm:spec_decode_num_draft_tokens_total[$interval])
105107
106-
The mean acceptance length can be calculated using:
108+
The mean acceptance length (conventionally including bonus tokens)
109+
can be calculated using:
107110
111+
1 + (
108112
rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
109-
rate(vllm:spec_decode_num_drafts[$interval])
113+
rate(vllm:spec_decode_num_drafts[$interval]))
110114
111115
A per-position acceptance rate vector can be computed using
112116

0 commit comments

Comments
 (0)