|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 |
|
| 3 | +import json |
3 | 4 | from dataclasses import dataclass, field
|
4 |
| -from typing import TYPE_CHECKING |
| 5 | +from typing import TYPE_CHECKING, Any |
5 | 6 |
|
6 | 7 | import torch
|
7 | 8 |
|
8 | 9 | import vllm.envs
|
9 | 10 | from vllm.config import VllmConfig
|
10 | 11 | from vllm.logger import init_logger
|
| 12 | +from vllm.sampling_params import SamplingParams |
11 | 13 | from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
12 | 14 | from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
13 | 15 | from vllm.utils import LazyLoader
|
14 | 16 | from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
15 | 17 | StructuredOutputGrammar,
|
16 | 18 | StructuredOutputOptions)
|
| 19 | +from vllm.v1.structured_output.utils import (choice_as_grammar, |
| 20 | + convert_lark_to_ebnf, |
| 21 | + grammar_is_likely_lark) |
17 | 22 |
|
18 | 23 | if TYPE_CHECKING:
|
19 | 24 | import xgrammar as xgr
|
@@ -156,3 +161,112 @@ def is_terminated(self) -> bool:
|
    def reset(self):
        """Reset per-request decoding state so the grammar can be reused.

        Clears the processed-token counter and returns the underlying
        matcher to its initial state.
        """
        # No tokens have been fed to the matcher after a reset.
        self.num_processed_tokens = 0
        # NOTE(review): presumably an xgr.GrammarMatcher — reset() should
        # rewind it to the grammar's start state; confirm against the
        # xgrammar matcher API.
        self.matcher.reset()
|
| 164 | + |
| 165 | + |
def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar."""
    # Keyword groups that xgrammar cannot express, keyed by schema "type".
    numeric_keys = ("minimum", "maximum", "exclusiveMinimum",
                    "exclusiveMaximum", "multipleOf")
    array_keys = ("uniqueItems", "contains", "minContains",
                  "maxContains", "minItems", "maxItems")
    object_keys = ("minProperties", "maxProperties",
                   "propertyNames", "patternProperties")

    # Walk the schema iteratively; only dict nodes carry keywords.
    pending: list[Any] = [schema]
    while pending:
        node = pending.pop()
        if not isinstance(node, dict):
            continue

        # Regex pattern restrictions are unsupported regardless of type.
        if "pattern" in node:
            return True

        node_type = node.get("type")
        if node_type in ("integer", "number") and any(
                key in node for key in numeric_keys):
            return True
        if node_type == "array" and any(key in node for key in array_keys):
            return True
        # String "format" constraints (e.g. date-time) are unsupported.
        if node_type == "string" and "format" in node:
            return True
        if node_type == "object" and any(key in node for key in object_keys):
            return True

        # Queue every nested schema: dict values directly, and dict
        # elements of list values (e.g. anyOf/allOf/items arrays).
        for value in node.values():
            if isinstance(value, dict):
                pending.append(value)
            elif isinstance(value, list):
                pending.extend(
                    item for item in value if isinstance(item, dict))

    return False
| 214 | + |
| 215 | + |
def validate_xgrammar_grammar(sampling_params: SamplingParams) -> None:
    """Validate that the request is supported by structured output.

    Normalizes ``choice`` requests into an EBNF grammar in-place and
    converts Lark grammars to EBNF, since xgrammar only accepts EBNF.

    Raises ValueError if the request is not supported.
    """
    if sampling_params.guided_decoding is None:
        return

    gd_params = sampling_params.guided_decoding

    if gd_params.regex:
        try:
            xgr.Grammar.from_regex(gd_params.regex)
        except Exception as err:
            raise ValueError("Failed to transform regex into a grammar: "
                             f"{err}") from err

    if gd_params.choice:
        choice_grammar = choice_as_grammar(gd_params.choice)
        try:
            xgr.Grammar.from_ebnf(choice_grammar)
        except Exception as err:
            # BUG FIX: the continuation string was missing the f-prefix,
            # so the message contained the literal text "{err}".
            raise ValueError("Failed to transform choices into a grammar: "
                             f"{err}") from err
        # Rewrite the request so downstream code only sees a grammar.
        gd_params.choice = None
        gd_params.grammar = choice_grammar
        return

    if gd_params.json:
        if isinstance(gd_params.json, str):
            try:
                schema = json.loads(gd_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            schema = gd_params.json

        if has_xgrammar_unsupported_json_features(schema):
            raise ValueError("The provided JSON schema contains features not "
                             "supported by xgrammar.")
        return

    if gd_params.grammar:
        if grammar_is_likely_lark(gd_params.grammar):
            # xgrammar supports EBNF grammars only
            try:
                gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar)
            except ValueError as e:
                raise ValueError(
                    "Failed to convert the grammar from Lark to EBNF. ") from e

        # Test parsing EBNF grammar, possibly already converted from Lark
        try:
            # parse the grammar, but we aren't compiling it.
            xgr.Grammar.from_ebnf(gd_params.grammar)
        except Exception as e:
            raise ValueError("Invalid grammar specification.") from e
0 commit comments