Commit 11bff29

MPT : support GQA for replit-code-v1.5 (#3627)

1 parent 11dc109 · commit 11bff29

File tree

2 files changed, +5 -3 lines

convert-mpt-hf-to-gguf.py (+2)

@@ -98,6 +98,8 @@ def parse_args() -> argparse.Namespace:
 gguf_writer.add_block_count(block_count)
 gguf_writer.add_feed_forward_length(4 * hparams["d_model"])
 gguf_writer.add_head_count(hparams["n_heads"])
+if kv_n_heads := hparams["attn_config"].get("kv_n_heads"):
+    gguf_writer.add_head_count_kv(kv_n_heads)
 gguf_writer.add_layer_norm_eps(1e-05)
 if hparams["attn_config"]["clip_qkv"] is not None:
     gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"])
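
With the walrus-operator guard, head_count_kv is only emitted when the checkpoint's attn_config actually carries kv_n_heads; older MHA-only MPT configs lack the key, nothing is written, and the loader keeps its default of n_head_kv == n_head. A minimal Python sketch of that guard follows; the config dicts and the write_kv helper are illustrative stand-ins, not values or functions from the actual script.

# Illustrative sketch of the converter's guard; write_kv is a hypothetical
# stand-in for gguf_writer.add_head_count_kv, and the configs are made up.
def write_kv(n: int) -> None:
    print(f"head_count_kv = {n}")

gqa_config = {"n_heads": 24, "attn_config": {"kv_n_heads": 8}}   # GQA-style config
mha_config = {"n_heads": 32, "attn_config": {"clip_qkv": None}}  # no kv_n_heads key

for hparams in (gqa_config, mha_config):
    # .get() returns None when the key is absent, so the metadata entry is only
    # written for GQA checkpoints; MHA models skip it and the loader falls back
    # to n_head_kv == n_head.
    if kv_n_heads := hparams["attn_config"].get("kv_n_heads"):
        write_kv(kv_n_heads)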

llama.cpp (+3 -3)

@@ -2839,8 +2839,8 @@ static void llm_load_tensors(
         auto & layer = model.layers[i];
 
         layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
-        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
-        layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
+        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
+        layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
 
         layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
 

@@ -5368,7 +5368,7 @@ static struct ggml_cgraph * llm_build_mpt(
     const int64_t n_layer = hparams.n_layer;
     const int64_t n_ctx = cparams.n_ctx;
     const int64_t n_head = hparams.n_head;
-    const int64_t n_head_kv = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA
+    const int64_t n_head_kv = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
     const int64_t n_embd_gqa = hparams.n_embd_gqa();

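Dropping the old "no MQA/GQA" assumption, the fused QKV projection now keeps a full-width query block plus key and value blocks that are only n_head_kv heads wide, which is where the new {n_embd, n_embd + 2*n_embd_gqa} shape comes from; since n_embd_gqa == n_embd_head * n_head_kv, it reduces to the previous 3*n_embd whenever n_head_kv == n_head. A minimal sketch of that arithmetic, using placeholder hyperparameters rather than values from a real checkpoint:

# Placeholder hyperparameters, for illustration only.
n_embd    = 2048   # d_model
n_head    = 16     # query heads
n_head_kv = 4      # KV heads (GQA); equals n_head for plain MHA

n_embd_head = n_embd // n_head         # per-head dimension
n_embd_gqa  = n_embd_head * n_head_kv  # width of the K (or V) block

qkv_width = n_embd + 2 * n_embd_gqa    # columns of the fused wqkv tensor
print(qkv_width)                       # 3072 here, vs. 3 * n_embd = 6144 for MHA

# Sanity check: with n_head_kv == n_head the shape collapses to the old 3*n_embd.
assert n_embd + 2 * n_embd_head * n_head == 3 * n_embd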