Skip to content

Commit 8de2901

Browse files
authored
[Bugfix] gemma[2,3] interleaved attention when sliding window is disabled (#17180)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
1 parent c53e073 commit 8de2901

File tree

3 files changed

+15
-11
lines changed

3 files changed

+15
-11
lines changed

vllm/model_executor/models/gemma2.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -145,8 +145,8 @@ def __init__(self,
145145
# reference:
146146
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
147147
layer_idx = extract_layer_index(prefix)
148-
use_sliding_window = (layer_idx % 2 == 0 and
149-
config.interleaved_sliding_window is not None)
148+
use_sliding_window = (layer_idx % 2 == 0 and getattr(
149+
config, "interleaved_sliding_window", None) is not None)
150150
sliding_window = config.interleaved_sliding_window if \
151151
use_sliding_window else None
152152
self.attn = Attention(self.num_heads,

vllm/model_executor/models/gemma3.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -146,7 +146,9 @@ def __init__(self,
146146

147147
# TODO(woosuk): Add reference to the original HF implementation.
148148
layer_idx = extract_layer_index(prefix)
149-
self.is_sliding = bool((layer_idx + 1) % config.sliding_window_pattern)
149+
self.is_sliding = (getattr(
150+
config, "interleaved_sliding_window", None) is not None and bool(
151+
(layer_idx + 1) % config.sliding_window_pattern))
150152
# Initialize the rotary embedding.
151153
if self.is_sliding:
152154
# Local attention. Override the values in config.json.

vllm/model_executor/models/gemma3_mm.py

+10-8
Original file line number | Diff line number | Diff line change
@@ -478,7 +478,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
478478
self.config = config
479479
self.quant_config = quant_config
480480
self.multimodal_config = multimodal_config
481-
self.sliding_window = config.text_config.interleaved_sliding_window
481+
self.sliding_window = getattr(config.text_config,
482+
"interleaved_sliding_window", None)
482483

483484
self.vision_tower = SiglipVisionModel(config.vision_config,
484485
quant_config,
@@ -680,13 +681,14 @@ def prepare_attn_masks(
680681
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
681682
global_attn_masks.append(global_attn_mask)
682683

683-
# Create a local causal mask with sliding window (1024).
684-
local_attn_mask = torch.ones_like(global_attn_mask)
685-
local_attn_mask = torch.tril(local_attn_mask,
686-
diagonal=-self.sliding_window)
687-
local_attn_mask = torch.where(local_attn_mask == 0,
688-
global_attn_mask, float("-inf"))
689-
local_attn_masks.append(local_attn_mask)
684+
if self.sliding_window is not None:
685+
# Create a local causal mask with sliding window (1024).
686+
local_attn_mask = torch.ones_like(global_attn_mask)
687+
local_attn_mask = torch.tril(local_attn_mask,
688+
diagonal=-self.sliding_window)
689+
local_attn_mask = torch.where(local_attn_mask == 0,
690+
global_attn_mask, float("-inf"))
691+
local_attn_masks.append(local_attn_mask)
690692
kwargs["global_attn_masks"] = global_attn_masks
691693
kwargs["local_attn_masks"] = local_attn_masks
692694
return kwargs

0 commit comments

Comments
 (0)