@@ -20,10 +20,13 @@
 
 from tensorrt_llm._utils import (global_mpi_rank, is_trace_enabled, nvtx_range,
                                  trace_func)
-from tensorrt_llm.bindings.executor import (FinishReason, InflightBatchingStats,
+from tensorrt_llm.bindings.executor import (DisServingRequestStats,
+                                            FinishReason, InflightBatchingStats,
                                             IterationStats, KvCacheStats,
+                                            RequestStage, RequestStats,
                                             RequestType, StaticBatchingStats)
-from tensorrt_llm.bindings.internal.batch_manager import ReqIdsSet
+from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
+                                                          ReqIdsSet)
 from tensorrt_llm.logger import logger
 
 from ..distributed import Distributed
@@ -196,6 +199,7 @@ def __init__(self,
         self.max_draft_tokens = max_draft_tokens
         self.print_log = model_engine.pytorch_backend_config.print_iter_log
         self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
+        self.enable_iter_req_stats = model_engine.pytorch_backend_config.enable_iter_req_stats
         self.num_fetch_requests_cur_rank = 0
         self.num_fetch_requests = 0
         self.shutdown_event = threading.Event()
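
The new enable_iter_req_stats flag is read from model_engine.pytorch_backend_config, right next to the existing enable_iter_perf_stats flag. A minimal sketch of the expected config shape, using a stand-in namespace because the concrete backend config class is not shown in this diff:

    # Sketch only: SimpleNamespace stands in for the real pytorch_backend_config
    # object; the attribute names match what the executor reads above.
    from types import SimpleNamespace

    pytorch_backend_config = SimpleNamespace(
        print_iter_log=False,
        enable_iter_perf_stats=True,   # existing flag: per-iteration stats
        enable_iter_req_stats=True,    # new flag: per-request stats each iteration
    )

Per-request stats are only produced when enable_iter_perf_stats is also enabled; see _process_iter_stats further down.
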
@@ -373,10 +377,10 @@ def get_latest_iteration_stats(self):
         if self.enable_iter_perf_stats == False:
             return []
 
-        latest_stats = tuple()
+        latest_stats = (IterationStats(), None)
         try:
             self.stats_lock.acquire()
-            latest_stats = tuple(self.stats)
+            latest_stats = self.stats
             self.stats = []
         finally:
             self.stats_lock.release()
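
With this change each entry handed back by get_latest_iteration_stats is an (IterationStats, Optional[List[RequestStats]]) pair instead of a bare IterationStats. A rough consumer-side sketch, where the executor variable is assumed to be the running PyExecutor instance:

    # req_stats is None unless enable_iter_req_stats (and enable_iter_perf_stats) is set.
    for iter_stats, req_stats in executor.get_latest_iteration_stats():
        print(f"iteration latency: {iter_stats.iter_latency_ms} ms, "
              f"queued requests: {iter_stats.num_queued_requests}")
        for req_stat in req_stats or []:
            print(f"  request {req_stat.id}: stage={req_stat.stage}, "
                  f"generated tokens={req_stat.num_generated_tokens}")
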
@@ -510,8 +514,63 @@ def _get_init_iter_stats(self, num_new_active_requests,
         stats.static_batching_stats = StaticBatchingStats()
         return stats
 
+    def _populate_req_stats(
+            self, finished_requests: List[LlmRequest],
+            active_requests: List[LlmRequest],
+            scheduled_requests: ScheduledRequests
+    ) -> Optional[List[RequestStats]]:
+
+        def get_req_stats(req: LlmRequest) -> RequestStats:
+            req_stat = RequestStats()
+            req_stat.id = req.request_id
+            req_stat.context_prefill_position = req.context_current_position
+            req_stat.num_generated_tokens = req.max_beam_num_tokens - req.orig_prompt_len
+            req_stat.avg_num_decoded_tokens_per_iter = req.avg_decoded_tokens_per_iter
+            req_stat.alloc_total_blocks_per_request = req.alloc_total_blocks
+            req_stat.alloc_new_blocks_per_request = req.alloc_new_blocks
+            req_stat.reused_blocks_per_request = req.reused_blocks
+            req_stat.missed_blocks_per_request = req.missed_blocks
+            req_stat.kv_cache_hit_rate_per_request = req.kv_cache_hit_rate
+            req_stat.scheduled = req in scheduled_requests.context_requests or req in scheduled_requests.generation_requests
+            if req.llm_request_type == LlmRequestType.LLMREQUEST_TYPE_CONTEXT_ONLY or req.llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY:
+                req_stat.dis_serving_stats = DisServingRequestStats()
+                req_stat.dis_serving_stats.kv_cache_transfer_ms = req.kv_cache_transfer_time_ms
+                req_stat.dis_serving_stats.kv_cache_size = req.kv_cache_size
+            return req_stat
+
+        def get_queued_req_stats(req: LlmRequest) -> RequestStats:
+            req_stat = RequestStats()
+            req_stat.id = req.request_id
+            req_stat.context_prefill_position = 0
+            req_stat.num_generated_tokens = 0
+            req_stat.avg_num_decoded_tokens_per_iter = 0
+            req_stat.alloc_total_blocks_per_request = 0
+            req_stat.alloc_new_blocks_per_request = 0
+            req_stat.reused_blocks_per_request = 0
+            req_stat.missed_blocks_per_request = 0
+            req_stat.kv_cache_hit_rate_per_request = 0
+            return req_stat
+
+        req_stats = []
+        for req in active_requests:
+            req_stat = get_req_stats(req)
+            req_stat.stage = req.stage
+            req_stats.append(req_stat)
+
+        for req in list(self.request_queue.queue):
+            req_stat = get_queued_req_stats(req)
+            req_stat.stage = RequestStage.QUEUED
+            req_stats.append(req_stat)
+
+        for req in finished_requests:
+            req_stat = get_req_stats(req)
+            req_stat.stage = RequestStage.GENERATION_COMPLETE
+            req_stats.append(req_stat)
+
+        return req_stats
+
     def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
-                           scheduled_batch):
+                           scheduled_batch) -> IterationStats:
         stats.iter_latency_ms = iter_latency_ms
 
         stats.num_queued_requests = self.request_queue.qsize()
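
_populate_req_stats emits one RequestStats entry per active, queued, and finished request: active entries keep the request's own stage, queued entries are tagged RequestStage.QUEUED with zeroed counters, finished entries are tagged RequestStage.GENERATION_COMPLETE, and context-only or generation-only requests (disaggregated serving) additionally carry a DisServingRequestStats with KV-cache transfer time and size. An illustrative helper, not part of this change, for summarizing such a list by stage:

    from collections import Counter

    def count_by_stage(req_stats):
        # Maps the string form of each RequestStage to a count, e.g.
        # {'RequestStage.QUEUED': 3, 'RequestStage.GENERATION_COMPLETE': 1}.
        return Counter(str(stat.stage) for stat in req_stats or [])
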
@@ -554,23 +613,34 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
         stats.inflight_batching_stats.micro_batch_id = 0
         return stats
 
-    def _append_iter_stats(self, stats):
+    def _append_iter_stats(self,
+                           stats: IterationStats,
+                           req_stats: Optional[List[RequestStats]] = None):
+
         try:
             self.stats_lock.acquire()
-            self.stats.append(stats)
+            self.stats.append((stats, req_stats))
         finally:
             self.stats_lock.release()
 
     def _process_iter_stats(self, finished_requests: list[LlmRequest],
+                            active_requests: List[LlmRequest],
                             batch_state: BatchState):
         iter_end_time = time.time()
         iter_latency_ms = iter_end_time - batch_state.iter_start_time
         if batch_state.iter_stats is None:
             return
+
+        req_stats = self._populate_req_stats(
+            finished_requests, active_requests,
+            batch_state.decoder_state.scheduled_requests) if (
+                self.enable_iter_req_stats
+                and self.enable_iter_perf_stats) else None
+
         self._append_iter_stats(
             self._update_iter_stats(
                 batch_state.iter_stats, iter_latency_ms, len(finished_requests),
-                batch_state.decoder_state.scheduled_requests))
+                batch_state.decoder_state.scheduled_requests), req_stats)
 
     def _executor_loop_cleanup(self):
         with self.response_cv:
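
Note that self.stats now holds (IterationStats, Optional[List[RequestStats]]) pairs appended under stats_lock, and that _process_iter_stats only builds req_stats when both flags are on. A small standalone sketch of that gating (the helper name is made up for illustration):

    from typing import Any, List, Optional, Tuple

    def make_stats_entry(iteration_stats: Any, req_stats: Optional[List[Any]],
                         enable_iter_req_stats: bool,
                         enable_iter_perf_stats: bool) -> Tuple[Any, Optional[List[Any]]]:
        # Mirrors the gating above: per-request stats ride along with the
        # iteration stats only when both flags are enabled.
        if not (enable_iter_req_stats and enable_iter_perf_stats):
            req_stats = None
        return (iteration_stats, req_stats)
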
@@ -677,7 +747,9 @@ def _executor_loop_pp(self):
                 self._gather_dp_requests_num()
 
                 if self.enable_iter_perf_stats and previous_batch is not None:
-                    self._process_iter_stats(finished_requests, previous_batch)
+                    self._process_iter_stats(finished_requests,
+                                             self.active_requests,
+                                             previous_batch)
         self._executor_loop_cleanup()
 
     def _executor_loop_pp_overlap(self):
@@ -815,7 +887,9 @@ def _executor_loop_pp_overlap(self):
                 self._gather_dp_requests_num()
 
                 if self.enable_iter_perf_stats and previous_batch is not None:
-                    self._process_iter_stats(finished_requests, previous_batch)
+                    self._process_iter_stats(finished_requests,
+                                             self.active_requests,
+                                             previous_batch)
         self._executor_loop_cleanup()
 
     def _executor_loop(self):
@@ -921,7 +995,7 @@ def _executor_loop(self):
                     iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
                         'num_ctx_tokens']
                     self._process_iter_stats(
-                        finished_requests,
+                        finished_requests, self.active_requests,
                         BatchState(decoder_state=DecoderState(
                             scheduled_requests=scheduled_batch),
                                    iter_stats=iter_stats,
@@ -1099,7 +1173,8 @@ def _process_previous_batch(self):
            self._add_kv_cache_events()
 
         if self.enable_iter_perf_stats:
-            self._process_iter_stats(finished_requests, self.previous_batch)
+            self._process_iter_stats(finished_requests, self.active_requests,
+                                     self.previous_batch)
 
     @nvtx_range("_forward_step_inter_pp")
     def _forward_step_inter_pp(self, scheduled_batch) -> DecoderState: