diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py index 46670026491..aa4b1ba7149 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -2,6 +2,7 @@ import json import os +import struct from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from typing import Optional, Union @@ -115,14 +116,14 @@ def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str, p_rank_offset = int(p_port) + 8 + self.local_rank * 2 d_rank_offset = int(d_port) + 8 + self.local_rank * 2 if kv_rank == 0: - self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}") + self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}") self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}") self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}") - self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}") + self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}") else: self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}") - self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}") - self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}") + self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}") + self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}") self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}") def initialize(self, local_hostname: str, metadata_server: str, @@ -176,7 +177,7 @@ def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes: def wait_for_ack(self, src_ptr: int, length: int) -> None: """Asynchronously wait for ACK from the receiver.""" - ack = self.sender_ack.recv_pyobj() + ack = self.sender_ack.recv() if ack != b'ACK': logger.error("Failed to receive ACK from the receiver") @@ -187,18 +188,22 @@ def send_bytes(self, user_data: bytes) -> None: length = len(user_data) src_ptr = self.allocate_managed_buffer(length) self.write_bytes_to_buffer(src_ptr, user_data, length) - self.sender_socket.send_pyobj((src_ptr, length)) + self.sender_socket.send_multipart( + [struct.pack("!Q", src_ptr), + struct.pack("!Q", length)]) self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) def recv_bytes(self) -> bytes: """Receive bytes from the remote process.""" - src_ptr, length = self.receiver_socket.recv_pyobj() + data = self.receiver_socket.recv_multipart() + src_ptr = struct.unpack("!Q", data[0])[0] + length = struct.unpack("!Q", data[1])[0] dst_ptr = self.allocate_managed_buffer(length) self.transfer_sync(dst_ptr, src_ptr, length) ret = self.read_bytes_from_buffer(dst_ptr, length) # Buffer cleanup - self.receiver_ack.send_pyobj(b'ACK') + self.receiver_ack.send(b'ACK') self.free_managed_buffer(dst_ptr, length) return ret