2
2
3
3
import json
4
4
import os
5
+ import struct
5
6
from concurrent .futures import ThreadPoolExecutor
6
7
from dataclasses import dataclass
7
8
from typing import Optional , Union
@@ -115,14 +116,14 @@ def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
115
116
p_rank_offset = int (p_port ) + 8 + self .local_rank * 2
116
117
d_rank_offset = int (d_port ) + 8 + self .local_rank * 2
117
118
if kv_rank == 0 :
118
- self .sender_socket .bind (f"tcp://* :{ p_rank_offset + 1 } " )
119
+ self .sender_socket .bind (f"tcp://{ p_host } :{ p_rank_offset + 1 } " )
119
120
self .receiver_socket .connect (f"tcp://{ d_host } :{ d_rank_offset + 1 } " )
120
121
self .sender_ack .connect (f"tcp://{ d_host } :{ d_rank_offset + 2 } " )
121
- self .receiver_ack .bind (f"tcp://* :{ p_rank_offset + 2 } " )
122
+ self .receiver_ack .bind (f"tcp://{ p_host } :{ p_rank_offset + 2 } " )
122
123
else :
123
124
self .receiver_socket .connect (f"tcp://{ p_host } :{ p_rank_offset + 1 } " )
124
- self .sender_socket .bind (f"tcp://* :{ d_rank_offset + 1 } " )
125
- self .receiver_ack .bind (f"tcp://* :{ d_rank_offset + 2 } " )
125
+ self .sender_socket .bind (f"tcp://{ d_host } :{ d_rank_offset + 1 } " )
126
+ self .receiver_ack .bind (f"tcp://{ d_host } :{ d_rank_offset + 2 } " )
126
127
self .sender_ack .connect (f"tcp://{ p_host } :{ p_rank_offset + 2 } " )
127
128
128
129
def initialize (self , local_hostname : str , metadata_server : str ,
@@ -176,7 +177,7 @@ def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
176
177
177
178
def wait_for_ack (self , src_ptr : int , length : int ) -> None :
178
179
"""Asynchronously wait for ACK from the receiver."""
179
- ack = self .sender_ack .recv_pyobj ()
180
+ ack = self .sender_ack .recv ()
180
181
if ack != b'ACK' :
181
182
logger .error ("Failed to receive ACK from the receiver" )
182
183
@@ -187,18 +188,22 @@ def send_bytes(self, user_data: bytes) -> None:
187
188
length = len (user_data )
188
189
src_ptr = self .allocate_managed_buffer (length )
189
190
self .write_bytes_to_buffer (src_ptr , user_data , length )
190
- self .sender_socket .send_pyobj ((src_ptr , length ))
191
+ self .sender_socket .send_multipart (
192
+ [struct .pack ("!Q" , src_ptr ),
193
+ struct .pack ("!Q" , length )])
191
194
self .buffer_cleaner .submit (self .wait_for_ack , src_ptr , length )
192
195
193
196
def recv_bytes (self ) -> bytes :
194
197
"""Receive bytes from the remote process."""
195
- src_ptr , length = self .receiver_socket .recv_pyobj ()
198
+ data = self .receiver_socket .recv_multipart ()
199
+ src_ptr = struct .unpack ("!Q" , data [0 ])[0 ]
200
+ length = struct .unpack ("!Q" , data [1 ])[0 ]
196
201
dst_ptr = self .allocate_managed_buffer (length )
197
202
self .transfer_sync (dst_ptr , src_ptr , length )
198
203
ret = self .read_bytes_from_buffer (dst_ptr , length )
199
204
200
205
# Buffer cleanup
201
- self .receiver_ack .send_pyobj (b'ACK' )
206
+ self .receiver_ack .send (b'ACK' )
202
207
self .free_managed_buffer (dst_ptr , length )
203
208
204
209
return ret
0 commit comments