@@ -4,7 +4,8 @@
 import pytest
 import torch
 
-from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
+from vllm.config import (CacheConfig, ModelConfig, SchedulerConfig,
+                         SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -25,6 +26,7 @@ def create_scheduler(
     enable_prefix_caching: Optional[bool] = None,
     long_prefill_token_threshold: int = 0,
     disable_chunked_mm_input: bool = False,
+    num_speculative_tokens: Optional[int] = None,
 ) -> Scheduler:
     '''Create scheduler under test.
 
@@ -80,12 +82,17 @@ def create_scheduler(
         ],
     )
     cache_config.num_gpu_blocks = 10000
+    speculative_config: Optional[SpeculativeConfig] = None
+    if num_speculative_tokens is not None:
+        speculative_config = SpeculativeConfig(
+            model="ngram", num_speculative_tokens=num_speculative_tokens)
     return Scheduler(
         scheduler_config,
         model_config,
         cache_config,
         lora_config=None,
         kv_cache_config=kv_cache_config,
+        speculative_config=speculative_config,
         log_stats=True,
         structured_output_manager=StructuredOutputManager(vllm_config),
     )
@@ -671,13 +678,14 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
 @pytest.mark.parametrize(
     "spec_tokens,output_tokens,expected",
     [
-        ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)),  # perfect match
-        ([[1, 2, 3]], [[1, 5]], (3, 1)),  # early mismatch
-        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)),  # multiple sequences
-        ([[1]], [[1, 2]], (1, 1)),  # single token sequence
-        ([[]], [[5]], (0, 0)),  # empty sequence
+        ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])),  # perfect match
+        ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])),  # early mismatch
+        ([[1, 2], [3]], [[1, 2, 5], [3, 4]],
+         (2, 3, 3, [2, 1])),  # multiple sequences
+        ([[1]], [[1, 2]], (1, 1, 1, [1])),  # single token sequence
+        ([[]], [[5]], (0, 0, 0, [0])),  # empty sequence
         ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
-         (6, 3)),  # multiple mismatches
+         (2, 6, 3, [2, 1, 0])),  # multiple mismatches
     ])
 def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     """Test scheduling behavior with speculative decoding.
@@ -686,7 +694,8 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     1. Speculated tokens get scheduled correctly
     2. Spec decoding stats properly count number of draft and accepted tokens
     """
-    scheduler = create_scheduler()
+    num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
+    scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
     requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
     req_ids = []
     req_to_index = {}
@@ -759,5 +768,7 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
     else:
         assert scheduler_stats.spec_decoding_stats is not None
         stats = scheduler_stats.spec_decoding_stats
-        assert stats.num_draft_tokens == expected[0]
-        assert stats.num_accepted_tokens == expected[1]
+        assert stats.num_drafts == expected[0]
+        assert stats.num_draft_tokens == expected[1]
+        assert stats.num_accepted_tokens == expected[2]
+        assert stats.num_accepted_tokens_per_pos == expected[3]
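For reference, the reworked expected tuples read as (num_drafts, num_draft_tokens, num_accepted_tokens, num_accepted_tokens_per_pos): each request's draft tokens are accepted left to right until the first token that disagrees with what the model sampled at that position. The helper below is a minimal sketch of that accounting, not vLLM code (the name expected_spec_stats is invented for illustration); under that reading it reproduces every tuple in the parametrize list above.

# Sketch only (not vLLM's implementation): derive the expected
# (num_drafts, num_draft_tokens, num_accepted_tokens,
#  num_accepted_tokens_per_pos) tuple from the test inputs.
def expected_spec_stats(spec_tokens, output_tokens):
    # A request counts as a draft only if it actually proposed tokens.
    num_drafts = sum(1 for draft in spec_tokens if draft)
    num_draft_tokens = sum(len(draft) for draft in spec_tokens)
    # Per-position counters sized like the test's num_spec_tokens,
    # i.e. max(1, longest draft).
    max_len = max((len(draft) for draft in spec_tokens), default=0)
    num_accepted_tokens_per_pos = [0] * max(1, max_len)
    num_accepted_tokens = 0
    for draft, output in zip(spec_tokens, output_tokens):
        for pos, token in enumerate(draft):
            # Accept draft tokens left to right; stop at the first
            # mismatch with the sampled output token at that position.
            if pos < len(output) and token == output[pos]:
                num_accepted_tokens_per_pos[pos] += 1
                num_accepted_tokens += 1
            else:
                break
    return (num_drafts, num_draft_tokens, num_accepted_tokens,
            num_accepted_tokens_per_pos)

# Reproduces the "multiple mismatches" case from the parametrize list:
assert expected_spec_stats([[1, 2, 3], [4, 5, 6]],
                           [[1, 2, 7], [4, 8]]) == (2, 6, 3, [2, 1, 0])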