Commit 01d5d45

zou3519 authored and mawong-amd committed

Add option to use torch._inductor.standalone_compile (vllm-project#17057)

Signed-off-by: rzou <zou3519@gmail.com>

1 parent c2d46b7 · commit 01d5d45

File tree

vllm/compilation/backends.py
vllm/compilation/compiler_interface.py
vllm/envs.py

3 files changed: +150 -29 lines changed

vllm/compilation/backends.py

Lines changed: 27 additions & 6 deletions
@@ -17,7 +17,8 @@
 from vllm.logger import init_logger
 from vllm.utils import weak_ref_tensors
 
-from .compiler_interface import EagerAdaptor, InductorAdaptor
+from .compiler_interface import (CompilerInterface, EagerAdaptor,
+                                 InductorAdaptor, InductorStandaloneAdaptor)
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
 from .monitor import end_monitoring_torch_compile
@@ -26,6 +27,19 @@
 logger = init_logger(__name__)
 
 
+def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
+    if compilation_config.use_inductor:
+        if envs.VLLM_TEST_STANDALONE_COMPILE:
+            logger.info("Using InductorStandaloneAdaptor")
+            return InductorStandaloneAdaptor()
+        else:
+            logger.info("Using InductorAdaptor")
+            return InductorAdaptor()
+    else:
+        logger.info("Using EagerAdaptor")
+        return EagerAdaptor()
+
+
 class CompilerManager:
     """
     A manager to manage the compilation process, including
@@ -41,11 +55,11 @@ class CompilerManager:
     support int as key.
     """
 
-    def __init__(self, use_inductor: bool):
+    def __init__(self, compilation_config: CompilationConfig):
         self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
-        cls = InductorAdaptor if use_inductor else EagerAdaptor
-        self.compiler = cls()
         self.is_cache_updated = False
+        self.compilation_config = compilation_config
+        self.compiler = make_compiler(compilation_config)
 
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)
@@ -123,8 +137,15 @@ def compile(self,
 
         # no compiler cached the graph, or the cache is disabled,
         # we need to compile it
+        if isinstance(self.compiler, InductorAdaptor):
+            # Let compile_fx generate a key for us
+            maybe_key = None
+        else:
+            maybe_key = \
+                f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
         compiled_graph, handle = self.compiler.compile(
-            graph, example_inputs, additional_inductor_config, runtime_shape)
+            graph, example_inputs, additional_inductor_config, runtime_shape,
+            maybe_key)
 
         assert compiled_graph is not None, "Failed to compile the graph"
 
@@ -336,7 +357,7 @@ def __init__(
         self.compilation_config = vllm_config.compilation_config
 
         self.compiler_manager: CompilerManager = CompilerManager(
-            self.compilation_config.use_inductor)
+            self.compilation_config)
 
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
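
The `maybe_key` built in this hunk determines where the standalone adaptor stores its artifact on disk; per the interface change in compiler_interface.py below, the artifact is saved to `cache_dir/key`. A minimal sketch of the resulting key and path, using a hypothetical cache directory (vLLM derives the real one from its config and compiler hashes):

import os

# Hypothetical values for illustration only.
cache_dir = "/tmp/vllm_compile_cache/deadbeef00"
runtime_shape = 8   # compiling for a concrete batch size
graph_index = 0     # first piecewise subgraph

maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
path = os.path.join(cache_dir, maybe_key)
print(path)  # /tmp/vllm_compile_cache/deadbeef00/artifact_shape_8_subgraph_0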

vllm/compilation/compiler_interface.py

Lines changed: 118 additions & 23 deletions
@@ -50,7 +50,8 @@ def compile(
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
         """
         Compile the graph with the given example inputs and compiler config,
@@ -71,6 +72,10 @@ def compile(
         If the compiler doesn't support caching, it should return None for the
         handle. If the compiler fails to compile the graph, it should return
         None for the compiled function as well.
+
+        `key` is required for InductorStandaloneAdaptor; it specifies where to
+        save the compiled artifact. The compiled artifact gets saved to
+        `cache_dir/key`.
         """
         return None, None
 
@@ -127,23 +132,108 @@ def produce_guards_expression(self, *args, **kwargs):
         return ""
 
 
+def get_inductor_factors() -> List[Any]:
+    factors: List[Any] = []
+    # summarize system state
+    from torch._inductor.codecache import CacheBase
+    system_factors = CacheBase.get_system()
+    factors.append(system_factors)
+
+    # summarize pytorch state
+    from torch._inductor.codecache import torch_key
+    torch_factors = torch_key()
+    factors.append(torch_factors)
+    return factors
+
+
+class InductorStandaloneAdaptor(CompilerInterface):
+    """
+    The adaptor for the Inductor compiler.
+    Requires PyTorch 2.8+.
+    This is not on by default yet, but we plan to turn it on by default for
+    PyTorch 2.8.
+
+    Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off.
+    """
+    name = "inductor_standalone"
+
+    def compute_hash(self, vllm_config: VllmConfig) -> str:
+        factors = get_inductor_factors()
+        hash_str = hashlib.md5(str(factors).encode(),
+                               usedforsecurity=False).hexdigest()[:10]
+        return hash_str
+
+    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+        self.cache_dir = cache_dir
+
+    def compile(
+        self,
+        graph: fx.GraphModule,
+        example_inputs: List[Any],
+        compiler_config: Dict[str, Any],
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
+    ) -> Tuple[Optional[Callable], Optional[Any]]:
+        current_config = {}
+        if compiler_config is not None:
+            current_config.update(compiler_config)
+        set_inductor_config(current_config, runtime_shape)
+
+        if isinstance(runtime_shape, int):
+            dynamic_shapes = "from_example_inputs"
+        else:
+            dynamic_shapes = "from_tracing_context"
+
+        from torch._inductor import standalone_compile
+        with pass_context(runtime_shape):
+            compiled_graph = standalone_compile(
+                graph,
+                example_inputs,
+                dynamic_shapes=dynamic_shapes,
+                options={"config_patches": current_config})
+
+        # Save the compiled artifact to disk in the specified path
+        assert key is not None
+        path = os.path.join(self.cache_dir, key)
+        compiled_graph.save(path=path, format="unpacked")
+        return compiled_graph, (key, path)
+
+    def load(self,
+             handle: Any,
+             graph: fx.GraphModule,
+             example_inputs: List[Any],
+             graph_index: int,
+             runtime_shape: Optional[int] = None) -> Callable:
+        assert isinstance(handle, tuple)
+        assert isinstance(handle[0], str)
+        assert isinstance(handle[1], str)
+        path = handle[1]
+        inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
+            path=path, format="unpacked")
+        from torch._inductor.compile_fx import graph_returns_tuple
+        returns_tuple = graph_returns_tuple(graph)
+
+        def compiled_graph_wrapper(*args):
+            graph_output = inductor_compiled_graph(*args)
+            # unpack the tuple if needed
+            # TODO(rzou): the implication is that we're not
+            # reading the python bytecode correctly in vLLM?
+            if returns_tuple:
+                return graph_output
+            else:
+                return graph_output[0]
+
+        return compiled_graph_wrapper
+
+
 class InductorAdaptor(CompilerInterface):
     """
-    The adaptor for the Inductor compiler, version 2.5 and 2.6.
+    The adaptor for the Inductor compiler, versions 2.5, 2.6, and 2.7.
     """
     name = "inductor"
 
     def compute_hash(self, vllm_config: VllmConfig) -> str:
-        factors: List[Any] = []
-        # summarize system state
-        from torch._inductor.codecache import CacheBase
-        system_factors = CacheBase.get_system()
-        factors.append(system_factors)
-
-        # summarize pytorch state
-        from torch._inductor.codecache import torch_key
-        torch_factors = torch_key()
-        factors.append(torch_factors)
+        factors = get_inductor_factors()
         hash_str = hashlib.md5(str(factors).encode(),
                                usedforsecurity=False).hexdigest()[:10]
         return hash_str
@@ -168,23 +258,19 @@ def compile(
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
-        current_config = {}
         from torch._inductor.compile_fx import compile_fx
+        current_config = {}
+        if compiler_config is not None:
+            current_config.update(compiler_config)
 
         # disable remote cache
         current_config["fx_graph_cache"] = True
         current_config["fx_graph_remote_cache"] = False
 
-        if compiler_config is not None:
-            current_config.update(compiler_config)
-
-        if isinstance(runtime_shape, int):
-            # for a specific batchsize, tuning triton kernel parameters
-            # can be beneficial
-            current_config["max_autotune"] = True
-            current_config["coordinate_descent_tuning"] = True
+        set_inductor_config(current_config, runtime_shape)
 
         # inductor can inplace modify the graph, so we need to copy it
         # see https://github.com/pytorch/pytorch/issues/138980
@@ -422,6 +508,14 @@ def metrics_context(self) -> contextlib.AbstractContextManager:
         return contextlib.nullcontext()
 
 
+def set_inductor_config(config, runtime_shape):
+    if isinstance(runtime_shape, int):
+        # for a specific batchsize, tuning triton kernel parameters
+        # can be beneficial
+        config["max_autotune"] = True
+        config["coordinate_descent_tuning"] = True
+
+
 class EagerAdaptor(CompilerInterface):
     name = "eager"
 
@@ -430,7 +524,8 @@ def compile(
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
    ) -> Tuple[Optional[Callable], Optional[Any]]:
         # we don't need to compile the graph, just return the graph itself.
         # It does not support caching, return None for the handle.
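
For context, here is a minimal save/load round trip with the standalone compile API the adaptor above relies on, applied to a toy FX graph. This is a sketch assuming PyTorch 2.8+, where `torch._inductor.standalone_compile` and `torch._inductor.CompiledArtifact` behave as used in this diff; the traced function and output path are illustrative only:

import torch
from torch import fx
from torch._inductor import CompiledArtifact, standalone_compile


def f(x):
    return torch.relu(x) + 1


# Toy graph standing in for the fx.GraphModule vLLM passes to the adaptor.
gm: fx.GraphModule = fx.symbolic_trace(f)
example_inputs = [torch.randn(8, 16)]

# Compile for this concrete shape, mirroring the adaptor's call.
compiled = standalone_compile(gm, example_inputs,
                              dynamic_shapes="from_example_inputs",
                              options={"config_patches": {}})

# Persist the artifact and reload it, as InductorStandaloneAdaptor does
# under cache_dir/key.
compiled.save(path="/tmp/standalone_artifact", format="unpacked")
loaded = CompiledArtifact.load(path="/tmp/standalone_artifact",
                               format="unpacked")
out = loaded(*example_inputs)  # may be a tuple; see compiled_graph_wrapper above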

vllm/envs.py

Lines changed: 5 additions & 0 deletions
@@ -270,6 +270,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: bool(
         os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
 
+    # Internal flag to enable/disable Inductor standalone compile
+    "VLLM_TEST_STANDALONE_COMPILE":
+    lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0",
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
@@ -812,6 +816,7 @@ def factorize(name: str):
         "VLLM_USE_TRITON_AWQ",
         "VLLM_DP_RANK",
         "VLLM_DP_SIZE",
+        "VLLM_TEST_STANDALONE_COMPILE",
     ]
     for key in environment_variables_to_hash:
         if key in environment_variables:
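
As a usage sketch, the flag can be flipped from the process environment before vLLM reads it; any value other than "0" enables the standalone path. The snippet below is illustrative only:

import os

# Must be set before vllm.envs evaluates the variable.
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"

import vllm.envs as envs

# True -> make_compiler() picks InductorStandaloneAdaptor when
# compilation_config.use_inductor is set.
print(envs.VLLM_TEST_STANDALONE_COMPILE)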
