Commit ebb3930

[Misc] Move config fields to MultiModalConfig (vllm-project#17343)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent: cde384c · commit: ebb3930

File tree: 8 files changed (+62, -36 lines)

vllm/config.py

Lines changed: 42 additions & 15 deletions
@@ -263,6 +263,10 @@ class ModelConfig:
             the model name will be the same as `model`.
         limit_mm_per_prompt: Maximum number of data items per modality
             per prompt. Only applicable for multimodal models.
+        mm_processor_kwargs: Overrides for the multi-modal processor obtained
+            from `AutoProcessor.from_pretrained`.
+        disable_mm_preprocessor_cache: If True, disable caching of the
+            processed multi-modal inputs.
         use_async_output_proc: Whether to use async output processor.
             Defaults to True.
         config_format: The config format which shall be loaded.
@@ -273,10 +277,6 @@ class ModelConfig:
         hf_overrides: If a dictionary, contains arguments to be forwarded to the
             HuggingFace config. If a callable, it is called to update the
             HuggingFace config.
-        mm_processor_kwargs: Arguments to be forwarded to the model's processor
-            for multi-modal data, e.g., image processor.
-        disable_mm_preprocessor_cache: If true, then disables caching of the
-            multi-modal preprocessor/mapper. (not recommended)
         override_neuron_config: Initialize non default neuron config or
             override default neuron config that are specific to Neuron devices,
             this argument will be used to configure the neuron config that
@@ -320,7 +320,6 @@ def compute_hash(self) -> str:
         factors.append(self.max_logprobs)
         factors.append(self.disable_sliding_window)
         factors.append(self.trust_remote_code)
-        factors.append(self.mm_processor_kwargs)
         factors.append(self.generation_config)
         factors.append(self.model_impl)
         factors.append(self.override_generation_config)
@@ -359,12 +358,12 @@ def __init__(
         skip_tokenizer_init: bool = False,
         served_model_name: Optional[Union[str, list[str]]] = None,
         limit_mm_per_prompt: Optional[dict[str, int]] = None,
+        mm_processor_kwargs: Optional[dict[str, Any]] = None,
+        disable_mm_preprocessor_cache: bool = False,
         use_async_output_proc: bool = True,
         config_format: ConfigFormat = ConfigFormat.AUTO,
         hf_token: Optional[Union[bool, str]] = None,
         hf_overrides: Optional[HfOverrides] = None,
-        mm_processor_kwargs: Optional[dict[str, Any]] = None,
-        disable_mm_preprocessor_cache: bool = False,
         override_neuron_config: Optional[dict[str, Any]] = None,
         override_pooler_config: Optional["PoolerConfig"] = None,
         logits_processor_pattern: Optional[str] = None,
@@ -469,8 +468,6 @@ def __init__(
             self.model, hf_token=hf_token, revision=revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self.use_async_output_proc = use_async_output_proc
-        self.mm_processor_kwargs = mm_processor_kwargs
-        self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache
 
         # Set enforce_eager to False if the value is unset.
         if self.enforce_eager is None:
@@ -515,7 +512,10 @@ def __init__(
         self.served_model_name = get_served_model_name(model,
                                                        served_model_name)
         self.multimodal_config = self._init_multimodal_config(
-            limit_mm_per_prompt)
+            limit_mm_per_prompt=limit_mm_per_prompt,
+            mm_processor_kwargs=mm_processor_kwargs,
+            disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
+        )
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
 
@@ -581,14 +581,27 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str,
             self.tokenizer = s3_tokenizer.dir
 
     def _init_multimodal_config(
-        self, limit_mm_per_prompt: Optional[dict[str, int]]
+        self,
+        limit_mm_per_prompt: Optional[dict[str, int]],
+        mm_processor_kwargs: Optional[dict[str, Any]],
+        disable_mm_preprocessor_cache: bool,
     ) -> Optional["MultiModalConfig"]:
         if self.registry.is_multimodal_model(self.architectures):
-            return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
+            return MultiModalConfig(
+                limit_per_prompt=limit_mm_per_prompt or {},
+                mm_processor_kwargs=mm_processor_kwargs or {},
+                disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
+            )
 
         if limit_mm_per_prompt:
             raise ValueError("`limit_mm_per_prompt` is only supported for "
                              "multimodal models.")
+        if mm_processor_kwargs:
+            raise ValueError("`mm_processor_kwargs` is only supported for "
+                             "multimodal models.")
+        if disable_mm_preprocessor_cache:
+            raise ValueError("`disable_mm_preprocessor_cache` is only "
+                             "supported for multimodal models.")
 
         return None
 
@@ -2776,7 +2789,23 @@ class MultiModalConfig:
     Defaults to 1 (V0) or 999 (V1) for each modality.
 
     For example, to allow up to 16 images and 2 videos per prompt:
-    ``{"images": 16, "videos": 2}``
+    :code:`{"images": 16, "videos": 2}`
+    """
+
+    mm_processor_kwargs: Optional[dict[str, object]] = None
+    """
+    Overrides for the multi-modal processor obtained from
+    :meth:`transformers.AutoProcessor.from_pretrained`.
+
+    The available overrides depend on the model that is being run.
+
+    For example, for Phi-3-Vision:
+    :code:`{"num_crops": 4}`.
+    """
+
+    disable_mm_preprocessor_cache: bool = False
+    """
+    If :code:`True`, disable caching of the processed multi-modal inputs.
     """
 
     def compute_hash(self) -> str:
@@ -4080,8 +4109,6 @@ def __str__(self):
            f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
            f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, "  # noqa
            f"use_async_output_proc={self.model_config.use_async_output_proc}, "
-           f"disable_mm_preprocessor_cache={self.model_config.disable_mm_preprocessor_cache!r}, "  # noqa
-           f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
            f"pooler_config={self.model_config.pooler_config!r}, "
            f"compilation_config={self.compilation_config!r}")
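
For orientation, the relocated fields now live on the `MultiModalConfig` dataclass rather than directly on `ModelConfig`. A minimal sketch using the illustrative values from the docstrings above (direct construction is for demonstration only; vLLM builds this object inside `ModelConfig._init_multimodal_config`):

    from vllm.config import MultiModalConfig

    # Field values copied from the docstring examples above.
    mm_config = MultiModalConfig(
        limit_per_prompt={"images": 16, "videos": 2},
        mm_processor_kwargs={"num_crops": 4},
        disable_mm_preprocessor_cache=False,
    )
    assert mm_config.mm_processor_kwargs == {"num_crops": 4}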

vllm/engine/arg_utils.py

Lines changed: 4 additions & 12 deletions
@@ -672,20 +672,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         )
         multimodal_group.add_argument('--limit-mm-per-prompt',
                                       **multimodal_kwargs["limit_per_prompt"])
-
-        parser.add_argument(
+        multimodal_group.add_argument(
             '--mm-processor-kwargs',
-            default=None,
-            type=json.loads,
-            help=('Overrides for the multi-modal processor obtained from '
-                  '``AutoProcessor.from_pretrained``. The available overrides '
-                  'depend on the model that is being run.'
-                  'For example, for Phi-3-Vision: ``{"num_crops": 4}``.'))
-        parser.add_argument(
+            **multimodal_kwargs["mm_processor_kwargs"])
+        multimodal_group.add_argument(
             '--disable-mm-preprocessor-cache',
-            action='store_true',
-            help='If True, disable caching of the processed multi-modal '
-                 'inputs.')
+            **multimodal_kwargs["disable_mm_preprocessor_cache"])
 
         # LoRA related configs
         lora_kwargs = get_kwargs(LoRAConfig)
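
Both flags keep their old parsing behavior; they are now declared through `multimodal_kwargs` generated from `MultiModalConfig`, instead of being hand-written on the top-level parser. A standalone argparse sketch (not vLLM's code) of what the regenerated options amount to:

    import argparse
    import json

    parser = argparse.ArgumentParser()
    multimodal_group = parser.add_argument_group("MultiModalConfig")
    # Roughly what the generated kwargs produce for each flag:
    multimodal_group.add_argument("--mm-processor-kwargs",
                                  default=None, type=json.loads)
    multimodal_group.add_argument("--disable-mm-preprocessor-cache",
                                  action="store_true")

    args = parser.parse_args(["--mm-processor-kwargs", '{"num_crops": 4}',
                              "--disable-mm-preprocessor-cache"])
    assert args.mm_processor_kwargs == {"num_crops": 4}
    assert args.disable_mm_preprocessor_cache is True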

vllm/inputs/registry.py

Lines changed: 4 additions & 2 deletions
@@ -101,7 +101,8 @@ def init_processor(
         Initialize a HuggingFace-like processor class, merging the
         keyword arguments with those in the model's configuration.
         """
-        base_kwargs = self.model_config.mm_processor_kwargs
+        mm_config = self.model_config.get_multimodal_config()
+        base_kwargs = mm_config.mm_processor_kwargs
         if base_kwargs is None:
             base_kwargs = {}
 
@@ -139,7 +140,8 @@ def call_hf_processor(
         """
         assert callable(hf_processor)
 
-        base_kwargs = self.model_config.mm_processor_kwargs
+        mm_config = self.model_config.get_multimodal_config()
+        base_kwargs = mm_config.mm_processor_kwargs
         if base_kwargs is None:
             base_kwargs = {}
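
Both call sites now go through `ModelConfig.get_multimodal_config()` rather than reading the field off `ModelConfig` directly. The accessor's definition is not part of this diff; a sketch of its assumed contract (an assumption, not the actual vLLM implementation):

    # Assumed contract of the accessor used above (not shown in this diff).
    def get_multimodal_config(self) -> "MultiModalConfig":
        if self.multimodal_config is None:
            raise ValueError("The model is not multimodal.")
        return self.multimodal_config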

vllm/model_executor/models/qwen2_vl.py

Lines changed: 3 additions & 2 deletions
@@ -774,8 +774,9 @@ def _get_image_processor_kwargs(
         size: Optional[dict[str, int]] = None,
         **kwargs: object,
     ):
-        if self.ctx.model_config.mm_processor_kwargs:
-            kwargs.update(self.ctx.model_config.mm_processor_kwargs)
+        mm_config = self.ctx.model_config.get_multimodal_config()
+        if mm_config.mm_processor_kwargs:
+            kwargs.update(mm_config.mm_processor_kwargs)
 
         if min_pixels is not None:
             kwargs["min_pixels"] = min_pixels
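
The merge order is unchanged: config-level `mm_processor_kwargs` are applied first, and explicit per-call arguments such as `min_pixels` still win. A standalone illustration (the function name is hypothetical, not vLLM's code):

    def merge_image_processor_kwargs(mm_processor_kwargs, min_pixels=None,
                                     **kwargs):
        # Config-level overrides form the base...
        if mm_processor_kwargs:
            kwargs.update(mm_processor_kwargs)
        # ...then explicit call-site arguments take precedence.
        if min_pixels is not None:
            kwargs["min_pixels"] = min_pixels
        return kwargs

    merged = merge_image_processor_kwargs({"min_pixels": 256}, min_pixels=1024)
    assert merged == {"min_pixels": 1024}  # explicit argument wins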

vllm/multimodal/registry.py

Lines changed: 2 additions & 1 deletion
@@ -262,7 +262,8 @@ def create_processor(
         if tokenizer is None:
             tokenizer = cached_tokenizer_from_config(model_config)
         if disable_cache is None:
-            disable_cache = model_config.disable_mm_preprocessor_cache
+            mm_config = model_config.get_multimodal_config()
+            disable_cache = mm_config.disable_mm_preprocessor_cache
 
         model_cls = self._get_model_cls(model_config)
         factories = self._processor_factories[model_cls]
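
When `disable_cache` is not passed explicitly, the default now comes from `MultiModalConfig`. The resolution logic as it reads after this change, extracted into a standalone function for illustration (this helper does not exist in vLLM):

    def resolve_disable_cache(model_config, disable_cache=None) -> bool:
        # An explicit argument wins; otherwise fall back to the config field.
        if disable_cache is None:
            mm_config = model_config.get_multimodal_config()
            disable_cache = mm_config.disable_mm_preprocessor_cache
        return disable_cache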

vllm/transformers_utils/processor.py

Lines changed: 2 additions & 1 deletion
@@ -33,7 +33,8 @@ def __hash__(self) -> int:  # type: ignore[override]
 
 
 def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs):
-    base_kwargs = model_config.mm_processor_kwargs
+    mm_config = model_config.get_multimodal_config()
+    base_kwargs = mm_config.mm_processor_kwargs
     if base_kwargs is None:
         base_kwargs = {}
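
`_merge_mm_kwargs` keeps its merge semantics: the config-level kwargs form the base and explicit call-site kwargs override them. A self-contained sketch of that contract (assumed from the surrounding code, not copied from vLLM):

    def merge_mm_kwargs_sketch(base_kwargs, **kwargs):
        base_kwargs = base_kwargs or {}
        # Call-site kwargs take precedence over config-level defaults.
        return {**base_kwargs, **kwargs}

    assert merge_mm_kwargs_sketch({"num_crops": 4}, num_crops=16) == {"num_crops": 16}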

vllm/v1/engine/mm_input_cache.py

Lines changed: 4 additions & 1 deletion
@@ -33,7 +33,10 @@
 class MirroredProcessingCache:
 
     def __init__(self, model_config):
-        self.use_cache = not model_config.disable_mm_preprocessor_cache
+        mm_config = model_config.multimodal_config
+        disable_mm_preprocessor_cache = mm_config is not None and \
+            not mm_config.disable_mm_preprocessor_cache
+        self.use_cache = not disable_mm_preprocessor_cache
         self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
                                                       MultiModalKwargs)
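
Because `model_config.multimodal_config` is `None` for text-only models, the cache decision now guards on that. The resulting behavior, written out as a standalone function that mirrors the diff verbatim:

    def use_cache_after(mm_config) -> bool:
        # Mirrors the diff above exactly; in particular,
        # mm_config is None -> disable is False -> use_cache is True.
        disable_mm_preprocessor_cache = mm_config is not None and \
            not mm_config.disable_mm_preprocessor_cache
        return not disable_mm_preprocessor_cache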

vllm/v1/engine/processor.py

Lines changed: 1 addition & 2 deletions
@@ -51,8 +51,7 @@ def __init__(
         self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
 
         # Multi-modal hasher (for images)
-        self.use_hash = (
-            not self.model_config.disable_mm_preprocessor_cache) or \
+        self.use_hash = self.mm_input_cache_client.use_cache or \
             self.cache_config.enable_prefix_caching
 
     def _validate_logprobs(
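
The hashing decision is now derived from the shared cache client instead of re-reading the config, so the two stay consistent by construction. Reduced to its inputs (names taken from the diff; the class context is elided):

    def compute_use_hash(cache_client_use_cache: bool,
                         enable_prefix_caching: bool) -> bool:
        # Hash multi-modal inputs whenever the mirrored cache is active
        # or prefix caching needs content hashes.
        return cache_client_use_cache or enable_prefix_caching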
