
Commit f6417b0

hmellor authored and Mu Huai committed
Improve configs - ModelConfig (vllm-project#17130)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
1 parent 6fa7a6f commit f6417b0

36 files changed: +492 −650 lines

tests/conftest.py
Lines changed: 1 addition & 1 deletion

@@ -738,7 +738,7 @@ class VllmRunner:
     - `block_size`: Set to `16` instead of `None` to reduce memory usage.
     - `enable_chunked_prefill`: Set to `False` instead of `None` for
       test reproducibility.
-    - `enforce_eager`: Set to `False` instead of `None` to test CUDA graph.
+    - `enforce_eager`: Set to `False` to test CUDA graph.
     """
 
     def __init__(

tests/engine/test_arg_utils.py
Lines changed: 1 addition & 12 deletions

@@ -8,7 +8,7 @@
 
 import pytest
 
-from vllm.config import PoolerConfig, config
+from vllm.config import config
 from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
                                    get_type, is_not_builtin, is_type,
                                    literal_to_kwargs, nullable_kvs,
@@ -222,17 +222,6 @@ def test_prefix_cache_default():
     assert not engine_args.enable_prefix_caching
 
 
-def test_valid_pooling_config():
-    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
-    args = parser.parse_args([
-        '--override-pooler-config',
-        '{"pooling_type": "MEAN"}',
-    ])
-    engine_args = EngineArgs.from_cli_args(args=args)
-    assert engine_args.override_pooler_config == PoolerConfig(
-        pooling_type="MEAN", )
-
-
 @pytest.mark.parametrize(
     ("arg"),
     [

tests/quantization/test_register_quantization_config.py
Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@
 from vllm.model_executor.layers.linear import LinearBase  # noqa: E501
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization import (
-    get_quantization_config, register_quantization_config)
+    QuantizationMethods, get_quantization_config, register_quantization_config)
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
 
@@ -54,7 +54,7 @@ def __init__(self, num_bits: int = 8) -> None:
         """Initialize the quantization config."""
         self.num_bits = num_bits
 
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         """Name of the quantization method."""
         return "custom_quant"
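
For context on the change above: a plugin's `get_name` is now typed against the `QuantizationMethods` literal rather than plain `str`. A minimal sketch of a registered custom config under the new typing — only `get_name`, `num_bits`, and the registration mirror the test file; the other abstract-method bodies are assumed placeholders:

```python
from typing import Any, Dict, List, Optional

import torch

from vllm.model_executor.layers.quantization import (
    QuantizationMethods, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)


@register_quantization_config("custom_quant")
class CustomQuantConfig(QuantizationConfig):
    """Toy config; registration adds "custom_quant" to the known methods."""

    def __init__(self, num_bits: int = 8) -> None:
        super().__init__()
        self.num_bits = num_bits

    def get_name(self) -> QuantizationMethods:
        # Typed as the QuantizationMethods literal instead of str.
        return "custom_quant"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70  # placeholder: assume Volta or newer

    @staticmethod
    def get_config_filenames() -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CustomQuantConfig":
        return cls(num_bits=config.get("num_bits", 8))

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional[QuantizeMethodBase]:
        return None  # placeholder: no layer is actually quantized here
```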

tests/test_config.py
Lines changed: 5 additions & 4 deletions

@@ -185,7 +185,7 @@ def test_get_pooling_config():
         revision=None,
     )
 
-    pooling_config = model_config._init_pooler_config(None)
+    pooling_config = model_config._init_pooler_config()
     assert pooling_config is not None
 
     assert pooling_config.normalize
@@ -205,11 +205,12 @@ def test_get_pooling_config_from_args():
         dtype="float16",
         revision=None)
 
-    override_config = PoolerConfig(pooling_type='CLS', normalize=True)
+    override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
+    model_config.override_pooler_config = override_pooler_config
 
-    pooling_config = model_config._init_pooler_config(override_config)
+    pooling_config = model_config._init_pooler_config()
     assert pooling_config is not None
-    assert asdict(pooling_config) == asdict(override_config)
+    assert asdict(pooling_config) == asdict(override_pooler_config)
 
 
 @pytest.mark.skipif(current_platform.is_rocm(),
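
The updated test reflects `_init_pooler_config` losing its parameter: the override now rides on the `ModelConfig` instance. A rough sketch of the new flow — the `ModelConfig` constructor arguments here are assumed, since the rendered diff only shows the tail (`dtype="float16", revision=None`):

```python
from vllm.config import ModelConfig, PoolerConfig

# Constructor arguments assumed/abbreviated; any pooling-capable
# checkpoint would do as the model here.
model_config = ModelConfig(
    "sentence-transformers/all-MiniLM-L12-v2",
    dtype="float16",
    revision=None,
)

# Before: pooling_config = model_config._init_pooler_config(override_config)
# After: attach the override to the config; it is read internally.
model_config.override_pooler_config = PoolerConfig(pooling_type="CLS",
                                                   normalize=True)
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
```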

vllm/config.py

Lines changed: 258 additions & 255 deletions
Large diffs are not rendered by default.

vllm/engine/arg_utils.py

Lines changed: 137 additions & 314 deletions
Large diffs are not rendered by default.

vllm/entrypoints/llm.py
Lines changed: 7 additions & 11 deletions

@@ -13,7 +13,7 @@
 
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
-from vllm.config import CompilationConfig
+from vllm.config import CompilationConfig, ModelDType, TokenizerMode
 from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                    TaskOption)
 from vllm.engine.llm_engine import LLMEngine
@@ -32,6 +32,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding.guided_fields import (
     GuidedDecodingRequest, LLMGuidedOptions)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput, RequestOutput,
                           ScoringRequestOutput)
@@ -163,20 +164,20 @@ def __init__(
         self,
         model: str,
         tokenizer: Optional[str] = None,
-        tokenizer_mode: str = "auto",
+        tokenizer_mode: TokenizerMode = "auto",
         skip_tokenizer_init: bool = False,
         trust_remote_code: bool = False,
         allowed_local_media_path: str = "",
         tensor_parallel_size: int = 1,
-        dtype: str = "auto",
-        quantization: Optional[str] = None,
+        dtype: ModelDType = "auto",
+        quantization: Optional[QuantizationMethods] = None,
         revision: Optional[str] = None,
         tokenizer_revision: Optional[str] = None,
         seed: Optional[int] = None,
         gpu_memory_utilization: float = 0.9,
         swap_space: float = 4,
         cpu_offload_gb: float = 0,
-        enforce_eager: Optional[bool] = None,
+        enforce_eager: bool = False,
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         disable_async_output_proc: bool = False,
@@ -189,12 +190,7 @@ def __init__(
         compilation_config: Optional[Union[int, dict[str, Any]]] = None,
         **kwargs,
     ) -> None:
-        '''
-        LLM constructor.
-
-        Note: if enforce_eager is unset (enforce_eager is None)
-        it defaults to False.
-        '''
+        """LLM constructor."""
 
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
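
For downstream callers the visible effect is tighter typing on the constructor: literal strings still work, but they now have to be valid `TokenizerMode`, `ModelDType`, or `QuantizationMethods` values, and `enforce_eager` is a plain `bool`. A small usage sketch (the model name is illustrative):

```python
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",  # illustrative checkpoint
    tokenizer_mode="auto",      # must be a TokenizerMode literal
    dtype="auto",               # must be a ModelDType literal
    quantization=None,          # Optional[QuantizationMethods]
    enforce_eager=False,        # plain bool; None is no longer accepted
)
```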

vllm/model_executor/layers/quantization/aqlm.py
Lines changed: 2 additions & 1 deletion

@@ -12,6 +12,7 @@
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs
@@ -186,7 +187,7 @@ def __repr__(self) -> str:
                 f"out_group_size={self.out_group_size})")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "aqlm"
 
     @classmethod

vllm/model_executor/layers/quantization/awq.py
Lines changed: 2 additions & 1 deletion

@@ -7,6 +7,7 @@
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -44,7 +45,7 @@ def __repr__(self) -> str:
                 f"zero_point={self.zero_point}, "
                 f"modules_to_not_convert={self.modules_to_not_convert})")
 
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "awq"
 
     def get_supported_act_dtypes(self) -> List[torch.dtype]:

vllm/model_executor/layers/quantization/awq_marlin.py
Lines changed: 4 additions & 3 deletions

@@ -13,6 +13,7 @@
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.awq import (AWQConfig,
                                                          is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -73,7 +74,7 @@ def __repr__(self) -> str:
                 f"modules_to_not_convert={self.modules_to_not_convert})")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "awq_marlin"
 
     @classmethod
@@ -101,8 +102,8 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
                    modules_to_not_convert, config)
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
         is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                                or user_quant == "awq_marlin")

vllm/model_executor/layers/quantization/base_config.py
Lines changed: 9 additions & 4 deletions

@@ -2,11 +2,16 @@
 
 import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
 import torch
 from torch import nn
 
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.quantization import QuantizationMethods
+else:
+    QuantizationMethods = str
+
 
 class QuantizeMethodBase(ABC):
     """Base class for different quantized methods."""
@@ -66,7 +71,7 @@ def __init__(self):
         self.packed_modules_mapping: Dict[str, List[str]] = dict()
 
     @abstractmethod
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         """Name of the quantization method."""
         raise NotImplementedError
 
@@ -99,8 +104,8 @@ def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
         raise NotImplementedError
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         """
         Detects if this quantization method can support a given checkpoint
         format by overriding the user specified quantization method --
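
The `if TYPE_CHECKING` guard in this hunk is the standard trick for annotating against a name whose import would be circular at runtime: `quantization/__init__.py` (where `QuantizationMethods` lives) imports `base_config`, so `base_config` can only import it while type checking. A generic, self-contained sketch of the pattern:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved only by static type checkers, so no runtime import cycle.
    from vllm.model_executor.layers.quantization import QuantizationMethods
else:
    # At runtime the annotation degrades to plain str, matching old behavior.
    QuantizationMethods = str


def get_name() -> QuantizationMethods:
    # Type checkers see the Literal; the interpreter just sees str.
    return "fp8"
```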

vllm/model_executor/layers/quantization/bitblas.py
Lines changed: 4 additions & 3 deletions

@@ -5,6 +5,7 @@
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
@@ -100,7 +101,7 @@ def __repr__(self) -> str:
                 f"quant_method={self.quant_method})")
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "bitblas"
 
     @classmethod
@@ -139,8 +140,8 @@ def from_config(cls, config: Dict[str, Any]) -> "BitBLASConfig":
                    lm_head_quantized)
 
     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         # compat: autogptq >=0.8.0 use checkpoint_format: str
         # compat: autogptq <=0.7.1 is_bitblas_format: bool
         is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"

vllm/model_executor/layers/quantization/bitsandbytes.py
Lines changed: 2 additions & 1 deletion

@@ -7,6 +7,7 @@
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.utils import direct_register_custom_op
@@ -56,7 +57,7 @@ def __repr__(self) -> str:
                 f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
 
     @classmethod
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "bitsandbytes"
 
     @classmethod

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
Lines changed: 2 additions & 1 deletion

@@ -16,6 +16,7 @@
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
@@ -71,7 +72,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]:
     def get_min_capability(cls) -> int:
         return 70
 
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "compressed-tensors"
 
     def get_quant_method(

vllm/model_executor/layers/quantization/deepspeedfp.py
Lines changed: 3 additions & 2 deletions

@@ -7,6 +7,7 @@
 import torch.nn.functional as F
 
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs
@@ -41,8 +42,8 @@ def __repr__(self) -> str:
                 f"group_size={self.group_size}")
 
     @classmethod
-    def get_name(cls) -> str:
-        return "DeepSpeedFP"
+    def get_name(cls) -> QuantizationMethods:
+        return "deepspeedfp"
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":

vllm/model_executor/layers/quantization/experts_int8.py
Lines changed: 2 additions & 1 deletion

@@ -8,6 +8,7 @@
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -20,7 +21,7 @@ def __init__(self) -> None:
         super().__init__()
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "experts_int8"
 
     @classmethod

vllm/model_executor/layers/quantization/fbgemm_fp8.py
Lines changed: 2 additions & 1 deletion

@@ -9,6 +9,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -38,7 +39,7 @@ def __init__(self, ignore_list: List[str], input_scale_ub: float):
         self.fp8_linear = Fp8LinearOp()
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "fbgemm_fp8"
 
     @classmethod

vllm/model_executor/layers/quantization/fp8.py
Lines changed: 2 additions & 1 deletion

@@ -16,6 +16,7 @@
     FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -83,7 +84,7 @@ def __init__(
         self.weight_block_size = weight_block_size
 
     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "fp8"
 
     @classmethod

vllm/model_executor/layers/quantization/gguf.py
Lines changed: 2 additions & 1 deletion

@@ -13,6 +13,7 @@
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -31,7 +32,7 @@ def __init__(self, ) -> None:
     def __repr__(self) -> str:
         return ("GGUFConfig()")
 
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "gguf"
 
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
