Commit 7016fe5

mamba : simplify the conv step with a self-overlapping view
Turns out the conv_state can be made smaller by one column. Note that this breaks existing GGUFs of Mamba, because the key_length field is tied to the conv_state size.

Convolution with a self-overlapping view is cool! And it's much simpler than what I initially thought would be necessary to make the convolution step work with more than 1 token at a time.

Next step is to make the SSM step work on batches of tokens too, and thus I need to figure out a way to make a parallel selective scan which will keep the ssm_state small and won't make it bigger by a factor of (n_layer * batch_size).

* llama : fix Mamba KV self size wrongly displaying as f16 instead of f32

Relatedly, I also tried to see if other types than f32 worked for the states, but they don't, because of the operators used. It's probably better anyway to keep lots of precision there, since the states are small anyway.
1 parent ba94c9d commit 7016fe5
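To illustrate the "self-overlapping view" idea described above, here is a minimal standalone sketch of a causal convolution computed through overlapping windows (plain C++, single channel, made-up values; it is only an analogy for the ggml view used in this commit, not the actual code):

#include <cstdio>
#include <vector>

int main() {
    const int d_conv = 4; // convolution width (the Mamba default)
    const int n_tok  = 3; // number of new tokens in the batch

    // one row per channel; a single channel here for clarity:
    // [ last (d_conv-1) inputs kept from the previous batch | n_tok new inputs ]
    std::vector<float> row = {0.1f, 0.2f, 0.3f,  1.0f, 2.0f, 3.0f};
    std::vector<float> w   = {0.25f, 0.25f, 0.25f, 0.25f}; // conv1d weights, length d_conv

    for (int t = 0; t < n_tok; ++t) {
        // the window for token t is row[t .. t+d_conv); consecutive windows overlap,
        // shifted by one element per token -- the "self-overlapping view"
        float y = 0.0f;
        for (int j = 0; j < d_conv; ++j) {
            y += w[j] * row[t + j];
        }
        std::printf("token %d: y = %f\n", t, y);
    }

    // only the last (d_conv-1) elements need to be kept as the next conv_state,
    // which is why the stored state can lose one column
    std::vector<float> next_conv_state(row.end() - (d_conv - 1), row.end());
    (void) next_conv_state;
    return 0;
}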

File tree (2 files changed: +63, -47 lines)

  convert-hf-to-gguf.py
  llama.cpp

convert-hf-to-gguf.py (+4, -2)

@@ -1510,10 +1510,12 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
-        self.gguf_writer.add_head_count(d_inner)
+        self.gguf_writer.add_head_count(d_inner) # the number of rows in conv_state and ssm_state
         self.gguf_writer.add_block_count(self.hparams["n_layer"])
         self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
-        self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4))
+        # NOTE: (ab)using the KV cache metadata to store dimensions for conv_state and ssm_state
+        # Since the first column of the conv_state is shifted out each time, it's not actually needed
+        self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4) - 1)
         self.gguf_writer.add_value_length(self.hparams.get("d_state", 16))
         self.gguf_writer.add_file_type(self.ftype)
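For a sense of scale, the per-layer state sizes implied by this metadata are small. A rough sketch of the arithmetic (C++; d_inner = 1536 is a hypothetical value, e.g. 2 * d_model for a d_model of 768, and is not taken from the diff):

#include <cstdint>
#include <cstdio>

int main() {
    // defaults shown in the diff above
    const int64_t d_conv  = 4;    // stored as key_length   = d_conv - 1 = 3
    const int64_t d_state = 16;   // stored as value_length = 16
    const int64_t d_inner = 1536; // hypothetical; stored as head_count

    // per-layer state sizes as of this commit (both kept in f32)
    const int64_t conv_state = (d_conv - 1) * d_inner; // 4608 values
    const int64_t ssm_state  = d_state * d_inner;      // 24576 values

    std::printf("conv_state: %lld floats (%lld bytes) per layer\n",
                (long long) conv_state, (long long) (conv_state * 4));
    std::printf("ssm_state : %lld floats (%lld bytes) per layer\n",
                (long long) ssm_state, (long long) (ssm_state * 4));
    return 0;
}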

llama.cpp (+59, -45)

@@ -1862,9 +1862,6 @@ static bool llama_kv_cache_init(
     if (model.arch == LLM_ARCH_MAMBA) {
         // only one slot is needed for Mamba
         n_ctx = 1;
-        // it's probably best to keep as much precision as possible for the states
-        ktype = GGML_TYPE_F32;
-        vtype = GGML_TYPE_F32;
     }
 
     cache.has_shift = false;
@@ -4179,7 +4176,7 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_MAMBA:
                 {
-                    const int64_t d_conv = hparams.n_embd_head_k;
+                    const int64_t d_conv = hparams.n_embd_head_k + 1;
                     const int64_t d_state = hparams.n_embd_head_v;
                     const int64_t d_inner = hparams.n_head;
                     // FIXME: ceiling instead of floor
@@ -6917,28 +6914,27 @@ struct llm_build_context {
     struct ggml_cgraph * build_mamba() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
-        const bool use_conv = batch.n_tokens > 1;
-        GGML_ASSERT(use_conv == false); // TODO: implement
+        const int32_t n_tok = batch.n_tokens;
 
         // hopefully the compiler does constant folding
         const int64_t d_model = n_embd;
         const int64_t d_inner = n_head;
         GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_conv = n_embd_head_k;
+        const int64_t d_conv = n_embd_head_k + 1;
         const int64_t d_state = n_embd_head_v;
         const int64_t dt_rank = d_model / 16;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        // NOTE: not sure what's the difference between the sequence length and the batch size in the paper.
-        // {n_embd, batch}
+        // {n_embd, n_tok}
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the kv cache to store the state
-            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv, d_inner);
+            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv - 1, d_inner);
             ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);
 
             // norm
@@ -6947,33 +6943,43 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            // {n_embd, 2*d_inner} * {n_embd, batch} = {2*d_inner, batch}
+            // {n_embd, 2*d_inner} * {n_embd, n_tok} => {2*d_inner, n_tok}
             struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
             // split the above in two
-            // assuming it's contiguous
-            // {d_inner, batch}
+            // => {d_inner, n_tok}
             struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
             struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
 
-            cur = x;
-
             // conv
             {
-                // shift conv state left
-                conv_state = ggml_set_2d(ctx0, conv_state, ggml_view_2d(ctx0, conv_state, (d_conv - 1), d_inner, conv_state->nb[1], ggml_element_size(conv_state)*1), conv_state->nb[1], 0);
-
-                // update last column
-                // x here is {d_inner, 1} (a row), but should be {1, d_inner} (a column)
-                conv_state = ggml_set_2d(ctx0, conv_state, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_state->nb[1], ggml_element_size(conv_state)*(d_conv - 1));
-
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));
-
-                // rearrange and sum
-                // no need to rearrange the conv_state, since it's already in the right shape
-                // => {1, d_inner}
-                x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
-                // => {d_inner, 1}
-                x = ggml_transpose(ctx0, x);
+                // concat last (d_conv - 1) columns of conv_state, and x
+
+                // The following tensor is too big in order to avoid an assertion error when making an overlapping view.
+                // TODO: in ggml_new_tensor_impl, handle overlapping data range in data size calculation
+                // This could then be a tensor with ne[] = {(d_conv-1)+n_tok, d_inner}
+                // which is around (d_conv-1) times as small as its current size.
+                struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tok);
+                const size_t conv_x_nb1 = (d_conv - 1 + n_tok) * ggml_element_size(conv_x);
+
+                conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0);
+                // unfortunately, making x contiguous is necessary because ggml_set expects nb0 == sizeof(float)
+                conv_x = ggml_set_2d(ctx0, conv_x, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_x_nb1, (d_conv - 1)*ggml_element_size(conv_x));
+
+                // store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state
+                ggml_build_forward_expand(gf,
+                    ggml_cpy(ctx0,
+                        ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tok*ggml_element_size(conv_x)),
+                        ggml_view_tensor(ctx0, kv_self.k_l[il])));
+
+                // prepare convolution for all tokens in the batch with a self-overlapping view
+                // {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok}
+                conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, 1*ggml_element_size(conv_x), 0);
+
+                // perform convolution
+                // => {1, d_inner, n_tok}
+                x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d));
+                // => {d_inner, n_tok, 1}
+                x = ggml_permute(ctx0, x, 2, 0, 1, 3);
 
                 // bias
                 x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
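As a reading aid for the hunk above, here is a plain-C++ analogue of the same data movement for one layer (a sketch under simplifying assumptions: contiguous f32 buffers, channel-major layout, no ggml; the parameter names mirror the tensors above but are otherwise hypothetical):

#include <vector>

// conv_state: d_inner rows of (d_conv - 1) values, carried between batches
// x:          d_inner rows of n_tok new values
// w:          d_inner rows of d_conv conv1d weights
// returns d_inner rows of n_tok convolved values and updates conv_state in place
std::vector<float> conv_step(std::vector<float> & conv_state,
                             const std::vector<float> & x,
                             const std::vector<float> & w,
                             int d_conv, int d_inner, int n_tok) {
    const int row = (d_conv - 1) + n_tok; // row stride, same role as conv_x_nb1 (in elements)
    std::vector<float> conv_x(row * d_inner);
    std::vector<float> y(d_inner * n_tok);

    for (int i = 0; i < d_inner; ++i) {
        // conv_x row i = [ previous conv_state | new tokens ]
        for (int j = 0; j < d_conv - 1; ++j) { conv_x[i*row + j] = conv_state[i*(d_conv - 1) + j]; }
        for (int t = 0; t < n_tok;      ++t) { conv_x[i*row + (d_conv - 1) + t] = x[i*n_tok + t]; }

        // self-overlapping windows: token t reads conv_x[i*row + t .. i*row + t + d_conv)
        for (int t = 0; t < n_tok; ++t) {
            float acc = 0.0f;
            for (int j = 0; j < d_conv; ++j) { acc += w[i*d_conv + j] * conv_x[i*row + t + j]; }
            y[i*n_tok + t] = acc; // bias and SiLU are applied afterwards, as in the graph above
        }

        // the last (d_conv - 1) columns become the next conv_state
        for (int j = 0; j < d_conv - 1; ++j) { conv_state[i*(d_conv - 1) + j] = conv_x[i*row + n_tok + j]; }
    }
    return y;
}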
@@ -6983,23 +6989,24 @@ struct llm_build_context {
 
             // ssm
             {
-                // {2*n_embd, batch} * {2*n_embd, dt_rank + 2*d_state} = {batch, dt_rank + 2*d_state}
-                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, x, model.layers[il].ssm_x);
-                // FIXME: handle batches of more than 1 token
-                struct ggml_tensor * dt = ggml_view_1d(ctx0, x_db, dt_rank, 0);
-                struct ggml_tensor * B = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*dt_rank);
-                struct ggml_tensor * C = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*(dt_rank+d_state));
-
-                // {dt_rank} * {dt_rank, d_inner} = {1, d_inner}
-                dt = ggml_mul_mat(ctx0, dt, model.layers[il].ssm_dt);
-                dt = ggml_add(ctx0, dt, ggml_transpose(ctx0, model.layers[il].ssm_dt_b));
+                // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tok} => {dt_rank + 2*d_state, n_tok}
+                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
+                // split
+                struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, x_db->ne[1], x_db->nb[1], 0);
+                struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+                struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+
+                // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
+                dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
+                dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
                 dt = ggml_soft_plus(ctx0, dt);
 
+                // FIXME: support batches with more than 1 token
                 // => {d_state, d_inner}
-                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, dt));
+                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));
 
                 // => {d_state, d_inner}
-                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, ggml_transpose(ctx0, dt));
+                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, dt);
 
                 // => {d_state, d_inner}
                 cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
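For reference, the recurrence that these dA/dB tensors feed into (the unchanged lines that follow update ssm_state and read it out through C) can be written per channel as a small scalar loop. A minimal sketch of that math for one channel and one token, with dt already passed through softplus; the names stand for slices of the corresponding tensors and are not llama.cpp variables:

#include <cmath>
#include <vector>

// One selective-SSM step for a single channel and a single token.
// h is the per-channel ssm_state ({d_state} values) carried across tokens.
float ssm_step(std::vector<float> & h,
               const std::vector<float> & A,  // {d_state}, negative values
               const std::vector<float> & B,  // {d_state}, input projection for this token
               const std::vector<float> & C,  // {d_state}, output projection for this token
               float D, float dt, float x) {
    float y = 0.0f;
    for (size_t i = 0; i < h.size(); ++i) {
        const float dA = std::exp(dt * A[i]); // discretized state transition
        const float dB = dt * B[i];           // discretized input projection
        h[i] = dA * h[i] + dB * x;            // ssm_state = ssm_state*dA + dB*x
        y   += C[i] * h[i];                   // readout through C
    }
    return y + D * x;                         // skip connection (the ssm_d term)
}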
@@ -7014,7 +7021,7 @@ struct llm_build_context {
                 y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
                 y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
 
-                // {d_inner, n_embd} * {d_inner, 1} = {n_embd, 1}
+                // {d_inner, n_embd} * {d_inner, 1} => {n_embd, 1}
                 cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
             }
 
@@ -10722,8 +10729,15 @@ struct llama_context * llama_new_context_with_model(
     ctx->rng = std::mt19937(params.seed);
     ctx->logits_all = params.logits_all;
 
-    const ggml_type type_k = params.type_k;
-    const ggml_type type_v = params.type_v;
+    ggml_type type_k = params.type_k;
+    ggml_type type_v = params.type_v;
+
+    // Mamba (mis)uses the KV cache to store its states
+    if (model->arch == LLM_ARCH_MAMBA) {
+        // it's probably best to keep as much precision as possible for the states
+        type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
+        type_v = GGML_TYPE_F32; // required by ggml_mul for Mamba's ssm_state
+    }
 
     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
     GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
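One user-facing consequence of the hunk above: whatever cache types a caller requests are overridden for Mamba models. A hypothetical usage sketch (llama.h API, error handling omitted; the function name is illustrative):

#include "llama.h"

// Even if quantized KV types are requested, a Mamba model's states stay in f32.
static llama_context * make_ctx(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.type_k = GGML_TYPE_Q8_0; // overridden to GGML_TYPE_F32 for LLM_ARCH_MAMBA
    cparams.type_v = GGML_TYPE_Q8_0; // overridden to GGML_TYPE_F32 for LLM_ARCH_MAMBA
    return llama_new_context_with_model(model, cparams);
}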
