
Commit 07b2c18

dyli-google and lk-chen authored and committed
[V1] Move usage stats to worker and start logging TPU hardware (vllm-project#16211)
1 parent 0d8a569 commit 07b2c18

File tree: 6 files changed (+22, −10 lines)


vllm/usage/usage_lib.py

Lines changed: 9 additions & 0 deletions
@@ -174,6 +174,15 @@ def _report_usage_once(self, model_architecture: str,
                 cuda_get_device_properties(0, ("name", "total_memory")))
         if current_platform.is_cuda():
             self.cuda_runtime = torch.version.cuda
+        if current_platform.is_tpu():
+            try:
+                import torch_xla
+                self.gpu_count = torch_xla.runtime.world_size()
+                self.gpu_type = torch_xla.tpu.get_tpu_type()
+                self.gpu_memory_per_device = (
+                    torch_xla.core.xla_model.get_memory_info()["bytes_limit"])
+            except Exception:
+                pass
         self.provider = _detect_cloud_provider()
         self.architecture = platform.machine()
         self.platform = platform.platform()
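
The added branch is a best-effort hardware probe: every torch_xla call sits inside a broad try/except so that telemetry can never break engine startup. A minimal standalone sketch of the same probe (illustrative only; probe_tpu_hardware is a hypothetical helper name, not part of the commit):

def probe_tpu_hardware() -> dict:
    """Best-effort TPU probe mirroring the usage_lib.py change above."""
    stats: dict = {}
    try:
        import torch_xla

        # Number of TPU devices participating in this runtime.
        stats["gpu_count"] = torch_xla.runtime.world_size()
        # TPU accelerator type string as reported by torch_xla.
        stats["gpu_type"] = torch_xla.tpu.get_tpu_type()
        # Per-device memory limit in bytes.
        stats["gpu_memory_per_device"] = (
            torch_xla.core.xla_model.get_memory_info()["bytes_limit"])
    except Exception:
        # Swallow everything, matching the commit's intent: usage
        # logging is optional and must not fail device initialization.
        pass
    return stats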

vllm/v1/engine/async_llm.py

Lines changed: 0 additions & 4 deletions
@@ -36,7 +36,6 @@
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
                                      StatLoggerBase)
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
-from vllm.v1.utils import report_usage_stats

 logger = init_logger(__name__)

@@ -113,9 +112,6 @@ def __init__(
         except RuntimeError:
             pass

-        # If usage stat is enabled, collect relevant info.
-        report_usage_stats(vllm_config, usage_context)
-
     @classmethod
     def from_vllm_config(
         cls,

vllm/v1/engine/llm_engine.py

Lines changed: 0 additions & 4 deletions
@@ -28,7 +28,6 @@
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
-from vllm.v1.utils import report_usage_stats

 logger = init_logger(__name__)

@@ -97,9 +96,6 @@ def __init__(
         # for v0 compatibility
         self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

-        # If usage stat is enabled, collect relevant info.
-        report_usage_stats(vllm_config, usage_context)
-
     @classmethod
     def from_vllm_config(
         cls,

vllm/v1/utils.py

Lines changed: 3 additions & 1 deletion
@@ -205,7 +205,9 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
     return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)


-def report_usage_stats(vllm_config, usage_context: UsageContext) -> None:
+def report_usage_stats(
+        vllm_config,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT) -> None:
     """Report usage statistics if enabled."""

     if not is_usage_stats_enabled():
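
With the new ENGINE_CONTEXT default, worker-side callers can drop the second argument while frontends keep passing an explicit context. A hedged usage sketch (assumes an existing VllmConfig instance named vllm_config; the explicit-context variant shown is an assumption about frontend call sites, not code from this commit):

from vllm.usage.usage_lib import UsageContext
from vllm.v1.utils import report_usage_stats

# Frontend-style call: context passed explicitly, as before.
report_usage_stats(vllm_config, UsageContext.OPENAI_API_SERVER)

# Worker-style call: context omitted, defaults to ENGINE_CONTEXT.
report_usage_stats(vllm_config)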

vllm/v1/worker/gpu_worker.py

Lines changed: 5 additions & 0 deletions
@@ -23,6 +23,7 @@
 from vllm.utils import GiB_bytes
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 from vllm.v1.worker.worker_base import WorkerBase

@@ -141,6 +142,10 @@ def init_device(self):
         self.model_runner: GPUModelRunner = GPUModelRunner(
             self.vllm_config, self.device)

+        if self.rank == 0:
+            # If usage stat is enabled, collect relevant info.
+            report_usage_stats(self.vllm_config)
+
     # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
     # to hijack tensor allocation.
     def load_model(self) -> None:
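
The rank guard is what keeps the move from frontend to worker from multiplying reports: every worker runs init_device(), but only rank 0 reports, so an N-way tensor-parallel engine still reports exactly once. A runnable toy illustration of the pattern (not vllm code):

class ToyWorker:
    """Stand-in for a tensor-parallel worker (illustration only)."""

    def __init__(self, rank: int):
        self.rank = rank

    def init_device(self) -> None:
        # ... device and model-runner setup would happen here ...
        if self.rank == 0:
            # Without this guard, an N-way setup would report N times.
            print("usage stats reported once, from rank 0")

for rank in range(4):  # e.g. 4-way tensor parallelism
    ToyWorker(rank).init_device()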

vllm/v1/worker/tpu_worker.py

Lines changed: 5 additions & 1 deletion
@@ -21,7 +21,7 @@
 from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache
+from vllm.v1.utils import bind_kv_cache, report_usage_stats
 from vllm.v1.worker.tpu_model_runner import TPUModelRunner

 logger = init_logger(__name__)

@@ -133,6 +133,10 @@ def init_device(self):
         # Init ModelRunner here, so that we have access to self.device.
         self.model_runner = TPUModelRunner(self.vllm_config, self.device)

+        if rank == 0:
+            # If usage stat is enabled, collect relevant info.
+            report_usage_stats(self.vllm_config)
+
     def determine_available_memory(self) -> int:
         kv_caches: dict[str, torch.Tensor] = {}
         kv_cache_spec = self.model_runner.get_kv_cache_spec()
