
Commit d354d03

Mrkvak authored and YellowRoseCx committed
convert-hf : support for mixtral-instruct (ggml-org#4428)
* convert : typo fix, add additional hyperparameters, use LLaMA arch for Mixtral-instruct
* convert : use sentencepiece tokenizer for Mixtral-instruct
* convert : make flake8 happy
1 parent b1cf642 commit d354d03


convert-hf-to-gguf.py

Lines changed: 20 additions & 1 deletion
@@ -77,8 +77,18 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_embedding_length(n_embd)
         if (n_ff := self.hparams.get("intermediate_size")) is not None:
             self.gguf_writer.add_feed_forward_length(n_ff)
-        if (n_head := self.hparams.get("num_attention_head")) is not None:
+        if (n_head := self.hparams.get("num_attention_heads")) is not None:
             self.gguf_writer.add_head_count(n_head)
+        if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
+            self.gguf_writer.add_head_count_kv(n_head_kv)
+
+        if (n_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
+            self.gguf_writer.add_layer_norm_rms_eps(n_rms_eps)
+        if (n_experts := self.hparams.get("num_local_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)
+
         self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
 
     def write_tensors(self):
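
The hyperparameter names queried in this hunk come straight from the checkpoint's config.json. A minimal sketch of the same walrus-operator pattern against an abbreviated Mixtral-style hparams dict; the values shown match the published Mixtral-8x7B config, but verify against your checkpoint, and the print calls are stand-ins for the real gguf_writer methods:

# Sketch only: each key is written to the GGUF header only if the
# config actually carries it.
hparams = {
    "num_attention_heads": 32,
    "num_key_value_heads": 8,     # grouped-query attention
    "rms_norm_eps": 1e-05,
    "num_local_experts": 8,       # MoE: experts per layer
    "num_experts_per_tok": 2,     # MoE: experts routed per token
}

if (n_head_kv := hparams.get("num_key_value_heads")) is not None:
    print("head_count_kv =", n_head_kv)   # stand-in for add_head_count_kv
if (n_experts := hparams.get("num_local_experts")) is not None:
    print("expert_count =", n_experts)    # stand-in for add_expert_count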
@@ -170,6 +180,8 @@ def from_model_architecture(model_architecture):
             return StableLMModel
         if model_architecture == "QWenLMHeadModel":
             return QwenModel
+        if model_architecture == "MixtralForCausalLM":
+            return MixtralModel
         return Model
 
     def _is_model_safetensors(self) -> bool:
@@ -207,6 +219,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.STABLELM
         if arch == "QWenLMHeadModel":
             return gguf.MODEL_ARCH.QWEN
+        if arch == "MixtralForCausalLM":
+            return gguf.MODEL_ARCH.LLAMA
 
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
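
These two dispatch hunks register the same architecture string twice: the HF class name "MixtralForCausalLM" selects the new MixtralModel converter class, while the GGUF architecture is reported as LLAMA, since (per the commit message) Mixtral reuses the LLaMA layout with added expert tensors. A minimal sketch of the lookup, assuming the usual HF convention that config.json names the model class in its "architectures" list:

import json
from pathlib import Path

def hf_architecture(model_dir: str) -> str:
    # Assumption for illustration: config.json sits at the top of model_dir.
    config = json.loads((Path(model_dir) / "config.json").read_text())
    return config["architectures"][0]  # e.g. "MixtralForCausalLM"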

@@ -837,6 +851,11 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_layer_norm_eps(1e-5)
 
 
+class MixtralModel(Model):
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+
 class QwenModel(Model):
     @staticmethod
     def token_bytes_to_string(b):
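
MixtralModel only overrides set_vocab, because Mixtral checkpoints ship a SentencePiece tokenizer.model rather than a BPE vocabulary; everything else is inherited from Model. A rough sketch of what a SentencePiece-backed vocab load amounts to (the real work lives in the base class's _set_vocab_sentencepiece; this is illustration, not the script's code):

from sentencepiece import SentencePieceProcessor

sp = SentencePieceProcessor()
sp.Load("tokenizer.model")  # the file name Mixtral checkpoints ship
tokens = [sp.IdToPiece(i) for i in range(sp.vocab_size())]
scores = [sp.GetScore(i) for i in range(sp.vocab_size())]
print(len(tokens), "pieces; first three:", tokens[:3])

With this in place, a Mixtral-instruct checkpoint converts like any other supported HF model via convert-hf-to-gguf.py (see the script's --help for output options).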
