Skip to content

Commit a6e72e1

Browse files
authored
[Bugfix] [pytorch] Patch AOTAutogradCache._get_shape_env (#17142)
Signed-off-by: James Wu <jjwu@meta.com>
1 parent 5e83a72 commit a6e72e1

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

vllm/compilation/compiler_interface.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ def compile(
195195
hash_str, file_path = None, None
196196
from torch._inductor.codecache import (FxGraphCache,
197197
compiled_fx_graph_hash)
198-
199198
if torch.__version__.startswith("2.5"):
200199
original_load = FxGraphCache.load
201200
original_load_name = "torch._inductor.codecache.FxGraphCache.load"
@@ -280,6 +279,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
280279
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
281280
_get_shape_env))
282281

282+
from torch._functorch._aot_autograd.autograd_cache import (
283+
AOTAutogradCache)
284+
285+
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
286+
if hasattr(AOTAutogradCache, "_get_shape_env"):
287+
stack.enter_context(
288+
patch(
289+
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
290+
_get_shape_env))
291+
283292
# for forcing the graph to be cached
284293
stack.enter_context(
285294
patch(
@@ -325,11 +334,19 @@ def load(self,
325334
assert isinstance(handle[1], str)
326335
hash_str = handle[0]
327336

337+
from torch._functorch._aot_autograd.autograd_cache import (
338+
AOTAutogradCache)
328339
from torch._inductor.codecache import FxGraphCache
329340
with ExitStack() as exit_stack:
330341
exit_stack.enter_context(
331342
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
332343
lambda *args, **kwargs: AlwaysHitShapeEnv()))
344+
# torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
345+
if hasattr(AOTAutogradCache, "_get_shape_env"):
346+
exit_stack.enter_context(
347+
patch(
348+
"torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
349+
lambda *args, **kwargs: AlwaysHitShapeEnv()))
333350

334351
# Dynamo metrics context, see method for more details.
335352
exit_stack.enter_context(self.metrics_context())

0 commit comments

Comments
 (0)