Skip to content

Commit a5450f1

Browse files
[Security] Use safe serialization and fix zmq setup for mooncake pipe (#17192)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com> Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
1 parent 9d98ab5 commit a5450f1

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import os
5+
import struct
56
from concurrent.futures import ThreadPoolExecutor
67
from dataclasses import dataclass
78
from typing import Optional, Union
@@ -115,14 +116,14 @@ def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
115116
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
116117
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
117118
if kv_rank == 0:
118-
self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}")
119+
self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
119120
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
120121
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
121-
self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}")
122+
self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
122123
else:
123124
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
124-
self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}")
125-
self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}")
125+
self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
126+
self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
126127
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
127128

128129
def initialize(self, local_hostname: str, metadata_server: str,
@@ -176,7 +177,7 @@ def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
176177

177178
def wait_for_ack(self, src_ptr: int, length: int) -> None:
178179
"""Asynchronously wait for ACK from the receiver."""
179-
ack = self.sender_ack.recv_pyobj()
180+
ack = self.sender_ack.recv()
180181
if ack != b'ACK':
181182
logger.error("Failed to receive ACK from the receiver")
182183

@@ -187,18 +188,22 @@ def send_bytes(self, user_data: bytes) -> None:
187188
length = len(user_data)
188189
src_ptr = self.allocate_managed_buffer(length)
189190
self.write_bytes_to_buffer(src_ptr, user_data, length)
190-
self.sender_socket.send_pyobj((src_ptr, length))
191+
self.sender_socket.send_multipart(
192+
[struct.pack("!Q", src_ptr),
193+
struct.pack("!Q", length)])
191194
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
192195

193196
def recv_bytes(self) -> bytes:
194197
"""Receive bytes from the remote process."""
195-
src_ptr, length = self.receiver_socket.recv_pyobj()
198+
data = self.receiver_socket.recv_multipart()
199+
src_ptr = struct.unpack("!Q", data[0])[0]
200+
length = struct.unpack("!Q", data[1])[0]
196201
dst_ptr = self.allocate_managed_buffer(length)
197202
self.transfer_sync(dst_ptr, src_ptr, length)
198203
ret = self.read_bytes_from_buffer(dst_ptr, length)
199204

200205
# Buffer cleanup
201-
self.receiver_ack.send_pyobj(b'ACK')
206+
self.receiver_ack.send(b'ACK')
202207
self.free_managed_buffer(dst_ptr, length)
203208

204209
return ret

0 commit comments

Comments
 (0)