[V1][Metrics] add support for kv event publishing #16750
Merged
20 commits merged on Apr 30, 2025
Changes from 10 commits
85 changes: 85 additions & 0 deletions examples/online_serving/kv_events.sh
@@ -0,0 +1,85 @@
#!/bin/bash
# This script demonstrates KV cache event publishing.
# It launches a vLLM instance configured to publish KV cache
# events, plus a simple subscriber that logs those events.

set -xe

echo "🚧🚧 Warning: The usage of KV cache events is experimental and subject to change 🚧🚧"
sleep 1

MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct}

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT

# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
    # Kill the background vLLM server and subscriber
    pkill -9 -u "$USER" -f python
    pkill -9 -u "$USER" -f vllm
echo "Cleanup complete. Exiting."
exit 0
}

export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')

# Wait for the vLLM server on the given port to accept requests
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}

vllm serve $MODEL_NAME \
--port 8100 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--trust-remote-code \
--kv-events-config \
'{"enable_kv_cache_events": true, "publisher": "zmq", "topic": "kv-events"}' &

wait_for_server 8100

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

python3 "$SCRIPT_DIR/kv_events_subscriber.py" &
sleep 1

# Send two example completion requests
output1=$(curl -X POST -s http://localhost:8100/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'"$MODEL_NAME"'",
"prompt": "Explain quantum computing in simple terms a 5-year-old could understand.",
"max_tokens": 80,
"temperature": 0
}')

output2=$(curl -X POST -s http://localhost:8100/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'"$MODEL_NAME"'",
"prompt": "Explain quantum computing in simple terms a 50-year-old could understand.",
"max_tokens": 80,
"temperature": 0
}')

# Kill the server and subscriber now that both requests have completed
pkill -9 -u "$USER" -f python
pkill -9 -u "$USER" -f vllm

sleep 1

echo "Cleaned up"

# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"

echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""
67 changes: 67 additions & 0 deletions examples/online_serving/kv_events_subscriber.py
@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional, Union

import msgspec
import zmq


#
# Types copied from vllm.distributed.kv_events
#
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True,
gc=False):
ts: float
events: list[Any]


class KVCacheEvent(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False,
tag=True):
"""Base class for all KV cache-related events"""


class BlockStored(KVCacheEvent):
block_hashes: list[int]
parent_block_hash: Optional[int]
token_ids: list[int]
num_toks_per_block: list[int]
lora_id: Optional[int]


class BlockRemoved(KVCacheEvent):
block_hashes: list[int]


class AllBlocksCleared(KVCacheEvent):
pass


class KVEventBatch(EventBatch):
events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]


decoder = msgspec.msgpack.Decoder(type=KVEventBatch)

context = zmq.Context()
socket = context.socket(zmq.SUB)
socket.connect("tcp://localhost:5557")
topic = "kv-events"
socket.setsockopt_string(zmq.SUBSCRIBE, topic)

print("Listening for KV cache events on topic:", topic)

while True:
try:
        # Each message carries three frames: the topic, an 8-byte
        # big-endian sequence number, and the msgpack-encoded batch.
        _, seq_bytes, payload = socket.recv_multipart()
seq = int.from_bytes(seq_bytes, "big")
event_batch = decoder.decode(payload)
print(f"Received event batch at {event_batch.ts}:")
for event in event_batch.events:
print(f" - {event}")
except KeyboardInterrupt:
print("Interrupted")
break
except Exception as e:
print("Error decoding message:", e)
146 changes: 146 additions & 0 deletions tests/distributed/conftest.py
@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Optional, Union

import msgspec
import msgspec.msgpack
import pytest
import zmq

from vllm.config import KVEventsConfig
from vllm.distributed.kv_events import EventPublisherFactory

from .test_events import SampleBatch


@pytest.fixture
def random_port():
"""Generate a random port number for testing"""
return random.randint(10000, 60000)


@pytest.fixture
def publisher_config(random_port, request):
"""Create a publisher config with inproc transport"""
how = request.param if hasattr(request, "param") else "inproc"

if how == "inproc":
endpoint = f"inproc://test-{random_port}"
replay_endpoint = endpoint + "-replay"
else:
endpoint = f"tcp://*:{random_port}"
replay_endpoint = f"tcp://*:{random_port + 1}"

return KVEventsConfig(enable_kv_cache_events=True,
publisher="zmq",
endpoint=endpoint,
replay_endpoint=replay_endpoint,
buffer_steps=100,
hwm=1000,
topic="test")


@pytest.fixture
def publisher(publisher_config):
"""Create and return a publisher instance"""
pub = EventPublisherFactory.create(publisher_config)
yield pub
pub.close()


@pytest.fixture
def subscriber(publisher_config):
"""Create and return a subscriber for testing"""
endpoint = publisher_config.endpoint
replay_endpoint = publisher_config.replay_endpoint

if endpoint.startswith("tcp://*"):
endpoint = endpoint.replace("*", "127.0.0.1")
if replay_endpoint and replay_endpoint.startswith("tcp://*"):
replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")

sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
yield sub
sub.close()


class MockSubscriber:
"""Helper class to receive and verify published events"""

def __init__(self,
pub_endpoint: str,
replay_endpoint: Optional[str] = None,
topic: str = "",
decode_type=SampleBatch):
self.ctx = zmq.Context.instance()

# Set up subscriber socket
self.sub = self.ctx.socket(zmq.SUB)
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
self.sub.connect(pub_endpoint)

# Set up replay socket if provided
self.replay = None
if replay_endpoint:
self.replay = self.ctx.socket(zmq.REQ)
self.replay.connect(replay_endpoint)

self.topic = topic
self.topic_bytes = topic.encode('utf-8')
self.received_msgs: list[tuple[int, SampleBatch]] = []
self.last_seq = -1
self.decoder = msgspec.msgpack.Decoder(type=decode_type)

def receive_one(self,
timeout=1000) -> Union[tuple[int, SampleBatch], None]:
"""Receive a single message with timeout"""
print(f"self.sub: {self.sub}")
if not self.sub.poll(timeout):
return None

topic_bytes, seq_bytes, payload = self.sub.recv_multipart()
assert topic_bytes == self.topic_bytes

seq = int.from_bytes(seq_bytes, "big")
data = self.decoder.decode(payload)
self.last_seq = seq
self.received_msgs.append((seq, data))
return seq, data

def request_replay(self, start_seq: int) -> None:
"""Request replay of messages starting from start_seq"""
if not self.replay:
raise ValueError("Replay socket not initialized")

self.replay.send(start_seq.to_bytes(8, "big"))

def receive_replay(self) -> list[tuple[int, SampleBatch]]:
"""Receive replayed messages"""
if not self.replay:
raise ValueError("Replay socket not initialized")

replayed: list[tuple[int, SampleBatch]] = []
while True:
try:
if not self.replay.poll(1000):
break

frames = self.replay.recv_multipart()
if not frames or (len(frames) == 1 and not frames[0]):
# End of replay marker
break

seq_bytes, payload = frames
seq = int.from_bytes(seq_bytes, "big")
data = self.decoder.decode(payload)
replayed.append((seq, data))
            except zmq.ZMQError:
break

return replayed

def close(self):
"""Clean up resources"""
self.sub.close()
if self.replay:
self.replay.close()
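
For reference, MockSubscriber can be exercised against a bare ZMQ publisher that emits the same three-frame format it expects: topic, an 8-byte big-endian sequence number, and a msgpack payload. A self-contained sketch; the port and the DemoBatch type are made up for illustration and bypass EventPublisherFactory entirely:

import time

import msgspec
import zmq


class DemoBatch(msgspec.Struct, array_like=True, omit_defaults=True):
    ts: float
    events: list[int]


ctx = zmq.Context.instance()
pub = ctx.socket(zmq.PUB)
pub.bind("tcp://127.0.0.1:5590")

sub = MockSubscriber("tcp://127.0.0.1:5590", topic="test",
                     decode_type=DemoBatch)
time.sleep(0.2)  # let the SUB socket finish connecting before publishing

payload = msgspec.msgpack.Encoder().encode(
    DemoBatch(ts=time.time(), events=[1, 2, 3]))
pub.send_multipart([b"test", (0).to_bytes(8, "big"), payload])

result = sub.receive_one(timeout=1000)
assert result is not None and result[0] == 0

sub.close()
pub.close()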