6
6
import torch
7
7
8
8
from vllm .config import (CacheConfig , KVTransferConfig , ModelConfig ,
9
- SchedulerConfig , VllmConfig )
9
+ SchedulerConfig , SpeculativeConfig , VllmConfig )
10
10
from vllm .multimodal .inputs import MultiModalKwargs , PlaceholderRange
11
11
from vllm .sampling_params import SamplingParams
12
12
from vllm .v1 .core .sched .output import SchedulerOutput
@@ -31,6 +31,7 @@ def create_scheduler(
31
31
num_blocks : int = 10000 ,
32
32
block_size : int = 16 ,
33
33
max_model_len : Optional [int ] = None ,
34
+ num_speculative_tokens : Optional [int ] = None ,
34
35
) -> Scheduler :
35
36
'''Create scheduler under test.
36
37
@@ -81,11 +82,17 @@ def create_scheduler(
81
82
kv_connector_extra_config = {"shared_storage_path" : "local_storage" },
82
83
) if use_kv_connector else None
83
84
85
+ speculative_config : Optional [SpeculativeConfig ] = None
86
+ if num_speculative_tokens is not None :
87
+ speculative_config = SpeculativeConfig (
88
+ model = "ngram" , num_speculative_tokens = num_speculative_tokens )
89
+
84
90
vllm_config = VllmConfig (
85
91
scheduler_config = scheduler_config ,
86
92
model_config = model_config ,
87
93
cache_config = cache_config ,
88
94
kv_transfer_config = kv_transfer_config ,
95
+ speculative_config = speculative_config ,
89
96
)
90
97
kv_cache_config = KVCacheConfig (
91
98
num_blocks = num_blocks , # A large number of blocks to hold all requests
@@ -429,7 +436,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
429
436
430
437
def test_stop_via_update_from_output ():
431
438
"""Test stopping behavior through update_from_output"""
432
- scheduler = create_scheduler ()
439
+ scheduler = create_scheduler (num_speculative_tokens = 1 )
433
440
434
441
# Test case 1: Stop on EOS token
435
442
requests = create_requests (num_requests = 2 , max_tokens = 10 )
@@ -481,7 +488,7 @@ def test_stop_via_update_from_output():
481
488
assert list (requests [1 ].output_token_ids ) == [10 , 11 ]
482
489
483
490
# Test case 2: Stop on custom stop token
484
- scheduler = create_scheduler ()
491
+ scheduler = create_scheduler (num_speculative_tokens = 2 )
485
492
requests = create_requests (num_requests = 2 ,
486
493
max_tokens = 10 ,
487
494
stop_token_ids = [42 , 43 ])
@@ -533,7 +540,7 @@ def test_stop_via_update_from_output():
533
540
assert list (requests [1 ].output_token_ids ) == [13 , 14 ]
534
541
535
542
# Test case 3: Stop on max tokens
536
- scheduler = create_scheduler ()
543
+ scheduler = create_scheduler (num_speculative_tokens = 2 )
537
544
requests = create_requests (num_requests = 2 , max_tokens = 2 )
538
545
for req in requests :
539
546
req .num_computed_tokens = req .num_tokens
@@ -583,7 +590,7 @@ def test_stop_via_update_from_output():
583
590
assert list (requests [1 ].output_token_ids ) == [13 ]
584
591
585
592
# Test case 4: Ignore EOS flag
586
- scheduler = create_scheduler ()
593
+ scheduler = create_scheduler (num_speculative_tokens = 2 )
587
594
requests = create_requests (num_requests = 1 , max_tokens = 10 )
588
595
requests [0 ].sampling_params .ignore_eos = True
589
596
requests [0 ].num_computed_tokens = requests [0 ].num_tokens
@@ -686,13 +693,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
686
693
@pytest .mark .parametrize (
687
694
"spec_tokens,output_tokens,expected" ,
688
695
[
689
- ([[1 , 2 , 3 ]], [[1 , 2 , 3 , 4 ]], (3 , 3 )), # perfect match
690
- ([[1 , 2 , 3 ]], [[1 , 5 ]], (3 , 1 )), # early mismatch
691
- ([[1 , 2 ], [3 ]], [[1 , 2 , 5 ], [3 , 4 ]], (3 , 3 )), # multiple sequences
692
- ([[1 ]], [[1 , 2 ]], (1 , 1 )), # single token sequence
693
- ([[]], [[5 ]], (0 , 0 )), # empty sequence
696
+ ([[1 , 2 , 3 ]], [[1 , 2 , 3 , 4 ]], (1 , 3 , 3 , [1 , 1 , 1 ])), # perfect match
697
+ ([[1 , 2 , 3 ]], [[1 , 5 ]], (1 , 3 , 1 , [1 , 0 , 0 ])), # early mismatch
698
+ ([[1 , 2 ], [3 ]], [[1 , 2 , 5 ], [3 , 4 ]],
699
+ (2 , 3 , 3 , [2 , 1 ])), # multiple sequences
700
+ ([[1 ]], [[1 , 2 ]], (1 , 1 , 1 , [1 ])), # single token sequence
701
+ ([[]], [[5 ]], (0 , 0 , 0 , [0 ])), # empty sequence
694
702
([[1 , 2 , 3 ], [4 , 5 , 6 ]], [[1 , 2 , 7 ], [4 , 8 ]],
695
- (6 , 3 )), # multiple mismatches
703
+ (2 , 6 , 3 , [ 2 , 1 , 0 ] )), # multiple mismatches
696
704
])
697
705
def test_schedule_spec_decoding_stats (spec_tokens , output_tokens , expected ):
698
706
"""Test scheduling behavior with speculative decoding.
@@ -701,7 +709,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
701
709
1. Speculated tokens get scheduled correctly
702
710
2. Spec decoding stats properly count number of draft and accepted tokens
703
711
"""
704
- scheduler = create_scheduler ()
712
+ num_spec_tokens = max (1 , max (len (t ) for t in spec_tokens ))
713
+ scheduler = create_scheduler (num_speculative_tokens = num_spec_tokens )
705
714
requests = create_requests (num_requests = len (spec_tokens ), num_tokens = 1 )
706
715
req_ids = []
707
716
req_to_index = {}
@@ -774,8 +783,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
774
783
else :
775
784
assert scheduler_stats .spec_decoding_stats is not None
776
785
stats = scheduler_stats .spec_decoding_stats
777
- assert stats .num_draft_tokens == expected [0 ]
778
- assert stats .num_accepted_tokens == expected [1 ]
786
+ assert stats .num_drafts == expected [0 ]
787
+ assert stats .num_draft_tokens == expected [1 ]
788
+ assert stats .num_accepted_tokens == expected [2 ]
789
+ assert stats .num_accepted_tokens_per_pos == expected [3 ]
779
790
780
791
781
792
def _assert_right_scheduler_output (
0 commit comments