from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
+from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)
+from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
@@ -276,20 +278,35 @@ def make_local_attention_virtual_batches(
        block_table_local


+def _get_sliding_window_configs(
+        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
+    """Get the set of all sliding window configs used in the model."""
+    sliding_window_configs: set[Optional[tuple[int, int]]] = set()
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
+    for layer in layers.values():
+        assert isinstance(layer.impl, FlashAttentionImpl)
+        sliding_window_configs.add(layer.impl.sliding_window)
+    return sliding_window_configs
+
+
class FlashAttentionMetadataBuilder:

    def __init__(self, runner: "GPUModelRunner"):
        model_config = runner.model_config

        self.runner = runner
-        self.aot_schedule = (get_flash_attn_version() == 3)
        self.num_heads_q = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(
            runner.parallel_config)
        self.headdim = model_config.get_head_size()
        self.page_size = self.runner.block_size

+        self.aot_schedule = (get_flash_attn_version() == 3)
+        # Sliding window size to be used with the AOT scheduler will be
+        # populated on first build() call.
+        self.aot_sliding_window: Optional[tuple[int, int]] = None
+
    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        return False
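
The new `_get_sliding_window_configs` helper simply walks every `Attention` layer registered in the `VllmConfig` and collects each `FlashAttentionImpl.sliding_window` tuple into a set, so the builder can later tell whether the model uses no window, one uniform window, or a mix. A minimal standalone sketch of that pattern, using made-up stand-ins (`_FakeImpl`, `_FakeLayer`, `collect_sliding_windows`) in place of vLLM's real layer objects:

```python
from dataclasses import dataclass
from typing import Optional

# Stand-ins for vLLM's Attention layer and FlashAttentionImpl; the real helper
# gets its layer mapping from get_layers_from_vllm_config(vllm_config, Attention).
@dataclass
class _FakeImpl:
    sliding_window: Optional[tuple[int, int]]

@dataclass
class _FakeLayer:
    impl: _FakeImpl

def collect_sliding_windows(
        layers: dict[str, _FakeLayer]) -> set[Optional[tuple[int, int]]]:
    """Same set-building pattern as _get_sliding_window_configs()."""
    return {layer.impl.sliding_window for layer in layers.values()}

# Two layers sharing one window -> a single config, which is what the AOT
# scheduler path needs; a mix of windows would yield a larger set.
layers = {
    "model.layers.0.self_attn.attn": _FakeLayer(_FakeImpl((1023, 0))),
    "model.layers.1.self_attn.attn": _FakeLayer(_FakeImpl((1023, 0))),
}
print(collect_sliding_windows(layers))  # {(1023, 0)}
```
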
@@ -307,6 +324,22 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
            self.runner.device, non_blocking=True).long()

+        if self.aot_sliding_window is None:
+            self.aot_sliding_window = (-1, -1)
+            # For the AOT scheduler we need the sliding window value to be
+            # constant for all layers. We have to populate this on the first
+            # build() call, once the layers have been constructed, since it
+            # cannot be populated in __init__.
+            if self.aot_schedule:
+                sliding_window_configs = _get_sliding_window_configs(
+                    self.runner.vllm_config)
+                if len(sliding_window_configs) == 1:
+                    sliding_window_config = sliding_window_configs.pop()
+                    if sliding_window_config is not None:
+                        self.aot_sliding_window = sliding_window_config
+                elif len(sliding_window_configs) > 1:
+                    self.aot_schedule = False
+
        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            if self.aot_schedule:
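
The added `build()` logic above makes a one-time decision on the first call: if every layer reports the same non-None sliding window, that window is handed to the ahead-of-time (AOT) scheduler; if layers disagree, AOT scheduling is disabled; otherwise the FlashAttention "no window" sentinel `(-1, -1)` is kept. A hedged, self-contained restatement of that decision (the function name `resolve_aot_sliding_window` is illustrative, not part of the patch):

```python
from typing import Optional

def resolve_aot_sliding_window(
        window_configs: set[Optional[tuple[int, int]]],
        aot_schedule: bool) -> tuple[tuple[int, int], bool]:
    """Mirror of the first-build() decision: returns (window_size, aot_schedule)."""
    window: tuple[int, int] = (-1, -1)  # "no sliding window" sentinel
    if aot_schedule:
        if len(window_configs) == 1:
            only = next(iter(window_configs))
            if only is not None:
                window = only          # one uniform window across all layers
        elif len(window_configs) > 1:
            aot_schedule = False       # mixed windows: give up on AOT scheduling
    return window, aot_schedule

print(resolve_aot_sliding_window({(1023, 0)}, True))        # ((1023, 0), True)
print(resolve_aot_sliding_window({None}, True))             # ((-1, -1), True)
print(resolve_aot_sliding_window({(1023, 0), None}, True))  # ((-1, -1), False)
```

In effect, models that interleave full-attention and sliding-window layers fall back to the non-AOT path, while uniform models keep AOT scheduling.
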
@@ -321,6 +354,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                    page_size=self.page_size,
                    cu_seqlens_q=cu_query_lens,
                    causal=causal,
+                    window_size=self.aot_sliding_window,
                )
            return None
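
For context on the new `window_size=` argument: FlashAttention-style kernels take a `(left, right)` tuple where `-1` means unbounded on that side, so `(-1, -1)` together with `causal=True` is plain causal attention. The toy mask builder below (illustration only, not vLLM or flash-attn code) shows which key positions a query may attend to under that convention:

```python
# Illustration only: how a (left, right) window_size is conventionally interpreted.
def allowed(q_idx: int, k_idx: int,
            window: tuple[int, int] = (-1, -1), causal: bool = True) -> bool:
    left, right = window
    if causal and k_idx > q_idx:
        return False            # no attending to future keys
    if left >= 0 and k_idx < q_idx - left:
        return False            # too far in the past
    if right >= 0 and k_idx > q_idx + right:
        return False            # too far in the future
    return True

# With window=(2, 0) each query sees itself plus the two previous keys.
for q in range(5):
    print(" ".join("x" if allowed(q, k, (2, 0)) else "." for k in range(5)))
# x . . . .
# x x . . .
# x x x . .
# . x x x .
# . . x x x
```
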