 import jsonschema.exceptions
 import pytest

-from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (  # noqa
-    MistralToolParser)
+from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
+    MistralToolCall, MistralToolParser)

 from vllm.sampling_params import GuidedDecodingParams, SamplingParams

 from ...utils import check_logprobs_close
@@ -194,7 +194,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
     )


-@pytest.mark.skip("RE-ENABLE: test is currently failing on main.")
 @pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
@@ -246,10 +245,8 @@ def test_mistral_symbolic_languages(vllm_runner, model: str,
             assert "�" not in outputs[0].outputs[0].text.strip()


-@pytest.mark.skip("RE-ENABLE: test is currently failing on main.")
+@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("model",
-                         MISTRAL_FORMAT_MODELS)  # v1 can't do func calling
 def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
     with vllm_runner(model,
                      dtype=dtype,
@@ -270,7 +267,8 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
         parsed_message = tool_parser.extract_tool_calls(model_output, None)

         assert parsed_message.tools_called
-        assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
+
+        assert MistralToolCall.is_valid_id(parsed_message.tool_calls[0].id)
         assert parsed_message.tool_calls[
             0].function.name == "get_current_weather"
         assert parsed_message.tool_calls[
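The hard-coded ID assertion is dropped because Mistral tool-call IDs are generated randomly per call, so pinning one value makes the test flaky; the hunk above checks the ID's shape instead, via the newly imported MistralToolCall.is_valid_id. As a minimal sketch of what such a check can look like (assuming Mistral-format IDs are 9-character alphanumeric strings; is_valid_mistral_tool_call_id is a hypothetical stand-in, not the actual vLLM implementation):

    import string

    # Hypothetical sketch, assuming 9-char alphanumeric IDs.
    ALPHANUMERIC = string.ascii_letters + string.digits

    def is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:
        # Accept any 9-character alphanumeric ID rather than one pinned value.
        return len(tool_call_id) == 9 and all(
            c in ALPHANUMERIC for c in tool_call_id)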
@@ -281,28 +279,38 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("guided_backend",
                          ["outlines", "lm-format-enforcer", "xgrammar"])
-def test_mistral_guided_decoding(vllm_runner, model: str,
-                                 guided_backend: str) -> None:
-    with vllm_runner(model, dtype='bfloat16',
-                     tokenizer_mode="mistral") as vllm_model:
+def test_mistral_guided_decoding(
+    monkeypatch: pytest.MonkeyPatch,
+    vllm_runner,
+    model: str,
+    guided_backend: str,
+) -> None:
+    with monkeypatch.context() as m:
+        # Guided JSON not supported in xgrammar + V1 yet
+        m.setenv("VLLM_USE_V1", "0")

-        guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA,
-                                               backend=guided_backend)
-        params = SamplingParams(max_tokens=512,
-                                temperature=0.7,
-                                guided_decoding=guided_decoding)
-
-        messages = [{
-            "role": "system",
-            "content": "you are a helpful assistant"
-        }, {
-            "role":
-            "user",
-            "content":
-            f"Give an example JSON for an employee profile that "
-            f"fits this schema: {SAMPLE_JSON_SCHEMA}"
-        }]
-        outputs = vllm_model.model.chat(messages, sampling_params=params)
+        with vllm_runner(
+                model,
+                dtype='bfloat16',
+                tokenizer_mode="mistral",
+                guided_decoding_backend=guided_backend,
+        ) as vllm_model:
+            guided_decoding = GuidedDecodingParams(json=SAMPLE_JSON_SCHEMA)
+            params = SamplingParams(max_tokens=512,
+                                    temperature=0.7,
+                                    guided_decoding=guided_decoding)
+
+            messages = [{
+                "role": "system",
+                "content": "you are a helpful assistant"
+            }, {
+                "role":
+                "user",
+                "content":
+                f"Give an example JSON for an employee profile that "
+                f"fits this schema: {SAMPLE_JSON_SCHEMA}"
+            }]
+            outputs = vllm_model.model.chat(messages, sampling_params=params)

         generated_text = outputs[0].outputs[0].text
         json_response = json.loads(generated_text)
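This hunk moves backend selection from the per-request GuidedDecodingParams to engine construction (guided_decoding_backend=guided_backend) and pins VLLM_USE_V1=0, since guided JSON is not yet supported with xgrammar on V1. Downstream, the decoded text is expected to parse as JSON and conform to SAMPLE_JSON_SCHEMA; a minimal sketch of that kind of check (illustrative only; the exact assertions are not shown in this hunk, though the file imports jsonschema.exceptions for this purpose):

    import json

    import jsonschema

    def assert_matches_schema(generated_text: str, schema: dict) -> None:
        # The guided output must first be valid JSON at all...
        parsed = json.loads(generated_text)
        # ...and then satisfy the schema; raises
        # jsonschema.exceptions.ValidationError on mismatch.
        jsonschema.validate(instance=parsed, schema=schema)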