Skip to content

Commit b56cd97

Browse files
committed
[V1][Spec Decoding] Add scheduler test for num_drafts and num_accepted_tokens_per_pos
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 4a62abd commit b56cd97

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

tests/v1/core/test_scheduler.py

+21-10
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import pytest
55
import torch
66

7-
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
7+
from vllm.config import (CacheConfig, ModelConfig, SchedulerConfig,
8+
SpeculativeConfig, VllmConfig)
89
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
910
from vllm.sampling_params import SamplingParams
1011
from vllm.v1.core.sched.output import SchedulerOutput
@@ -25,6 +26,7 @@ def create_scheduler(
2526
enable_prefix_caching: Optional[bool] = None,
2627
long_prefill_token_threshold: int = 0,
2728
disable_chunked_mm_input: bool = False,
29+
num_speculative_tokens: Optional[int] = None,
2830
) -> Scheduler:
2931
'''Create scheduler under test.
3032
@@ -80,12 +82,17 @@ def create_scheduler(
8082
],
8183
)
8284
cache_config.num_gpu_blocks = 10000
85+
speculative_config: Optional[SpeculativeConfig] = None
86+
if num_speculative_tokens is not None:
87+
speculative_config = SpeculativeConfig(
88+
model="ngram", num_speculative_tokens=num_speculative_tokens)
8389
return Scheduler(
8490
scheduler_config,
8591
model_config,
8692
cache_config,
8793
lora_config=None,
8894
kv_cache_config=kv_cache_config,
95+
speculative_config=speculative_config,
8996
log_stats=True,
9097
structured_output_manager=StructuredOutputManager(vllm_config),
9198
)
@@ -671,13 +678,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
671678
@pytest.mark.parametrize(
672679
"spec_tokens,output_tokens,expected",
673680
[
674-
([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)), # perfect match
675-
([[1, 2, 3]], [[1, 5]], (3, 1)), # early mismatch
676-
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)), # multiple sequences
677-
([[1]], [[1, 2]], (1, 1)), # single token sequence
678-
([[]], [[5]], (0, 0)), # empty sequence
681+
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
682+
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
683+
([[1, 2], [3]], [[1, 2, 5], [3, 4]],
684+
(2, 3, 3, [2, 1])), # multiple sequences
685+
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
686+
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
679687
([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
680-
(6, 3)), # multiple mismatches
688+
(2, 6, 3, [2, 1, 0])), # multiple mismatches
681689
])
682690
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
683691
"""Test scheduling behavior with speculative decoding.
@@ -686,7 +694,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
686694
1. Speculated tokens get scheduled correctly
687695
2. Spec decoding stats properly count number of draft and accepted tokens
688696
"""
689-
scheduler = create_scheduler()
697+
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
698+
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
690699
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
691700
req_ids = []
692701
req_to_index = {}
@@ -759,5 +768,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
759768
else:
760769
assert scheduler_stats.spec_decoding_stats is not None
761770
stats = scheduler_stats.spec_decoding_stats
762-
assert stats.num_draft_tokens == expected[0]
763-
assert stats.num_accepted_tokens == expected[1]
771+
assert stats.num_drafts == expected[0]
772+
assert stats.num_draft_tokens == expected[1]
773+
assert stats.num_accepted_tokens == expected[2]
774+
assert stats.num_accepted_tokens_per_pos == expected[3]

0 commit comments

Comments
 (0)