[V1] Add request-level, per-step acceptance counts tracking for spec dec. #16367

Open · wants to merge 5 commits into main
18 changes: 14 additions & 4 deletions examples/offline_inference/eagle.py
@@ -45,8 +45,12 @@ def main():
parser.add_argument("--enable_chunked_prefill", action='store_true')
parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
parser.add_argument("--temp", type=float, default=0)
parser.add_argument("--use_v1", type=str, default="1", help='1 or 0')
args = parser.parse_args()

# TODO: remove this option once EAGLE in v1 is ready.
os.environ["VLLM_USE_V1"] = args.use_v1

model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"

@@ -94,10 +98,16 @@ def main():
# to account for the token from the target model that's always going to be
# accepted
acceptance_counts = [0] * (args.num_spec_tokens + 1)
for output in outputs:
for step, count in enumerate(
output.metrics.spec_token_acceptance_counts):
acceptance_counts[step] += count
if args.use_v1 == '1':
for output in outputs:
for step, count in enumerate(
output.spec_token_acceptance_counts[0]):
acceptance_counts[step] += count
markmc (Member):
There are a few things to note here:

  1. The RequestOutput.metrics API in general is not implemented in V1 (and may no longer work in V0). So far, I would say there has been no decision on whether to retain this API in the long term.
  2. With Prometheus, I would say we will expose the "num_accepted_tokens per draft position" metric as an aggregate across all requests. See this design doc.
  3. A viable replacement for the RequestOutput.metrics API might be to expose the in-memory Prometheus metrics via an API. In that case, we would only be exposing via the API the same metrics that are exposed via Prometheus, so I don't want to introduce a metric that will not be available in Prometheus.
  4. What's actually being printed below is the mean acceptance rate, which can be more easily calculated as num_accepted_tokens_total / num_drafts, and which is much more in line with what we will expose through Prometheus (a short arithmetic sketch follows this list).
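
To make point 4 concrete, here is a minimal editorial sketch of the arithmetic, with hypothetical numbers; it assumes num_accepted_tokens_total counts only accepted draft tokens, excluding the target-model token emitted at every step:

# Hypothetical aggregate counters for one run (illustrative numbers only).
num_drafts = 1000                 # speculation steps (drafts) issued
num_accepted_tokens_total = 1500  # accepted draft tokens across all steps

# Mean acceptance length: accepted draft tokens per draft, plus the one
# target-model token that is always emitted per step.
mean_acceptance_length = 1 + num_accepted_tokens_total / num_drafts  # 2.5

# The same quantity from per-position counts, where acceptance_counts[k]
# is the number of steps that produced at least k + 1 tokens.
acceptance_counts = [1000, 800, 500, 200]
assert sum(acceptance_counts) / acceptance_counts[0] == mean_acceptance_length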

Contributor Author:
@markmc Regarding 1, 2, and 3: they all make sense.

Regarding 4, it's indeed the same, but we also have the option of looking at the acceptance rate at each position (which should just be num_accepted_tokens per draft position), since we have kept track of the acceptance counts. I only computed the acceptance length (AL) for simplicity, since we are not doing finer-grained analysis yet.

Also, a quick clarification question: are you suggesting that we don't actually need this per-request spec_token_acceptance_counts once Prometheus support is ready? My understanding is that Prometheus does not track request-level stats. But if it does, then we probably don't need this PR, since I also feel that output.spec_token_acceptance_counts is a bit ugly.

markmc (Member):
In the design doc, I'm proposing adding:

  1. num_drafts, so the mean acceptance length can be calculated
  2. num_accepted_tokens_per_pos, to track the acceptance rate per position

But these aggregate across all requests - you would not be able to look at the acceptance counts of individual requests.

However, note that Prometheus does allow you to slice according to time intervals - so with a long-running vLLM instance, you can look at these aggregated-across-all-requests metrics over whatever time period you choose.

How would you imagine using per-request stats?

Contributor Author:
In addition to debugging, another use case for per-request stats is fine-grained benchmarking. For example, I might want to understand the acceptance length for different tasks (summarization, coding), and RequestOutput.metrics used to be very helpful for that.

Per-request stats might also help identify outliers, for example requests with really poor acceptance counts, to shed light on how to further improve the speculator.
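
As an editorial illustration of the benchmarking use case above (a sketch only: the task_tags pairing is hypothetical, outputs is the list produced by llm.generate in the eagle.py example, and spec_token_acceptance_counts is the per-request field proposed in this PR):

from collections import defaultdict

# Hypothetical: each prompt is tagged with a task, in the same order as `outputs`.
task_tags = ["summarization", "coding", "summarization"]
per_task_al = defaultdict(list)
for task, output in zip(task_tags, outputs):
    counts = output.spec_token_acceptance_counts[0]     # [num_spec_tokens + 1]
    per_task_al[task].append(sum(counts) / counts[0])   # per-request acceptance length

for task, lengths in per_task_al.items():
    print(f"{task}: mean acceptance length = {sum(lengths) / len(lengths):.2f}")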

else:
for output in outputs:
for step, count in enumerate(
output.metrics.spec_token_acceptance_counts):
acceptance_counts[step] += count

print("-" * 50)
print(f"mean acceptance length: \
8 changes: 8 additions & 0 deletions vllm/outputs.py
@@ -33,6 +33,10 @@ class CompletionOutput:
to stop, None if the completion finished for some other reason
including encountering the EOS token.
lora_request: The LoRA request that was used to generate the output.
spec_token_acceptance_counts: A list tracking the total number of
accepted tokens at each speculation step for a request. Its length
is num_spec_tokens + 1 since there is always one token generated
by the target model.
"""

index: int
@@ -43,6 +47,7 @@ class CompletionOutput:
finish_reason: Optional[str] = None
stop_reason: Union[int, str, None] = None
lora_request: Optional[LoRARequest] = None
spec_token_acceptance_counts: Optional[list[int]] = None

def finished(self) -> bool:
return self.finish_reason is not None
@@ -133,6 +138,9 @@ def __init__(
self.encoder_prompt = encoder_prompt
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
self.spec_token_acceptance_counts = [
o.spec_token_acceptance_counts for o in outputs
]

def add(self, next_output: "RequestOutput") -> None:
"""Merge subsequent RequestOutput into this one"""
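
A brief usage sketch of the field added above (editorial; it reuses the outputs list from the eagle.py example, and assumes n=1 sampling so index 0 selects the only completion, with at least one decoding step per request):

# RequestOutput.spec_token_acceptance_counts collects one list per completion.
for output in outputs:
    counts = output.spec_token_acceptance_counts[0]  # len == num_spec_tokens + 1
    num_steps = counts[0]              # one target-model token per decoding step
    num_accepted = sum(counts[1:])     # accepted draft tokens across all steps
    print(f"request {output.request_id}: "
          f"acceptance length = {(num_steps + num_accepted) / num_steps:.2f}")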
7 changes: 5 additions & 2 deletions vllm/v1/core/sched/scheduler.py
@@ -608,7 +608,8 @@ def update_from_output(
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids),
num_accepted_tokens=len(generated_token_ids) - 1)
num_accepted_tokens=len(generated_token_ids) - 1,
request_id=req_id)

cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
@@ -769,11 +770,13 @@ def make_spec_decoding_stats(
spec_decoding_stats: Optional[SpecDecodingStats],
num_draft_tokens: int,
num_accepted_tokens: int,
request_id: str,
) -> Optional[SpecDecodingStats]:
if not self.log_stats:
return None
if spec_decoding_stats is None:
spec_decoding_stats = SpecDecodingStats()
spec_decoding_stats.observe(num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted_tokens)
num_accepted_tokens=num_accepted_tokens,
request_id=request_id)
return spec_decoding_stats
1 change: 1 addition & 0 deletions vllm/v1/engine/__init__.py
@@ -60,6 +60,7 @@ class EngineCoreRequest(
eos_token_id: Optional[int]
arrival_time: float
lora_request: Optional[LoRARequest]
num_spec_tokens: int


class EngineCoreEventType(enum.IntEnum):
23 changes: 16 additions & 7 deletions vllm/v1/engine/llm_engine.py
@@ -92,7 +92,7 @@ def __init__(
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False, # FIXME: implement
log_stats=True, # FIXME: implement
)

if not multiprocess_mode:
@@ -183,11 +183,20 @@ def add_request(
priority: int = 0,
) -> None:
# Process raw inputs into the request.
request = self.processor.process_inputs(request_id, prompt, params,
arrival_time, lora_request,
trace_headers,
prompt_adapter_request,
priority)
num_spec_tokens = 0
if self.vllm_config.speculative_config is not None:
num_spec_tokens = (
self.vllm_config.speculative_config.num_speculative_tokens)
request = self.processor.process_inputs(
request_id,
prompt,
params,
arrival_time,
lora_request,
trace_headers,
prompt_adapter_request,
priority,
num_spec_tokens=num_spec_tokens)

n = params.n if isinstance(params, SamplingParams) else 1

@@ -223,7 +232,7 @@ def step(self) -> list[RequestOutput]:

# 2) Process EngineCoreOutputs.
processed_outputs = self.output_processor.process_outputs(
outputs.outputs)
outputs.outputs, scheduler_stats=outputs.scheduler_stats)

# 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
39 changes: 32 additions & 7 deletions vllm/v1/engine/output_processor.py
@@ -14,7 +14,7 @@
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
RequestStateStats)
RequestStateStats, SchedulerStats)


class RequestOutputCollector:
@@ -81,6 +81,7 @@ def __init__(
arrival_time: float,
queue: Optional[RequestOutputCollector],
log_stats: bool,
num_spec_tokens: int = 0,
):
self.request_id = request_id
self.parent_req = parent_req
@@ -99,6 +100,8 @@ def __init__(
self.stats = RequestStateStats(
arrival_time=arrival_time) if log_stats else None

self.spec_token_acceptance_counts = [0] * (num_spec_tokens + 1)

@classmethod
def from_new_request(
cls,
@@ -133,13 +136,13 @@ def from_new_request(
arrival_time=request.arrival_time,
queue=queue,
log_stats=log_stats,
num_spec_tokens=request.num_spec_tokens,
)

def make_request_output(
self,
new_token_ids: list[int],
finish_reason: Optional[FinishReason],
self, new_token_ids: list[int], finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
spec_token_acceptance_counts: Optional[list[int]]
ekagra-ranjan (Contributor) · Apr 10, 2025:
Can we please add the dim of this list: [bs] or [num_spec_tokens + 1]?

Contributor Author:
Ah sure, it's [num_spec_tokens + 1]. Will add it to the annotation above!

) -> Optional[RequestOutput]:

finished = finish_reason is not None
@@ -150,7 +153,10 @@ def make_request_output(
return None

completion_output = self._new_completion_output(
new_token_ids, finish_reason, stop_reason)
new_token_ids,
finish_reason,
stop_reason,
spec_token_acceptance_counts=spec_token_acceptance_counts)

request_id = self.request_id
if self.parent_req is None:
@@ -190,6 +196,7 @@ def _new_completion_output(
token_ids: list[int],
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
spec_token_acceptance_counts: Optional[list[int]],
) -> CompletionOutput:

finished = finish_reason is not None
@@ -212,7 +219,8 @@ def _new_completion_output(
logprobs=logprobs,
cumulative_logprob=self.logprobs_processor.cumulative_logprob,
finish_reason=str(finish_reason) if finished else None,
stop_reason=stop_reason if finished else None)
stop_reason=stop_reason if finished else None,
spec_token_acceptance_counts=spec_token_acceptance_counts)


class OutputProcessor:
@@ -280,6 +288,7 @@ def process_outputs(
engine_core_outputs: list[EngineCoreOutput],
engine_core_timestamp: Optional[float] = None,
iteration_stats: Optional[IterationStats] = None,
scheduler_stats: Optional[SchedulerStats] = None,
) -> OutputProcessorOutput:
"""
Process the EngineCoreOutputs:
@@ -318,6 +327,8 @@
self._update_stats_from_output(req_state, engine_core_output,
engine_core_timestamp,
iteration_stats)
self._update_stats_from_scheduler(req_id, req_state,
scheduler_stats)

new_token_ids = engine_core_output.new_token_ids
finish_reason = engine_core_output.finish_reason
@@ -337,7 +348,11 @@

# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason):
new_token_ids,
finish_reason,
stop_reason,
spec_token_acceptance_counts=req_state.
spec_token_acceptance_counts):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)
@@ -403,3 +418,13 @@ def _update_stats_from_finished(self, req_state: RequestState,
ParentRequest.observe_finished_request(
req_state.parent_req, iteration_stats,
req_state.stats.num_generation_tokens)

def _update_stats_from_scheduler(
self, req_id: str, req_state: RequestState,
scheduler_stats: Optional[SchedulerStats]):
if scheduler_stats is not None and \
scheduler_stats.spec_decoding_stats is not None:
num_accepted_tokens = scheduler_stats. \
spec_decoding_stats.per_request_stats.get(req_id, 0)
for i in range(num_accepted_tokens):
req_state.spec_token_acceptance_counts[i] += 1
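
To spell out what the accumulation above does, here is an editorial sketch with hypothetical per-step values; per_request_stats stores, for each step, the accepted draft tokens plus the always-emitted bonus token:

# Hypothetical values of per_request_stats[req_id] over three decoding steps:
# 2, 0 and 1 draft tokens accepted, each plus the bonus token from the target model.
per_step_values = [3, 1, 2]
counts = [0] * 4                      # num_spec_tokens = 3 -> length 4

for value in per_step_values:
    for i in range(value):            # same loop as _update_stats_from_scheduler
        counts[i] += 1

# counts[k] is now the number of steps that produced at least k + 1 tokens.
assert counts == [3, 2, 1, 0]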
3 changes: 2 additions & 1 deletion vllm/v1/engine/processor.py
@@ -198,6 +198,7 @@ def process_inputs(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
num_spec_tokens: int = 0,
) -> EngineCoreRequest:

# TODO(woosuk): Support pooling models.
@@ -313,7 +314,7 @@ def process_inputs(
eos_token_id=eos_token_id,
arrival_time=arrival_time,
lora_request=lora_request,
)
num_spec_tokens=num_spec_tokens)

def _validate_model_inputs(self,
inputs: ProcessorInputs,
11 changes: 8 additions & 3 deletions vllm/v1/spec_decode/metrics.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from dataclasses import dataclass, field

import numpy as np

@@ -13,20 +13,25 @@
class SpecDecodingStats:
num_draft_tokens: int = 0
num_accepted_tokens: int = 0
per_request_stats: dict = field(default_factory=dict)

def take(self):
copied = SpecDecodingStats(self.num_draft_tokens,
self.num_accepted_tokens)
self.num_accepted_tokens,
self.per_request_stats)
self.reset()
return copied

def reset(self):
self.num_draft_tokens = 0
self.num_accepted_tokens = 0
self.per_request_stats = {}

def observe(self, num_draft_tokens: int, num_accepted_tokens: int):
def observe(self, num_draft_tokens: int, num_accepted_tokens: int,
request_id: str):
self.num_draft_tokens += num_draft_tokens
self.num_accepted_tokens += num_accepted_tokens
self.per_request_stats[request_id] = num_accepted_tokens + 1


class SpecDecodingMetrics:
Expand Down