
Commit 9643be5

[TRTLLM-5050][feat] Enable per-request stats with PyT backend (#4156)
* feat: Add per-request stats support with PyT backend
* Adding unit test
* Fixing stats unit test
* Fixing test with overlap

Signed-off-by: Patrice Castonguay <55748270+pcastonguay@users.noreply.github.com>
1 parent 286a789 commit 9643be5

File tree: 6 files changed (+224, -51 lines)

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 9 additions & 0 deletions
@@ -273,6 +273,15 @@ void initBindings(pybind11::module_& m)
         .def_property_readonly("is_generation_in_progress_state", &GenLlmReq::isGenerationInProgressState)
         .def_property_readonly("is_disagg_context_transmission_state", &GenLlmReq::isDisaggContextTransmissionState)
         .def_property_readonly("is_disagg_context_complete_state", &GenLlmReq::isDisaggContextCompleteState)
+        .def_property_readonly("stage", &GenLlmReq::getRequestStage)
+        .def_property_readonly("kv_cache_transfer_time_ms", &GenLlmReq::getKvCacheTransferTimeMS)
+        .def_property_readonly("kv_cache_size", &GenLlmReq::getKvCacheSize)
+        .def_property_readonly("avg_decoded_tokens_per_iter", &GenLlmReq::getAvgDecodedTokensPerIter)
+        .def_property_readonly("alloc_total_blocks", &GenLlmReq::getAllocTotalBlocksPerRequest)
+        .def_property_readonly("alloc_new_blocks", &GenLlmReq::getAllocNewBlocksPerRequest)
+        .def_property_readonly("reused_blocks", &GenLlmReq::getReusedBlocksPerRequest)
+        .def_property_readonly("missed_blocks", &GenLlmReq::getMissedBlocksPerRequest)
+        .def_property_readonly("kv_cache_hit_rate", &GenLlmReq::getKVCacheHitRatePerRequest)
         .def_property_readonly("llm_request_type", &GenLlmReq::getLlmRequestType)
         .def_property_readonly("position_ids",
             [](GenLlmReq& self)
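These bindings surface the per-request counters tracked in C++ as read-only properties on the Python request object. A minimal sketch of reading them, assuming req is a bound request instance taken from the PyT executor's active-request list (setup omitted; illustrative only):

# Sketch: inspect the newly bound read-only properties on a request req.
# All names below come from the .def_property_readonly calls in this diff.
print(req.stage)                        # current RequestStage of the request
print(req.avg_decoded_tokens_per_iter)  # average tokens decoded per iteration
print(req.alloc_total_blocks,           # KV-cache block accounting
      req.alloc_new_blocks,
      req.reused_blocks,
      req.missed_blocks)
print(req.kv_cache_hit_rate)            # per-request KV-cache hit rate
print(req.kv_cache_transfer_time_ms,    # disaggregated-serving transfer stats
      req.kv_cache_size)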

tensorrt_llm/_torch/pyexecutor/config.py

Lines changed: 3 additions & 0 deletions
@@ -61,6 +61,9 @@ class PyTorchConfig:
     kv_cache_dtype: str = "auto"
     use_kv_cache: bool = True
     enable_iter_perf_stats: bool = False
+    # If true, enables per request stats per iteration
+    # Must also set enable_iter_perf_stats to true to get request stats
+    enable_iter_req_stats: bool = False
     print_iter_log: bool = False
 
     torch_compile_enabled: bool = False
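As the new comment notes, per-request stats are gated on both flags. A minimal sketch of enabling them together, assuming PyTorchConfig is instantiated directly with keyword arguments (it is declared as a dataclass, so this should hold):

from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig

# enable_iter_req_stats alone has no effect: request stats are only
# collected when iteration stats are enabled as well.
config = PyTorchConfig(
    enable_iter_perf_stats=True,
    enable_iter_req_stats=True,
)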

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 87 additions & 12 deletions
@@ -20,10 +20,13 @@
 
 from tensorrt_llm._utils import (global_mpi_rank, is_trace_enabled, nvtx_range,
                                  trace_func)
-from tensorrt_llm.bindings.executor import (FinishReason, InflightBatchingStats,
+from tensorrt_llm.bindings.executor import (DisServingRequestStats,
+                                            FinishReason, InflightBatchingStats,
                                             IterationStats, KvCacheStats,
+                                            RequestStage, RequestStats,
                                             RequestType, StaticBatchingStats)
-from tensorrt_llm.bindings.internal.batch_manager import ReqIdsSet
+from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
+                                                          ReqIdsSet)
 from tensorrt_llm.logger import logger
 
 from ..distributed import Distributed
@@ -196,6 +199,7 @@ def __init__(self,
         self.max_draft_tokens = max_draft_tokens
         self.print_log = model_engine.pytorch_backend_config.print_iter_log
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
+        self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
@@ -373,10 +377,10 @@ def get_latest_iteration_stats(self):
         if self.enable_iter_perf_stats == False:
             return []
 
-        latest_stats = tuple()
+        latest_stats = (IterationStats(), None)
         try:
             self.stats_lock.acquire()
-            latest_stats = tuple(self.stats)
+            latest_stats = self.stats
             self.stats = []
         finally:
             self.stats_lock.release()
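After this hunk, each element stored in self.stats (and handed back by get_latest_iteration_stats) is a pair of IterationStats and an optional list of RequestStats, rather than a bare IterationStats. A sketch of consuming the new shape, where executor stands for the PyExecutor instance (illustrative only):

# req_stats is None unless both enable_iter_perf_stats and
# enable_iter_req_stats are set on the backend config.
for iteration_stats, req_stats in executor.get_latest_iteration_stats():
    print(iteration_stats.iter, iteration_stats.iter_latency_ms)
    if req_stats is not None:
        for req_stat in req_stats:
            print(req_stat.id, req_stat.stage, req_stat.num_generated_tokens)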
@@ -510,8 +514,63 @@
         stats.static_batching_stats = StaticBatchingStats()
         return stats
 
+    def _populate_req_stats(
+            self, finished_requests: List[LlmRequest],
+            active_requests: List[LlmRequest],
+            scheduled_requests: ScheduledRequests
+    ) -> Optional[List[RequestStats]]:
+
+        def get_req_stats(req: LlmRequest) -> RequestStats:
+            req_stat = RequestStats()
+            req_stat.id = req.request_id
+            req_stat.context_prefill_position = req.context_current_position
+            req_stat.num_generated_tokens = req.max_beam_num_tokens - req.orig_prompt_len
+            req_stat.avg_num_decoded_tokens_per_iter = req.avg_decoded_tokens_per_iter
+            req_stat.alloc_total_blocks_per_request = req.alloc_total_blocks
+            req_stat.alloc_new_blocks_per_request = req.alloc_new_blocks
+            req_stat.reused_blocks_per_request = req.reused_blocks
+            req_stat.missed_blocks_per_request = req.missed_blocks
+            req_stat.kv_cache_hit_rate_per_request = req.kv_cache_hit_rate
+            req_stat.scheduled = req in scheduled_requests.context_requests or req in scheduled_requests.generation_requests
+            if req.llm_request_type == LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY or req.llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY:
+                req_stat.dis_serving_stats = DisServingRequestStats()
+                req_stat.dis_serving_stats.kv_cache_transfer_ms = req.kv_cache_transfer_time_ms
+                req_stat.dis_serving_stats.kv_cache_size = req.kv_cache_size
+            return req_stat
+
+        def get_queued_req_stats(req: LlmRequest) -> RequestStats:
+            req_stat = RequestStats()
+            req_stat.id = req.request_id
+            req_stat.context_prefill_position = 0
+            req_stat.num_generated_tokens = 0
+            req_stat.avg_num_decoded_tokens_per_iter = 0
+            req_stat.alloc_total_blocks_per_request = 0
+            req_stat.alloc_new_blocks_per_request = 0
+            req_stat.reused_blocks_per_request = 0
+            req_stat.missed_blocks_per_request = 0
+            req_stat.kv_cache_hit_rate_per_request = 0
+            return req_stat
+
+        req_stats = []
+        for req in active_requests:
+            req_stat = get_req_stats(req)
+            req_stat.stage = req.stage
+            req_stats.append(req_stat)
+
+        for req in list(self.request_queue.queue):
+            req_stat = get_queued_req_stats(req)
+            req.stage = RequestStage.QUEUED
+            req_stats.append(req_stat)
+
+        for req in finished_requests:
+            req_stat = get_req_stats(req)
+            req_stat.stage = RequestStage.GENERATION_COMPLETE
+            req_stats.append(req_stat)
+
+        return req_stats
+
     def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
-                           scheduled_batch):
+                           scheduled_batch) -> IterationStats:
         stats.iter_latency_ms = iter_latency_ms
 
         stats.num_queued_requests = self.request_queue.qsize()
@@ -554,23 +613,34 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
             stats.inflight_batching_stats.micro_batch_id = 0
         return stats
 
-    def _append_iter_stats(self, stats):
+    def _append_iter_stats(self,
+                           stats: IterationStats,
+                           req_stats: Optional[List[RequestStats]] = None):
+
         try:
             self.stats_lock.acquire()
-            self.stats.append(stats)
+            self.stats.append((stats, req_stats))
         finally:
             self.stats_lock.release()
 
     def _process_iter_stats(self, finished_requests: list[LlmRequest],
+                            active_requests: List[LlmRequest],
                             batch_state: BatchState):
         iter_end_time = time.time()
         iter_latency_ms = iter_end_time - batch_state.iter_start_time
         if batch_state.iter_stats is None:
             return
+
+        req_stats = self._populate_req_stats(
+            finished_requests, active_requests,
+            batch_state.decoder_state.scheduled_requests) if (
+                self.enable_iter_req_stats
+                and self.enable_iter_perf_stats) else None
+
         self._append_iter_stats(
             self._update_iter_stats(
                 batch_state.iter_stats, iter_latency_ms, len(finished_requests),
-                batch_state.decoder_state.scheduled_requests))
+                batch_state.decoder_state.scheduled_requests), req_stats)
 
     def _executor_loop_cleanup(self):
         with self.response_cv:
@@ -677,7 +747,9 @@ def _executor_loop_pp(self):
                 self._gather_dp_requests_num()
 
                 if self.enable_iter_perf_stats and previous_batch is not None:
-                    self._process_iter_stats(finished_requests, previous_batch)
+                    self._process_iter_stats(finished_requests,
+                                             self.active_requests,
+                                             previous_batch)
         self._executor_loop_cleanup()
 
     def _executor_loop_pp_overlap(self):
@@ -815,7 +887,9 @@ def _executor_loop_pp_overlap(self):
                 self._gather_dp_requests_num()
 
                 if self.enable_iter_perf_stats and previous_batch is not None:
-                    self._process_iter_stats(finished_requests, previous_batch)
+                    self._process_iter_stats(finished_requests,
+                                             self.active_requests,
+                                             previous_batch)
         self._executor_loop_cleanup()
 
     def _executor_loop(self):
@@ -921,7 +995,7 @@ def _executor_loop(self):
                     iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
                         'num_ctx_tokens']
                     self._process_iter_stats(
-                        finished_requests,
+                        finished_requests, self.active_requests,
                         BatchState(decoder_state=DecoderState(
                             scheduled_requests=scheduled_batch),
                                    iter_stats=iter_stats,
@@ -1099,7 +1173,8 @@ def _process_previous_batch(self):
         self._add_kv_cache_events()
 
         if self.enable_iter_perf_stats:
-            self._process_iter_stats(finished_requests, self.previous_batch)
+            self._process_iter_stats(finished_requests, self.active_requests,
+                                     self.previous_batch)
 
     @nvtx_range("_forward_step_inter_pp")
     def _forward_step_inter_pp(self, scheduled_batch) -> DecoderState:
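Across the populations gathered by _populate_req_stats, each RequestStats entry carries a lifecycle stage: active requests keep their live stage from the binding, queued requests are reported as RequestStage.QUEUED, and finished ones as RequestStage.GENERATION_COMPLETE. A quick sketch of splitting one iteration's list by stage (req_stats as produced by _populate_req_stats; illustrative only):

from tensorrt_llm.bindings.executor import RequestStage

# Partition one iteration's per-request stats by lifecycle stage.
queued = [r for r in req_stats if r.stage == RequestStage.QUEUED]
finished = [r for r in req_stats if r.stage == RequestStage.GENERATION_COMPLETE]
in_flight = [
    r for r in req_stats
    if r.stage not in (RequestStage.QUEUED, RequestStage.GENERATION_COMPLETE)
]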

tensorrt_llm/executor/worker.py

Lines changed: 28 additions & 3 deletions
@@ -272,9 +272,34 @@ def _iteration_result_task(self, it_result_queue: IterationResultQueue,
         return True  # success
 
     def dispatch_stats_task(self) -> bool:
-        return self._iteration_result_task(
-            self.stats_queues, self.engine.get_latest_iteration_stats,
-            self._iter_stats_result, lambda x: x.to_json_str())
+
+        # Define a Callable to join iteration and request stats
+        def stats_serializer(
+                stats: Tuple[tllm.IterationStats, tllm.RequestStats]) -> str:
+            iteration_stats, req_stats = stats
+            stats_dict = json.loads(iteration_stats.to_json_str())
+
+            if req_stats is not None and len(req_stats) > 0:
+                stats_dict["requestStats"] = []
+                for req_stat in req_stats:
+                    stats_dict["requestStats"].append(
+                        json.loads(req_stat.to_json_str()))
+
+            # Convert back to JSON string
+            return json.dumps(stats_dict)
+
+        def get_stats():
+            if isinstance(self.engine, tllm.Executor):
+                iter_stats = self.engine.get_latest_iteration_stats()
+                # TODO: Support req stats with TRT engine
+                # This would require ensuring iter and req stats have same size
+                return [(iter_stat, None) for iter_stat in iter_stats]
+            else:
+                return self.engine.get_latest_iteration_stats()
+
+        return self._iteration_result_task(self.stats_queues, get_stats,
+                                           self._iter_stats_result,
+                                           stats_serializer)
 
     def dispatch_kv_cache_events_task(self) -> bool:
         if isinstance(self.engine, tllm.Executor):
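The serializer folds the per-request entries into each iteration's JSON document under a requestStats key, so downstream consumers still receive one JSON string per iteration. A sketch of unpacking a merged payload on the consumer side, where serialized stands for one string produced by stats_serializer (the per-entry field names come from RequestStats.to_json_str() and are not spelled out here):

import json

# The top-level keys come from IterationStats.to_json_str(); the merged
# per-request list, when present, sits under "requestStats".
payload = json.loads(serialized)
for entry in payload.get("requestStats", []):  # absent if no request stats
    print(entry)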
