|
5 | 5 |
|
6 | 6 | import msgspec
|
7 | 7 | import numpy as np
|
| 8 | +import pytest |
8 | 9 | import torch
|
9 | 10 |
|
10 | 11 | from vllm.multimodal.inputs import (MultiModalBatchedField,
|
@@ -196,3 +197,100 @@ def assert_equal(obj1: MyType, obj2: MyType):
|
196 | 197 | assert torch.equal(obj1.large_non_contig_tensor,
|
197 | 198 | obj2.large_non_contig_tensor)
|
198 | 199 | assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
|
| 200 | + |
| 201 | + |
| 202 | +@pytest.mark.parametrize("allow_pickle", [True, False]) |
| 203 | +def test_dict_serialization(allow_pickle: bool): |
| 204 | + """Test encoding and decoding of a generic Python object using pickle.""" |
| 205 | + encoder = MsgpackEncoder(allow_pickle=allow_pickle) |
| 206 | + decoder = MsgpackDecoder(allow_pickle=allow_pickle) |
| 207 | + |
| 208 | + # Create a sample Python object |
| 209 | + obj = {"key": "value", "number": 42} |
| 210 | + |
| 211 | + # Encode the object |
| 212 | + encoded = encoder.encode(obj) |
| 213 | + |
| 214 | + # Decode the object |
| 215 | + decoded = decoder.decode(encoded) |
| 216 | + |
| 217 | + # Verify the decoded object matches the original |
| 218 | + assert obj == decoded, "Decoded object does not match the original object." |
| 219 | + |
| 220 | + |
| 221 | +@pytest.mark.parametrize("allow_pickle", [True, False]) |
| 222 | +def test_tensor_serialization(allow_pickle: bool): |
| 223 | + """Test encoding and decoding of a torch.Tensor.""" |
| 224 | + encoder = MsgpackEncoder(allow_pickle=allow_pickle) |
| 225 | + decoder = MsgpackDecoder(torch.Tensor, allow_pickle=allow_pickle) |
| 226 | + |
| 227 | + # Create a sample tensor |
| 228 | + tensor = torch.rand(10, 10) |
| 229 | + |
| 230 | + # Encode the tensor |
| 231 | + encoded = encoder.encode(tensor) |
| 232 | + |
| 233 | + # Decode the tensor |
| 234 | + decoded = decoder.decode(encoded) |
| 235 | + |
| 236 | + # Verify the decoded tensor matches the original |
| 237 | + assert torch.allclose( |
| 238 | + tensor, decoded), "Decoded tensor does not match the original tensor." |
| 239 | + |
| 240 | + |
| 241 | +@pytest.mark.parametrize("allow_pickle", [True, False]) |
| 242 | +def test_numpy_array_serialization(allow_pickle: bool): |
| 243 | + """Test encoding and decoding of a numpy array.""" |
| 244 | + encoder = MsgpackEncoder(allow_pickle=allow_pickle) |
| 245 | + decoder = MsgpackDecoder(np.ndarray, allow_pickle=allow_pickle) |
| 246 | + |
| 247 | + # Create a sample numpy array |
| 248 | + array = np.random.rand(10, 10) |
| 249 | + |
| 250 | + # Encode the numpy array |
| 251 | + encoded = encoder.encode(array) |
| 252 | + |
| 253 | + # Decode the numpy array |
| 254 | + decoded = decoder.decode(encoded) |
| 255 | + |
| 256 | + # Verify the decoded array matches the original |
| 257 | + assert np.allclose( |
| 258 | + array, |
| 259 | + decoded), "Decoded numpy array does not match the original array." |
| 260 | + |
| 261 | + |
| 262 | +class CustomClass: |
| 263 | + |
| 264 | + def __init__(self, value): |
| 265 | + self.value = value |
| 266 | + |
| 267 | + def __eq__(self, other): |
| 268 | + return isinstance(other, CustomClass) and self.value == other.value |
| 269 | + |
| 270 | + |
| 271 | +def test_custom_class_serialization_allowed_with_pickle(): |
| 272 | + """Test that serializing a custom class succeeds when allow_pickle=True.""" |
| 273 | + encoder = MsgpackEncoder(allow_pickle=True) |
| 274 | + decoder = MsgpackDecoder(CustomClass, allow_pickle=True) |
| 275 | + |
| 276 | + obj = CustomClass("test_value") |
| 277 | + |
| 278 | + # Encode the custom class |
| 279 | + encoded = encoder.encode(obj) |
| 280 | + |
| 281 | + # Decode the custom class |
| 282 | + decoded = decoder.decode(encoded) |
| 283 | + |
| 284 | + # Verify the decoded object matches the original |
| 285 | + assert obj == decoded, "Decoded object does not match the original object." |
| 286 | + |
| 287 | + |
| 288 | +def test_custom_class_serialization_disallowed_without_pickle(): |
| 289 | + """Test that serializing a custom class fails when allow_pickle=False.""" |
| 290 | + encoder = MsgpackEncoder(allow_pickle=False) |
| 291 | + |
| 292 | + obj = CustomClass("test_value") |
| 293 | + |
| 294 | + with pytest.raises(TypeError): |
| 295 | + # Attempt to encode the custom class |
| 296 | + encoder.encode(obj) |
0 commit comments