
[Bugfix] Fix mistral model tests #17181

Merged · 1 commit · Apr 25, 2025
tests/models/decoder_only/language/test_mistral.py: 64 changes (36 additions, 28 deletions)
@@ -10,8 +10,8 @@
 import jsonschema.exceptions
 import pytest
 
-from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (  # noqa
-    MistralToolParser)
+from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
+    MistralToolCall, MistralToolParser)
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 from ...utils import check_logprobs_close
@@ -194,7 +194,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
     )
 
 
-@pytest.mark.skip("RE-ENABLE: test is currently failing on main.")
 @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
@@ -246,10 +245,8 @@ def test_mistral_symbolic_languages(vllm_runner, model: str,
         assert "�" not in outputs[0].outputs[0].text.strip()
 
 
-@pytest.mark.skip("RE-ENABLE: test is currently failing on main.")
+@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("model",
-                         MISTRAL_FORMAT_MODELS)  # v1 can't do func calling
 def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
     with vllm_runner(model,
                      dtype=dtype,
@@ -270,7 +267,8 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
         parsed_message = tool_parser.extract_tool_calls(model_output, None)
 
         assert parsed_message.tools_called
-        assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
+
+        assert MistralToolCall.is_valid_id(parsed_message.tool_calls[0].id)
         assert parsed_message.tool_calls[
             0].function.name == "get_current_weather"
         assert parsed_message.tool_calls[
@@ -281,28 +279,38 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("guided_backend",
                          ["outlines", "lm-format-enforcer", "xgrammar"])
-def test_mistral_guided_decoding(vllm_runner, model: str,
-                                 guided_backend: str) -> None:
-    with vllm_runner(model, dtype='bfloat16',
-                     tokenizer_mode="mistral") as vllm_model:
+def test_mistral_guided_decoding(
+    monkeypatch: pytest.MonkeyPatch,
+    vllm_runner,
+    model: str,
+    guided_backend: str,
+) -> None:
+    with monkeypatch.context() as m:
+        # Guided JSON not supported in xgrammar + V1 yet
+        m.setenv("VLLM_USE_V1", "0")
 
-        guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
-                                               backend=guided_backend)
-        params = SamplingParams(max_tokens=512,
-                                temperature=0.7,
-                                guided_decoding=guided_decoding)
-
-        messages = [{
-            "role": "system",
-            "content": "you are a helpful assistant"
-        }, {
-            "role":
-            "user",
-            "content":
-            f"Give an example JSON for an employee profile that "
-            f"fits this schema: {SAMPLE_JSON_SCHEMA}"
-        }]
-        outputs = vllm_model.model.chat(messages, sampling_params=params)
+        with vllm_runner(
+                model,
+                dtype='bfloat16',
+                tokenizer_mode="mistral",
+                guided_decoding_backend=guided_backend,
+        ) as vllm_model:
+            guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA)
+            params = SamplingParams(max_tokens=512,
+                                    temperature=0.7,
+                                    guided_decoding=guided_decoding)
+
+            messages = [{
+                "role": "system",
+                "content": "you are a helpful assistant"
+            }, {
+                "role":
+                "user",
+                "content":
+                f"Give an example JSON for an employee profile that "
+                f"fits this schema: {SAMPLE_JSON_SCHEMA}"
+            }]
+            outputs = vllm_model.model.chat(messages, sampling_params=params)
 
     generated_text = outputs[0].outputs[0].text
     json_response = json.loads(generated_text)
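Note on the test-side changes above: the function-calling test now checks the tool-call id's format with MistralToolCall.is_valid_id instead of comparing against the hard-coded id "0UAqFzWsD" (the ids are randomly generated, so an exact match is brittle), and the guided-decoding test now pins VLLM_USE_V1=0 and selects the backend once at engine construction via guided_decoding_backend rather than per request. The sketch below reproduces that updated guided-decoding flow outside pytest; it is an illustration only, the model name and the SAMPLE_JSON_SCHEMA placeholder are assumptions, and the final jsonschema.validate call stands in for whatever the truncated tail of the test asserts.

```python
# Hypothetical standalone reproduction of the updated guided-decoding flow.
# Model name and schema are placeholders, not taken from the PR.
import json
import os

import jsonschema

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

SAMPLE_JSON_SCHEMA = {  # placeholder; the test defines its own schema dict
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

# Guided JSON is not yet supported with xgrammar on the V1 engine,
# so force the V0 engine before constructing the LLM.
os.environ["VLLM_USE_V1"] = "0"

llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.3",
          dtype="bfloat16",
          tokenizer_mode="mistral",
          guided_decoding_backend="xgrammar")  # backend chosen at engine init

params = SamplingParams(
    max_tokens=512,
    temperature=0.7,
    guided_decoding=GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA))

messages = [{
    "role": "user",
    "content": f"Give an example JSON for an employee profile that "
               f"fits this schema: {SAMPLE_JSON_SCHEMA}"
}]
outputs = llm.chat(messages, sampling_params=params)

generated = json.loads(outputs[0].outputs[0].text)
jsonschema.validate(instance=generated, schema=SAMPLE_JSON_SCHEMA)  # raises if invalid
```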
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py: 4 changes (4 additions, 0 deletions)
@@ -38,6 +38,10 @@ def generate_random_id():
         # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
         return "".join(choices(ALPHANUMERIC, k=9))
 
+    @staticmethod
+    def is_valid_id(id: str) -> bool:
+        return id.isalnum() and len(id) == 9
+
 
 @ToolParserManager.register_module("mistral")
 class MistralToolParser(ToolParser):
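Context for the parser-side addition: MistralToolCall.generate_random_id produces a random 9-character alphanumeric id (see the mistral-common link above), so a test can only reasonably assert the id's shape, which is exactly what the new is_valid_id helper checks. Below is a self-contained sketch of that invariant; it mirrors the logic rather than importing the vLLM module, and the ALPHANUMERIC definition here is an assumption.

```python
# Standalone sketch of the id invariant encoded by the new helper.
# Mirrors vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py rather
# than importing it; ALPHANUMERIC is assumed to be ASCII letters plus digits.
from random import choices
from string import ascii_letters, digits

ALPHANUMERIC = ascii_letters + digits


def generate_random_id() -> str:
    # Mistral tool-call ids must be alphanumeric with a length of exactly 9.
    return "".join(choices(ALPHANUMERIC, k=9))


def is_valid_id(id: str) -> bool:
    return id.isalnum() and len(id) == 9


# Every id the generator can emit should pass the validity check.
assert all(is_valid_id(generate_random_id()) for _ in range(1_000))
```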