@@ -51,7 +51,7 @@ def propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    ) -> torch.Tensor:
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
         last_token_indices = cu_num_tokens[1:] - 1
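As an aside on the unchanged context above: `cu_num_tokens` appears to be a prefix sum over per-request token counts, so `cu_num_tokens[1:] - 1` yields the flat index of each request's last token. A minimal sketch, with made-up values for illustration:

```python
import torch

# Hypothetical example: 3 requests with 3, 4, and 2 tokens each.
cu_num_tokens = torch.tensor([0, 3, 7, 9])
last_token_indices = cu_num_tokens[1:] - 1  # tensor([2, 6, 8])
# Indexing the flattened hidden states with these picks one row per request.
```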
@@ -94,17 +94,15 @@ def propose(
             )
         sample_hidden_states = hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
-        draft_token_ids, draft_probs = compute_probs_and_sample_next_token(
-            logits, sampling_metadata)
+        draft_token_ids = logits.argmax(dim=-1)

         # Early exit if there is only one draft token to be generated.
         if self.num_speculative_tokens == 1:
-            # [batch_size, 1] and [batch_size, 1, vocab_size]
-            return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1)
+            # [batch_size, 1]
+            return draft_token_ids.view(-1, 1)

         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
-        draft_probs_list = [draft_probs]

         positions = target_positions[last_token_indices]
         hidden_states = sample_hidden_states
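The change above replaces per-step sampling with greedy selection: `logits.argmax(dim=-1)` returns one token id per row, which the early-exit path reshapes to `[batch_size, 1]`. A minimal sketch of the shapes involved (sizes are illustrative):

```python
import torch

batch_size, vocab_size = 2, 5
logits = torch.randn(batch_size, vocab_size)        # [batch_size, vocab_size]
draft_token_ids = logits.argmax(dim=-1)             # [batch_size]
assert draft_token_ids.view(-1, 1).shape == (2, 1)  # [batch_size, 1]
```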
@@ -159,16 +157,12 @@ def propose(
                     positions=clamped_positions,
                 )
             logits = self.model.compute_logits(hidden_states, None)
-            draft_token_ids, probs = compute_probs_and_sample_next_token(
-                logits, sampling_metadata)
+            draft_token_ids = logits.argmax(dim=-1)
             draft_token_ids_list.append(draft_token_ids)
-            draft_probs_list.append(probs)

         # [batch_size, num_speculative_tokens]
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
-        # [batch_size, num_speculative_tokens, vocab_size]
-        draft_probs = torch.stack(draft_probs_list, dim=1)
-        return draft_token_ids, draft_probs
+        return draft_token_ids

     @staticmethod
     def prepare_inputs(
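For the stacking step above: each loop iteration appends a `[batch_size]` tensor of token ids, and `torch.stack(..., dim=1)` assembles them into `[batch_size, num_speculative_tokens]`. A minimal sketch with illustrative sizes:

```python
import torch

batch_size, num_speculative_tokens = 2, 3
per_step_ids = [torch.randint(0, 100, (batch_size,))
                for _ in range(num_speculative_tokens)]
draft_token_ids = torch.stack(per_step_ids, dim=1)
assert draft_token_ids.shape == (2, 3)
```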
@@ -238,6 +232,10 @@ def load_model(self, target_model: nn.Module) -> None:
         self.model.lm_head = target_model.lm_head


+# NOTE(woosuk): Currently, the below code is not used and we always use argmax
+# to sample the draft tokens. We will use this after we find a way to manage
+# the draft prob tensor.
+# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
 # FIXME(woosuk): The logic here is duplicated with the main sampling code.
 # We should refactor this to reuse the same sampling implementation.
 def compute_probs_and_sample_next_token(
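For context on the now-disabled path: `compute_probs_and_sample_next_token` (body not shown in this hunk) samples draft tokens from the softmax distribution rather than taking the argmax. A rough, hypothetical sketch of that kind of helper, ignoring the `SamplingMetadata` handling the real implementation performs:

```python
import torch

def sample_next_token_sketch(logits: torch.Tensor,
                             temperature: float = 1.0) -> torch.Tensor:
    # Softmax over the vocabulary, then one multinomial draw per row.
    # Hypothetical helper for illustration only; not vLLM's implementation.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
```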