Skip to content

[Bugfix][Misc] Use TritonPlaceholderModule to defensively import triton #15099

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions benchmarks/kernels/benchmark_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@
from utils import ArgPool, Bench, CudaGraphBenchParams
from weight_shapes import WEIGHT_SHAPES

from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
lora_shrink)
from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
_LORA_B_PTR_DICT)

from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/mamba/ops/mamba_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import HAS_TRITON

TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
>= version.parse("3.0.0"))

if TRITON3:

Expand Down
2 changes: 1 addition & 1 deletion vllm/triton_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

from vllm.triton_utils.importing import HAS_TRITON

__all__ = ["HAS_TRITON"]
__all__ = ["HAS_TRITON"]
40 changes: 38 additions & 2 deletions vllm/triton_utils/importing.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,53 @@
# SPDX-License-Identifier: Apache-2.0

import sys
import types
from importlib.util import find_spec

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

HAS_TRITON = (
find_spec("triton") is not None
and not current_platform.is_xpu() # Not compatible
or find_spec("pytorch-triton-xpu") is not None # Not compatible
)

if not HAS_TRITON:
logger.info("Triton not installed or not compatible; certain GPU-related"
" functions will not be available.")

class TritonPlaceholder(types.ModuleType):

def __init__(self):
super().__init__("triton")
self.jit = self._dummy_decorator("jit")
self.autotune = self._dummy_decorator("autotune")
self.heuristics = self._dummy_decorator("heuristics")
self.language = TritonLanguagePlaceholder()
logger.warning_once(
"Triton is not installed. Using dummy decorators. "
"Install it via `pip install triton` to enable kernel"
"compilation.")

def _dummy_decorator(self, name):

def decorator(func=None, **kwargs):
if func is None:
return lambda f: f
return func

return decorator

class TritonLanguagePlaceholder(types.ModuleType):

def __init__(self):
super().__init__("triton.language")
self.constexpr = None
self.dtype = None

sys.modules['triton'] = TritonPlaceholder()
sys.modules['triton.language'] = TritonLanguagePlaceholder()

if 'triton' in sys.modules:
logger.info("Triton module has been replaced with a placeholder.")
3 changes: 3 additions & 0 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@
from typing_extensions import Never, ParamSpec, TypeIs, assert_never

import vllm.envs as envs
# NOTE: import triton_utils to make TritonPlaceholderModule work
# if triton is unavailable
import vllm.triton_utils # noqa: F401
from vllm.logger import enable_trace_function_call, init_logger

if TYPE_CHECKING:
Expand Down