@@ -45,6 +45,7 @@ def __init__(self, use_inductor: bool):
45
45
self .cache : Dict [Tuple [Optional [int ], int , str ], Any ] = dict ()
46
46
cls = InductorAdaptor if use_inductor else EagerAdaptor
47
47
self .compiler = cls ()
48
+ self .is_cache_updated = False
48
49
49
50
def compute_hash(self, vllm_config: VllmConfig) -> str:
    """Return the hash that the wrapped compiler computes for a config.

    This is a pure delegation: the value comes from
    ``self.compiler.compute_hash(vllm_config)`` and is returned unchanged.
    """
    backend = self.compiler
    config_hash = backend.compute_hash(vllm_config)
    return config_hash
@@ -66,11 +67,11 @@ def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
66
67
disable_cache = disable_cache )
67
68
68
69
def save_to_file(self):
    """Write the in-memory compilation cache to ``self.cache_file_path``.

    Nothing is written when ``self.disable_cache`` is true, or when
    ``self.is_cache_updated`` is false (no entry has been stored since
    the flag was last cleared).  After a successful write the flag is
    cleared, so repeated calls do not rewrite an unchanged cache.
    """
    if self.disable_cache or not self.is_cache_updated:
        return
    # Format before opening the file: mode "w" truncates on open, so a
    # failure inside pformat must not leave an empty cache file behind.
    printer = pprint.PrettyPrinter(indent=4)
    data = printer.pformat(self.cache)
    with open(self.cache_file_path, "w") as f:
        f.write(data)
    # The file now matches the in-memory cache; mark it clean.
    self.is_cache_updated = False
75
76
76
77
def load (self ,
@@ -131,6 +132,7 @@ def compile(self,
131
132
if handle is not None :
132
133
self .cache [(runtime_shape , graph_index ,
133
134
self .compiler .name )] = handle
135
+ self .is_cache_updated = True
134
136
if graph_index == 0 :
135
137
# adds some info logging for the first graph
136
138
logger .info ("Cache the graph of shape %s for later use" ,
0 commit comments