@@ -1862,9 +1862,6 @@ static bool llama_kv_cache_init(
     if (model.arch == LLM_ARCH_MAMBA) {
         // only one slot is needed for Mamba
         n_ctx = 1;
-        // it's probably best to keep as much precision as possible for the states
-        ktype = GGML_TYPE_F32;
-        vtype = GGML_TYPE_F32;
     }

     cache.has_shift = false;
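(Annotation, not part of the patch.) With a single cache slot, the repurposed KV cache holds one fixed-size recurrent state per layer instead of one entry per token, so its footprint no longer grows with the context size; the F32 forcing removed here reappears in llama_new_context_with_model in the last hunk below. A rough sketch of the per-layer state size, assuming the usual Mamba dimensions (d_conv = 4, d_state = 16, d_inner = 2*n_embd); the function and its constants are illustrative, not taken from the patch:

#include <stddef.h>

// Per-layer state Mamba keeps in the repurposed KV cache: a rolling convolution
// window in the K tensor and the SSM recurrent state in the V tensor.
static size_t mamba_state_bytes_per_layer(size_t n_embd) {
    const size_t d_conv  = 4;           // assumed conv kernel width
    const size_t d_state = 16;          // assumed SSM state size
    const size_t d_inner = 2 * n_embd;  // inner dimension (2*d_model, as asserted later in the diff)

    const size_t conv_state = (d_conv - 1) * d_inner; // columns actually stored per channel
    const size_t ssm_state  = d_state * d_inner;

    return (conv_state + ssm_state) * sizeof(float);  // states are kept in F32
}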
@@ -4179,7 +4176,7 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_MAMBA:
                 {
-                    const int64_t d_conv = hparams.n_embd_head_k;
+                    const int64_t d_conv = hparams.n_embd_head_k + 1;
                     const int64_t d_state = hparams.n_embd_head_v;
                     const int64_t d_inner = hparams.n_head;
                     // FIXME: ceiling instead of floor
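(Annotation.) The + 1 pairs with the conv_state layout introduced later in this patch: only d_conv - 1 columns per inner channel are ever stored in the K cache (the column that gets shifted out is not needed), so the overloaded n_embd_head_k hparam carries d_conv - 1 and the real kernel width is recovered by adding one; d_state and d_inner are likewise read from n_embd_head_v and n_head. A small sketch of that mapping, with example values for a hypothetical d_model = 768 model; the struct and function are illustrative only:

#include <stdint.h>

struct mamba_dims {
    int64_t d_conv;   // e.g. 4
    int64_t d_state;  // e.g. 16
    int64_t d_inner;  // e.g. 1536 for d_model = 768
    int64_t dt_rank;  // e.g. 48   for d_model = 768
};

// Recover Mamba's dimensions from the overloaded attention hparams used by this branch.
static struct mamba_dims mamba_dims_from_hparams(
        int64_t n_embd, int64_t n_embd_head_k, int64_t n_embd_head_v, int64_t n_head) {
    struct mamba_dims d;
    d.d_conv  = n_embd_head_k + 1; // K cache rows hold only d_conv - 1 columns per channel
    d.d_state = n_embd_head_v;
    d.d_inner = n_head;
    d.dt_rank = n_embd / 16;       // the FIXME above: should be a ceiling, not a floor
    return d;
}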
@@ -6917,28 +6914,27 @@ struct llm_build_context {
     struct ggml_cgraph * build_mamba() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

-        const bool use_conv = batch.n_tokens > 1;
-        GGML_ASSERT(use_conv == false); // TODO: implement
+        const int32_t n_tok = batch.n_tokens;

         // hopefully the compiler does constant folding
         const int64_t d_model = n_embd;
         const int64_t d_inner = n_head;
         GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_conv = n_embd_head_k;
+        const int64_t d_conv = n_embd_head_k + 1;
         const int64_t d_state = n_embd_head_v;
         const int64_t dt_rank = d_model / 16;

         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;

-        // NOTE: not sure what's the difference between the sequence length and the batch size in the paper.
-        // {n_embd, batch}
+        // {n_embd, n_tok}
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);

         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the kv cache to store the state
-            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv, d_inner);
+            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv - 1, d_inner);
             ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);

             // norm
@@ -6947,33 +6943,43 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);

-            // {n_embd, 2*d_inner} * {n_embd, batch} = {2*d_inner, batch}
+            // {n_embd, 2*d_inner} * {n_embd, n_tok} => {2*d_inner, n_tok}
             struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
             // split the above in two
-            // assuming it's contiguous
-            // {d_inner, batch}
+            // => {d_inner, n_tok}
             struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
             struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);

-            cur = x;
-
             // conv
             {
-                // shift conv state left
-                conv_state = ggml_set_2d(ctx0, conv_state, ggml_view_2d(ctx0, conv_state, (d_conv - 1), d_inner, conv_state->nb[1], ggml_element_size(conv_state)*1), conv_state->nb[1], 0);
-
-                // update last column
-                // x here is {d_inner, 1} (a row), but should be {1, d_inner} (a column)
-                conv_state = ggml_set_2d(ctx0, conv_state, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_state->nb[1], ggml_element_size(conv_state)*(d_conv - 1));
-
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));
-
-                // rearrange and sum
-                // no need to rearrange the conv_state, since it's already in the right shape
-                // => {1, d_inner}
-                x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
-                // => {d_inner, 1}
-                x = ggml_transpose(ctx0, x);
+                // concat last (d_conv - 1) columns of conv_state, and x
+
+                // The following tensor is too big in order to avoid an assertion error when making an overlapping view.
+                // TODO: in ggml_new_tensor_impl, handle overlapping data range in data size calculation
+                // This could then be a tensor with ne[] = {(d_conv-1)+n_tok, d_inner}
+                // which is around (d_conv-1) times as small as its current size.
+                struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tok);
+                const size_t conv_x_nb1 = (d_conv - 1 + n_tok) * ggml_element_size(conv_x);
+
+                conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0);
+                // unfortunately, making x contiguous is necessary because ggml_set expects nb0 == sizeof(float)
+                conv_x = ggml_set_2d(ctx0, conv_x, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_x_nb1, (d_conv - 1)*ggml_element_size(conv_x));
+
+                // store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state
+                ggml_build_forward_expand(gf,
+                    ggml_cpy(ctx0,
+                        ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tok*ggml_element_size(conv_x)),
+                        ggml_view_tensor(ctx0, kv_self.k_l[il])));
+
+                // prepare convolution for all tokens in the batch with a self-overlapping view
+                // {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok}
+                conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, -(d_conv - 1)*d_inner*ggml_element_size(conv_x), 0);
+
+                // perform convolution
+                // => {1, d_inner, n_tok}
+                x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d));
+                // => {d_inner, n_tok, 1}
+                x = ggml_permute(ctx0, x, 2, 0, 1, 3);

                 // bias
                 x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
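(Annotation.) The rewritten conv block computes a depthwise causal convolution of width d_conv over each of the d_inner channels: every token is convolved with the d_conv - 1 previous inputs kept in conv_state plus the new inputs in x, and the self-overlapping view supplies the sliding windows without materializing them. A scalar reference sketch of the same computation, including the rolling state update and the bias; layouts and names are local to this sketch and do not match the ggml tensors above:

// y[t][c] = b[c] + sum_k w[c][k] * input(t - (d_conv - 1) + k), where inputs at negative
// times come from conv_state; afterwards conv_state is advanced to the last d_conv - 1 inputs.
static void mamba_conv_ref(
        float       * y,           // out: n_tok x d_inner
        float       * conv_state,  // in/out: d_inner x (d_conv - 1) rolling window
        const float * x,           // in: n_tok x d_inner new inputs
        const float * w,           // in: d_inner x d_conv per-channel kernels
        const float * b,           // in: d_inner bias
        int d_inner, int d_conv, int n_tok) {
    for (int c = 0; c < d_inner; ++c) {
        for (int t = 0; t < n_tok; ++t) {
            float acc = b[c];
            for (int k = 0; k < d_conv; ++k) {
                const int i = t - (d_conv - 1) + k; // input time step
                const float v = i < 0 ? conv_state[c*(d_conv - 1) + (d_conv - 1 + i)]
                                      : x[i*d_inner + c];
                acc += v * w[c*d_conv + k];
            }
            y[t*d_inner + c] = acc;
        }
        // keep the last d_conv - 1 inputs of this channel for the next batch
        for (int k = 0; k < d_conv - 1; ++k) {
            const int i = n_tok - (d_conv - 1) + k;
            conv_state[c*(d_conv - 1) + k] = i < 0 ? conv_state[c*(d_conv - 1) + (d_conv - 1 + i)]
                                                   : x[i*d_inner + c];
        }
    }
}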
@@ -6983,23 +6989,24 @@ struct llm_build_context {

             // ssm
             {
-                // {2*n_embd, batch} * {2*n_embd, dt_rank + 2*d_state} = {batch, dt_rank + 2*d_state}
-                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, x, model.layers[il].ssm_x);
-                // FIXME: handle batches of more than 1 token
-                struct ggml_tensor * dt = ggml_view_1d(ctx0, x_db, dt_rank, 0);
-                struct ggml_tensor * B = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*dt_rank);
-                struct ggml_tensor * C = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*(dt_rank+d_state));
-
-                // {dt_rank} * {dt_rank, d_inner} = {1, d_inner}
-                dt = ggml_mul_mat(ctx0, dt, model.layers[il].ssm_dt);
-                dt = ggml_add(ctx0, dt, ggml_transpose(ctx0, model.layers[il].ssm_dt_b));
+                // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tok} => {dt_rank + 2*d_state, n_tok}
+                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
+                // split
+                struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, x_db->ne[1], x_db->nb[1], 0);
+                struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+                struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+
+                // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
+                dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
+                dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
                 dt = ggml_soft_plus(ctx0, dt);

+                // FIXME: support batches with more than 1 token
                 // => {d_state, d_inner}
-                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, dt));
+                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));

                 // => {d_state, d_inner}
-                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, ggml_transpose(ctx0, dt));
+                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, dt);

                 // => {d_state, d_inner}
                 cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
@@ -7014,7 +7021,7 @@ struct llm_build_context {
                 y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
                 y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));

-                // {d_inner, n_embd} * {d_inner, 1} = {n_embd, 1}
+                // {d_inner, n_embd} * {d_inner, 1} => {n_embd, 1}
                 cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
             }

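(Annotation.) Taken together, the ssm block implements the discretized selective state-space update: dt is projected through ssm_dt, biased, and passed through softplus; dA = exp(dt * A) and dB = dt outer B discretize the continuous parameters; the hidden state advances as h = h*dA + dB*x; and the output is C·h plus the D skip term, which is then gated by silu(z) and projected by ssm_out as in the hunk just above. The FIXME notes that this part of the graph still assumes a single token. A scalar single-token sketch of the update; all names and layouts are local to this sketch:

#include <math.h>

// One recurrence step per inner channel c and state dimension s:
//   dA = exp(dt[c] * A[c][s]); dB = dt[c] * B[s];
//   h[c][s] = h[c][s]*dA + dB*x[c];  y[c] = sum_s C[s]*h[c][s] + D[c]*x[c]
static void mamba_ssm_step_ref(
        float       * y,   // out: d_inner
        float       * h,   // in/out: d_inner x d_state recurrent state (ssm_state)
        const float * x,   // in: d_inner post-conv activations
        const float * dt,  // in: d_inner time steps, already through softplus
        const float * A,   // in: d_inner x d_state
        const float * B,   // in: d_state (input-dependent)
        const float * C,   // in: d_state (input-dependent)
        const float * D,   // in: d_inner skip weights
        int d_inner, int d_state) {
    for (int c = 0; c < d_inner; ++c) {
        float acc = 0.0f;
        for (int s = 0; s < d_state; ++s) {
            const float dA = expf(dt[c] * A[c*d_state + s]);
            const float dB = dt[c] * B[s];
            h[c*d_state + s] = h[c*d_state + s] * dA + dB * x[c];
            acc += C[s] * h[c*d_state + s];
        }
        y[c] = acc + D[c] * x[c];
    }
}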
@@ -10722,8 +10729,15 @@ struct llama_context * llama_new_context_with_model(
     ctx->rng = std::mt19937(params.seed);
     ctx->logits_all = params.logits_all;

-    const ggml_type type_k = params.type_k;
-    const ggml_type type_v = params.type_v;
+    ggml_type type_k = params.type_k;
+    ggml_type type_v = params.type_v;
+
+    // Mamba (mis)uses the KV cache to store its states
+    if (model->arch == LLM_ARCH_MAMBA) {
+        // it's probably best to keep as much precision as possible for the states
+        type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
+        type_v = GGML_TYPE_F32; // required by ggml_mul for Mamba's ssm_state
+    }

     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
     GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
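(Annotation.) Moving the F32 override here keeps it next to the other context-level type choices, and it is not only about precision: the graph built above calls ggml_set and ggml_mul directly on the cache tensors that hold conv_state and ssm_state, which is why the inline comments mark F32 as required. A hypothetical guard expressing the same constraint (not part of the patch):

#include <stdbool.h>
#include "ggml.h"

// Mamba's recurrent states cannot live in a quantized cache: ggml_set and ggml_mul
// on conv_state/ssm_state need float data, hence the F32 override above.
static bool mamba_cache_types_ok(enum ggml_type type_k, enum ggml_type type_v) {
    return type_k == GGML_TYPE_F32 && type_v == GGML_TYPE_F32;
}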