Skip to content

Commit c129623

Browse files
gcalmettesliuzijing2014
authored andcommitted
[Bugfix] validate urls object for multimodal content parts (vllm-project#16990)
Signed-off-by: Guillaume Calmettes <gcalmettes@scaleway.com> Signed-off-by: Zijing Liu <liuzijing2014@gmail.com>
1 parent 5cca099 commit c129623

File tree

4 files changed

+94
-4
lines changed

4 files changed

+94
-4
lines changed

tests/entrypoints/openai/test_audio.py

+29
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,35 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
104104
assert message.content is not None and len(message.content) >= 0
105105

106106

107+
@pytest.mark.asyncio
108+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
109+
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])
110+
async def test_error_on_invalid_audio_url_type(client: openai.AsyncOpenAI,
111+
model_name: str,
112+
audio_url: str):
113+
messages = [{
114+
"role":
115+
"user",
116+
"content": [
117+
{
118+
"type": "audio_url",
119+
"audio_url": audio_url
120+
},
121+
{
122+
"type": "text",
123+
"text": "What's happening in this audio?"
124+
},
125+
],
126+
}]
127+
128+
# audio_url should be a dict {"url": "some url"}, not directly a string
129+
with pytest.raises(openai.BadRequestError):
130+
_ = await client.chat.completions.create(model=model_name,
131+
messages=messages,
132+
max_completion_tokens=10,
133+
temperature=0.0)
134+
135+
107136
@pytest.mark.asyncio
108137
@pytest.mark.parametrize("model_name", [MODEL_NAME])
109138
@pytest.mark.parametrize("audio_url", [TEST_AUDIO_URLS[0]])

tests/entrypoints/openai/test_video.py

+29
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,35 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
108108
assert message.content is not None and len(message.content) >= 0
109109

110110

111+
@pytest.mark.asyncio
112+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
113+
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
114+
async def test_error_on_invalid_video_url_type(client: openai.AsyncOpenAI,
115+
model_name: str,
116+
video_url: str):
117+
messages = [{
118+
"role":
119+
"user",
120+
"content": [
121+
{
122+
"type": "video_url",
123+
"video_url": video_url
124+
},
125+
{
126+
"type": "text",
127+
"text": "What's in this video?"
128+
},
129+
],
130+
}]
131+
132+
# video_url should be a dict {"url": "some url"}, not directly a string
133+
with pytest.raises(openai.BadRequestError):
134+
_ = await client.chat.completions.create(model=model_name,
135+
messages=messages,
136+
max_completion_tokens=10,
137+
temperature=0.0)
138+
139+
111140
@pytest.mark.asyncio
112141
@pytest.mark.parametrize("model_name", [MODEL_NAME])
113142
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)

tests/entrypoints/openai/test_vision.py

+30
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,36 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI,
137137
assert message.content is not None and len(message.content) >= 0
138138

139139

140+
@pytest.mark.asyncio
141+
@pytest.mark.parametrize("model_name", [MODEL_NAME])
142+
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
143+
async def test_error_on_invalid_image_url_type(client: openai.AsyncOpenAI,
144+
model_name: str,
145+
image_url: str):
146+
content_text = "What's in this image?"
147+
messages = [{
148+
"role":
149+
"user",
150+
"content": [
151+
{
152+
"type": "image_url",
153+
"image_url": image_url
154+
},
155+
{
156+
"type": "text",
157+
"text": content_text
158+
},
159+
],
160+
}]
161+
162+
# image_url should be a dict {"url": "some url"}, not directly a string
163+
with pytest.raises(openai.BadRequestError):
164+
_ = await client.chat.completions.create(model=model_name,
165+
messages=messages,
166+
max_completion_tokens=10,
167+
temperature=0.0)
168+
169+
140170
@pytest.mark.asyncio
141171
@pytest.mark.parametrize("model_name", [MODEL_NAME])
142172
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)

vllm/entrypoints/chat_utils.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@
2727
ChatCompletionToolMessageParam)
2828
from openai.types.chat.chat_completion_content_part_input_audio_param import (
2929
InputAudio)
30+
from pydantic import TypeAdapter
3031
# yapf: enable
31-
# pydantic needs the TypedDict from typing_extensions
3232
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
3333
ProcessorMixin)
34+
# pydantic needs the TypedDict from typing_extensions
3435
from typing_extensions import Required, TypeAlias, TypedDict
3536

3637
from vllm.config import ModelConfig
@@ -879,12 +880,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
879880

880881
# No need to validate using Pydantic again
881882
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
882-
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
883883
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
884-
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
885884
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
886885
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
887-
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
886+
# Need to validate url objects
887+
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
888+
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
889+
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
888890

889891
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio]
890892

0 commit comments

Comments
 (0)