Commit 709ea7d

mamba : stop abusing attention metadata
This breaks existing converted-to-GGUF Mamba models, but will allow supporting mixed architectures like MambaFormer without needing to break Mamba models. This will also allow changing the size of Mamba's states without having to reconvert models in the future (e.g. using something other than d_conv - 1 columns for the conv_states will not require breaking existing converted Mamba models again).

* gguf-py : add new KV metadata key-value pairs for Mamba
* llama : add new metadata key-value pairs for Mamba
* llama : guard against divisions by zero when n_head is 0
* mamba : rename "unlimited" KV cache property to "recurrent"
1 parent 919d79f commit 709ea7d
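
In practice, the state dimensions now travel under per-architecture "ssm" keys instead of the attention head/key/value metadata. A minimal sketch (not part of this commit) of checking the new keys in a converted file with gguf-py's GGUFReader, assuming the architecture string is "mamba"; the file name is hypothetical:

from gguf import GGUFReader

reader = GGUFReader("mamba-2.8b.gguf")  # hypothetical path

for key in ("mamba.ssm.d_conv", "mamba.ssm.d_inner",
            "mamba.ssm.d_state", "mamba.ssm.dt_rank"):
    field = reader.get_field(key)
    if field is not None:
        # scalar fields: the indexed data part holds the value
        print(key, "=", field.parts[field.data[0]][0])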

File tree

4 files changed, +128 −49 lines changed


convert-hf-to-gguf.py

+12 −5
@@ -1857,21 +1857,28 @@ def set_vocab(self):
 
     def set_gguf_parameters(self):
         d_model = self.hparams["d_model"]
+        d_conv = self.hparams.get("d_conv", 4)
         d_inner = self.hparams.get("d_inner", 2 * d_model)
+        d_state = self.hparams.get("d_state", 16)
+        # ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
+        dt_rank = self.hparams.get("dt_rank", -(d_model // -16))
+
         # Fail early for models which don't have a block expansion factor of 2
         assert d_inner == 2 * d_model
 
         self.gguf_writer.add_name(self.dir_model.name)
         self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
-        self.gguf_writer.add_head_count(d_inner)  # the number of rows in conv_state and ssm_state
+        self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_ssm_conv_kernel_size(d_conv)
+        self.gguf_writer.add_ssm_inner_length(d_inner)
+        self.gguf_writer.add_ssm_state_length(d_state)
+        self.gguf_writer.add_ssm_dt_rank(dt_rank)
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
-        # NOTE: (ab)using the KV cache metadata to store dimensions for conv_state and ssm_state
-        # Since the first column of the conv_state is shifted out each time, it's not actually needed
-        self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4) - 1)
-        self.gguf_writer.add_value_length(self.hparams.get("d_state", 16))
         self.gguf_writer.add_file_type(self.ftype)
 
     def write_tensors(self):
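
The dt_rank fallback uses negated floor division as a ceiling division, as the comment's references explain; a quick worked check of the identity (d_model = 2560, Mamba-2.8B's width, is only an example):

import math

d_model = 2560               # e.g. Mamba-2.8B
dt_rank = -(d_model // -16)  # floor division of a negated numerator rounds up
assert dt_rank == math.ceil(d_model / 16) == 160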

gguf-py/gguf/constants.py

+12
@@ -61,6 +61,12 @@ class Rope:
         SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
         SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
 
+    class SSM:
+        CONV_KERNEL_SIZE = "{arch}.ssm.d_conv"
+        INNER_LENGTH = "{arch}.ssm.d_inner"
+        STATE_LENGTH = "{arch}.ssm.d_state"
+        DT_RANK = "{arch}.ssm.dt_rank"
+
     class Tokenizer:
         MODEL = "tokenizer.ggml.model"
         LIST = "tokenizer.ggml.tokens"

@@ -726,6 +732,12 @@ def get_type(val: Any) -> GGUFValueType:
 KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
 KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED
 
+# SSM
+KEY_SSM_CONV_KERNEL_SIZE = Keys.SSM.CONV_KERNEL_SIZE
+KEY_SSM_INNER_LENGTH = Keys.SSM.INNER_LENGTH
+KEY_SSM_STATE_LENGTH = Keys.SSM.STATE_LENGTH
+KEY_SSM_DT_RANK = Keys.SSM.DT_RANK
+
 # tokenization
 KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
 KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST
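
Each template fills {arch} with the model's architecture name, keeping the keys namespaced per architecture. An illustrative expansion (the "mamba" architecture string is an assumption here):

from gguf.constants import Keys

# "{arch}" is replaced with the architecture name; "mamba" is illustrative
assert Keys.SSM.CONV_KERNEL_SIZE.format(arch="mamba") == "mamba.ssm.d_conv"
assert Keys.SSM.INNER_LENGTH.format(arch="mamba") == "mamba.ssm.d_inner"
assert Keys.SSM.STATE_LENGTH.format(arch="mamba") == "mamba.ssm.d_state"
assert Keys.SSM.DT_RANK.format(arch="mamba") == "mamba.ssm.dt_rank"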

gguf-py/gguf/gguf_writer.py

+12
@@ -382,6 +382,18 @@ def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
     def add_rope_scaling_finetuned(self, value: bool) -> None:
         self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
 
+    def add_ssm_conv_kernel_size(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.CONV_KERNEL_SIZE.format(arch=self.arch), value)
+
+    def add_ssm_inner_length(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.INNER_LENGTH.format(arch=self.arch), value)
+
+    def add_ssm_state_length(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.STATE_LENGTH.format(arch=self.arch), value)
+
+    def add_ssm_dt_rank(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.DT_RANK.format(arch=self.arch), value)
+
     def add_tokenizer_model(self, model: str) -> None:
         self.add_string(Keys.Tokenizer.MODEL, model)
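
Paired with the converter change above, the new writer methods are used like any other per-arch metadata setters. A standalone sketch, not part of the commit; the values match the published Mamba-130m sizes and the output path is hypothetical:

from gguf import GGUFWriter

writer = GGUFWriter("mamba-130m.gguf", arch="mamba")  # hypothetical path
writer.add_architecture()
writer.add_ssm_conv_kernel_size(4)  # d_conv
writer.add_ssm_inner_length(1536)   # d_inner = 2 * d_model (d_model = 768)
writer.add_ssm_state_length(16)     # d_state
writer.add_ssm_dt_rank(48)          # ceil(768 / 16)
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()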
