File tree 1 file changed +7
-0
lines changed
1 file changed +7
-0
lines changed Original file line number Diff line number Diff line change @@ -280,6 +280,10 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
280
280
patch ("torch._inductor.codecache.FxGraphCache._get_shape_env" ,
281
281
_get_shape_env ))
282
282
283
+ stack .enter_context (
284
+ patch ("torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env" ,
285
+ _get_shape_env ))
286
+
283
287
# for forcing the graph to be cached
284
288
stack .enter_context (
285
289
patch (
@@ -330,6 +334,9 @@ def load(self,
330
334
exit_stack .enter_context (
331
335
patch ("torch._inductor.codecache.FxGraphCache._get_shape_env" ,
332
336
lambda * args , ** kwargs : AlwaysHitShapeEnv ()))
337
+ exit_stack .enter_context (
338
+ patch ("torch._inductor.codecache.GuardedCache._get_shape_env" ,
339
+ lambda * args , ** kwargs : AlwaysHitShapeEnv ()))
333
340
334
341
# Dynamo metrics context, see method for more details.
335
342
exit_stack .enter_context (self .metrics_context ())
You can’t perform that action at this time.
0 commit comments