[Chore] Remove Sampler from Model Code #17084

Merged · 11 commits · Apr 24, 2025
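This PR strips the per-model sampling hooks from vLLM's model code: each model file loses its `get_sampler`/`SamplerOutput` import, the `self.sampler = get_sampler()` line in `__init__`, and the boilerplate `sample()` method, leaving the models to end at `compute_logits()`. Sampling becomes the responsibility of the model runner. A minimal sketch of the resulting division of labor follows; `SimpleRunner` and its `execute` signature are illustrative stand-ins, not the actual vLLM runner API:

import torch

from vllm.model_executor.layers.sampler import get_sampler


class SimpleRunner:
    """Illustrative stand-in for vLLM's model runner (not the real class)."""

    def __init__(self, model):
        self.model = model
        # The sampler now lives on the runner, not on the model.
        self.sampler = get_sampler()

    @torch.inference_mode()
    def execute(self, input_ids, positions, sampling_metadata):
        # Models still expose forward() and compute_logits()...
        hidden_states = self.model(input_ids, positions)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)
        # ...but sample() is gone; the runner invokes the sampler itself.
        return self.sampler(logits, sampling_metadata)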
5 changes: 2 additions & 3 deletions tests/spec_decode/test_scorer.py
@@ -62,9 +62,8 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
     scorer_worker = create_worker(Worker, model_name, block_size,
                                   num_gpu_blocks, seed)
     scorer_worker.model_runner.disable_logprobs = True  # accessed by mqa_scorer
-    scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True
-    scorer_worker.model_runner.model.sampler.\
-        should_modify_greedy_probs_inplace = True
+    scorer_worker.model_runner.sampler.include_gpu_probs_tensor = True
+    scorer_worker.model_runner.sampler.should_modify_greedy_probs_inplace = True
 
     vocab_size = scorer_worker.vocab_size
 
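The test now configures the shared sampler through `scorer_worker.model_runner.sampler` instead of reaching into `model_runner.model.sampler`, which also lets the second flag fit on one line without the backslash continuation.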
10 changes: 0 additions & 10 deletions vllm/model_executor/models/arctic.py
@@ -24,7 +24,6 @@
     DeepSpeedFPConfig, DeepSpeedFPParameter)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -435,7 +434,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.unpadded_vocab_size = config.vocab_size
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -462,14 +460,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
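The same three-part removal repeats across the model files that follow: drop the `SamplerOutput`/`get_sampler` import, drop `self.sampler = get_sampler()` from `__init__`, and delete the boilerplate `sample()` method that merely forwarded `logits` and `sampling_metadata` to the sampler.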
12 changes: 1 addition & 11 deletions vllm/model_executor/models/aria.py
@@ -15,11 +15,10 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import (SamplerOutput,
-                                                SamplingMetadata, get_sampler)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs)
@@ -527,7 +526,6 @@ def __init__(
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 self.vocab_size, logit_scale)
-        self.sampler = get_sampler()
 
     def _validate_image_sizes(
             self, images: List[torch.Tensor]) -> List[torch.Tensor]:
@@ -653,14 +651,6 @@ def compute_logits(self, hidden_states: torch.Tensor,
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
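Note the one addition in this file: `SamplingMetadata` was previously pulled in through the sampler module's re-exports; with that import gone, it now comes from its defining module, `vllm.model_executor.sampling_metadata`.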
16 changes: 0 additions & 16 deletions vllm/model_executor/models/aya_vision.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0 Adapted from
 # https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
-from functools import cached_property
 from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple,
                     TypedDict, Union, cast)
 
@@ -17,7 +16,6 @@
 
 from vllm.config import VllmConfig
 from vllm.jsontree import json_map_leaves
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalDataDict, MultiModalKwargs
@@ -461,17 +459,3 @@ def compute_logits(
     ) -> Optional[torch.Tensor]:
         return self.language_model.compute_logits(hidden_states,
                                                   sampling_metadata)
-
-    @cached_property
-    def sampler(self):
-        if hasattr(self.language_model, "sampler"):
-            return self.language_model.sampler
-
-        return get_sampler()
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        return self.language_model.sample(logits, sampling_metadata)
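Multimodal wrappers such as this one carried an extra `cached_property` that delegated to the inner language model's sampler when present; with the sampler owned by the runner, both the delegating property and the forwarding `sample()` method disappear.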
10 changes: 0 additions & 10 deletions vllm/model_executor/models/baichuan.py
@@ -39,7 +39,6 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -396,7 +395,6 @@ def __init__(
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -423,14 +421,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
10 changes: 0 additions & 10 deletions vllm/model_executor/models/bamba.py
@@ -24,7 +24,6 @@
     MambaMixer2, extra_groups_for_head_shards)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -462,7 +461,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
-        self.sampler = get_sampler()
 
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
@@ -538,14 +536,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
10 changes: 0 additions & 10 deletions vllm/model_executor/models/bart.py
@@ -37,7 +37,6 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -791,7 +790,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
-        self.sampler = get_sampler()
 
     def forward(
         self,
@@ -828,14 +826,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: Optional[torch.Tensor],
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     stacked_params_mapping = {
         "q_proj": {
             "param_name": "qkv_proj",
18 changes: 1 addition & 17 deletions vllm/model_executor/models/blip2.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections.abc import Iterable, Mapping, Sequence
-from functools import cached_property
 from typing import Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
@@ -12,7 +11,6 @@
 from vllm.config import CacheConfig, VllmConfig
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -530,13 +528,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    @cached_property
-    def sampler(self):
-        if hasattr(self.language_model, "sampler"):
-            return self.language_model.sampler
-
-        return get_sampler()
-
     def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
         h = w = self.config.vision_config.image_size
         expected_dims = (3, h, w)
@@ -649,7 +640,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
-    ) -> Union[SamplerOutput, IntermediateTensors]:
+    ) -> IntermediateTensors:
         """Run forward pass for BLIP-2.
 
         One key thing to understand is the `input_ids` already accounts for the
@@ -707,13 +698,6 @@ def compute_logits(
         return self.language_model.compute_logits(hidden_states,
                                                   sampling_metadata)
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        return self.language_model.sample(logits, sampling_metadata)
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
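Since `forward()` can no longer end in a `SamplerOutput`, its return annotation narrows from `Union[SamplerOutput, IntermediateTensors]` to `IntermediateTensors`.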
10 changes: 0 additions & 10 deletions vllm/model_executor/models/bloom.py
@@ -35,7 +35,6 @@
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -297,7 +296,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                           self.config.hidden_size)
 
         self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
@@ -324,14 +322,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
10 changes: 0 additions & 10 deletions vllm/model_executor/models/chameleon.py
@@ -22,7 +22,6 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -950,7 +949,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size, logit_scale)
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -1054,14 +1052,6 @@ def compute_logits(
 
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
10 changes: 0 additions & 10 deletions vllm/model_executor/models/chatglm.py
@@ -21,7 +21,6 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -429,7 +428,6 @@ def __init__(
                 self.transformer.embedding.weight)
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config.padded_vocab_size)
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
@@ -442,14 +440,6 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
10 changes: 0 additions & 10 deletions vllm/model_executor/models/commandr.py
@@ -38,7 +38,6 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
@@ -372,7 +371,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                                 scale=config.logit_scale)
         self.model = CohereModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
-        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
@@ -406,14 +404,6 @@ def compute_logits(
 
         return logits
 
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [