
Commit 346ca1f

mgoin authored and njhill committed
[Bugfix] Enable V1 usage stats (vllm-project#16986)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
1 parent 86504c9 commit 346ca1f

File tree

5 files changed: +75 -5 lines changed

vllm/usage/usage_lib.py (+4 -4)
vllm/utils.py (+19 -1)
vllm/v1/engine/async_llm.py (+4)
vllm/v1/engine/llm_engine.py (+4)
vllm/v1/utils.py (+44)

vllm/usage/usage_lib.py (+4 -4)

@@ -19,6 +19,7 @@

 import vllm.envs as envs
 from vllm.connections import global_http_connection
+from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties
 from vllm.version import __version__ as VLLM_VERSION

 _config_home = envs.VLLM_CONFIG_ROOT

@@ -168,10 +169,9 @@ def _report_usage_once(self, model_architecture: str,
         # Platform information
         from vllm.platforms import current_platform
         if current_platform.is_cuda_alike():
-            device_property = torch.cuda.get_device_properties(0)
-            self.gpu_count = torch.cuda.device_count()
-            self.gpu_type = device_property.name
-            self.gpu_memory_per_device = device_property.total_memory
+            self.gpu_count = cuda_device_count_stateless()
+            self.gpu_type, self.gpu_memory_per_device = (
+                cuda_get_device_properties(0, ("name", "total_memory")))
         if current_platform.is_cuda():
             self.cuda_runtime = torch.version.cuda
         self.provider = _detect_cloud_provider()
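For orientation, a minimal sketch of how the two helpers imported above are meant to be called on a CUDA host; the print formatting is illustrative and not part of the change:

# Minimal sketch (assumes a visible CUDA device): query GPU info without
# initializing CUDA in the calling process.
from vllm.utils import cuda_device_count_stateless, cuda_get_device_properties

gpu_count = cuda_device_count_stateless()
if gpu_count > 0:
    # Fetch only the fields the usage reporter needs from device 0.
    gpu_type, gpu_memory_per_device = cuda_get_device_properties(
        0, ("name", "total_memory"))
    print(f"{gpu_count}x {gpu_type}, {gpu_memory_per_device} bytes each")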

vllm/utils.py (+19 -1)

@@ -38,11 +38,13 @@
 from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
                              Iterable, Iterator, KeysView, Mapping)
+from concurrent.futures.process import ProcessPoolExecutor
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, Tuple, Type, TypeVar, Union, cast, overload)
+                    Optional, Sequence, Tuple, Type, TypeVar, Union, cast,
+                    overload)
 from uuid import uuid4

 import cachetools

@@ -1235,6 +1237,22 @@ def cuda_is_initialized() -> bool:
     return torch.cuda.is_initialized()


+def cuda_get_device_properties(device,
+                               names: Sequence[str],
+                               init_cuda=False) -> tuple[Any, ...]:
+    """Get specified CUDA device property values without initializing CUDA in
+    the current process."""
+    if init_cuda or cuda_is_initialized():
+        props = torch.cuda.get_device_properties(device)
+        return tuple(getattr(props, name) for name in names)
+
+    # Run in subprocess to avoid initializing CUDA as a side effect.
+    mp_ctx = multiprocessing.get_context("fork")
+    with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
+        return executor.submit(cuda_get_device_properties, device, names,
+                               True).result()
+
+
 def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
     """Make an instance method that weakly references
     its associated instance and no-ops once that
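A standalone sketch of the same fork-and-query pattern, assuming only PyTorch with CUDA support and a fork-capable OS such as Linux; the helper names below are illustrative rather than part of vLLM:

# Sketch: read device properties in a forked child so the parent process
# never initializes CUDA as a side effect.
import multiprocessing
from concurrent.futures.process import ProcessPoolExecutor

import torch


def _read_props(device: int, names: tuple) -> tuple:
    # Runs in the child process, where initializing CUDA is harmless.
    props = torch.cuda.get_device_properties(device)
    return tuple(getattr(props, name) for name in names)


def get_device_properties(device: int, names: tuple) -> tuple:
    mp_ctx = multiprocessing.get_context("fork")
    with ProcessPoolExecutor(max_workers=1, mp_context=mp_ctx) as executor:
        return executor.submit(_read_props, device, names).result()


if __name__ == "__main__":
    name, total_memory = get_device_properties(0, ("name", "total_memory"))
    print(name, total_memory)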

vllm/v1/engine/async_llm.py (+4)

@@ -36,6 +36,7 @@
 from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger,
                                      StatLoggerBase)
 from vllm.v1.metrics.stats import IterationStats, SchedulerStats
+from vllm.v1.utils import report_usage_stats

 logger = init_logger(__name__)


@@ -114,6 +115,9 @@ def __init__(
         except RuntimeError:
             pass

+        # If usage stat is enabled, collect relevant info.
+        report_usage_stats(vllm_config, usage_context)
+
     @classmethod
     def from_vllm_config(
         cls,

vllm/v1/engine/llm_engine.py (+4)

@@ -28,6 +28,7 @@
 from vllm.v1.engine.parallel_sampling import ParentRequest
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.utils import report_usage_stats

 logger = init_logger(__name__)


@@ -99,6 +100,9 @@ def __init__(
         # for v0 compatibility
         self.model_executor = self.engine_core.engine_core.model_executor  # type: ignore

+        # If usage stat is enabled, collect relevant info.
+        report_usage_stats(vllm_config, usage_context)
+
     @classmethod
     def from_vllm_config(
         cls,
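Both V1 entry points now report during construction, and the call is an unconditional one-liner because the helper itself checks whether stats are enabled. A hedged sketch of the user-facing effect, assuming VLLM_NO_USAGE_STATS as vLLM's opt-out variable and an illustrative model name:

# Sketch: opting out before constructing an engine makes the new call a no-op.
import os

os.environ["VLLM_NO_USAGE_STATS"] = "1"  # set before importing vllm

from vllm import LLM

# AsyncLLM/LLMEngine.__init__ now calls report_usage_stats(vllm_config,
# usage_context); with the opt-out set, is_usage_stats_enabled() is False
# and nothing is collected or sent.
llm = LLM(model="facebook/opt-125m")  # illustrative model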

vllm/v1/utils.py (+44)

@@ -12,6 +12,8 @@

 from vllm.logger import init_logger
 from vllm.model_executor.models.utils import extract_layer_index
+from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
+                                  usage_message)
 from vllm.utils import get_mp_context, kill_process_tree

 if TYPE_CHECKING:

@@ -201,3 +203,45 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
     Returns the sliced target tensor.
     """
     return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
+
+
+def report_usage_stats(vllm_config, usage_context: UsageContext) -> None:
+    """Report usage statistics if enabled."""
+
+    if not is_usage_stats_enabled():
+        return
+
+    from vllm.model_executor.model_loader import get_architecture_class_name
+
+    usage_message.report_usage(
+        get_architecture_class_name(vllm_config.model_config),
+        usage_context,
+        extra_kvs={
+            # Common configuration
+            "dtype":
+            str(vllm_config.model_config.dtype),
+            "tensor_parallel_size":
+            vllm_config.parallel_config.tensor_parallel_size,
+            "block_size":
+            vllm_config.cache_config.block_size,
+            "gpu_memory_utilization":
+            vllm_config.cache_config.gpu_memory_utilization,
+
+            # Quantization
+            "quantization":
+            vllm_config.model_config.quantization,
+            "kv_cache_dtype":
+            str(vllm_config.cache_config.cache_dtype),
+
+            # Feature flags
+            "enable_lora":
+            bool(vllm_config.lora_config),
+            "enable_prompt_adapter":
+            bool(vllm_config.prompt_adapter_config),
+            "enable_prefix_caching":
+            vllm_config.cache_config.enable_prefix_caching,
+            "enforce_eager":
+            vllm_config.model_config.enforce_eager,
+            "disable_custom_all_reduce":
+            vllm_config.parallel_config.disable_custom_all_reduce,
+        })
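As a usage note, the new helper only needs a fully built config and a context. A hedged sketch of driving it directly, assuming EngineArgs, UsageContext, and create_engine_config behave as in current vLLM and using an illustrative model name:

# Sketch (assumes vLLM is installed and the illustrative model is available):
# build a config and report usage for it outside of engine construction.
from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.utils import report_usage_stats

engine_args = EngineArgs(model="facebook/opt-125m")  # illustrative model
vllm_config = engine_args.create_engine_config(
    usage_context=UsageContext.ENGINE_CONTEXT)

# Returns immediately unless usage stats are enabled in this environment.
report_usage_stats(vllm_config, UsageContext.ENGINE_CONTEXT)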
