Commit f8db0bd

LucasWilkinson authored and heyselbi committed
[BugFix][Attention] Fix sliding window attention in V1 giving incorrect results (vllm-project#17574)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
1 parent e335c34 commit f8db0bd

File tree

1 file changed: +35 -1 lines changed

vllm/v1/attention/backends/flash_attn.py

Lines changed: 35 additions & 1 deletion
@@ -10,9 +10,11 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
+from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                            get_flash_attn_version)
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
@@ -276,20 +278,35 @@ def make_local_attention_virtual_batches(
         block_table_local
 
 
+def _get_sliding_window_configs(
+        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
+    """Get the set of all sliding window configs used in the model."""
+    sliding_window_configs: set[Optional[tuple[int, int]]] = set()
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    for layer in layers.values():
+        assert isinstance(layer.impl, FlashAttentionImpl)
+        sliding_window_configs.add(layer.impl.sliding_window)
+    return sliding_window_configs
+
+
 class FlashAttentionMetadataBuilder:
 
     def __init__(self, runner: "GPUModelRunner"):
         model_config = runner.model_config
 
         self.runner = runner
-        self.aot_schedule = (get_flash_attn_version() == 3)
         self.num_heads_q = model_config.get_num_attention_heads(
             runner.parallel_config)
         self.num_heads_kv = model_config.get_num_kv_heads(
             runner.parallel_config)
         self.headdim = model_config.get_head_size()
         self.page_size = self.runner.block_size
 
+        self.aot_schedule = (get_flash_attn_version() == 3)
+        # Sliding window size to be used with the AOT scheduler will be
+        # populated on first build() call.
+        self.aot_sliding_window: Optional[tuple[int, int]] = None
+
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return False
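For readers skimming the patch, the sketch below is a minimal, self-contained illustration (not vLLM code; DummyLayer and the sample windows are made up) of what _get_sliding_window_configs does: collect each attention layer's sliding-window setting into a set, so a single surviving entry means the whole model shares one configuration.

from typing import Optional


class DummyLayer:
    """Hypothetical stand-in for an attention layer implementation."""

    def __init__(self, sliding_window: Optional[tuple[int, int]]):
        self.sliding_window = sliding_window


def collect_sliding_window_configs(
        layers: list[DummyLayer]) -> set[Optional[tuple[int, int]]]:
    """Gather every distinct sliding-window config used by the layers."""
    return {layer.sliding_window for layer in layers}


# All layers sharing one window collapse to a single entry; mixing
# windowed and full-attention layers yields several entries.
assert collect_sliding_window_configs(
    [DummyLayer((1023, 0))] * 4) == {(1023, 0)}
assert collect_sliding_window_configs(
    [DummyLayer((1023, 0)), DummyLayer(None)]) == {(1023, 0), None}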
@@ -307,6 +324,22 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
+        if self.aot_sliding_window is None:
+            self.aot_sliding_window = (-1, -1)
+            # For the AOT scheduler we need the sliding window value to be
+            # constant for all layers. We have to populate this on the first
+            # build() call, once the layers have been constructed (so it
+            # cannot be populated in __init__).
+            if self.aot_schedule:
+                sliding_window_configs = _get_sliding_window_configs(
+                    self.runner.vllm_config)
+                if len(sliding_window_configs) == 1:
+                    sliding_window_config = sliding_window_configs.pop()
+                    if sliding_window_config is not None:
+                        self.aot_sliding_window = sliding_window_config
+                elif len(sliding_window_configs) > 1:
+                    self.aot_schedule = False
+
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
             if self.aot_schedule:
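The decision made by the new block in build() can be summarized with this standalone sketch (resolve_aot_sliding_window is a hypothetical helper, not part of the patch): keep the default (-1, -1) "no window" value, adopt the single shared window, or fall back to non-AOT scheduling when layers disagree.

from typing import Optional


def resolve_aot_sliding_window(
    configs: set[Optional[tuple[int, int]]],
    aot_schedule: bool,
) -> tuple[tuple[int, int], bool]:
    """Return (window_size, aot_schedule) following the same three-way logic."""
    window: tuple[int, int] = (-1, -1)  # default: no sliding window
    if aot_schedule:
        if len(configs) == 1:
            only = next(iter(configs))
            if only is not None:
                window = only  # one shared window: pass it to the scheduler
        elif len(configs) > 1:
            aot_schedule = False  # mixed windows: disable AOT scheduling
    return window, aot_schedule


assert resolve_aot_sliding_window({None}, True) == ((-1, -1), True)
assert resolve_aot_sliding_window({(1023, 0)}, True) == ((1023, 0), True)
assert resolve_aot_sliding_window({(1023, 0), None}, True) == ((-1, -1), False)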
@@ -321,6 +354,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     page_size=self.page_size,
                     cu_seqlens_q=cu_query_lens,
                     causal=causal,
+                    window_size=self.aot_sliding_window,
                 )
             return None
 
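For context, flash-attn style kernels interpret window_size as a (left, right) pair, with -1 meaning unbounded on that side; the one-line change above forwards the model's actual window to get_scheduler_metadata so the precomputed schedule is consistent with the masking the kernel applies. The sketch below (plain Python, hypothetical allowed helper, not vLLM code) illustrates that convention.

def allowed(i: int, j: int, window: tuple[int, int]) -> bool:
    """Can query token i attend to key token j under (left, right) limits?

    A value of -1 on either side means that side is unbounded.
    """
    left, right = window
    if left != -1 and j < i - left:
        return False
    if right != -1 and j > i + right:
        return False
    return True


# (-1, -1): no sliding window (causality is enforced separately).
assert allowed(10, 0, (-1, -1))
# (4, 0): the current token plus at most the 4 preceding tokens.
assert allowed(10, 6, (4, 0)) and not allowed(10, 5, (4, 0))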