Skip to content

Commit 32d11db

Browse files
committed
gguf-py : always defer GGUFWriter output file opening
Changing what happens when the output file is opened will be easier, since this reduces the cases to consider.

* gguf-py : prevent GGUFWriter from writing all tensors multiple times

It was already checked with an assertion before, but using WriterState should make the error message slightly less cryptic.
1 parent fe59f20 commit 32d11db

File tree

1 file changed

+32
-17
lines changed

1 file changed

+32
-17
lines changed

gguf-py/gguf/gguf_writer.py

+32-17
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,12 @@ class WriterState(Enum):
5151
HEADER = auto()
5252
KV_DATA = auto()
5353
TI_DATA = auto()
54+
WEIGHTS = auto()
5455

5556

5657
class GGUFWriter:
5758
fout: BufferedWriter | None
59+
path: os.PathLike[str] | str | None
5860
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
5961
tensors: dict[str, TensorInfo]
6062
kv_data: dict[str, GGUFValue]
@@ -77,7 +79,8 @@ def __init__(
7779
self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False,
7880
endianess: GGUFEndian = GGUFEndian.LITTLE,
7981
):
80-
self.fout = open(path, "wb") if path is not None else None
82+
self.fout = None
83+
self.path = path
8184
self.arch = arch
8285
self.endianess = endianess
8386
self.data_alignment = GGUF_DEFAULT_ALIGNMENT
@@ -88,19 +91,29 @@ def __init__(
8891
logger.info("gguf: This GGUF file is for {0} Endian only".format(
8992
"Big" if self.endianess == GGUFEndian.BIG else "Little",
9093
))
91-
self.state = WriterState.NO_FILE if self.fout is None else WriterState.EMPTY
94+
self.state = WriterState.NO_FILE
9295

9396
self.add_architecture()
9497

95-
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
96-
# NOTE: not checking for WriterState.NO_FILE,
97-
# because writing can technically be started over from any state,
98-
# as long as a new path is provided
98+
def open_output_file(self, path: os.PathLike[str] | str | None = None) -> None:
99+
if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path):
100+
# allow calling this multiple times as long as the path is the same
101+
return
102+
if self.state is not WriterState.NO_FILE:
103+
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
104+
99105
if path is not None:
106+
self.path = path
107+
108+
if self.path is not None:
100109
if self.fout is not None:
101110
self.fout.close()
102-
self.fout = open(path, "wb")
111+
self.fout = open(self.path, "wb")
103112
self.state = WriterState.EMPTY
113+
114+
def write_header_to_file(self, path: os.PathLike[str] | str | None = None) -> None:
115+
self.open_output_file(path)
116+
104117
if self.state is not WriterState.EMPTY:
105118
raise ValueError(f'Expected output file to be empty, got {self.state}')
106119

@@ -206,8 +219,8 @@ def add_tensor_info(
206219
self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype,
207220
tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None,
208221
) -> None:
209-
if self.state is not WriterState.EMPTY and self.state is not WriterState.NO_FILE:
210-
raise ValueError(f'Expected output file to be empty or absent, got {self.state}')
222+
if self.state is not WriterState.NO_FILE:
223+
raise ValueError(f'Expected output file to be not yet opened, got {self.state}')
211224

212225
if name in self.tensors:
213226
raise ValueError(f'Duplicated tensor name {name!r}')
@@ -263,8 +276,8 @@ def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None
263276
fp.write(bytes([0] * pad))
264277

265278
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
266-
if self.state is not WriterState.TI_DATA:
267-
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
279+
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
280+
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
268281
assert self.fout is not None
269282

270283
if self.endianess == GGUFEndian.BIG:
@@ -273,6 +286,8 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
273286
tensor.tofile(self.fout)
274287
self.write_padding(self.fout, tensor.nbytes)
275288

289+
self.state = WriterState.WEIGHTS
290+
276291
def write_tensors_to_file(self, *, progress: bool = False) -> None:
277292
self.write_ti_data_to_file()
278293

@@ -299,14 +314,14 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
299314
bar.update(ti.nbytes)
300315
self.write_padding(self.fout, ti.nbytes)
301316
ti.tensor = None
317+
else:
318+
self.temp_file.seek(0)
302319

303-
return
304-
305-
self.temp_file.seek(0)
320+
shutil.copyfileobj(self.temp_file, self.fout)
321+
self.flush()
322+
self.temp_file.close()
306323

307-
shutil.copyfileobj(self.temp_file, self.fout)
308-
self.flush()
309-
self.temp_file.close()
324+
self.state = WriterState.WEIGHTS
310325

311326
def flush(self) -> None:
312327
assert self.fout is not None

0 commit comments

Comments
 (0)