@@ -50,7 +50,8 @@ def compile(
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
         """
         Compile the graph with the given example inputs and compiler config,
@@ -71,6 +72,10 @@ def compile(
         If the compiler doesn't support caching, it should return None for the
         handle. If the compiler fails to compile the graph, it should return
         None for the compiled function as well.
+
+        `key` is required for InductorStandaloneAdaptor; it specifies where to
+        save the compiled artifact. The compiled artifact gets saved to
+        `cache_dir/key`.
         """
         return None, None
 
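
Since the docstring states the return contract rather than showing it, here is a minimal sketch (not part of the diff) of an adaptor honoring the updated signature. The class name `NoCacheAdaptor` and its behavior are illustrative assumptions, and it relies on `CompilerInterface` from the surrounding module:

```python
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch.fx as fx

class NoCacheAdaptor(CompilerInterface):    # hypothetical subclass for illustration
    name = "no_cache"

    def compile(
        self,
        graph: fx.GraphModule,
        example_inputs: List[Any],
        compiler_config: Dict[str, Any],
        runtime_shape: Optional[int] = None,
        key: Optional[str] = None,          # ignored: nothing is persisted to cache_dir/key
    ) -> Tuple[Optional[Callable], Optional[Any]]:
        # no caching support -> handle is None; the graph itself serves as the
        # "compiled" callable, so compilation never fails here
        return graph, None
```
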
@@ -127,23 +132,108 @@ def produce_guards_expression(self, *args, **kwargs):
         return ""
 
 
+def get_inductor_factors() -> List[Any]:
+    factors: List[Any] = []
+    # summarize system state
+    from torch._inductor.codecache import CacheBase
+    system_factors = CacheBase.get_system()
+    factors.append(system_factors)
+
+    # summarize pytorch state
+    from torch._inductor.codecache import torch_key
+    torch_factors = torch_key()
+    factors.append(torch_factors)
+    return factors
+
+
+class InductorStandaloneAdaptor(CompilerInterface):
+    """
+    The adaptor for the Inductor compiler.
+    Requires PyTorch 2.8+.
+    This is not on by default yet, but we plan to turn it on by default for
+    PyTorch 2.8.
+
+    Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off.
+    """
+    name = "inductor_standalone"
+
+    def compute_hash(self, vllm_config: VllmConfig) -> str:
+        factors = get_inductor_factors()
+        hash_str = hashlib.md5(str(factors).encode(),
+                               usedforsecurity=False).hexdigest()[:10]
+        return hash_str
+
+    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+        self.cache_dir = cache_dir
+
+    def compile(
+        self,
+        graph: fx.GraphModule,
+        example_inputs: List[Any],
+        compiler_config: Dict[str, Any],
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
+    ) -> Tuple[Optional[Callable], Optional[Any]]:
+        current_config = {}
+        if compiler_config is not None:
+            current_config.update(compiler_config)
+        set_inductor_config(current_config, runtime_shape)
+
+        if isinstance(runtime_shape, int):
+            dynamic_shapes = "from_example_inputs"
+        else:
+            dynamic_shapes = "from_tracing_context"
+
+        from torch._inductor import standalone_compile
+        with pass_context(runtime_shape):
+            compiled_graph = standalone_compile(
+                graph,
+                example_inputs,
+                dynamic_shapes=dynamic_shapes,
+                options={"config_patches": current_config})
+
+        # Save the compiled artifact to disk in the specified path
+        assert key is not None
+        path = os.path.join(self.cache_dir, key)
+        compiled_graph.save(path=path, format="unpacked")
+        return compiled_graph, (key, path)
+
+    def load(self,
+             handle: Any,
+             graph: fx.GraphModule,
+             example_inputs: List[Any],
+             graph_index: int,
+             runtime_shape: Optional[int] = None) -> Callable:
+        assert isinstance(handle, tuple)
+        assert isinstance(handle[0], str)
+        assert isinstance(handle[1], str)
+        path = handle[1]
+        inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
+            path=path, format="unpacked")
+        from torch._inductor.compile_fx import graph_returns_tuple
+        returns_tuple = graph_returns_tuple(graph)
+
+        def compiled_graph_wrapper(*args):
+            graph_output = inductor_compiled_graph(*args)
+            # unpack the tuple if needed
+            # TODO(rzou): the implication is that we're not
+            # reading the python bytecode correctly in vLLM?
+            if returns_tuple:
+                return graph_output
+            else:
+                return graph_output[0]
+
+        return compiled_graph_wrapper
+
+
 class InductorAdaptor(CompilerInterface):
     """
-    The adaptor for the Inductor compiler, version 2.5 and 2.6.
+    The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7.
     """
     name = "inductor"
 
     def compute_hash(self, vllm_config: VllmConfig) -> str:
-        factors: List[Any] = []
-        # summarize system state
-        from torch._inductor.codecache import CacheBase
-        system_factors = CacheBase.get_system()
-        factors.append(system_factors)
-
-        # summarize pytorch state
-        from torch._inductor.codecache import torch_key
-        torch_factors = torch_key()
-        factors.append(torch_factors)
+        factors = get_inductor_factors()
         hash_str = hashlib.md5(str(factors).encode(),
                                usedforsecurity=False).hexdigest()[:10]
         return hash_str
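
For orientation (not part of the diff), a minimal end-to-end sketch of how the new adaptor is meant to be driven, assuming PyTorch 2.8+ and the surrounding module context; the traced function, cache directory, and `key` value are illustrative assumptions:

```python
import torch
import torch.fx as fx

def f(x):
    return torch.relu(x) + 1

graph = fx.symbolic_trace(f)                 # fx.GraphModule, as compile() expects
example_inputs = [torch.randn(8, 16)]

adaptor = InductorStandaloneAdaptor()
adaptor.initialize_cache("/tmp/vllm_compile_cache")   # assumed cache_dir

compiled_fn, handle = adaptor.compile(
    graph,
    example_inputs,
    compiler_config={},
    runtime_shape=8,                         # static shape -> "from_example_inputs"
    key="graph_0_shape_8",                   # hypothetical key; artifact saved to cache_dir/key
)

# A later run can skip compilation and restore from the (key, path) handle:
restored_fn = adaptor.load(handle, graph, example_inputs,
                           graph_index=0, runtime_shape=8)
out = restored_fn(*example_inputs)
```
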
@@ -168,23 +258,19 @@ def compile(
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
-        current_config = {}
         from torch._inductor.compile_fx import compile_fx
+        current_config = {}
+        if compiler_config is not None:
+            current_config.update(compiler_config)
 
         # disable remote cache
         current_config["fx_graph_cache"] = True
         current_config["fx_graph_remote_cache"] = False
 
-        if compiler_config is not None:
-            current_config.update(compiler_config)
-
-        if isinstance(runtime_shape, int):
-            # for a specific batchsize, tuning triton kernel parameters
-            # can be beneficial
-            current_config["max_autotune"] = True
-            current_config["coordinate_descent_tuning"] = True
+        set_inductor_config(current_config, runtime_shape)
 
         # inductor can inplace modify the graph, so we need to copy it
         # see https://github.com/pytorch/pytorch/issues/138980
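
A small sketch (not part of the diff) of the config that now reaches `compile_fx` for a static shape after this reordering; `set_inductor_config` is the helper introduced later in this diff, and the user-supplied value below is an assumption:

```python
current_config = {}
user_config = {"fx_graph_cache": False}     # assumed compiler_config from the user
current_config.update(user_config)

# the cache flags are written after the user config is merged, so they win
current_config["fx_graph_cache"] = True
current_config["fx_graph_remote_cache"] = False

set_inductor_config(current_config, runtime_shape=16)   # static shape
assert current_config == {
    "fx_graph_cache": True,
    "fx_graph_remote_cache": False,
    "max_autotune": True,
    "coordinate_descent_tuning": True,
}
```
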
@@ -422,6 +508,14 @@ def metrics_context(self) -> contextlib.AbstractContextManager:
         return contextlib.nullcontext()
 
 
+def set_inductor_config(config, runtime_shape):
+    if isinstance(runtime_shape, int):
+        # for a specific batchsize, tuning triton kernel parameters
+        # can be beneficial
+        config["max_autotune"] = True
+        config["coordinate_descent_tuning"] = True
+
+
 class EagerAdaptor(CompilerInterface):
     name = "eager"
 
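
As a quick illustration (not part of the diff) of the helper's behavior: only a concrete integer shape turns the autotuning knobs on, while the general/dynamic case (`runtime_shape=None`) leaves the config untouched:

```python
static_cfg: dict = {}
dynamic_cfg: dict = {}

set_inductor_config(static_cfg, runtime_shape=32)
set_inductor_config(dynamic_cfg, runtime_shape=None)

assert static_cfg == {"max_autotune": True, "coordinate_descent_tuning": True}
assert dynamic_cfg == {}    # no per-shape tuning for the general shape
```
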
@@ -430,7 +524,8 @@ def compile(
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
         # we don't need to compile the graph, just return the graph itself.
         # It does not support caching, return None for the handle.