6
6
import torch
7
7
8
8
from vllm .config import (CacheConfig , KVTransferConfig , ModelConfig ,
9
- SchedulerConfig , VllmConfig )
9
+ SchedulerConfig , SpeculativeConfig , VllmConfig )
10
10
from vllm .multimodal .inputs import MultiModalKwargs , PlaceholderRange
11
11
from vllm .sampling_params import SamplingParams
12
12
from vllm .v1 .core .sched .output import SchedulerOutput
@@ -31,6 +31,7 @@ def create_scheduler(
31
31
num_blocks : int = 10000 ,
32
32
block_size : int = 16 ,
33
33
max_model_len : Optional [int ] = None ,
34
+ num_speculative_tokens : Optional [int ] = None ,
34
35
) -> Scheduler :
35
36
'''Create scheduler under test.
36
37
@@ -81,11 +82,17 @@ def create_scheduler(
81
82
kv_connector_extra_config = {"shared_storage_path" : "local_storage" },
82
83
) if use_kv_connector else None
83
84
85
+ speculative_config : Optional [SpeculativeConfig ] = None
86
+ if num_speculative_tokens is not None :
87
+ speculative_config = SpeculativeConfig (
88
+ model = "ngram" , num_speculative_tokens = num_speculative_tokens )
89
+
84
90
vllm_config = VllmConfig (
85
91
scheduler_config = scheduler_config ,
86
92
model_config = model_config ,
87
93
cache_config = cache_config ,
88
94
kv_transfer_config = kv_transfer_config ,
95
+ speculative_config = speculative_config ,
89
96
)
90
97
kv_cache_config = KVCacheConfig (
91
98
num_blocks = num_blocks , # A large number of blocks to hold all requests
@@ -429,7 +436,7 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
429
436
430
437
def test_stop_via_update_from_output ():
431
438
"""Test stopping behavior through update_from_output"""
432
- scheduler = create_scheduler ()
439
+ scheduler = create_scheduler (num_speculative_tokens = 1 )
433
440
434
441
# Test case 1: Stop on EOS token
435
442
requests = create_requests (num_requests = 2 , max_tokens = 10 )
@@ -480,7 +487,7 @@ def test_stop_via_update_from_output():
480
487
assert list (requests [1 ].output_token_ids ) == [10 , 11 ]
481
488
482
489
# Test case 2: Stop on custom stop token
483
- scheduler = create_scheduler ()
490
+ scheduler = create_scheduler (num_speculative_tokens = 2 )
484
491
requests = create_requests (num_requests = 2 ,
485
492
max_tokens = 10 ,
486
493
stop_token_ids = [42 , 43 ])
@@ -531,7 +538,7 @@ def test_stop_via_update_from_output():
531
538
assert list (requests [1 ].output_token_ids ) == [13 , 14 ]
532
539
533
540
# Test case 3: Stop on max tokens
534
- scheduler = create_scheduler ()
541
+ scheduler = create_scheduler (num_speculative_tokens = 2 )
535
542
requests = create_requests (num_requests = 2 , max_tokens = 2 )
536
543
for req in requests :
537
544
req .num_computed_tokens = req .num_tokens
@@ -580,7 +587,7 @@ def test_stop_via_update_from_output():
580
587
assert list (requests [1 ].output_token_ids ) == [13 ]
581
588
582
589
# Test case 4: Ignore EOS flag
583
- scheduler = create_scheduler ()
590
+ scheduler = create_scheduler (num_speculative_tokens = 2 )
584
591
requests = create_requests (num_requests = 1 , max_tokens = 10 )
585
592
requests [0 ].sampling_params .ignore_eos = True
586
593
requests [0 ].num_computed_tokens = requests [0 ].num_tokens
@@ -682,13 +689,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
682
689
@pytest .mark .parametrize (
683
690
"spec_tokens,output_tokens,expected" ,
684
691
[
685
- ([[1 , 2 , 3 ]], [[1 , 2 , 3 , 4 ]], (3 , 3 )), # perfect match
686
- ([[1 , 2 , 3 ]], [[1 , 5 ]], (3 , 1 )), # early mismatch
687
- ([[1 , 2 ], [3 ]], [[1 , 2 , 5 ], [3 , 4 ]], (3 , 3 )), # multiple sequences
688
- ([[1 ]], [[1 , 2 ]], (1 , 1 )), # single token sequence
689
- ([[]], [[5 ]], (0 , 0 )), # empty sequence
692
+ ([[1 , 2 , 3 ]], [[1 , 2 , 3 , 4 ]], (1 , 3 , 3 , [1 , 1 , 1 ])), # perfect match
693
+ ([[1 , 2 , 3 ]], [[1 , 5 ]], (1 , 3 , 1 , [1 , 0 , 0 ])), # early mismatch
694
+ ([[1 , 2 ], [3 ]], [[1 , 2 , 5 ], [3 , 4 ]],
695
+ (2 , 3 , 3 , [2 , 1 ])), # multiple sequences
696
+ ([[1 ]], [[1 , 2 ]], (1 , 1 , 1 , [1 ])), # single token sequence
697
+ ([[]], [[5 ]], (0 , 0 , 0 , [0 ])), # empty sequence
690
698
([[1 , 2 , 3 ], [4 , 5 , 6 ]], [[1 , 2 , 7 ], [4 , 8 ]],
691
- (6 , 3 )), # multiple mismatches
699
+ (2 , 6 , 3 , [ 2 , 1 , 0 ] )), # multiple mismatches
692
700
])
693
701
def test_schedule_spec_decoding_stats (spec_tokens , output_tokens , expected ):
694
702
"""Test scheduling behavior with speculative decoding.
@@ -697,7 +705,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
697
705
1. Speculated tokens get scheduled correctly
698
706
2. Spec decoding stats properly count number of draft and accepted tokens
699
707
"""
700
- scheduler = create_scheduler ()
708
+ num_spec_tokens = max (1 , max (len (t ) for t in spec_tokens ))
709
+ scheduler = create_scheduler (num_speculative_tokens = num_spec_tokens )
701
710
requests = create_requests (num_requests = len (spec_tokens ), num_tokens = 1 )
702
711
req_ids = []
703
712
req_to_index = {}
@@ -770,8 +779,10 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
770
779
else :
771
780
assert scheduler_stats .spec_decoding_stats is not None
772
781
stats = scheduler_stats .spec_decoding_stats
773
- assert stats .num_draft_tokens == expected [0 ]
774
- assert stats .num_accepted_tokens == expected [1 ]
782
+ assert stats .num_drafts == expected [0 ]
783
+ assert stats .num_draft_tokens == expected [1 ]
784
+ assert stats .num_accepted_tokens == expected [2 ]
785
+ assert stats .num_accepted_tokens_per_pos == expected [3 ]
775
786
776
787
777
788
def _assert_right_scheduler_output (
0 commit comments