@@ -51,10 +51,12 @@ class WriterState(Enum):
51
51
HEADER = auto ()
52
52
KV_DATA = auto ()
53
53
TI_DATA = auto ()
54
+ WEIGHTS = auto ()
54
55
55
56
56
57
class GGUFWriter :
57
58
fout : BufferedWriter | None
59
+ path : os .PathLike [str ] | str | None
58
60
temp_file : tempfile .SpooledTemporaryFile [bytes ] | None
59
61
tensors : dict [str , TensorInfo ]
60
62
kv_data : dict [str , GGUFValue ]
@@ -77,7 +79,8 @@ def __init__(
77
79
self , path : os .PathLike [str ] | str | None , arch : str , use_temp_file : bool = False ,
78
80
endianess : GGUFEndian = GGUFEndian .LITTLE ,
79
81
):
80
- self .fout = open (path , "wb" ) if path is not None else None
82
+ self .fout = None
83
+ self .path = path
81
84
self .arch = arch
82
85
self .endianess = endianess
83
86
self .data_alignment = GGUF_DEFAULT_ALIGNMENT
@@ -88,19 +91,29 @@ def __init__(
88
91
logger .info ("gguf: This GGUF file is for {0} Endian only" .format (
89
92
"Big" if self .endianess == GGUFEndian .BIG else "Little" ,
90
93
))
91
- self .state = WriterState .NO_FILE if self . fout is None else WriterState . EMPTY
94
+ self .state = WriterState .NO_FILE
92
95
93
96
self .add_architecture ()
94
97
95
- def write_header_to_file (self , path : os .PathLike [str ] | str | None = None ) -> None :
96
- # NOTE: not checking for WriterState.NO_FILE,
97
- # because writing can technically be started over from any state,
98
- # as long as a new path is provided
98
+ def open_output_file (self , path : os .PathLike [str ] | str | None = None ) -> None :
99
+ if self .state is WriterState .EMPTY and self .fout is not None and (path is None or path == self .path ):
100
+ # allow calling this multiple times as long as the path is the same
101
+ return
102
+ if self .state is not WriterState .NO_FILE :
103
+ raise ValueError (f'Expected output file to be not yet opened, got { self .state } ' )
104
+
99
105
if path is not None :
106
+ self .path = path
107
+
108
+ if self .path is not None :
100
109
if self .fout is not None :
101
110
self .fout .close ()
102
- self .fout = open (path , "wb" )
111
+ self .fout = open (self . path , "wb" )
103
112
self .state = WriterState .EMPTY
113
+
114
+ def write_header_to_file (self , path : os .PathLike [str ] | str | None = None ) -> None :
115
+ self .open_output_file (path )
116
+
104
117
if self .state is not WriterState .EMPTY :
105
118
raise ValueError (f'Expected output file to be empty, got { self .state } ' )
106
119
@@ -206,8 +219,8 @@ def add_tensor_info(
206
219
self , name : str , tensor_shape : Sequence [int ], tensor_dtype : np .dtype ,
207
220
tensor_nbytes : int , raw_dtype : GGMLQuantizationType | None = None ,
208
221
) -> None :
209
- if self .state is not WriterState .EMPTY and self . state is not WriterState . NO_FILE :
210
- raise ValueError (f'Expected output file to be empty or absent , got { self .state } ' )
222
+ if self .state is not WriterState .NO_FILE :
223
+ raise ValueError (f'Expected output file to be not yet opened , got { self .state } ' )
211
224
212
225
if name in self .tensors :
213
226
raise ValueError (f'Duplicated tensor name { name !r} ' )
@@ -263,8 +276,8 @@ def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None
263
276
fp .write (bytes ([0 ] * pad ))
264
277
265
278
def write_tensor_data (self , tensor : np .ndarray [Any , Any ]) -> None :
266
- if self .state is not WriterState .TI_DATA :
267
- raise ValueError (f'Expected output file to contain tensor info, got { self .state } ' )
279
+ if self .state is not WriterState .TI_DATA and self . state is not WriterState . WEIGHTS :
280
+ raise ValueError (f'Expected output file to contain tensor info or weights , got { self .state } ' )
268
281
assert self .fout is not None
269
282
270
283
if self .endianess == GGUFEndian .BIG :
@@ -273,6 +286,8 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
273
286
tensor .tofile (self .fout )
274
287
self .write_padding (self .fout , tensor .nbytes )
275
288
289
+ self .state = WriterState .WEIGHTS
290
+
276
291
def write_tensors_to_file (self , * , progress : bool = False ) -> None :
277
292
self .write_ti_data_to_file ()
278
293
@@ -299,14 +314,14 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
299
314
bar .update (ti .nbytes )
300
315
self .write_padding (self .fout , ti .nbytes )
301
316
ti .tensor = None
317
+ else :
318
+ self .temp_file .seek (0 )
302
319
303
- return
304
-
305
- self .temp_file .seek ( 0 )
320
+ shutil . copyfileobj ( self . temp_file , self . fout )
321
+ self . flush ()
322
+ self .temp_file .close ( )
306
323
307
- shutil .copyfileobj (self .temp_file , self .fout )
308
- self .flush ()
309
- self .temp_file .close ()
324
+ self .state = WriterState .WEIGHTS
310
325
311
326
def flush (self ) -> None :
312
327
assert self .fout is not None
0 commit comments