diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py
index 453ae7b6f56..bb9993448aa 100644
--- a/examples/offline_inference/eagle.py
+++ b/examples/offline_inference/eagle.py
@@ -45,8 +45,12 @@ def main():
     parser.add_argument("--enable_chunked_prefill", action='store_true')
     parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
     parser.add_argument("--temp", type=float, default=0)
+    parser.add_argument("--use_v1", type=str, default="1", help='1 or 0')
     args = parser.parse_args()
 
+    # TODO: remove this option once EAGLE in v1 is ready.
+    os.environ["VLLM_USE_V1"] = args.use_v1
+
     model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
     eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
 
@@ -94,10 +98,16 @@ def main():
     # to account for the token from the target model that's always going to be
     # accepted
     acceptance_counts = [0] * (args.num_spec_tokens + 1)
-    for output in outputs:
-        for step, count in enumerate(
-                output.metrics.spec_token_acceptance_counts):
-            acceptance_counts[step] += count
+    if args.use_v1 == '1':
+        for output in outputs:
+            for step, count in enumerate(
+                    output.spec_token_acceptance_counts[0]):
+                acceptance_counts[step] += count
+    else:
+        for output in outputs:
+            for step, count in enumerate(
+                    output.metrics.spec_token_acceptance_counts):
+                acceptance_counts[step] += count
 
     print("-" * 50)
     print(f"mean acceptance length: \
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 014e8d5d882..19d8fe08eb6 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -33,6 +33,10 @@ class CompletionOutput:
             to stop, None if the completion finished for some other reason
             including encountering the EOS token.
         lora_request: The LoRA request that was used to generate the output.
+        spec_token_acceptance_counts: A list tracking the total number of
+            accepted tokens at each speculation step for a request. Its length
+            is num_spec_tokens + 1 since there is always one token generated
+            by the target model.
     """
 
     index: int
@@ -43,6 +47,7 @@ class CompletionOutput:
     finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
     lora_request: Optional[LoRARequest] = None
+    spec_token_acceptance_counts: Optional[list[int]] = None
 
     def finished(self) -> bool:
         return self.finish_reason is not None
@@ -133,6 +138,9 @@ def __init__(
         self.encoder_prompt = encoder_prompt
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens
+        self.spec_token_acceptance_counts = [
+            o.spec_token_acceptance_counts for o in outputs
+        ]
 
     def add(self, next_output: "RequestOutput") -> None:
         """Merge subsequent RequestOutput into this one"""
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index a81574875a5..ccddd341743 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -608,7 +608,8 @@ def update_from_output(
                 spec_decoding_stats = self.make_spec_decoding_stats(
                     spec_decoding_stats,
                     num_draft_tokens=len(scheduled_spec_token_ids),
-                    num_accepted_tokens=len(generated_token_ids) - 1)
+                    num_accepted_tokens=len(generated_token_ids) - 1,
+                    request_id=req_id)
 
             cached_encoder_input_ids = (
                 self.encoder_cache_manager.get_cached_input_ids(request))
@@ -769,11 +770,13 @@ def make_spec_decoding_stats(
         spec_decoding_stats: Optional[SpecDecodingStats],
         num_draft_tokens: int,
         num_accepted_tokens: int,
+        request_id: str,
     ) -> Optional[SpecDecodingStats]:
         if not self.log_stats:
             return None
         if spec_decoding_stats is None:
             spec_decoding_stats = SpecDecodingStats()
         spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens,
-                                    num_accepted_tokens=num_accepted_tokens)
+                                    num_accepted_tokens=num_accepted_tokens,
+                                    request_id=request_id)
         return spec_decoding_stats
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 1264e43c79d..33a6225009b 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -60,6 +60,7 @@ class EngineCoreRequest(
     eos_token_id: Optional[int]
     arrival_time: float
     lora_request: Optional[LoRARequest]
+    num_spec_tokens: int
 
 
 class EngineCoreEventType(enum.IntEnum):
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 4c67186f704..f3285735e57 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -92,7 +92,7 @@ def __init__(
             asyncio_mode=False,
             vllm_config=vllm_config,
             executor_class=executor_class,
-            log_stats=False,  # FIXME: implement
+            log_stats=True,  # FIXME: implement
         )
 
         if not multiprocess_mode:
@@ -183,11 +183,20 @@ def add_request(
         priority: int = 0,
     ) -> None:
         # Process raw inputs into the request.
-        request = self.processor.process_inputs(request_id, prompt, params,
-                                                arrival_time, lora_request,
-                                                trace_headers,
-                                                prompt_adapter_request,
-                                                priority)
+        num_spec_tokens = 0
+        if self.vllm_config.speculative_config is not None:
+            num_spec_tokens = (
+                self.vllm_config.speculative_config.num_speculative_tokens)
+        request = self.processor.process_inputs(
+            request_id,
+            prompt,
+            params,
+            arrival_time,
+            lora_request,
+            trace_headers,
+            prompt_adapter_request,
+            priority,
+            num_spec_tokens=num_spec_tokens)
 
         n = params.n if isinstance(params, SamplingParams) else 1
 
@@ -223,7 +232,7 @@ def step(self) -> list[RequestOutput]:
 
         # 2) Process EngineCoreOutputs.
         processed_outputs = self.output_processor.process_outputs(
-            outputs.outputs)
+            outputs.outputs, scheduler_stats=outputs.scheduler_stats)
 
         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 70f072d3c93..4d7b86ec951 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -14,7 +14,7 @@
 from vllm.v1.engine.logprobs import LogprobsProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
-                                   RequestStateStats)
+                                   RequestStateStats, SchedulerStats)
 
 
 class RequestOutputCollector:
@@ -81,6 +81,7 @@ def __init__(
         arrival_time: float,
         queue: Optional[RequestOutputCollector],
         log_stats: bool,
+        num_spec_tokens: int = 0,
     ):
         self.request_id = request_id
         self.parent_req = parent_req
@@ -99,6 +100,8 @@ def __init__(
         self.stats = RequestStateStats(
             arrival_time=arrival_time) if log_stats else None
 
+        self.spec_token_acceptance_counts = [0] * (num_spec_tokens + 1)
+
     @classmethod
     def from_new_request(
         cls,
@@ -133,13 +136,13 @@ def from_new_request(
             arrival_time=request.arrival_time,
             queue=queue,
             log_stats=log_stats,
+            num_spec_tokens=request.num_spec_tokens,
         )
 
     def make_request_output(
-        self,
-        new_token_ids: list[int],
-        finish_reason: Optional[FinishReason],
+        self, new_token_ids: list[int], finish_reason: Optional[FinishReason],
         stop_reason: Union[int, str, None],
+        spec_token_acceptance_counts: Optional[list[int]]
     ) -> Optional[RequestOutput]:
 
         finished = finish_reason is not None
@@ -150,7 +153,10 @@
             return None
 
         completion_output = self._new_completion_output(
-            new_token_ids, finish_reason, stop_reason)
+            new_token_ids,
+            finish_reason,
+            stop_reason,
+            spec_token_acceptance_counts=spec_token_acceptance_counts)
 
         request_id = self.request_id
         if self.parent_req is None:
@@ -190,6 +196,7 @@ def _new_completion_output(
         token_ids: list[int],
         finish_reason: Optional[FinishReason],
         stop_reason: Union[int, str, None],
+        spec_token_acceptance_counts: Optional[list[int]],
     ) -> CompletionOutput:
 
         finished = finish_reason is not None
@@ -212,7 +219,8 @@
             logprobs=logprobs,
             cumulative_logprob=self.logprobs_processor.cumulative_logprob,
             finish_reason=str(finish_reason) if finished else None,
-            stop_reason=stop_reason if finished else None)
+            stop_reason=stop_reason if finished else None,
+            spec_token_acceptance_counts=spec_token_acceptance_counts)
 
 
 class OutputProcessor:
@@ -280,6 +288,7 @@ def process_outputs(
         engine_core_outputs: list[EngineCoreOutput],
         engine_core_timestamp: Optional[float] = None,
         iteration_stats: Optional[IterationStats] = None,
+        scheduler_stats: Optional[SchedulerStats] = None,
     ) -> OutputProcessorOutput:
         """
         Process the EngineCoreOutputs:
@@ -318,6 +327,8 @@
             self._update_stats_from_output(req_state, engine_core_output,
                                            engine_core_timestamp,
                                            iteration_stats)
+            self._update_stats_from_scheduler(req_id, req_state,
+                                              scheduler_stats)
 
             new_token_ids = engine_core_output.new_token_ids
             finish_reason = engine_core_output.finish_reason
@@ -337,7 +348,11 @@
 
             # 4) Create and handle RequestOutput objects.
             if request_output := req_state.make_request_output(
-                    new_token_ids, finish_reason, stop_reason):
+                    new_token_ids,
+                    finish_reason,
+                    stop_reason,
+                    spec_token_acceptance_counts=req_state.
+                    spec_token_acceptance_counts):
                 if req_state.queue is not None:
                     # AsyncLLM: put into queue for handling by generate().
                     req_state.queue.put(request_output)
@@ -403,3 +418,13 @@ def _update_stats_from_finished(self, req_state: RequestState,
         ParentRequest.observe_finished_request(
             req_state.parent_req, iteration_stats,
             req_state.stats.num_generation_tokens)
+
+    def _update_stats_from_scheduler(
+            self, req_id: str, req_state: RequestState,
+            scheduler_stats: Optional[SchedulerStats]):
+        if scheduler_stats is not None and \
+            scheduler_stats.spec_decoding_stats is not None:
+            num_accepted_tokens = scheduler_stats. \
+                spec_decoding_stats.per_request_stats.get(req_id, 0)
+            for i in range(num_accepted_tokens):
+                req_state.spec_token_acceptance_counts[i] += 1
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 6d3290f1656..6488ca6ed4d 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -198,6 +198,7 @@ def process_inputs(
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
+        num_spec_tokens: int = 0,
     ) -> EngineCoreRequest:
 
         # TODO(woosuk): Support pooling models.
@@ -313,7 +314,7 @@
             eos_token_id=eos_token_id,
             arrival_time=arrival_time,
             lora_request=lora_request,
-        )
+            num_spec_tokens=num_spec_tokens)
 
     def _validate_model_inputs(self,
                                inputs: ProcessorInputs,
diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py
index 7bb3c209d1d..a44523c945c 100644
--- a/vllm/v1/spec_decode/metrics.py
+++ b/vllm/v1/spec_decode/metrics.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 import numpy as np
 
@@ -13,20 +13,25 @@
 class SpecDecodingStats:
     num_draft_tokens: int = 0
     num_accepted_tokens: int = 0
+    per_request_stats: dict = field(default_factory=dict)
 
     def take(self):
         copied = SpecDecodingStats(self.num_draft_tokens,
-                                   self.num_accepted_tokens)
+                                   self.num_accepted_tokens,
+                                   self.per_request_stats)
         self.reset()
         return copied
 
     def reset(self):
         self.num_draft_tokens = 0
         self.num_accepted_tokens = 0
+        self.per_request_stats = {}
 
-    def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
+    def observe(self, num_draft_tokens: int, num_accepted_tokens: int,
+                request_id: str):
         self.num_draft_tokens += num_draft_tokens
         self.num_accepted_tokens += num_accepted_tokens
+        self.per_request_stats[request_id] = num_accepted_tokens + 1
 
 
 class SpecDecodingMetrics:
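
Usage note (not part of the patch): below is a minimal sketch of how the new per-request counters could be consumed on the V1 path, mirroring the aggregation already done in examples/offline_inference/eagle.py above. The summarize_acceptance helper and the sample numbers are illustrative only; the sketch assumes each output.spec_token_acceptance_counts[0] has length num_spec_tokens + 1, with index 0 counting the always-accepted token from the target model.

from typing import Sequence


def summarize_acceptance(
        per_request_counts: Sequence[Sequence[int]]) -> float:
    """Aggregate per-request acceptance counts and return the mean
    acceptance length (illustrative helper, not part of this diff)."""
    num_positions = len(per_request_counts[0])
    totals = [0] * num_positions
    for counts in per_request_counts:
        for step, count in enumerate(counts):
            totals[step] += count
    # totals[0] counts the target-model token emitted at every decoding
    # step, so the ratio below is the average number of tokens emitted
    # per forward pass of the target model.
    return sum(totals) / totals[0]


if __name__ == "__main__":
    # Hypothetical counts gathered from output.spec_token_acceptance_counts[0]
    # for two requests run with num_spec_tokens=3.
    print(summarize_acceptance([[10, 7, 4, 1], [8, 6, 3, 2]]))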