Skip to content

Commit 947f2f5

Browse files
[V1] Allow turning off pickle fallback in vllm.v1.serial_utils (#17427)
Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
1 parent 739e03b commit 947f2f5

File tree

2 files changed

+113
-6
lines changed

2 files changed

+113
-6
lines changed

tests/v1/test_serial_utils.py

+98
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import msgspec
77
import numpy as np
8+
import pytest
89
import torch
910

1011
from vllm.multimodal.inputs import (MultiModalBatchedField,
@@ -196,3 +197,100 @@ def assert_equal(obj1: MyType, obj2: MyType):
196197
assert torch.equal(obj1.large_non_contig_tensor,
197198
obj2.large_non_contig_tensor)
198199
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
200+
201+
202+
@pytest.mark.parametrize("allow_pickle", [True, False])
203+
def test_dict_serialization(allow_pickle: bool):
204+
"""Test encoding and decoding of a generic Python object using pickle."""
205+
encoder = MsgpackEncoder(allow_pickle=allow_pickle)
206+
decoder = MsgpackDecoder(allow_pickle=allow_pickle)
207+
208+
# Create a sample Python object
209+
obj = {"key": "value", "number": 42}
210+
211+
# Encode the object
212+
encoded = encoder.encode(obj)
213+
214+
# Decode the object
215+
decoded = decoder.decode(encoded)
216+
217+
# Verify the decoded object matches the original
218+
assert obj == decoded, "Decoded object does not match the original object."
219+
220+
221+
@pytest.mark.parametrize("allow_pickle", [True, False])
222+
def test_tensor_serialization(allow_pickle: bool):
223+
"""Test encoding and decoding of a torch.Tensor."""
224+
encoder = MsgpackEncoder(allow_pickle=allow_pickle)
225+
decoder = MsgpackDecoder(torch.Tensor, allow_pickle=allow_pickle)
226+
227+
# Create a sample tensor
228+
tensor = torch.rand(10, 10)
229+
230+
# Encode the tensor
231+
encoded = encoder.encode(tensor)
232+
233+
# Decode the tensor
234+
decoded = decoder.decode(encoded)
235+
236+
# Verify the decoded tensor matches the original
237+
assert torch.allclose(
238+
tensor, decoded), "Decoded tensor does not match the original tensor."
239+
240+
241+
@pytest.mark.parametrize("allow_pickle", [True, False])
242+
def test_numpy_array_serialization(allow_pickle: bool):
243+
"""Test encoding and decoding of a numpy array."""
244+
encoder = MsgpackEncoder(allow_pickle=allow_pickle)
245+
decoder = MsgpackDecoder(np.ndarray, allow_pickle=allow_pickle)
246+
247+
# Create a sample numpy array
248+
array = np.random.rand(10, 10)
249+
250+
# Encode the numpy array
251+
encoded = encoder.encode(array)
252+
253+
# Decode the numpy array
254+
decoded = decoder.decode(encoded)
255+
256+
# Verify the decoded array matches the original
257+
assert np.allclose(
258+
array,
259+
decoded), "Decoded numpy array does not match the original array."
260+
261+
262+
class CustomClass:
263+
264+
def __init__(self, value):
265+
self.value = value
266+
267+
def __eq__(self, other):
268+
return isinstance(other, CustomClass) and self.value == other.value
269+
270+
271+
def test_custom_class_serialization_allowed_with_pickle():
272+
"""Test that serializing a custom class succeeds when allow_pickle=True."""
273+
encoder = MsgpackEncoder(allow_pickle=True)
274+
decoder = MsgpackDecoder(CustomClass, allow_pickle=True)
275+
276+
obj = CustomClass("test_value")
277+
278+
# Encode the custom class
279+
encoded = encoder.encode(obj)
280+
281+
# Decode the custom class
282+
decoded = decoder.decode(encoded)
283+
284+
# Verify the decoded object matches the original
285+
assert obj == decoded, "Decoded object does not match the original object."
286+
287+
288+
def test_custom_class_serialization_disallowed_without_pickle():
289+
"""Test that serializing a custom class fails when allow_pickle=False."""
290+
encoder = MsgpackEncoder(allow_pickle=False)
291+
292+
obj = CustomClass("test_value")
293+
294+
with pytest.raises(TypeError):
295+
# Attempt to encode the custom class
296+
encoder.encode(obj)

vllm/v1/serial_utils.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ class MsgpackEncoder:
4747
via dedicated messages. Note that this is a per-tensor limit.
4848
"""
4949

50-
def __init__(self, size_threshold: Optional[int] = None):
50+
def __init__(self,
51+
size_threshold: Optional[int] = None,
52+
allow_pickle: bool = True):
5153
if size_threshold is None:
5254
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
5355
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
@@ -56,6 +58,7 @@ def __init__(self, size_threshold: Optional[int] = None):
5658
# pass custom data to the hook otherwise.
5759
self.aux_buffers: Optional[list[bytestr]] = None
5860
self.size_threshold = size_threshold
61+
self.allow_pickle = allow_pickle
5962

6063
def encode(self, obj: Any) -> Sequence[bytestr]:
6164
try:
@@ -105,6 +108,9 @@ def enc_hook(self, obj: Any) -> Any:
105108
for itemlist in mm._items_by_modality.values()
106109
for item in itemlist]
107110

111+
if not self.allow_pickle:
112+
raise TypeError(f"Object of type {type(obj)} is not serializable")
113+
108114
if isinstance(obj, FunctionType):
109115
# `pickle` is generally faster than cloudpickle, but can have
110116
# problems serializing methods.
@@ -179,12 +185,13 @@ class MsgpackDecoder:
179185
not thread-safe when encoding tensors / numpy arrays.
180186
"""
181187

182-
def __init__(self, t: Optional[Any] = None):
188+
def __init__(self, t: Optional[Any] = None, allow_pickle: bool = True):
183189
args = () if t is None else (t, )
184190
self.decoder = msgpack.Decoder(*args,
185191
ext_hook=self.ext_hook,
186192
dec_hook=self.dec_hook)
187193
self.aux_buffers: Sequence[bytestr] = ()
194+
self.allow_pickle = allow_pickle
188195

189196
def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
190197
if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
@@ -265,10 +272,12 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
265272
def ext_hook(self, code: int, data: memoryview) -> Any:
266273
if code == CUSTOM_TYPE_RAW_VIEW:
267274
return data
268-
if code == CUSTOM_TYPE_PICKLE:
269-
return pickle.loads(data)
270-
if code == CUSTOM_TYPE_CLOUDPICKLE:
271-
return cloudpickle.loads(data)
275+
276+
if self.allow_pickle:
277+
if code == CUSTOM_TYPE_PICKLE:
278+
return pickle.loads(data)
279+
if code == CUSTOM_TYPE_CLOUDPICKLE:
280+
return cloudpickle.loads(data)
272281

273282
raise NotImplementedError(
274283
f"Extension type code {code} is not supported")

0 commit comments

Comments
 (0)