diff --git a/pyproject.toml b/pyproject.toml
index febe2de61b4..d4ea9bb0868 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -58,7 +58,8 @@ ignore_patterns = [
 line-length = 80
 exclude = [
     # External file, leaving license intact
-    "examples/other/fp8/quantizer/quantize.py"
+    "examples/other/fp8/quantizer/quantize.py",
+    "vllm/vllm_flash_attn/flash_attn_interface.pyi"
 ]
 
 [tool.ruff.lint.per-file-ignores]
diff --git a/setup.py b/setup.py
index b0cc2f48163..ed4b88364a6 100755
--- a/setup.py
+++ b/setup.py
@@ -378,7 +378,6 @@ def run(self) -> None:
             "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
             "vllm/vllm_flash_attn/flash_attn_interface.py",
-            "vllm/vllm_flash_attn/__init__.py",
             "vllm/cumem_allocator.abi3.so",
             # "vllm/_version.py", # not available in nightly wheels yet
         ]
diff --git a/vllm/vllm_flash_attn/__init__.py b/vllm/vllm_flash_attn/__init__.py
index e69de29bb2d..cf8f1207a65 100644
--- a/vllm/vllm_flash_attn/__init__.py
+++ b/vllm/vllm_flash_attn/__init__.py
@@ -0,0 +1,22 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import importlib.metadata
+
+try:
+    __version__ = importlib.metadata.version("vllm-flash-attn")
+except importlib.metadata.PackageNotFoundError:
+    # in this case, vllm-flash-attn is built from installing vllm editable
+    __version__ = "0.0.0.dev0"
+
+from .flash_attn_interface import (fa_version_unsupported_reason,
+                                   flash_attn_varlen_func,
+                                   flash_attn_with_kvcache,
+                                   get_scheduler_metadata,
+                                   is_fa_version_supported, sparse_attn_func,
+                                   sparse_attn_varlen_func)
+
+__all__ = [
+    'flash_attn_varlen_func', 'flash_attn_with_kvcache',
+    'get_scheduler_metadata', 'sparse_attn_func', 'sparse_attn_varlen_func',
+    'is_fa_version_supported', 'fa_version_unsupported_reason'
+]
diff --git a/vllm/vllm_flash_attn/flash_attn_interface.pyi b/vllm/vllm_flash_attn/flash_attn_interface.pyi
new file mode 100644
index 00000000000..ca8311e0135
--- /dev/null
+++ b/vllm/vllm_flash_attn/flash_attn_interface.pyi
@@ -0,0 +1,245 @@
+# ruff: ignore
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import Any, Literal, overload
+
+import torch
+
+def get_scheduler_metadata(
+    batch_size: int,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    num_heads_q: int,
+    num_heads_kv: int,
+    headdim: int,
+    cache_seqlens: torch.Tensor,
+    qkv_dtype: torch.dtype = ...,
+    headdim_v: int | None = ...,
+    cu_seqlens_q: torch.Tensor | None = ...,
+    cu_seqlens_k_new: torch.Tensor | None = ...,
+    cache_leftpad: torch.Tensor | None = ...,
+    page_size: int = ...,
+    max_seqlen_k_new: int = ...,
+    causal: bool = ...,
+    window_size: tuple[int, int] = ...,
+    has_softcap: bool = ...,
+    num_splits: int = ...,
+    pack_gqa: Any | None = ...,
+    sm_margin: int = ...,
+): ...
+@overload
+def flash_attn_varlen_func(
+    q: tuple[int, int, int],
+    k: tuple[int, int, int],
+    v: tuple[int, int, int],
+    max_seqlen_q: int,
+    cu_seqlens_q: torch.Tensor | None,
+    max_seqlen_k: int,
+    cu_seqlens_k: torch.Tensor | None = ...,
+    seqused_k: Any | None = ...,
+    q_v: Any | None = ...,
+    dropout_p: float = ...,
+    causal: bool = ...,
+    window_size: list[int] | None = ...,
+    softmax_scale: float = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    deterministic: bool = ...,
+    return_attn_probs: bool = ...,
+    block_table: Any | None = ...,
+    return_softmax_lse: Literal[False] = ...,
+    out: Any = ...,
+    # FA3 Only
+    scheduler_metadata: Any | None = ...,
+    q_descale: Any | None = ...,
+    k_descale: Any | None = ...,
+    v_descale: Any | None = ...,
+    # Version selector
+    fa_version: int = ...,
+) -> tuple[int, int, int]: ...
+@overload
+def flash_attn_varlen_func(
+    q: tuple[int, int, int],
+    k: tuple[int, int, int],
+    v: tuple[int, int, int],
+    max_seqlen_q: int,
+    cu_seqlens_q: torch.Tensor | None,
+    max_seqlen_k: int,
+    cu_seqlens_k: torch.Tensor | None = ...,
+    seqused_k: Any | None = ...,
+    q_v: Any | None = ...,
+    dropout_p: float = ...,
+    causal: bool = ...,
+    window_size: list[int] | None = ...,
+    softmax_scale: float = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    deterministic: bool = ...,
+    return_attn_probs: bool = ...,
+    block_table: Any | None = ...,
+    return_softmax_lse: Literal[True] = ...,
+    out: Any = ...,
+    # FA3 Only
+    scheduler_metadata: Any | None = ...,
+    q_descale: Any | None = ...,
+    k_descale: Any | None = ...,
+    v_descale: Any | None = ...,
+    # Version selector
+    fa_version: int = ...,
+) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
+@overload
+def flash_attn_with_kvcache(
+    q: tuple[int, int, int, int],
+    k_cache: tuple[int, int, int, int],
+    v_cache: tuple[int, int, int, int],
+    k: tuple[int, int, int, int] | None = ...,
+    v: tuple[int, int, int, int] | None = ...,
+    rotary_cos: tuple[int, int] | None = ...,
+    rotary_sin: tuple[int, int] | None = ...,
+    cache_seqlens: int | torch.Tensor | None = None,
+    cache_batch_idx: torch.Tensor | None = None,
+    cache_leftpad: torch.Tensor | None = ...,
+    block_table: torch.Tensor | None = ...,
+    softmax_scale: float = ...,
+    causal: bool = ...,
+    window_size: tuple[int, int] = ..., # -1 means infinite context window
+    softcap: float = ...,
+    rotary_interleaved: bool = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    num_splits: int = ...,
+    return_softmax_lse: Literal[False] = ...,
+    *,
+    out: Any = ...,
+    # FA3 Only
+    scheduler_metadata: Any | None = ...,
+    q_descale: Any | None = ...,
+    k_descale: Any | None = ...,
+    v_descale: Any | None = ...,
+    # Version selector
+    fa_version: int = ...,
+) -> tuple[int, int, int, int]: ...
+@overload
+def flash_attn_with_kvcache(
+    q: tuple[int, int, int, int],
+    k_cache: tuple[int, int, int, int],
+    v_cache: tuple[int, int, int, int],
+    k: tuple[int, int, int, int] | None = ...,
+    v: tuple[int, int, int, int] | None = ...,
+    rotary_cos: tuple[int, int] | None = ...,
+    rotary_sin: tuple[int, int] | None = ...,
+    cache_seqlens: int | torch.Tensor | None = None,
+    cache_batch_idx: torch.Tensor | None = None,
+    cache_leftpad: torch.Tensor | None = ...,
+    block_table: torch.Tensor | None = ...,
+    softmax_scale: float = ...,
+    causal: bool = ...,
+    window_size: tuple[int, int] = ..., # -1 means infinite context window
+    softcap: float = ...,
+    rotary_interleaved: bool = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    num_splits: int = ...,
+    return_softmax_lse: Literal[True] = ...,
+    *,
+    out: Any = ...,
+    # FA3 Only
+    scheduler_metadata: Any | None = ...,
+    q_descale: Any | None = ...,
+    k_descale: Any | None = ...,
+    v_descale: Any | None = ...,
+    # Version selector
+    fa_version: int = ...,
+) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
+@overload
+def sparse_attn_func(
+    q: tuple[int, int, int, int],
+    k: tuple[int, int, int, int],
+    v: tuple[int, int, int, int],
+    block_count: tuple[int, int, float],
+    block_offset: tuple[int, int, float, int],
+    column_count: tuple[int, int, float],
+    column_index: tuple[int, int, float, int],
+    dropout_p: float = ...,
+    softmax_scale: float = ...,
+    causal: bool = ...,
+    softcap: float = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    deterministic: bool = ...,
+    return_attn_probs: bool = ...,
+    *,
+    return_softmax_lse: Literal[False] = ...,
+    out: Any = ...,
+) -> tuple[int, int, int]: ...
+@overload
+def sparse_attn_func(
+    q: tuple[int, int, int, int],
+    k: tuple[int, int, int, int],
+    v: tuple[int, int, int, int],
+    block_count: tuple[int, int, float],
+    block_offset: tuple[int, int, float, int],
+    column_count: tuple[int, int, float],
+    column_index: tuple[int, int, float, int],
+    dropout_p: float = ...,
+    softmax_scale: float = ...,
+    causal: bool = ...,
+    softcap: float = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    deterministic: bool = ...,
+    return_attn_probs: bool = ...,
+    *,
+    return_softmax_lse: Literal[True] = ...,
+    out: Any = ...,
+) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
+@overload
+def sparse_attn_varlen_func(
+    q: tuple[int, int, int],
+    k: tuple[int, int, int],
+    v: tuple[int, int, int],
+    block_count: tuple[int, int, float],
+    block_offset: tuple[int, int, float, int],
+    column_count: tuple[int, int, float],
+    column_index: tuple[int, int, float, int],
+    cu_seqlens_q: torch.Tensor | None,
+    cu_seqlens_k: torch.Tensor | None,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    dropout_p: float = ...,
+    softmax_scale: float = ...,
+    causal: bool = ...,
+    softcap: float = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    deterministic: bool = ...,
+    return_attn_probs: bool = ...,
+    *,
+    return_softmax_lse: Literal[False] = ...,
+    out: Any = ...,
+) -> tuple[int, int, int]: ...
+@overload
+def sparse_attn_varlen_func(
+    q: tuple[int, int, int],
+    k: tuple[int, int, int],
+    v: tuple[int, int, int],
+    block_count: tuple[int, int, float],
+    block_offset: tuple[int, int, float, int],
+    column_count: tuple[int, int, float],
+    column_index: tuple[int, int, float, int],
+    cu_seqlens_q: torch.Tensor | None,
+    cu_seqlens_k: torch.Tensor | None,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    dropout_p: float = ...,
+    softmax_scale: float = ...,
+    causal: bool = ...,
+    softcap: float = ...,
+    alibi_slopes: tuple[int] | tuple[int, int] | None = ...,
+    deterministic: bool = ...,
+    return_attn_probs: bool = ...,
+    *,
+    return_softmax_lse: Literal[True] = ...,
+    out: Any = ...,
+) -> tuple[tuple[int, int, int], tuple[int, int]]: ...
+def is_fa_version_supported(
+    fa_version: int, device: torch.device | None = None
+) -> bool: ...
+def fa_version_unsupported_reason(
+    fa_version: int, device: torch.device | None = None
+) -> str | None: ...
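
For reference, the paired `Literal[False]` / `Literal[True]` overloads on `return_softmax_lse` are what let a type checker distinguish the single-output call from the `(output, softmax_lse)` form. Below is a minimal usage sketch, not part of the diff: it assumes a CUDA device, a built vllm_flash_attn extension with FA2 support, the packed varlen layout `(total_tokens, num_heads, head_dim)`, and arbitrary illustrative sizes.

```python
# Illustrative sketch only: exercises the public API re-exported by the new
# vllm/vllm_flash_attn/__init__.py. Shapes and sizes are made up.
import torch

from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                  flash_attn_varlen_func,
                                  is_fa_version_supported)

if not is_fa_version_supported(2):
    # Explains why FA2 is unavailable (e.g. unsupported GPU arch or dtype).
    raise RuntimeError(fa_version_unsupported_reason(2))

device = torch.device("cuda")
# Two sequences of lengths 3 and 5, packed into 8 tokens total.
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device=device)
q = torch.randn(8, 4, 64, dtype=torch.float16, device=device)
k = torch.randn(8, 4, 64, dtype=torch.float16, device=device)
v = torch.randn(8, 4, 64, dtype=torch.float16, device=device)

# Default return_softmax_lse=False: the Literal[False] overload types the
# result as a single output tensor.
out = flash_attn_varlen_func(
    q, k, v,
    max_seqlen_q=5,
    cu_seqlens_q=cu_seqlens,
    max_seqlen_k=5,
    cu_seqlens_k=cu_seqlens,
    causal=True,
    fa_version=2,
)

# return_softmax_lse=True selects the Literal[True] overload, so the result
# is typed as an (output, softmax_lse) pair.
out, lse = flash_attn_varlen_func(
    q, k, v,
    max_seqlen_q=5,
    cu_seqlens_q=cu_seqlens,
    max_seqlen_k=5,
    cu_seqlens_k=cu_seqlens,
    causal=True,
    return_softmax_lse=True,
    fa_version=2,
)
```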