
Commit c579b14

WoosukKwon authored and Mu Huai committed
[Chore] Remove Sampler from Model Code (vllm-project#17084)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
1 parent 67baab6 commit c579b14


103 files changed: +48 -1,099 lines changed
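The same mechanical change repeats across every model file below: the model-level sampler import, the `self.sampler = get_sampler()` attribute, and the per-model `sample()` method are deleted, while `compute_logits()` is kept. As a minimal sketch of the deleted boilerplate (a hypothetical toy model, not actual vLLM model code; only the two vllm imports and the `sample()` signature are taken from the diffs below):

from typing import Optional

import torch
import torch.nn as nn

# The import that every diff below removes from the model files.
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata


class ToyForCausalLM(nn.Module):
    """Hypothetical stand-in for the model classes touched by this commit."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.sampler = get_sampler()  # removed from every __init__ below

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Kept: models still turn hidden states into logits.
        return self.lm_head(hidden_states)

    # Removed wholesale: sampling is no longer the model's responsibility.
    # The runner that drives the model owns the Sampler and applies it to
    # the logits returned by compute_logits().
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.sampler(logits, sampling_metadata)

After the commit, each model class keeps only the logits path; the sampler is owned and invoked outside the model code.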

tests/spec_decode/test_scorer.py

+2 -3

@@ -62,9 +62,8 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
     scorer_worker = create_worker(Worker, model_name, block_size,
                                   num_gpu_blocks, seed)
     scorer_worker.model_runner.disable_logprobs = True  # accessed by mqa_scorer
-    scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
-    scorer_worker.model_runner.model.sampler.\
-        should_modify_greedy_probs_inplace = True
+    scorer_worker.model_runner.sampler.include_gpu_probs_tensor = True
+    scorer_worker.model_runner.sampler.should_modify_greedy_probs_inplace = True
 
     vocab_size = scorer_worker.vocab_size
 
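The test change above shows where the sampler now lives: on the model runner rather than on the model it wraps. A hedged sketch of the resulting access pattern, assuming a `scorer_worker` built by the same `create_worker` helper as in the test:

# Before this commit, spec-decode code reached through the model:
#     scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
# After it, the runner exposes the sampler directly:
sampler = scorer_worker.model_runner.sampler
sampler.include_gpu_probs_tensor = True
sampler.should_modify_greedy_probs_inplace = True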

vllm/model_executor/models/arctic.py

-10

@@ -24,7 +24,6 @@
 from vllm.model_executor.layers.quantization.deepspeedfp import (
     DeepSpeedFPConfig, DeepSpeedFPParameter)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -435,7 +434,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.unpadded_vocab_size = config.vocab_size
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -462,14 +460,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [

vllm/model_executor/models/aria.py

+1 -11

@@ -15,11 +15,10 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import (SamplerOutput,
-                                                SamplingMetadata, get_sampler)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs)
@@ -527,7 +526,6 @@ def __init__(
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 self.vocab_size, logit_scale)
-        self.sampler = get_sampler()
 
     def _validate_image_sizes(
             self, images: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -653,14 +651,6 @@ def compute_logits(self, hidden_states: torch.Tensor,
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

vllm/model_executor/models/aya_vision.py

-16

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0 Adapted from
 # https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
-from functools import cached_property
 from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple,
                     TypedDict, Union, cast)
 
@@ -17,7 +16,6 @@
 
 from vllm.config import VllmConfig
 from vllm.jsontree import json_map_leaves
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
@@ -461,17 +459,3 @@ def compute_logits(
     ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states,
                                                    sampling_metadata)
-
-    @cached_property
-    def sampler(self):
-        if hasattr(self.language_model, "sampler"):
-            return self.language_model.sampler
-
-        return get_sampler()
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        return self.language_model.sample(logits, sampling_metadata)
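aya_vision.py is the first of several composite multimodal models in this commit: previously each kept a `cached_property` that returned the wrapped language model's sampler (or a fresh `get_sampler()`) plus a `sample()` method that delegated to it. After the change, only the logits path survives. A minimal sketch of the remaining shape of such a wrapper, using a hypothetical `ToyVisionLM` (the delegation in `compute_logits` mirrors the diffs; everything else is illustrative):

from typing import Optional

import torch
import torch.nn as nn

from vllm.model_executor.sampling_metadata import SamplingMetadata


class ToyVisionLM(nn.Module):
    """Hypothetical wrapper mirroring the multimodal models in this commit."""

    def __init__(self, vision_tower: nn.Module, language_model: nn.Module):
        super().__init__()
        self.vision_tower = vision_tower
        self.language_model = language_model
        # No `self.sampler`, no `@cached_property sampler`, no `sample()`:
        # the wrapper only produces hidden states and logits.

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # Logits computation is still delegated to the wrapped language model.
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)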

vllm/model_executor/models/baichuan.py

-10

@@ -39,7 +39,6 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -396,7 +395,6 @@ def __init__(
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -423,14 +421,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)

vllm/model_executor/models/bamba.py

-10

@@ -24,7 +24,6 @@
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -462,7 +461,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
-        self.sampler = get_sampler()
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
@@ -538,14 +536,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)

vllm/model_executor/models/bart.py

-10

@@ -37,7 +37,6 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -791,7 +790,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
-        self.sampler = get_sampler()
 
     def forward(
         self,
@@ -828,14 +826,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     stacked_params_mapping = {
         "q_proj": {
             "param_name": "qkv_proj",

vllm/model_executor/models/blip2.py

+1 -17

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections.abc import Iterable, Mapping, Sequence
-from functools import cached_property
 from typing import Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
@@ -12,7 +11,6 @@
 from vllm.config import CacheConfig, VllmConfig
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -530,13 +528,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    @cached_property
-    def sampler(self):
-        if hasattr(self.language_model, "sampler"):
-            return self.language_model.sampler
-
-        return get_sampler()
-
     def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
         h = w = self.config.vision_config.image_size
         expected_dims = (3, h, w)
@@ -649,7 +640,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
-    ) -> Union[SamplerOutput, IntermediateTensors]:
+    ) -> IntermediateTensors:
         """Run forward pass for BLIP-2.
 
         One key thing to understand is the `input_ids` already accounts for the
@@ -707,13 +698,6 @@ def compute_logits(
         return self.language_model.compute_logits(hidden_states,
                                                   sampling_metadata)
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        return self.language_model.sample(logits, sampling_metadata)
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)

vllm/model_executor/models/bloom.py

-10

@@ -35,7 +35,6 @@
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -297,7 +296,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                       self.config.hidden_size)
 
         self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
@@ -324,14 +322,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))

vllm/model_executor/models/chameleon.py

-10

@@ -22,7 +22,6 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -950,7 +949,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size, logit_scale)
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -1054,14 +1052,6 @@ def compute_logits(
 
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [

vllm/model_executor/models/chatglm.py

-10

@@ -21,7 +21,6 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -429,7 +428,6 @@ def __init__(
                 self.transformer.embedding.weight)
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config.padded_vocab_size)
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
@@ -442,14 +440,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

vllm/model_executor/models/commandr.py

-10

@@ -38,7 +38,6 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -372,7 +371,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                                 scale=config.logit_scale)
         self.model = CohereModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -406,14 +404,6 @@ def compute_logits(
 
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
