@@ -195,7 +195,6 @@ def compile(
         hash_str, file_path = None, None
         from torch._inductor.codecache import (FxGraphCache,
                                                compiled_fx_graph_hash)
-
         if torch.__version__.startswith("2.5"):
             original_load = FxGraphCache.load
             original_load_name = "torch._inductor.codecache.FxGraphCache.load"
@@ -280,6 +279,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
                 patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                       _get_shape_env))
 
+            from torch._functorch._aot_autograd.autograd_cache import (
+                AOTAutogradCache)
+
+            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
+            if hasattr(AOTAutogradCache, "_get_shape_env"):
+                stack.enter_context(
+                    patch(
+                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
+                        _get_shape_env))
+
             # for forcing the graph to be cached
             stack.enter_context(
                 patch(
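Note on the hunk above: the AOTAutogradCache patch is guarded with hasattr, so it only takes effect on torch builds (2.8+ on main) that actually expose _get_shape_env there; older releases fall through unchanged. Below is a minimal, self-contained sketch of the same hasattr-guarded ExitStack/patch pattern. FakeAutogradCache and the method on AlwaysHitShapeEnv are illustrative stand-ins, not the real torch or project classes:

    from contextlib import ExitStack
    from unittest.mock import patch


    class AlwaysHitShapeEnv:
        # Illustrative stand-in: a shape env whose guard checks always pass,
        # which is what makes a cached graph lookup succeed unconditionally.
        def evaluate_guards_expression(self, *args, **kwargs):
            return True


    class FakeAutogradCache:
        # Illustrative stand-in for AOTAutogradCache; only newer torch versions
        # expose a _get_shape_env staticmethod on it.
        @staticmethod
        def _get_shape_env():
            return None


    def _get_shape_env() -> AlwaysHitShapeEnv:
        return AlwaysHitShapeEnv()


    with ExitStack() as stack:
        # Patch only when the attribute exists, so older versions are untouched;
        # ExitStack keeps the patch active for the rest of the block.
        if hasattr(FakeAutogradCache, "_get_shape_env"):
            stack.enter_context(
                patch(f"{__name__}.FakeAutogradCache._get_shape_env",
                      _get_shape_env))
        assert isinstance(FakeAutogradCache._get_shape_env(), AlwaysHitShapeEnv)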
@@ -325,11 +334,19 @@ def load(self,
         assert isinstance(handle[1], str)
         hash_str = handle[0]
 
+        from torch._functorch._aot_autograd.autograd_cache import (
+            AOTAutogradCache)
         from torch._inductor.codecache import FxGraphCache
         with ExitStack() as exit_stack:
             exit_stack.enter_context(
                 patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
                       lambda *args, **kwargs: AlwaysHitShapeEnv()))
+            # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
+            if hasattr(AOTAutogradCache, "_get_shape_env"):
+                exit_stack.enter_context(
+                    patch(
+                        "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
+                        lambda *args, **kwargs: AlwaysHitShapeEnv()))
 
             # Dynamo metrics context, see method for more details.
             exit_stack.enter_context(self.metrics_context())
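In the load() path above the replacement is an inline lambda rather than the named _get_shape_env helper, and the patches are unwound automatically when the ExitStack exits. A small sketch of that behaviour, again with an illustrative stand-in class rather than the real torch._inductor.codecache.FxGraphCache:

    from contextlib import ExitStack
    from unittest.mock import patch


    class AlwaysHitShapeEnv:
        # Illustrative stand-in: pretends every recorded guard evaluates to True.
        def evaluate_guards_expression(self, *args, **kwargs):
            return True


    class FakeFxGraphCache:
        # Illustrative stand-in for torch._inductor.codecache.FxGraphCache.
        @staticmethod
        def _get_shape_env():
            return None


    with ExitStack() as exit_stack:
        exit_stack.enter_context(
            patch(f"{__name__}.FakeFxGraphCache._get_shape_env",
                  lambda *args, **kwargs: AlwaysHitShapeEnv()))
        # While the patch is active, the cache sees a shape env that always hits.
        assert isinstance(FakeFxGraphCache._get_shape_env(), AlwaysHitShapeEnv)

    # Once the ExitStack unwinds, mock.patch restores the original staticmethod.
    assert FakeFxGraphCache._get_shape_env() is None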