Skip to content

Commit 3393d47

Browse files
committed
Patch AOTAutogradCache._get_shape_env
1 parent 6d0df0e commit 3393d47

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

vllm/compilation/compiler_interface.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,10 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
280280
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
281281
_get_shape_env))
282282

283+
stack.enter_context(
284+
patch("torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env",
285+
_get_shape_env))
286+
283287
# for forcing the graph to be cached
284288
stack.enter_context(
285289
patch(
@@ -330,6 +334,9 @@ def load(self,
330334
exit_stack.enter_context(
331335
patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
332336
lambda *args, **kwargs: AlwaysHitShapeEnv()))
337+
exit_stack.enter_context(
338+
patch("torch._inductor.codecache.GuardedCache._get_shape_env",
339+
lambda *args, **kwargs: AlwaysHitShapeEnv()))
333340

334341
# Dynamo metrics context, see method for more details.
335342
exit_stack.enter_context(self.metrics_context())

0 commit comments

Comments
 (0)