[V1] Add request-level, per-step acceptance counts tracking for spec dec. #16367
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Status: Open
luyuzhe111 wants to merge 5 commits into vllm-project:main from luyuzhe111:spec_dec_stats
Commits (5):
131e679  add request level, per-step acceptance counts tracking for spec dec (luyuzhe111)
d21afbf  rebase (luyuzhe111)
ab71966  Merge remote-tracking branch 'upstream/main' into spec_dec_stats (luyuzhe111)
ddc1afd  update design (luyuzhe111)
cbb96bc  minor (luyuzhe111)
Diff (changes from all commits):

@@ -14,7 +14,7 @@
 from vllm.v1.engine.logprobs import LogprobsProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
-                                   RequestStateStats)
+                                   RequestStateStats, SchedulerStats)
 
 
 class RequestOutputCollector:

@@ -81,6 +81,7 @@ def __init__(
         arrival_time: float,
         queue: Optional[RequestOutputCollector],
         log_stats: bool,
+        num_spec_tokens: int = 0,
     ):
         self.request_id = request_id
         self.parent_req = parent_req

@@ -99,6 +100,8 @@ def __init__(
         self.stats = RequestStateStats(
             arrival_time=arrival_time) if log_stats else None
 
+        self.spec_token_acceptance_counts = [0] * (num_spec_tokens + 1)
+
     @classmethod
     def from_new_request(
         cls,

@@ -133,13 +136,13 @@ def from_new_request(
             arrival_time=request.arrival_time,
             queue=queue,
             log_stats=log_stats,
+            num_spec_tokens=request.num_spec_tokens,
         )
 
     def make_request_output(
-        self,
-        new_token_ids: list[int],
-        finish_reason: Optional[FinishReason],
+        self, new_token_ids: list[int], finish_reason: Optional[FinishReason],
         stop_reason: Union[int, str, None],
+        spec_token_acceptance_counts: Optional[list[int]]
     ) -> Optional[RequestOutput]:
 
         finished = finish_reason is not None

Review comment on the new spec_token_acceptance_counts parameter above:
"can we pls add the dim of this list: [bs] or [num_spec_tokens + 1]?"
Reply: "ah sure. it's [num_spec_tokens + 1]. will add it to the annotation above!"

@@ -150,7 +153,10 @@ def make_request_output(
             return None
 
         completion_output = self._new_completion_output(
-            new_token_ids, finish_reason, stop_reason)
+            new_token_ids,
+            finish_reason,
+            stop_reason,
+            spec_token_acceptance_counts=spec_token_acceptance_counts)
 
         request_id = self.request_id
         if self.parent_req is None:

@@ -190,6 +196,7 @@ def _new_completion_output(
         token_ids: list[int],
         finish_reason: Optional[FinishReason],
         stop_reason: Union[int, str, None],
+        spec_token_acceptance_counts: Optional[list[int]],
     ) -> CompletionOutput:
 
         finished = finish_reason is not None

@@ -212,7 +219,8 @@
             logprobs=logprobs,
             cumulative_logprob=self.logprobs_processor.cumulative_logprob,
             finish_reason=str(finish_reason) if finished else None,
-            stop_reason=stop_reason if finished else None)
+            stop_reason=stop_reason if finished else None,
+            spec_token_acceptance_counts=spec_token_acceptance_counts)
 
 
 class OutputProcessor:

@@ -280,6 +288,7 @@ def process_outputs(
         engine_core_outputs: list[EngineCoreOutput],
         engine_core_timestamp: Optional[float] = None,
         iteration_stats: Optional[IterationStats] = None,
+        scheduler_stats: Optional[SchedulerStats] = None,
     ) -> OutputProcessorOutput:
         """
         Process the EngineCoreOutputs:

@@ -318,6 +327,8 @@ def process_outputs(
                 self._update_stats_from_output(req_state, engine_core_output,
                                                engine_core_timestamp,
                                                iteration_stats)
+                self._update_stats_from_scheduler(req_id, req_state,
+                                                  scheduler_stats)
 
                 new_token_ids = engine_core_output.new_token_ids
                 finish_reason = engine_core_output.finish_reason

@@ -337,7 +348,11 @@
 
                 # 4) Create and handle RequestOutput objects.
                 if request_output := req_state.make_request_output(
-                        new_token_ids, finish_reason, stop_reason):
+                        new_token_ids,
+                        finish_reason,
+                        stop_reason,
+                        spec_token_acceptance_counts=req_state.
+                        spec_token_acceptance_counts):
                     if req_state.queue is not None:
                         # AsyncLLM: put into queue for handling by generate().
                         req_state.queue.put(request_output)

@@ -403,3 +418,13 @@ def _update_stats_from_finished(self, req_state: RequestState,
             ParentRequest.observe_finished_request(
                 req_state.parent_req, iteration_stats,
                 req_state.stats.num_generation_tokens)
+
+    def _update_stats_from_scheduler(
+            self, req_id: str, req_state: RequestState,
+            scheduler_stats: Optional[SchedulerStats]):
+        if scheduler_stats is not None and \
+                scheduler_stats.spec_decoding_stats is not None:
+            num_accepted_tokens = scheduler_stats. \
+                spec_decoding_stats.per_request_stats.get(req_id, 0)
+            for i in range(num_accepted_tokens):
+                req_state.spec_token_acceptance_counts[i] += 1
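To make the semantics of the new list concrete, here is a small self-contained sketch (illustrative only, not vLLM code) of the accumulation performed by _update_stats_from_scheduler, assuming per_request_stats maps a request id to the number of tokens emitted for that request in the current decoding step:

    # Illustrative sketch of the per-request accumulation above; the sample
    # numbers are made up.
    num_spec_tokens = 3
    spec_token_acceptance_counts = [0] * (num_spec_tokens + 1)

    # Hypothetical per-step token counts for one request (target token plus
    # accepted draft tokens, so at most num_spec_tokens + 1 per step).
    accepted_per_step = [1, 4, 2, 3]

    for num_accepted in accepted_per_step:
        # Position i is credited whenever at least i + 1 tokens were emitted
        # in a step, mirroring _update_stats_from_scheduler.
        for i in range(num_accepted):
            spec_token_acceptance_counts[i] += 1

    print(spec_token_acceptance_counts)  # [4, 3, 2, 1]

    # counts[0] equals the number of decoding steps, so the average number of
    # tokens emitted per step (the "AL" discussed below) can be derived as:
    print(sum(spec_token_acceptance_counts) / spec_token_acceptance_counts[0])  # 2.5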
markmc commented:
There's a few things to note here:
1. The RequestOutput.metrics API in general is not implemented in V1 (and maybe no longer works in V0). So far, I would say there is no decision whether to retain this API in the long term.
2. One option for the RequestOutput.metrics API might be to expose the in-memory Prometheus metrics via an API. In which case, we would only be exposing via the API the same metrics that are exposed via Prometheus, so I don't want to introduce a metric that will not be available in Prometheus.
3. [...]
4. The value reported here is effectively num_accepted_tokens_total / num_drafts, which is much more in line with what we will expose through Prometheus.
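(For reference, a minimal illustration of the aggregate calculation mentioned in the last point; the counter names follow the comment above and the numbers are made up:)

    # Hypothetical aggregate counters of the kind exposed via Prometheus.
    num_drafts = 1000                 # decoding steps in which drafts were proposed
    num_accepted_tokens_total = 1500  # accepted draft tokens across all requests

    # Mean accepted draft tokens per draft, aggregated over all requests.
    print(num_accepted_tokens_total / num_drafts)  # 1.5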
luyuzhe111 (author) replied:
@markmc regarding 1, 2, 3, they all make sense.

Regarding 4, it's indeed the same, but we also have the option of looking at the acceptance rate at each position (it should just be num_accepted_tokens per draft position) since we have kept track of the acceptance counts. I only computed the AL for simplicity since we are not doing finer-grained analysis yet.

Also, a quick clarification question: are you suggesting that we actually don't need this per-request spec_token_acceptance_counts thing once Prometheus is ready? My understanding is that Prometheus does not track request-level stats. But if it does, then we probably don't need this PR, since I also feel like output.spec_token_acceptance_counts is a bit ugly.
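(As an illustration of the per-position option mentioned above, a small sketch assuming the per-request list tracked in this PR; the numbers are made up:)

    # Per-request counts as accumulated in this PR, e.g. after many steps.
    spec_token_acceptance_counts = [100, 70, 45, 20]

    num_steps = spec_token_acceptance_counts[0]
    # Fraction of steps in which the draft token at each position was accepted.
    acceptance_rate_per_pos = [c / num_steps for c in spec_token_acceptance_counts[1:]]
    print(acceptance_rate_per_pos)  # [0.7, 0.45, 0.2]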
markmc replied:
In the design doc, I'm proposing adding:
- num_drafts, so mean acceptance length can be calculated
- num_accepted_tokens_per_pos, to track the acceptance rate per position

But these aggregate across all requests - you would not be able to look at the acceptance counts of individual requests. However, note that Prometheus does allow you to slice according to time intervals - so with a long-running vLLM instance, you can look at these aggregated-across-all-requests metrics over whatever time period you choose.

How would you imagine using per-request stats?
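(A rough sketch of what such aggregated metrics could look like with prometheus_client; the metric and label names here are assumptions, not the ones from the design doc:)

    from prometheus_client import Counter

    # Hypothetical metric names for the two counters proposed above.
    num_drafts = Counter(
        "vllm_spec_decode_num_drafts",
        "Number of speculative drafts proposed.")
    num_accepted_tokens_per_pos = Counter(
        "vllm_spec_decode_num_accepted_tokens",
        "Accepted draft tokens, labeled by draft position.",
        labelnames=["position"])

    # For one decoding step of one request that accepted 2 of 3 draft tokens:
    num_drafts.inc()
    for pos in range(2):
        num_accepted_tokens_per_pos.labels(position=str(pos)).inc()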
luyuzhe111 replied:
In addition to debugging, another use case for per-request stats is fine-grained benchmarking. For example, I might want to understand the acceptance length for different tasks (summarization, coding), and RequestOutput.metrics used to be very helpful for that.

Per-request stats might also help identify outliers, for example, requests with really poor acceptance counts, to shed light on how to further improve the speculator.
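(To make the benchmarking use case concrete, a hedged sketch that groups per-request acceptance counts of the kind proposed in this PR by task; the data and grouping are hypothetical:)

    from collections import defaultdict

    # Hypothetical per-request acceptance-count lists, tagged with the task
    # each request belonged to.
    requests = [
        ("summarization", [40, 35, 28, 20]),
        ("summarization", [50, 30, 18, 10]),
        ("coding", [60, 20, 10, 5]),
    ]

    per_task = defaultdict(list)
    for task, counts in requests:
        per_task[task].append(sum(counts) / counts[0])  # tokens emitted per step

    for task, lengths in per_task.items():
        print(f"{task}: mean acceptance length {sum(lengths) / len(lengths):.2f}")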