
Commit d0e5f74

DarkLight1337 authored and liuzijing2014 committed
[Bugfix] Fix mistral model tests (vllm-project#17181)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 15fcf76 commit d0e5f74

File tree: 2 files changed (+40 −28)

tests/models/decoder_only/language/test_mistral.py (+36 −28)
```diff
@@ -10,8 +10,8 @@
 import jsonschema.exceptions
 import pytest
 
-from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (  # noqa
-    MistralToolParser)
+from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
+    MistralToolCall, MistralToolParser)
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 from ...utils import check_logprobs_close
@@ -194,7 +194,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
 )
 
 
-@pytest.mark.skip("RE-ENABLE: test is currently failing on main.")
 @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
@@ -246,10 +245,8 @@ def test_mistral_symbolic_languages(vllm_runner, model: str,
         assert "�" not in outputs[0].outputs[0].text.strip()
 
 
-@pytest.mark.skip("RE-ENABLE: test is currently failing on main.")
+@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("model",
-                         MISTRAL_FORMAT_MODELS)  # v1 can't do func calling
 def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
     with vllm_runner(model,
                      dtype=dtype,
@@ -270,7 +267,8 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
         parsed_message = tool_parser.extract_tool_calls(model_output, None)
 
         assert parsed_message.tools_called
-        assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
+
+        assert MistralToolCall.is_valid_id(parsed_message.tool_calls[0].id)
         assert parsed_message.tool_calls[
             0].function.name == "get_current_weather"
         assert parsed_message.tool_calls[
@@ -281,28 +279,38 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("guided_backend",
                          ["outlines", "lm-format-enforcer", "xgrammar"])
-def test_mistral_guided_decoding(vllm_runner, model: str,
-                                 guided_backend: str) -> None:
-    with vllm_runner(model, dtype='bfloat16',
-                     tokenizer_mode="mistral") as vllm_model:
+def test_mistral_guided_decoding(
+    monkeypatch: pytest.MonkeyPatch,
+    vllm_runner,
+    model: str,
+    guided_backend: str,
+) -> None:
+    with monkeypatch.context() as m:
+        # Guided JSON not supported in xgrammar + V1 yet
+        m.setenv("VLLM_USE_V1", "0")
 
-        guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
-                                               backend=guided_backend)
-        params = SamplingParams(max_tokens=512,
-                                temperature=0.7,
-                                guided_decoding=guided_decoding)
-
-        messages = [{
-            "role": "system",
-            "content": "you are a helpful assistant"
-        }, {
-            "role":
-            "user",
-            "content":
-            f"Give an example JSON for an employee profile that "
-            f"fits this schema: {SAMPLE_JSON_SCHEMA}"
-        }]
-        outputs = vllm_model.model.chat(messages, sampling_params=params)
+        with vllm_runner(
+                model,
+                dtype='bfloat16',
+                tokenizer_mode="mistral",
+                guided_decoding_backend=guided_backend,
+        ) as vllm_model:
+            guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA)
+            params = SamplingParams(max_tokens=512,
+                                    temperature=0.7,
+                                    guided_decoding=guided_decoding)
+
+            messages = [{
+                "role": "system",
+                "content": "you are a helpful assistant"
+            }, {
+                "role":
+                "user",
+                "content":
+                f"Give an example JSON for an employee profile that "
+                f"fits this schema: {SAMPLE_JSON_SCHEMA}"
+            }]
+            outputs = vllm_model.model.chat(messages, sampling_params=params)
 
         generated_text = outputs[0].outputs[0].text
         json_response = json.loads(generated_text)
```
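
For context on the guided-decoding hunk: the backend is now chosen once at engine construction via `guided_decoding_backend=` instead of per request via `GuidedDecodingParams(backend=...)`, and `VLLM_USE_V1=0` is pinned because guided JSON is not yet supported with xgrammar on the V1 engine. Below is a minimal standalone sketch of the same configuration, assuming a vLLM build of this era where `LLM(...)` forwards `guided_decoding_backend` to the engine args (as the `vllm_runner` fixture does); the model name and schema are illustrative, not taken from the test:

```python
import os

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Mirrors the test's workaround: guided JSON is not yet supported
# with xgrammar on the V1 engine, so force V0.
os.environ["VLLM_USE_V1"] = "0"

llm = LLM(
    model="mistralai/Mistral-7B-Instruct-v0.3",  # illustrative choice
    tokenizer_mode="mistral",
    guided_decoding_backend="xgrammar",  # engine-level, as in the new test
)

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

params = SamplingParams(
    max_tokens=512,
    temperature=0.7,
    # No backend= here anymore; the engine-level setting applies.
    guided_decoding=GuidedDecodingParams(json=schema),
)

prompt = ("Give an example JSON for an employee profile "
          f"that fits this schema: {schema}")
outputs = llm.chat([{"role": "user", "content": prompt}],
                   sampling_params=params)
print(outputs[0].outputs[0].text)
```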

vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py (+4 −0)
```diff
@@ -38,6 +38,10 @@ def generate_random_id():
         # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
         return "".join(choices(ALPHANUMERIC, k=9))
 
+    @staticmethod
+    def is_valid_id(id: str) -> bool:
+        return id.isalnum() and len(id) == 9
+
 
 @ToolParserManager.register_module("mistral")
 class MistralToolParser(ToolParser):
```
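
The new `is_valid_id` helper is what lets the function-calling test stop comparing against the hard-coded id `"0UAqFzWsD"`: since `generate_random_id` draws 9 characters from `ALPHANUMERIC`, any well-formed Mistral tool-call id is exactly 9 alphanumeric characters, and that is all the helper checks. A minimal sketch of the round trip, with both functions replicated outside the class (the `ALPHANUMERIC` definition here is an assumption chosen to match the k=9 draw above):

```python
from random import choices
from string import ascii_letters, digits

# Assumed to match the parser's alphabet for generated tool-call ids.
ALPHANUMERIC = ascii_letters + digits


def generate_random_id() -> str:
    # Same shape as MistralToolCall.generate_random_id:
    # 9 random alphanumeric characters.
    return "".join(choices(ALPHANUMERIC, k=9))


def is_valid_id(id: str) -> bool:
    # Mirrors the new MistralToolCall.is_valid_id staticmethod.
    return id.isalnum() and len(id) == 9


assert is_valid_id(generate_random_id())  # fresh ids always validate
assert is_valid_id("0UAqFzWsD")           # the old hard-coded test id passes
assert not is_valid_id("too-short")       # hyphen fails isalnum()
assert not is_valid_id("abc123")          # wrong length (6, not 9)
```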
