Skip to content

Commit c30b25f

Browse files
committed
[Bugfix][Misc] use TritonPlaceholderModule to defensively import triton
Signed-off-by: Mengqing Cao <cmq0113@163.com>
1 parent 9420a1f commit c30b25f

File tree

5 files changed

+53
-6
lines changed

5 files changed

+53
-6
lines changed

benchmarks/kernels/benchmark_lora.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,14 @@
1717
from utils import ArgPool, Bench, CudaGraphBenchParams
1818
from weight_shapes import WEIGHT_SHAPES
1919

20-
from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink
21-
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
20+
from vllm.triton_utils import HAS_TRITON
21+
22+
if HAS_TRITON:
23+
from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
24+
lora_shrink)
25+
from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
26+
_LORA_B_PTR_DICT)
27+
2228
from vllm.utils import FlexibleArgumentParser
2329

2430
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())

vllm/model_executor/layers/mamba/ops/mamba_ssm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
from vllm import _custom_ops as ops
1212
from vllm.attention.backends.utils import PAD_SLOT_ID
13+
from vllm.triton_utils import HAS_TRITON
1314

14-
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
15+
TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
16+
>= version.parse("3.0.0"))
1517

1618
if TRITON3:
1719

vllm/triton_utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
from vllm.triton_utils.importing import HAS_TRITON
44

5-
__all__ = ["HAS_TRITON"]
5+
__all__ = ["HAS_TRITON"]

vllm/triton_utils/importing.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,53 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3+
import sys
4+
import types
35
from importlib.util import find_spec
46

57
from vllm.logger import init_logger
6-
from vllm.platforms import current_platform
78

89
logger = init_logger(__name__)
910

1011
HAS_TRITON = (
1112
find_spec("triton") is not None
12-
and not current_platform.is_xpu() # Not compatible
13+
or find_spec("pytorch-triton-xpu") is not None # Not compatible
1314
)
1415

1516
if not HAS_TRITON:
1617
logger.info("Triton not installed or not compatible; certain GPU-related"
1718
" functions will not be available.")
19+
20+
class TritonPlaceholder(types.ModuleType):
21+
22+
def __init__(self):
23+
super().__init__("triton")
24+
self.jit = self._dummy_decorator("jit")
25+
self.autotune = self._dummy_decorator("autotune")
26+
self.heuristics = self._dummy_decorator("heuristics")
27+
self.language = TritonLanguagePlaceholder()
28+
logger.warning_once(
29+
"Triton is not installed. Using dummy decorators. "
30+
"Install it via `pip install triton` to enable kernel"
31+
"compilation.")
32+
33+
def _dummy_decorator(self, name):
34+
35+
def decorator(func=None, **kwargs):
36+
if func is None:
37+
return lambda f: f
38+
return func
39+
40+
return decorator
41+
42+
class TritonLanguagePlaceholder(types.ModuleType):
43+
44+
def __init__(self):
45+
super().__init__("triton.language")
46+
self.constexpr = None
47+
self.dtype = None
48+
49+
sys.modules['triton'] = TritonPlaceholder()
50+
sys.modules['triton.language'] = TritonLanguagePlaceholder()
51+
52+
if 'triton' in sys.modules:
53+
logger.info("Triton module has been replaced with a placeholder.")

vllm/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@
6363
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
6464

6565
import vllm.envs as envs
66+
# NOTE: import triton_utils to make TritonPlaceholderModule work
67+
# if triton is unavailable
68+
import vllm.triton_utils # noqa: F401
6669
from vllm.logger import enable_trace_function_call, init_logger
6770

6871
if TYPE_CHECKING:

0 commit comments

Comments
 (0)