From 0931329a5a849875543c694236ce9d93d317579d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 4 Apr 2025 22:03:00 -0700 Subject: [PATCH 01/12] [Spec Decode] Do not generate draft tokens beyond max_model_len Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 41 +++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 3aaaf34bc79..19a658004a4 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -9,6 +9,8 @@ from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata +PADDING_SLOT_ID = -1 + class EagleProposer: @@ -20,7 +22,10 @@ def __init__( self.vllm_config = vllm_config self.num_speculative_tokens = ( vllm_config.speculative_config.num_speculative_tokens) + self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size + + # Optimization: pre-compute and cache the arange tensor. self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs, device=device) @@ -103,22 +108,52 @@ def propose( # Update the inputs. input_ids = draft_token_ids_list[-1] positions += 1 + + # NOTE(woosuk): We should handle the case where the draft model + # generates tokens beyond the max model length. Since it is complex + # to remove the request from the batch, we keep the request in the + # batch but adjust the position ids and slot mappings to avoid the + # out-of-range access during the model execution. The draft tokens + # generated through this adjustment should be ignored. + exceeds_max_model_len = positions >= self.max_model_len + # Mask out the position ids that exceed the max model length. + # Otherwise, we may get out-of-range error in RoPE. + clamped_positions = torch.where(exceeds_max_model_len, 0, + positions) + + # Increment the sequence lengths. attn_metadata.max_seq_len += 1 attn_metadata.seq_lens += 1 + # Consider max model length. + attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, + self.max_model_len) + # For the requests that exceed the max model length, we set the + # sequence length to 1 to minimize the overheads in attention. + attn_metadata.seq_lens = torch.where(exceeds_max_model_len, 1, + attn_metadata.seq_lens) + # Compute the slot mapping. - block_numbers = positions // self.block_size + block_numbers = clamped_positions // self.block_size block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) attn_metadata.slot_mapping = (block_ids * self.block_size + - positions % self.block_size) + clamped_positions % self.block_size) + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadverently updated with the + # padding tokens. + attn_metadata.slot_mapping = torch.where( + exceeds_max_model_len, + PADDING_SLOT_ID, + attn_metadata.slot_mapping, + ) # Run the model. 
with set_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( input_ids=input_ids, hidden_states=hidden_states, - positions=positions, + positions=clamped_positions, ) logits = self.model.compute_logits(hidden_states, None) draft_token_ids, probs = compute_probs_and_sample_next_token( From 55b6e1d3eb67aa75d37aae3b1252ca0cef4e9568 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 5 Apr 2025 10:02:41 -0700 Subject: [PATCH 02/12] Fix ngram Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/ngram_proposer.py | 7 ++++++- vllm/v1/worker/gpu_model_runner.py | 8 +++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 7e548bb48b5..3d978431e63 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -18,6 +18,9 @@ def __init__(self, vllm_config: VllmConfig): # tokens follow the match, we will return the maximum amount of # tokens until the end. self.k = vllm_config.speculative_config.num_speculative_tokens + # Maximum length of the model. + self.max_model_len = vllm_config.model_config.max_model_len + # Trigger Numba JIT compilation for N-gram proposer. # This usually takes less than 1 second. self.propose(np.zeros(1024, dtype=np.int32)) @@ -50,9 +53,11 @@ def propose( followed that pattern. Here we will return [4,2,3] because we only have three tokens after the match. """ + # Do not generate draft tokens beyond the max model length. + k = min(self.k, self.max_model_len - context_token_ids.shape[0]) # TODO(woosuk): Optimize this. for n in range(self.max_n, self.min_n - 1, -1): - result = _find_subarray_kmp(context_token_ids, n, self.k) + result = _find_subarray_kmp(context_token_ids, n, k) if result is not None: return result return None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 82b07c6cd32..e9f079988e6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1235,7 +1235,8 @@ def generate_draft_token_ids( draft_token_ids.append([]) continue - # Skip requests that require top-p, top-k, etc. + # Skip requests that require sampling parameters that are not + # supported with speculative decoding. req_id = self.input_batch.req_ids[i] if not is_spec_decode_supported(req_id, self.input_batch): draft_token_ids.append([]) @@ -1244,6 +1245,11 @@ def generate_draft_token_ids( # Add sampled_token_ids to token_ids_cpu. start_idx = self.input_batch.num_tokens_no_spec[i] end_idx = start_idx + num_sampled_ids + if end_idx >= self.max_model_len: + # Skip requests that have already reached the max model length. + draft_token_ids.append([]) + continue + self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids drafter_output = self.drafter.propose( self.input_batch.token_ids_cpu[i, :end_idx]) From f5c3af6e0a62aede42dbf038ea93b9a51dce1832 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 5 Apr 2025 10:04:50 -0700 Subject: [PATCH 03/12] Update comments Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 19a658004a4..1088a5c675c 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -111,10 +111,10 @@ def propose( # NOTE(woosuk): We should handle the case where the draft model # generates tokens beyond the max model length. 
Since it is complex - # to remove the request from the batch, we keep the request in the - # batch but adjust the position ids and slot mappings to avoid the + # to remove such requests from the batch, we keep them in the batch + # but adjust the position ids and slot mappings to avoid the # out-of-range access during the model execution. The draft tokens - # generated through this adjustment should be ignored. + # generated with this adjustment should be ignored. exceeds_max_model_len = positions >= self.max_model_len # Mask out the position ids that exceed the max model length. # Otherwise, we may get out-of-range error in RoPE. @@ -128,7 +128,7 @@ def propose( attn_metadata.max_seq_len = min(attn_metadata.max_seq_len, self.max_model_len) # For the requests that exceed the max model length, we set the - # sequence length to 1 to minimize the overheads in attention. + # sequence length to 1 to minimize their overheads in attention. attn_metadata.seq_lens = torch.where(exceeds_max_model_len, 1, attn_metadata.seq_lens) @@ -140,7 +140,7 @@ def propose( attn_metadata.slot_mapping = (block_ids * self.block_size + clamped_positions % self.block_size) # Mask out the slot mappings that exceed the max model length. - # Otherwise, the KV cache will be inadverently updated with the + # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. attn_metadata.slot_mapping = torch.where( exceeds_max_model_len, From d449ede1101ce6330967bc37f34614c645986277 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 20 Apr 2025 19:42:03 -0700 Subject: [PATCH 04/12] Add test Signed-off-by: Woosuk Kwon --- tests/v1/spec_decode/test_max_len.py | 57 ++++++++++++++++++++++++++++ vllm/v1/core/sched/scheduler.py | 6 +++ vllm/v1/spec_decode/eagle.py | 10 ++--- 3 files changed, 66 insertions(+), 7 deletions(-) create mode 100644 tests/v1/spec_decode/test_max_len.py diff --git a/tests/v1/spec_decode/test_max_len.py b/tests/v1/spec_decode/test_max_len.py new file mode 100644 index 00000000000..f577fb4ab32 --- /dev/null +++ b/tests/v1/spec_decode/test_max_len.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test whether spec decoding handles the max model length properly.""" + +import pytest + +from vllm import LLM, SamplingParams + +_PROMPTS = [ + "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1", + "Repeat the following sentence 10 times: Consistency is key to mastering any skill.", # noqa: E501 + "Who won the Turing Award in 2018, and for what contribution? Describe in detail.", # noqa: E501 +] + + +@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) +def test_ngram_max_len( + monkeypatch: pytest.MonkeyPatch, + num_speculative_tokens: int, +): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + llm = LLM( + model="facebook/opt-125m", + max_model_len=100, + enforce_eager=True, # For faster initialization. + speculative_config={ + "method": "ngram", + "prompt_lookup_max": 5, + "prompt_lookup_min": 3, + "num_speculative_tokens": num_speculative_tokens, + }, + ) + sampling_params = SamplingParams(max_tokens=100, ignore_eos=True) + llm.generate(_PROMPTS, sampling_params) + + +@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 10]) +def test_eagle_max_len( + monkeypatch: pytest.MonkeyPatch, + num_speculative_tokens: int, +): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + llm = LLM( + model="meta-llama/Meta-Llama-3-8B-Instruct", + enforce_eager=True, # For faster initialization. 
+ speculative_config={ + "method": "eagle", + "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", + "num_speculative_tokens": num_speculative_tokens, + }, + max_model_len=100, + ) + sampling_params = SamplingParams(max_tokens=100, ignore_eos=True) + llm.generate(_PROMPTS, sampling_params) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 69e7cc8ee08..8108b106126 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -185,6 +185,12 @@ def schedule(self) -> SchedulerOutput: num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 + # Make sure the input position does not exceed the max model len. + num_new_tokens = min( + num_new_tokens, + self.max_model_len - request.num_computed_tokens) + assert num_new_tokens > 0 + # Schedule encoder inputs. if request.has_encoder_inputs: (encoder_inputs_to_schedule, num_new_tokens, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 981a7ae9013..9505bd7ce43 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -136,8 +136,7 @@ def propose( self.max_model_len) # For the requests that exceed the max model length, we set the # sequence length to 1 to minimize their overheads in attention. - attn_metadata.seq_lens = torch.where(exceeds_max_model_len, 1, - attn_metadata.seq_lens) + attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1) # Compute the slot mapping. block_numbers = clamped_positions // self.block_size @@ -149,11 +148,8 @@ def propose( # Mask out the slot mappings that exceed the max model length. # Otherwise, the KV cache will be inadvertently updated with the # padding tokens. - attn_metadata.slot_mapping = torch.where( - exceeds_max_model_len, - PADDING_SLOT_ID, - attn_metadata.slot_mapping, - ) + attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len, + PADDING_SLOT_ID) # Run the model. with set_forward_context(attn_metadata, self.vllm_config): From af7462b6d1ebc10ca83da57f2804c9e49a40e83c Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 20 Apr 2025 19:45:28 -0700 Subject: [PATCH 05/12] Add comment Signed-off-by: Woosuk Kwon --- vllm/v1/core/sched/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8108b106126..16efc42f212 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -186,6 +186,7 @@ def schedule(self) -> SchedulerOutput: assert num_new_tokens > 0 # Make sure the input position does not exceed the max model len. + # This is necessary when using spec decoding. 
num_new_tokens = min( num_new_tokens, self.max_model_len - request.num_computed_tokens) From f2edf18ea59f5d4f077d4ddecd40a1531a443c5d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 20 Apr 2025 20:13:30 -0700 Subject: [PATCH 06/12] [V1][Spec Decode] Use argmax for sampling draft tokens Signed-off-by: Woosuk Kwon --- vllm/v1/sample/rejection_sampler.py | 14 ++++---- vllm/v1/spec_decode/eagle.py | 55 ++++------------------------- vllm/v1/worker/gpu_model_runner.py | 5 +-- 3 files changed, 14 insertions(+), 60 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 3cf7fde5cd0..9061a64db57 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -226,7 +226,7 @@ def rejection_sample( is_greedy, max_spec_len, vocab_size, - IS_NGRAM=draft_probs is None, + NO_DRAFT_PROBS=draft_probs is None, num_warps=1, ) return output_token_ids @@ -423,7 +423,7 @@ def sample_recovered_tokens( q, vocab_size, triton.next_power_of_2(vocab_size), - IS_NGRAM=draft_probs is None, + NO_DRAFT_PROBS=draft_probs is None, ) return recovered_token_ids @@ -490,7 +490,7 @@ def rejection_random_sample_kernel( is_greedy_ptr, # [batch_size] max_spec_len, vocab_size, - IS_NGRAM: tl.constexpr, + NO_DRAFT_PROBS: tl.constexpr, ): req_idx = tl.program_id(0) is_greedy = tl.load(is_greedy_ptr + req_idx) @@ -509,7 +509,7 @@ def rejection_random_sample_kernel( for pos in range(num_draft_tokens): if not rejected: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - if IS_NGRAM: + if NO_DRAFT_PROBS: draft_prob = 1 else: draft_prob = tl.load(draft_probs_ptr + @@ -575,7 +575,7 @@ def sample_recovered_tokens_kernel( q_ptr, # [batch_size, vocab_size] vocab_size, PADDED_VOCAB_SIZE: tl.constexpr, - IS_NGRAM: tl.constexpr, + NO_DRAFT_PROBS: tl.constexpr, ): req_idx = tl.program_id(0) if req_idx == 0: @@ -591,7 +591,7 @@ def sample_recovered_tokens_kernel( return vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) - if IS_NGRAM: + if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id) @@ -624,7 +624,7 @@ def sample_recovered_tokens_kernel( recovered_id = tl.argmax(prob / q, axis=-1) tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) - if IS_NGRAM: + if NO_DRAFT_PROBS: # Restore the original probability. tl.store( target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 9505bd7ce43..9adb9397870 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -51,7 +51,7 @@ def propose( # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] last_token_indices = cu_num_tokens[1:] - 1 @@ -94,17 +94,15 @@ def propose( ) sample_hidden_states = hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) - draft_token_ids, draft_probs = compute_probs_and_sample_next_token( - logits, sampling_metadata) + draft_token_ids = logits.argmax(dim=-1) # Early exit if there is only one draft token to be generated. 
if self.num_speculative_tokens == 1: - # [batch_size, 1] and [batch_size, 1, vocab_size] - return draft_token_ids.view(-1, 1), draft_probs.unsqueeze(dim=1) + # [batch_size, 1] + return draft_token_ids.view(-1, 1) # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] - draft_probs_list = [draft_probs] positions = target_positions[last_token_indices] hidden_states = sample_hidden_states @@ -159,16 +157,12 @@ def propose( positions=clamped_positions, ) logits = self.model.compute_logits(hidden_states, None) - draft_token_ids, probs = compute_probs_and_sample_next_token( - logits, sampling_metadata) + draft_token_ids = logits.argmax(dim=-1) draft_token_ids_list.append(draft_token_ids) - draft_probs_list.append(probs) # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - # [batch_size, num_speculative_tokens, vocab_size] - draft_probs = torch.stack(draft_probs_list, dim=1) - return draft_token_ids, draft_probs + return draft_token_ids @staticmethod def prepare_inputs( @@ -238,43 +232,6 @@ def load_model(self, target_model: nn.Module) -> None: self.model.lm_head = target_model.lm_head -# FIXME(woosuk): The logic here is duplicated with the main sampling code. -# We should refactor this to reuse the same sampling implementation. -def compute_probs_and_sample_next_token( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, -) -> tuple[torch.Tensor, torch.Tensor]: - if sampling_metadata.all_greedy: - # For greedy requests, draft_probs is not used in rejection sampling. - # Therefore, we can just return the logits. - probs = logits - next_token_ids = logits.argmax(dim=-1) - return next_token_ids, probs - - is_greedy = sampling_metadata.temperature == -1 - temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) - logits.div_(temperature.view(-1, 1)) - probs = logits.softmax(dim=-1, dtype=torch.float32) - - # NOTE(woosuk): Currently, we ignore most of the sampling parameters in - # generating the draft tokens. We only use the temperature. While this - # could degrade the acceptance rate, it does not affect the distribution - # of the generated tokens after rejection sampling. - - # TODO(woosuk): Consider seeds. - q = torch.empty_like(probs) - q.exponential_() - next_token_ids = probs.div_(q).argmax(dim=-1).view(-1) - if not sampling_metadata.all_random: - greedy_token_ids = probs.argmax(dim=-1) - next_token_ids = torch.where( - is_greedy, - greedy_token_ids, - next_token_ids, - ) - return next_token_ids, probs - - @triton.jit def prepare_input_kernel( out_ptr, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4cb5a8e171a..269a7fabf75 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1229,7 +1229,7 @@ def execute_model( target_hidden_states = hidden_states[token_indices] target_slot_mapping = attn_metadata.slot_mapping[token_indices] - draft_token_ids, draft_probs = self.drafter.propose( + draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, @@ -1240,9 +1240,6 @@ def execute_model( sampling_metadata=sampling_metadata, ) spec_token_ids = draft_token_ids.tolist() - # TODO(woosuk): Cache draft_probs and use it for rejection sampling - # in the next step. - del draft_probs # Clear KVConnector state after all KVs are generated. 
if has_kv_transfer_group(): From 5a0646d4dbb53bd625a430578a843b5e9caa55f4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 20 Apr 2025 21:00:02 -0700 Subject: [PATCH 07/12] Fix test Signed-off-by: Woosuk Kwon --- tests/v1/core/test_scheduler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 691ca59b062..ea024fca33b 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -30,6 +30,7 @@ def create_scheduler( use_kv_connector: bool = False, num_blocks: int = 10000, block_size: int = 16, + max_model_len: Optional[int] = None, ) -> Scheduler: '''Create scheduler under test. @@ -44,10 +45,12 @@ def create_scheduler( Returns: :class:`Scheduler` instance ''' + if max_model_len is None: + max_model_len = max_num_batched_tokens scheduler_config = SchedulerConfig( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, - max_model_len=max_num_batched_tokens, + max_model_len=max_model_len, long_prefill_token_threshold=long_prefill_token_threshold, disable_chunked_mm_input=disable_chunked_mm_input, ) @@ -296,6 +299,7 @@ def test_no_mm_input_chunking(): model="llava-hf/llava-1.5-7b-hf", max_num_batched_tokens=1024, disable_chunked_mm_input=True, + max_model_len=2048, ) mm_positions = [[PlaceholderRange(offset=400, length=800)]] requests = create_requests(num_requests=1, From 0c6b21110ce0801b67a33fcb82bc0602a3ca3454 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 20 Apr 2025 22:31:06 -0700 Subject: [PATCH 08/12] Fix test Signed-off-by: Woosuk Kwon --- tests/v1/core/test_scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index ea024fca33b..f173344344f 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -53,6 +53,7 @@ def create_scheduler( max_model_len=max_model_len, long_prefill_token_threshold=long_prefill_token_threshold, disable_chunked_mm_input=disable_chunked_mm_input, + enable_chunked_prefill=True, ) model_config = ModelConfig( model=model, From 2e1f95d55202d18c023c152b4e883325032bc657 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 21 Apr 2025 08:41:43 -0700 Subject: [PATCH 09/12] minor Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/ngram_proposer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 3d978431e63..704153d43a2 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -55,6 +55,9 @@ def propose( """ # Do not generate draft tokens beyond the max model length. k = min(self.k, self.max_model_len - context_token_ids.shape[0]) + if k <= 0: + return None + # TODO(woosuk): Optimize this. 
for n in range(self.max_n, self.min_n - 1, -1): result = _find_subarray_kmp(context_token_ids, n, k) From cfe9668fb91f0fc8fa0f1f7d06a178fbb7af67b5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 21 Apr 2025 09:29:48 -0700 Subject: [PATCH 10/12] fix ngram test Signed-off-by: Woosuk Kwon --- tests/v1/spec_decode/test_ngram.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/tests/v1/spec_decode/test_ngram.py b/tests/v1/spec_decode/test_ngram.py index 5caa4f052fc..50548219fff 100644 --- a/tests/v1/spec_decode/test_ngram.py +++ b/tests/v1/spec_decode/test_ngram.py @@ -2,7 +2,7 @@ import numpy as np -from vllm.config import SpeculativeConfig, VllmConfig +from vllm.config import ModelConfig, SpeculativeConfig, VllmConfig from vllm.v1.spec_decode.ngram_proposer import (NgramProposer, _find_subarray_kmp, _kmp_lps_array) @@ -42,14 +42,24 @@ def test_find_subarray_kmp(): def test_ngram_proposer(): def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: - return NgramProposer(vllm_config=VllmConfig( - speculative_config=SpeculativeConfig.from_dict( - { - "prompt_lookup_min": min_n, - "prompt_lookup_max": max_n, - "num_speculative_tokens": k, - "method": "ngram", - }))) + # Dummy model config. Just to set max_model_len. + model_config = ModelConfig(model="facebook/opt-125m", + task="generate", + max_model_len=100, + tokenizer="facebook/opt-125m", + tokenizer_mode="auto", + dtype="auto", + seed=None, + trust_remote_code=False) + return NgramProposer( + vllm_config=VllmConfig(model_config=model_config, + speculative_config=SpeculativeConfig. + from_dict({ + "prompt_lookup_min": min_n, + "prompt_lookup_max": max_n, + "num_speculative_tokens": k, + "method": "ngram", + }))) # No match. result = ngram_proposer( From 4911507c1c4128427d618d30b9ad838b387928cc Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 23 Apr 2025 13:58:31 -0700 Subject: [PATCH 11/12] Add back Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 43 ++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 9adb9397870..013f5902418 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -232,6 +232,49 @@ def load_model(self, target_model: nn.Module) -> None: self.model.lm_head = target_model.lm_head +# NOTE(woosuk): Currently, the below code is not used and we always use argmax +# to sample the draft tokens. We will use this after we find a way to manage +# the draft prob tensor. +# Refer to https://github.com/vllm-project/vllm/pull/16899 +# FIXME(woosuk): The logic here is duplicated with the main sampling code. +# We should refactor this to reuse the same sampling implementation. +def compute_probs_and_sample_next_token( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> tuple[torch.Tensor, torch.Tensor]: + if sampling_metadata.all_greedy: + # For greedy requests, draft_probs is not used in rejection sampling. + # Therefore, we can just return the logits. + probs = logits + next_token_ids = logits.argmax(dim=-1) + return next_token_ids, probs + + is_greedy = sampling_metadata.temperature == -1 + temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature) + logits.div_(temperature.view(-1, 1)) + probs = logits.softmax(dim=-1, dtype=torch.float32) + + # NOTE(woosuk): Currently, we ignore most of the sampling parameters in + # generating the draft tokens. We only use the temperature. 
While this + # could degrade the acceptance rate, it does not affect the distribution + # of the generated tokens after rejection sampling. + + # TODO(woosuk): Consider seeds. + q = torch.empty_like(probs) + q.exponential_() + # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs + # will be used later for rejection sampling. + next_token_ids = probs.div(q).argmax(dim=-1).view(-1) + if not sampling_metadata.all_random: + greedy_token_ids = probs.argmax(dim=-1) + next_token_ids = torch.where( + is_greedy, + greedy_token_ids, + next_token_ids, + ) + return next_token_ids, probs + + @triton.jit def prepare_input_kernel( out_ptr, From 74fce8326587d663e93ced85dfa774fb7de84655 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 23 Apr 2025 13:59:29 -0700 Subject: [PATCH 12/12] minor Signed-off-by: Woosuk Kwon --- vllm/v1/spec_decode/eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 013f5902418..95f0c067d40 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -235,7 +235,7 @@ def load_model(self, target_model: nn.Module) -> None: # NOTE(woosuk): Currently, the below code is not used and we always use argmax # to sample the draft tokens. We will use this after we find a way to manage # the draft prob tensor. -# Refer to https://github.com/vllm-project/vllm/pull/16899 +# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details. # FIXME(woosuk): The logic here is duplicated with the main sampling code. # We should refactor this to reuse the same sampling implementation. def compute_probs_and_sample_next_token(
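
For reference, the adjustment that runs through this series boils down to a small amount of tensor arithmetic: clamp the positions of requests that have hit max_model_len, shrink their sequence lengths, and point their slot mappings at a padding slot so the KV cache is never written for them. The sketch below reproduces that logic outside of vLLM; the helper name clamp_draft_step, the toy batch in the __main__ block, and the return signature are illustrative assumptions, while PADDING_SLOT_ID, the block-table indexing, and the torch.where/masked_fill calls mirror the patched eagle.py.

import torch

PADDING_SLOT_ID = -1  # same sentinel as in vllm/v1/spec_decode/eagle.py


def clamp_draft_step(
    positions: torch.Tensor,    # [batch_size] next draft position per request
    seq_lens: torch.Tensor,     # [batch_size] sequence length per request
    block_table: torch.Tensor,  # [batch_size, max_num_blocks_per_req]
    block_size: int,
    max_model_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """One draft step: returns (clamped_positions, seq_lens, slot_mapping)."""
    exceeds_max_model_len = positions >= max_model_len
    # Clamp positions so RoPE never sees an out-of-range index.
    clamped_positions = torch.where(exceeds_max_model_len,
                                    torch.zeros_like(positions), positions)
    # Requests past the limit attend over a single token to keep attention cheap.
    seq_lens = seq_lens.masked_fill(exceeds_max_model_len, 1)
    # Map the (clamped) positions to physical KV-cache slots via the block table.
    block_numbers = clamped_positions // block_size
    block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
    slot_mapping = block_ids * block_size + clamped_positions % block_size
    # Redirect out-of-range requests to the padding slot so their (garbage)
    # draft tokens never overwrite real KV-cache entries.
    slot_mapping = slot_mapping.masked_fill(exceeds_max_model_len, PADDING_SLOT_ID)
    return clamped_positions, seq_lens, slot_mapping


if __name__ == "__main__":
    # Toy batch: the second request has already reached max_model_len = 100.
    positions = torch.tensor([7, 100])
    seq_lens = torch.tensor([8, 101])
    block_table = torch.arange(16).view(2, 8)  # 2 requests x 8 blocks each
    pos, lens, slots = clamp_draft_step(positions, seq_lens, block_table,
                                        block_size=16, max_model_len=100)
    print(pos.tolist(), lens.tolist(), slots.tolist())
    # -> [7, 0] [8, 1] [7, -1]: the capped request gets position 0, seq_len 1,
    #    and the padding slot, so its draft token is ignored downstream.

As the NOTE in eagle.py explains, requests that exceed the limit stay in the batch because removing them mid-proposal would be complex; the drafts produced for them are simply discarded after rejection sampling.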