
Commit bb095ff

WoosukKwon authored and liuzijing2014 committed
[V1][Spec Decode] Always use argmax for sampling draft tokens (vllm-project#16899)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
1 parent cf179b2 commit bb095ff

File tree

3 files changed: +18 -23 lines changed
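In short: the EAGLE drafter now always picks its draft tokens greedily from the drafter's logits instead of sampling them from a computed probability distribution, so no per-step draft-probability tensor has to be kept around (see the NOTE added in eagle.py below). A minimal sketch of the new selection step, with hypothetical shapes and not actual vLLM code:

import torch

# Hypothetical logits from the draft model head: [batch_size, vocab_size].
logits = torch.randn(4, 32000)

# Greedy draft-token selection, as introduced by this commit: one token id per
# request, with no [batch_size, vocab_size] draft-probability tensor retained.
draft_token_ids = logits.argmax(dim=-1)  # shape: [batch_size]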

vllm/v1/sample/rejection_sampler.py

+7 -7
@@ -226,7 +226,7 @@ def rejection_sample(
         is_greedy,
         max_spec_len,
         vocab_size,
-        IS_NGRAM=draft_probs is None,
+        NO_DRAFT_PROBS=draft_probs is None,
         num_warps=1,
     )
     return output_token_ids
@@ -423,7 +423,7 @@ def sample_recovered_tokens(
         q,
         vocab_size,
         triton.next_power_of_2(vocab_size),
-        IS_NGRAM=draft_probs is None,
+        NO_DRAFT_PROBS=draft_probs is None,
     )
     return recovered_token_ids

@@ -490,7 +490,7 @@ def rejection_random_sample_kernel(
     is_greedy_ptr,  # [batch_size]
     max_spec_len,
     vocab_size,
-    IS_NGRAM: tl.constexpr,
+    NO_DRAFT_PROBS: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
     is_greedy = tl.load(is_greedy_ptr + req_idx)
@@ -509,7 +509,7 @@ def rejection_random_sample_kernel(
     for pos in range(num_draft_tokens):
         if not rejected:
             draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-            if IS_NGRAM:
+            if NO_DRAFT_PROBS:
                 draft_prob = 1
             else:
                 draft_prob = tl.load(draft_probs_ptr +
@@ -575,7 +575,7 @@ def sample_recovered_tokens_kernel(
     q_ptr,  # [batch_size, vocab_size]
     vocab_size,
     PADDED_VOCAB_SIZE: tl.constexpr,
-    IS_NGRAM: tl.constexpr,
+    NO_DRAFT_PROBS: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
     if req_idx == 0:
@@ -591,7 +591,7 @@ def sample_recovered_tokens_kernel(
         return

     vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
-    if IS_NGRAM:
+    if NO_DRAFT_PROBS:
         draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
         orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
                             draft_token_id)
@@ -624,7 +624,7 @@ def sample_recovered_tokens_kernel(
     recovered_id = tl.argmax(prob / q, axis=-1)
     tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)

-    if IS_NGRAM:
+    if NO_DRAFT_PROBS:
         # Restore the original probability.
         tl.store(
             target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
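When no draft probabilities are supplied (NO_DRAFT_PROBS), the kernels fall back to draft_prob = 1, so a draft token is accepted whenever the target model's probability for it is at least the drawn uniform value. A rough host-side sketch of that acceptance rule, for illustration only; the real logic lives in the Triton kernels above and also handles greedy requests, recovered tokens, and bonus tokens:

import torch

def accept_with_unit_draft_prob(target_probs: torch.Tensor,     # [num_tokens, vocab_size]
                                draft_token_ids: torch.Tensor,  # [num_tokens], long dtype
                                ) -> torch.Tensor:
    """Accept each draft token with probability min(1, p_target / draft_prob);
    with draft_prob fixed to 1 this reduces to p_target >= uniform sample."""
    p_target = target_probs.gather(1, draft_token_ids.unsqueeze(1)).squeeze(1)
    uniform = torch.rand_like(p_target)
    return p_target >= uniform  # boolean accept mask, [num_tokens]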

vllm/v1/spec_decode/eagle.py

+10 -12
@@ -51,7 +51,7 @@ def propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
@@ -94,17 +94,15 @@ def propose(
         )
         sample_hidden_states = hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
-        draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
-            logits, sampling_metadata)
+        draft_token_ids = logits.argmax(dim=-1)

         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1:
-            # [batch_size, 1] and [batch_size, 1, vocab_size]
-            return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
+            # [batch_size, 1]
+            return draft_token_ids.view(-1, 1)

         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
-        draft_probs_list = [draft_probs]

         positions = target_positions[last_token_indices]
         hidden_states = sample_hidden_states
@@ -159,16 +157,12 @@ def propose(
                 positions=clamped_positions,
             )
             logits = self.model.compute_logits(hidden_states, None)
-            draft_token_ids, probs = compute_probs_and_sample_next_token(
-                logits, sampling_metadata)
+            draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)
-            draft_probs_list.append(probs)

         # [batch_size, num_speculative_tokens]
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-        # [batch_size, num_speculative_tokens, vocab_size]
-        draft_probs = torch.stack(draft_probs_list, dim=1)
-        return draft_token_ids, draft_probs
+        return draft_token_ids

     @staticmethod
     def prepare_inputs(
@@ -238,6 +232,10 @@ def load_model(self, target_model: nn.Module) -> None:
         self.model.lm_head = target_model.lm_head


+# NOTE(woosuk): Currently, the below code is not used and we always use argmax
+# to sample the draft tokens. We will use this after we find a way to manage
+# the draft prob tensor.
+# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
 # FIXME(woosuk): The logic here is duplicated with the main sampling code.
 # We should refactor this to reuse the same sampling implementation.
 def compute_probs_and_sample_next_token(
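The resulting propose() control flow is a purely greedy loop. A simplified sketch under assumed shapes, with a hypothetical draft_step() callable standing in for the drafter forward pass plus compute_logits:

import torch
from typing import Callable

def propose_greedy(logits: torch.Tensor,                  # [batch_size, vocab_size]
                   num_speculative_tokens: int,
                   draft_step: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    """Pick the argmax token at every speculative step and stack the results."""
    draft_token_ids = logits.argmax(dim=-1)               # [batch_size]
    if num_speculative_tokens == 1:
        return draft_token_ids.view(-1, 1)                # [batch_size, 1]

    draft_token_ids_list = [draft_token_ids]
    for _ in range(num_speculative_tokens - 1):
        logits = draft_step(draft_token_ids)              # hypothetical drafter step
        draft_token_ids = logits.argmax(dim=-1)
        draft_token_ids_list.append(draft_token_ids)
    # [batch_size, num_speculative_tokens]
    return torch.stack(draft_token_ids_list, dim=1)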

vllm/v1/worker/gpu_model_runner.py

+1 -4
@@ -1230,7 +1230,7 @@ def execute_model(
             target_hidden_states = hidden_states[token_indices]
             target_slot_mapping = attn_metadata.slot_mapping[token_indices]

-            draft_token_ids, draft_probs = self.drafter.propose(
+            draft_token_ids = self.drafter.propose(
                 target_token_ids=target_token_ids,
                 target_positions=target_positions,
                 target_hidden_states=target_hidden_states,
@@ -1241,9 +1241,6 @@ def execute_model(
                 sampling_metadata=sampling_metadata,
             )
             spec_token_ids = draft_token_ids.tolist()
-            # TODO(woosuk): Cache draft_probs and use it for rejection sampling
-            # in the next step.
-            del draft_probs

         # Clear KVConnector state after all KVs are generated.
         if has_kv_transfer_group():
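At the call site, propose() now returns a single tensor, so there is no draft_probs to unpack or delete. A hedged usage sketch with a stand-in for the propose() result:

import torch

# Stand-in for `self.drafter.propose(...)`, which now returns only the token
# IDs with shape [batch_size, num_speculative_tokens].
draft_token_ids = torch.tensor([[11, 42], [7, 3]])
spec_token_ids = draft_token_ids.tolist()  # [[11, 42], [7, 3]], one list per request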

0 commit comments
