Commit ee0eb6b

huydhn authored and radeksm committed
Fix more broken speculative decode tests (vllm-project#17450)
Signed-off-by: Huy Do <huydhn@gmail.com>
1 parent f71f63d · commit ee0eb6b

File tree: 4 files changed (+9, -4 lines)

tests/spec_decode/e2e/test_medusa_correctness.py

Lines changed: 1 addition & 1 deletion

@@ -205,7 +205,7 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "block_size": 8,
+        "block_size": 16,
         # 2 for small prompt, 256//8 for generated.
         "num_gpu_blocks_override": 2 + 256 // 8,
         "max_model_len": (2 + 256 // 8) * 8,

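The same one-line change, block_size 8 to 16, recurs in the mlp and ngram tests below. The surrounding overrides are still written in terms of the old size 8, so as a sanity check, here is the arithmetic (my own sketch, not part of the commit) showing the cache budget still covers max_model_len after the bump:

    # Illustrative arithmetic only; variable names mirror the kwargs in the
    # diff, but this snippet is not vLLM code.
    block_size_old = 8
    block_size_new = 16

    # From the diff comment: "2 for small prompt, 256//8 for generated."
    num_gpu_blocks_override = 2 + 256 // 8      # 34 blocks
    max_model_len = (2 + 256 // 8) * 8          # 272 tokens (still uses size 8)

    # Token capacity of the overridden cache under each block size.
    capacity_old = num_gpu_blocks_override * block_size_old  # 34 * 8  = 272
    capacity_new = num_gpu_blocks_override * block_size_new  # 34 * 16 = 544

    assert capacity_old == max_model_len   # old config was exactly tight
    assert capacity_new >= max_model_len   # new config leaves headroom
    print(num_gpu_blocks_override, max_model_len, capacity_new)  # 34 272 544
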
tests/spec_decode/e2e/test_mlp_correctness.py

Lines changed: 2 additions & 2 deletions

@@ -267,7 +267,7 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "block_size": 8,
+        "block_size": 16,
         # 2 for small prompt, 256//8 for generated.
         "num_gpu_blocks_override": 2 + 256 // 8,
         "max_model_len": (2 + 256 // 8) * 8,

@@ -321,7 +321,7 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "block_size": 8,
+        "block_size": 16,
         # 2 for small prompt, 256//8 for generated.
         "num_gpu_blocks_override": 2 + 256 // 8,
         "max_model_len": (2 + 256 // 8) * 8,

tests/spec_decode/e2e/test_ngram_correctness.py

Lines changed: 1 addition & 1 deletion

@@ -152,7 +152,7 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "block_size": 8,
+        "block_size": 16,
         # 2 for small prompt, 256//8 for generated.
         "num_gpu_blocks_override": 2 + 256 // 8,
         "max_model_len": (2 + 256 // 8) * 8,

vllm/spec_decode/multi_step_worker.py

Lines changed: 5 additions & 0 deletions

@@ -51,9 +51,14 @@ def init_device(self) -> None:
     def set_include_gpu_probs_tensor(self) -> None:
         # Need include_gpu_probs_tensor for MultiStepWorker
         self.model_runner.sampler.include_gpu_probs_tensor = True
+        if hasattr(self.model_runner.model, "sampler"):
+            (self.model_runner.model.sampler.include_gpu_probs_tensor) = True
 
     def set_should_modify_greedy_probs_inplace(self) -> None:
         self.model_runner.sampler.should_modify_greedy_probs_inplace = True
+        if hasattr(self.model_runner.model, "sampler"):
+            (self.model_runner.model.sampler.should_modify_greedy_probs_inplace
+             ) = True
 
     @torch.inference_mode()
     def sampler_output(
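
Below is a minimal, runnable sketch of the guard pattern these five lines add, using hypothetical stand-in classes rather than vLLM's real worker and runner types: some draft models (e.g. Medusa-style heads) expose their own sampler attribute, and the hasattr check keeps the flag assignment from raising AttributeError on models that don't.

    # Hypothetical stand-ins, just to exercise the hasattr guard from the diff.
    class _Sampler:
        include_gpu_probs_tensor = False

    class _Model:
        def __init__(self, with_sampler: bool):
            if with_sampler:
                self.sampler = _Sampler()  # e.g. a Medusa-style draft model

    class _ModelRunner:
        def __init__(self, with_model_sampler: bool):
            self.sampler = _Sampler()
            self.model = _Model(with_model_sampler)

    def set_include_gpu_probs_tensor(model_runner: _ModelRunner) -> None:
        model_runner.sampler.include_gpu_probs_tensor = True
        # Mirror the flag onto the model's own sampler only when one exists;
        # an unguarded assignment would raise AttributeError otherwise.
        if hasattr(model_runner.model, "sampler"):
            model_runner.model.sampler.include_gpu_probs_tensor = True

    set_include_gpu_probs_tensor(_ModelRunner(with_model_sampler=True))
    set_include_gpu_probs_tensor(_ModelRunner(with_model_sampler=False))  # no error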

0 commit comments