Commit 86ec439

markmc authored and Mu Huai committed
[V1][Spec Decoding] Add num_drafts and num_accepted_tokens_per_position metrics (vllm-project#16665)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
1 parent c41536f commit 86ec439

File tree

4 files changed: +158 -60 lines changed

tests/v1/core/test_scheduler.py (+25 -14)
@@ -6,7 +6,7 @@
 import torch
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
-                         SchedulerConfig, VllmConfig)
+                         SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -31,6 +31,7 @@ def create_scheduler(
     num_blocks: int = 10000,
     block_size: int = 16,
     max_model_len: Optional[int] = None,
+    num_speculative_tokens: Optional[int] = None,
 ) -> Scheduler:
     '''Create scheduler under test.
 
@@ -81,11 +82,17 @@ def create_scheduler(
         kv_connector_extra_config={"shared_storage_path": "local_storage"},
     ) if use_kv_connector else None
 
+    speculative_config: Optional[SpeculativeConfig] = None
+    if num_speculative_tokens is not None:
+        speculative_config = SpeculativeConfig(
+            model="ngram", num_speculative_tokens=num_speculative_tokens)
+
     vllm_config = VllmConfig(
         scheduler_config=scheduler_config,
         model_config=model_config,
         cache_config=cache_config,
         kv_transfer_config=kv_transfer_config,
+        speculative_config=speculative_config,
     )
     kv_cache_config = KVCacheConfig(
         num_blocks=num_blocks,  # A large number of blocks to hold all requests
@@ -429,7 +436,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
 
 def test_stop_via_update_from_output():
     """Test stopping behavior through update_from_output"""
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=1)
 
     # Test case 1: Stop on EOS token
     requests = create_requests(num_requests=2, max_tokens=10)
@@ -480,7 +487,7 @@ def test_stop_via_update_from_output():
     assert list(requests[1].output_token_ids) == [10, 11]
 
     # Test case 2: Stop on custom stop token
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=2)
     requests = create_requests(num_requests=2,
                                max_tokens=10,
                                stop_token_ids=[42, 43])
@@ -531,7 +538,7 @@ def test_stop_via_update_from_output():
     assert list(requests[1].output_token_ids) == [13, 14]
 
     # Test case 3: Stop on max tokens
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=2)
     requests = create_requests(num_requests=2, max_tokens=2)
     for req in requests:
         req.num_computed_tokens = req.num_tokens
@@ -580,7 +587,7 @@ def test_stop_via_update_from_output():
     assert list(requests[1].output_token_ids) == [13]
 
     # Test case 4: Ignore EOS flag
-    scheduler = create_scheduler()
+    scheduler = create_scheduler(num_speculative_tokens=2)
    requests = create_requests(num_requests=1, max_tokens=10)
    requests[0].sampling_params.ignore_eos = True
    requests[0].num_computed_tokens = requests[0].num_tokens
@@ -682,13 +689,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
 @pytest.mark.parametrize(
     "spec_tokens,output_tokens,expected",
     [
-        ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)),  # perfect match
-        ([[1, 2, 3]], [[1, 5]], (3, 1)),  # early mismatch
-        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)),  # multiple sequences
-        ([[1]], [[1, 2]], (1, 1)),  # single token sequence
-        ([[]], [[5]], (0, 0)),  # empty sequence
+        ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])),  # perfect match
+        ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])),  # early mismatch
+        ([[1, 2], [3]], [[1, 2, 5], [3, 4]],
+         (2, 3, 3, [2, 1])),  # multiple sequences
+        ([[1]], [[1, 2]], (1, 1, 1, [1])),  # single token sequence
+        ([[]], [[5]], (0, 0, 0, [0])),  # empty sequence
         ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
-         (6, 3)),  # multiple mismatches
+         (2, 6, 3, [2, 1, 0])),  # multiple mismatches
     ])
 def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     """Test scheduling behavior with speculative decoding.
@@ -697,7 +705,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     1. Speculated tokens get scheduled correctly
     2. Spec decoding stats properly count number of draft and accepted tokens
     """
-    scheduler = create_scheduler()
+    num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
+    scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
     requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
     req_ids = []
     req_to_index = {}
@@ -770,8 +779,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     else:
         assert scheduler_stats.spec_decoding_stats is not None
         stats = scheduler_stats.spec_decoding_stats
-        assert stats.num_draft_tokens == expected[0]
-        assert stats.num_accepted_tokens == expected[1]
+        assert stats.num_drafts == expected[0]
+        assert stats.num_draft_tokens == expected[1]
+        assert stats.num_accepted_tokens == expected[2]
+        assert stats.num_accepted_tokens_per_pos == expected[3]
 
 
 def _assert_right_scheduler_output(
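
The expected tuples above are (num_drafts, num_draft_tokens, num_accepted_tokens, num_accepted_tokens_per_pos). As a sanity check, here is a small standalone sketch (plain Python, not vLLM code) that recomputes those values from one parametrized case, assuming the test's convention that output_tokens holds the verified tokens and a draft token is accepted only if every earlier position was also accepted:

def expected_stats(spec_tokens, output_tokens):
    # One counter slot per draft position.
    num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
    num_drafts = num_draft_tokens = num_accepted = 0
    per_pos = [0] * num_spec_tokens
    for draft, output in zip(spec_tokens, output_tokens):
        if not draft:
            continue  # no draft tokens, nothing to observe
        num_drafts += 1
        num_draft_tokens += len(draft)
        for i, token in enumerate(draft):
            # Acceptance is prefix-shaped: stop at the first mismatch.
            if i < len(output) and output[i] == token:
                num_accepted += 1
                per_pos[i] += 1
            else:
                break
    return num_drafts, num_draft_tokens, num_accepted, per_pos


# The "multiple mismatches" case from the parametrization above:
assert expected_stats([[1, 2, 3], [4, 5, 6]],
                      [[1, 2, 7], [4, 8]]) == (2, 6, 3, [2, 1, 0])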

vllm/v1/core/sched/scheduler.py (+9 -7)
@@ -122,11 +122,12 @@ def __init__(
         self.encoder_cache_manager = EncoderCacheManager(
             cache_size=encoder_cache_size)
 
-        self.num_lookahead_tokens = 0
         speculative_config = vllm_config.speculative_config
-        if speculative_config and speculative_config.method == "eagle":
-            self.num_lookahead_tokens = \
-                speculative_config.num_speculative_tokens
+        self.num_spec_tokens = self.num_lookahead_tokens = 0
+        if speculative_config:
+            self.num_spec_tokens = speculative_config.num_speculative_tokens
+            if speculative_config.method == "eagle":
+                self.num_lookahead_tokens = self.num_spec_tokens
 
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
@@ -824,7 +825,8 @@ def make_spec_decoding_stats(
         if not self.log_stats:
             return None
         if spec_decoding_stats is None:
-            spec_decoding_stats = SpecDecodingStats()
-        spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens,
-                                    num_accepted_tokens=num_accepted_tokens)
+            spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens)
+        spec_decoding_stats.observe_draft(
+            num_draft_tokens=num_draft_tokens,
+            num_accepted_tokens=num_accepted_tokens)
         return spec_decoding_stats
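
SpecDecodingStats.new() and observe_draft() live in vllm/v1/spec_decode/metrics.py, the fourth changed file, which this view does not include. A minimal sketch of the stats container that is consistent with the scheduler call above and with the test assertions (an assumption about its shape, not the actual implementation):

from dataclasses import dataclass, field


@dataclass
class SpecDecodingStats:
    # Sketch: per-scheduler-iteration speculative decoding counters.
    num_drafts: int = 0
    num_draft_tokens: int = 0
    num_accepted_tokens: int = 0
    # Slot i counts drafts whose token at position i was accepted.
    num_accepted_tokens_per_pos: list[int] = field(default_factory=list)

    @classmethod
    def new(cls, num_spec_tokens: int) -> "SpecDecodingStats":
        return cls(num_accepted_tokens_per_pos=[0] * num_spec_tokens)

    def observe_draft(self, num_draft_tokens: int, num_accepted_tokens: int):
        self.num_drafts += 1
        self.num_draft_tokens += num_draft_tokens
        self.num_accepted_tokens += num_accepted_tokens
        # Acceptance is prefix-shaped: accepting k tokens means
        # positions 0..k-1 were all accepted.
        for i in range(num_accepted_tokens):
            self.num_accepted_tokens_per_pos[i] += 1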

vllm/v1/metrics/loggers.py (+9 -26)
@@ -12,7 +12,7 @@
 from vllm.v1.core.kv_cache_utils import PrefixCachingMetrics
 from vllm.v1.engine import FinishReason
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
-from vllm.v1.spec_decode.metrics import SpecDecodingMetrics
+from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
 
 logger = init_logger(__name__)
 
@@ -39,7 +39,7 @@ def __init__(self, engine_index: int = 0):
         # Prefix cache metrics. This cannot be reset.
         # TODO: Make the interval configurable.
         self.prefix_caching_metrics = PrefixCachingMetrics()
-        self.spec_decoding_metrics = SpecDecodingMetrics()
+        self.spec_decoding_logging = SpecDecodingLogging()
         self.last_prompt_throughput: float = 0.0
         self.last_generation_throughput: float = 0.0
 
@@ -70,7 +70,7 @@ def record(self, scheduler_stats: SchedulerStats,
         self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
 
         if scheduler_stats.spec_decoding_stats is not None:
-            self.spec_decoding_metrics.observe(
+            self.spec_decoding_logging.observe(
                 scheduler_stats.spec_decoding_stats)
 
         self.last_scheduler_stats = scheduler_stats
@@ -112,7 +112,7 @@ def log(self):
         )
 
         if scheduler_stats.spec_decoding_stats is not None:
-            self.spec_decoding_metrics.log(log_fn=log_fn)
+            self.spec_decoding_logging.log(log_fn=log_fn)
 
 
 class PrometheusStatLogger(StatLoggerBase):
@@ -133,6 +133,9 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
 
         max_model_len = vllm_config.model_config.max_model_len
 
+        self.spec_decoding_prom = SpecDecodingProm(
+            vllm_config.speculative_config, labelnames, labelvalues)
+
         #
         # Scheduler state
         #
@@ -323,24 +326,6 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
             self.labelname_running_lora_adapters,
         ])
 
-        #
-        # Speculative Decoding metrics
-        # The acceptance rate can be calculated using a PromQL query:
-        #
-        #   rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) /
-        #   rate(vllm:spec_decode_num_draft_tokens_total[$interval])
-        #
-        self.counter_spec_decode_num_draft_tokens = \
-            prometheus_client.Counter(
-                name="vllm:spec_decode_num_draft_tokens_total",
-                documentation="Number of draft tokens.",
-                labelnames=labelnames).labels(*labelvalues)
-        self.counter_spec_decode_num_accepted_tokens = \
-            prometheus_client.Counter(
-                name="vllm:spec_decode_num_accepted_tokens_total",
-                documentation="Number of accepted tokens.",
-                labelnames=labelnames).labels(*labelvalues)
-
         #
         # Cache config info metric
         #
@@ -378,10 +363,8 @@ def record(self, scheduler_stats: SchedulerStats,
                 scheduler_stats.prefix_cache_stats.hits)
 
         if scheduler_stats.spec_decoding_stats is not None:
-            self.counter_spec_decode_num_draft_tokens.inc(
-                scheduler_stats.spec_decoding_stats.num_draft_tokens)
-            self.counter_spec_decode_num_accepted_tokens.inc(
-                scheduler_stats.spec_decoding_stats.num_accepted_tokens)
+            self.spec_decoding_prom.observe(
+                scheduler_stats.spec_decoding_stats)
 
         if iteration_stats is None:
             return
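
The removed comment documented the overall acceptance-rate query; the new per-draft and per-position counters also support finer-grained rates. Illustrative arithmetic only (plain Python on the "multiple mismatches" numbers from the tests; the actual Prometheus metric names emitted by SpecDecodingProm are not shown in this view):

num_drafts = 2
num_draft_tokens = 6
num_accepted_tokens = 3
num_accepted_tokens_per_pos = [2, 1, 0]

# Overall acceptance rate, as the removed PromQL comment computed it:
acceptance_rate = num_accepted_tokens / num_draft_tokens  # 0.5

# Mean tokens generated per drafting step (one bonus token per draft):
mean_acceptance_length = 1 + num_accepted_tokens / num_drafts  # 2.5

# Chance that the draft token at position i is accepted:
per_position_rate = [n / num_drafts for n in num_accepted_tokens_per_pos]
# -> [1.0, 0.5, 0.0]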
