diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py
index b4b91eda284..d382ede10b4 100644
--- a/benchmarks/kernels/benchmark_lora.py
+++ b/benchmarks/kernels/benchmark_lora.py
@@ -17,8 +17,14 @@
 from utils import ArgPool, Bench, CudaGraphBenchParams
 from weight_shapes import WEIGHT_SHAPES
 
-from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
-from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
+                                          lora_shrink)
+    from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
+                                                _LORA_B_PTR_DICT)
+
 from vllm.utils import FlexibleArgumentParser
 
 DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
index b31b980fbe8..9fbad9d2f91 100644
--- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -10,8 +10,10 @@
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.triton_utils import HAS_TRITON
 
-TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
+TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
+                          >= version.parse("3.0.0"))
 
 if TRITON3:
 
diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py
index 43918bcd7c5..bffc56a2e75 100644
--- a/vllm/triton_utils/__init__.py
+++ b/vllm/triton_utils/__init__.py
@@ -2,4 +2,4 @@
 
 from vllm.triton_utils.importing import HAS_TRITON
 
-__all__ = ["HAS_TRITON"]
\ No newline at end of file
+__all__ = ["HAS_TRITON"]
diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py
index a20700248c2..fa29efbf6b2 100644
--- a/vllm/triton_utils/importing.py
+++ b/vllm/triton_utils/importing.py
@@ -1,17 +1,51 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import sys
+import types
 from importlib.util import find_spec
 
 from vllm.logger import init_logger
-from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
 HAS_TRITON = (
     find_spec("triton") is not None
-    and not current_platform.is_xpu()  # Not compatible
+    or find_spec("pytorch-triton-xpu") is not None  # triton for Intel XPU
 )
 
 if not HAS_TRITON:
     logger.info("Triton not installed or not compatible; certain GPU-related"
                 " functions will not be available.")
+
+    class TritonPlaceholder(types.ModuleType):
+
+        def __init__(self):
+            super().__init__("triton")
+            self.jit = self._dummy_decorator("jit")
+            self.autotune = self._dummy_decorator("autotune")
+            self.heuristics = self._dummy_decorator("heuristics")
+            self.language = TritonLanguagePlaceholder()
+            logger.warning_once(
+                "Triton is not installed. Using dummy decorators. "
+                "Install it via `pip install triton` to enable kernel "
+                "compilation.")
+
+        def _dummy_decorator(self, name):
+
+            def decorator(func=None, **kwargs):
+                if func is None:
+                    return lambda f: f
+                return func
+
+            return decorator
+
+    class TritonLanguagePlaceholder(types.ModuleType):
+
+        def __init__(self):
+            super().__init__("triton.language")
+            self.constexpr = None
+            self.dtype = None
+
+    sys.modules['triton'] = TritonPlaceholder()
+    sys.modules['triton.language'] = TritonLanguagePlaceholder()
+    logger.info("Triton module has been replaced with a placeholder.")
diff --git a/vllm/utils.py b/vllm/utils.py
index ed406a6b7b1..92b493a69e5 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -63,6 +63,9 @@
 from typing_extensions import Never, ParamSpec, TypeIs, assert_never
 
 import vllm.envs as envs
+# NOTE: import triton_utils so the TritonPlaceholder module is installed
+# when triton is unavailable
+import vllm.triton_utils  # noqa: F401
 from vllm.logger import enable_trace_function_call, init_logger
 
 if TYPE_CHECKING: