@@ -478,7 +478,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
478
478
self .config = config
479
479
self .quant_config = quant_config
480
480
self .multimodal_config = multimodal_config
481
- self .sliding_window = config .text_config .interleaved_sliding_window
481
+ self .sliding_window = getattr (config .text_config ,
482
+ "interleaved_sliding_window" , None )
482
483
483
484
self .vision_tower = SiglipVisionModel (config .vision_config ,
484
485
quant_config ,
@@ -680,13 +681,14 @@ def prepare_attn_masks(
680
681
global_attn_mask = torch .where (img_mask == 2 , 0 , global_attn_mask )
681
682
global_attn_masks .append (global_attn_mask )
682
683
683
- # Create a local causal mask with sliding window (1024).
684
- local_attn_mask = torch .ones_like (global_attn_mask )
685
- local_attn_mask = torch .tril (local_attn_mask ,
686
- diagonal = - self .sliding_window )
687
- local_attn_mask = torch .where (local_attn_mask == 0 ,
688
- global_attn_mask , float ("-inf" ))
689
- local_attn_masks .append (local_attn_mask )
684
+ if self .sliding_window is not None :
685
+ # Create a local causal mask with sliding window (1024).
686
+ local_attn_mask = torch .ones_like (global_attn_mask )
687
+ local_attn_mask = torch .tril (local_attn_mask ,
688
+ diagonal = - self .sliding_window )
689
+ local_attn_mask = torch .where (local_attn_mask == 0 ,
690
+ global_attn_mask , float ("-inf" ))
691
+ local_attn_masks .append (local_attn_mask )
690
692
kwargs ["global_attn_masks" ] = global_attn_masks
691
693
kwargs ["local_attn_masks" ] = local_attn_masks
692
694
return kwargs
0 commit comments