Skip to content

Commit 2682063

Browse files
committed
[V1][Spec Decoding] Add scheduler test for num_drafts and num_accepted_tokens_per_pos
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent bf35874 commit 2682063

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

tests/v1/core/test_scheduler.py

+25-14
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
import torch
77

88
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
9-
SchedulerConfig, VllmConfig)
9+
SchedulerConfig, SpeculativeConfig, VllmConfig)
1010
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
1111
from vllm.sampling_params import SamplingParams
1212
from vllm.v1.core.sched.output import SchedulerOutput
@@ -31,6 +31,7 @@ def create_scheduler(
3131
num_blocks: int = 10000,
3232
block_size: int = 16,
3333
max_model_len: Optional[int] = None,
34+
num_speculative_tokens: Optional[int] = None,
3435
) -> Scheduler:
3536
'''Create scheduler under test.
3637
@@ -81,11 +82,17 @@ def create_scheduler(
8182
kv_connector_extra_config={"shared_storage_path": "local_storage"},
8283
) if use_kv_connector else None
8384

85+
speculative_config: Optional[SpeculativeConfig] = None
86+
if num_speculative_tokens is not None:
87+
speculative_config = SpeculativeConfig(
88+
model="ngram", num_speculative_tokens=num_speculative_tokens)
89+
8490
vllm_config = VllmConfig(
8591
scheduler_config=scheduler_config,
8692
model_config=model_config,
8793
cache_config=cache_config,
8894
kv_transfer_config=kv_transfer_config,
95+
speculative_config=speculative_config,
8996
)
9097
kv_cache_config = KVCacheConfig(
9198
num_blocks=num_blocks, # A large number of blocks to hold all requests
@@ -429,7 +436,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
429436

430437
def test_stop_via_update_from_output():
431438
"""Test stopping behavior through update_from_output"""
432-
scheduler = create_scheduler()
439+
scheduler = create_scheduler(num_speculative_tokens=1)
433440

434441
# Test case 1: Stop on EOS token
435442
requests = create_requests(num_requests=2, max_tokens=10)
@@ -481,7 +488,7 @@ def test_stop_via_update_from_output():
481488
assert list(requests[1].output_token_ids) == [10, 11]
482489

483490
# Test case 2: Stop on custom stop token
484-
scheduler = create_scheduler()
491+
scheduler = create_scheduler(num_speculative_tokens=2)
485492
requests = create_requests(num_requests=2,
486493
max_tokens=10,
487494
stop_token_ids=[42, 43])
@@ -533,7 +540,7 @@ def test_stop_via_update_from_output():
533540
assert list(requests[1].output_token_ids) == [13, 14]
534541

535542
# Test case 3: Stop on max tokens
536-
scheduler = create_scheduler()
543+
scheduler = create_scheduler(num_speculative_tokens=2)
537544
requests = create_requests(num_requests=2, max_tokens=2)
538545
for req in requests:
539546
req.num_computed_tokens = req.num_tokens
@@ -583,7 +590,7 @@ def test_stop_via_update_from_output():
583590
assert list(requests[1].output_token_ids) == [13]
584591

585592
# Test case 4: Ignore EOS flag
586-
scheduler = create_scheduler()
593+
scheduler = create_scheduler(num_speculative_tokens=2)
587594
requests = create_requests(num_requests=1, max_tokens=10)
588595
requests[0].sampling_params.ignore_eos = True
589596
requests[0].num_computed_tokens = requests[0].num_tokens
@@ -686,13 +693,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
686693
@pytest.mark.parametrize(
687694
"spec_tokens,output_tokens,expected",
688695
[
689-
([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)), # perfect match
690-
([[1, 2, 3]], [[1, 5]], (3, 1)), # early mismatch
691-
([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)), # multiple sequences
692-
([[1]], [[1, 2]], (1, 1)), # single token sequence
693-
([[]], [[5]], (0, 0)), # empty sequence
696+
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
697+
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
698+
([[1, 2], [3]], [[1, 2, 5], [3, 4]],
699+
(2, 3, 3, [2, 1])), # multiple sequences
700+
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
701+
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
694702
([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
695-
(6, 3)), # multiple mismatches
703+
(2, 6, 3, [2, 1, 0])), # multiple mismatches
696704
])
697705
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
698706
"""Test scheduling behavior with speculative decoding.
@@ -701,7 +709,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
701709
1. Speculated tokens get scheduled correctly
702710
2. Spec decoding stats properly count number of draft and accepted tokens
703711
"""
704-
scheduler = create_scheduler()
712+
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
713+
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
705714
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
706715
req_ids = []
707716
req_to_index = {}
@@ -774,8 +783,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
774783
else:
775784
assert scheduler_stats.spec_decoding_stats is not None
776785
stats = scheduler_stats.spec_decoding_stats
777-
assert stats.num_draft_tokens == expected[0]
778-
assert stats.num_accepted_tokens == expected[1]
786+
assert stats.num_drafts == expected[0]
787+
assert stats.num_draft_tokens == expected[1]
788+
assert stats.num_accepted_tokens == expected[2]
789+
assert stats.num_accepted_tokens_per_pos == expected[3]
779790

780791

781792
def _assert_right_scheduler_output(

0 commit comments

Comments (0)