Skip to content

Commit 26bc4bb

Browse files
authored
Avoid overwriting vllm_compile_cache.py (#17418)
Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
1 parent 3c3d767 commit 26bc4bb

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

vllm/compilation/backends.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,7 @@ def __init__(self, use_inductor: bool):
4545
self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
4646
cls = InductorAdaptor if use_inductor else EagerAdaptor
4747
self.compiler = cls()
48+
self.is_cache_updated = False
4849

4950
def compute_hash(self, vllm_config: VllmConfig) -> str:
5051
return self.compiler.compute_hash(vllm_config)
@@ -66,11 +67,11 @@ def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
6667
disable_cache=disable_cache)
6768

6869
def save_to_file(self):
69-
if self.disable_cache:
70+
if self.disable_cache or not self.is_cache_updated:
7071
return
72+
printer = pprint.PrettyPrinter(indent=4)
73+
data = printer.pformat(self.cache)
7174
with open(self.cache_file_path, "w") as f:
72-
printer = pprint.PrettyPrinter(indent=4)
73-
data = printer.pformat(self.cache)
7475
f.write(data)
7576

7677
def load(self,
@@ -131,6 +132,7 @@ def compile(self,
131132
if handle is not None:
132133
self.cache[(runtime_shape, graph_index,
133134
self.compiler.name)] = handle
135+
self.is_cache_updated = True
134136
if graph_index == 0:
135137
# adds some info logging for the first graph
136138
logger.info("Cache the graph of shape %s for later use",

0 commit comments

Comments (0)