Skip to content

Commit c42f0ec

Browse files
mmnga and mmnga authored
examples : fix gpt-neox (#2943)
Co-authored-by: mmnga <mmnga1mmnga@gmail.com>
1 parent 2753415 commit c42f0ec

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

examples/gptneox-wip/gptneox-main.cpp

+7-6
Original file line numberDiff line numberDiff line change
@@ -660,9 +660,10 @@ bool gpt_neox_model_load(const std::string & fname, gpt_neox_model & model, gpt2
660660
ggml_tensor * gpt_neox_ff(
661661
const gpt_neox_block &block,
662662
ggml_context * ctx0,
663-
ggml_tensor * inp) {
663+
ggml_tensor * inp,
664+
const gpt_neox_hparams &hparams) {
664665

665-
ggml_tensor * cur = ggml_norm(ctx0, inp);
666+
ggml_tensor * cur = ggml_norm(ctx0, inp, hparams.norm_eps);
666667

667668
cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, block.ln_2_g, cur), cur), ggml_repeat(ctx0, block.ln_2_b, cur));
668669
cur = ggml_mul_mat(ctx0, block.c_mlp_fc_w, cur);
@@ -753,7 +754,7 @@ bool gpt_neox_eval(
753754
// self-attention
754755
{
755756
{
756-
cur = ggml_norm(ctx0, inpL);
757+
cur = ggml_norm(ctx0, inpL, hparams.norm_eps);
757758

758759
cur = ggml_add(ctx0,
759760
ggml_mul(ctx0, ggml_repeat(ctx0, model.blocks[il].ln_1_g, cur), cur),
@@ -844,7 +845,7 @@ bool gpt_neox_eval(
844845
if (hparams.par_res == 0) {
845846
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL);
846847

847-
cur = gpt_neox_ff(model.blocks[il], ctx0, inpFF);
848+
cur = gpt_neox_ff(model.blocks[il], ctx0, inpFF, hparams);
848849

849850
// input for next layer
850851
inpL = ggml_add(ctx0, cur, inpFF);
@@ -853,7 +854,7 @@ bool gpt_neox_eval(
853854

854855
// this is independent of the self-attention result, so it could be done in parallel to the self-attention
855856
// note here we pass inpL instead of cur
856-
cur = gpt_neox_ff(model.blocks[il], ctx0, inpL);
857+
cur = gpt_neox_ff(model.blocks[il], ctx0, inpL, hparams);
857858

858859
// layer input + FF
859860
cur = ggml_add(ctx0, cur, inpFF);
@@ -867,7 +868,7 @@ bool gpt_neox_eval(
867868

868869
// norm
869870
{
870-
inpL = ggml_norm(ctx0, inpL);
871+
inpL = ggml_norm(ctx0, inpL, hparams.norm_eps);
871872

872873
// inpL = ln_f_g*inpL + ln_f_b
873874
inpL = ggml_add(ctx0,

llama.cpp

+44-2
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,44 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
325325
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
326326
},
327327
},
328+
{
329+
LLM_ARCH_GPT2,
330+
{
331+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
332+
},
333+
},
334+
{
335+
LLM_ARCH_GPTJ,
336+
{
337+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
338+
},
339+
},
340+
{
341+
LLM_ARCH_GPTNEOX,
342+
{
343+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
344+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
345+
{ LLM_TENSOR_OUTPUT, "output" },
346+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
347+
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
348+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
349+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
350+
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
351+
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
352+
},
353+
},
354+
{
355+
LLM_ARCH_MPT,
356+
{
357+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
358+
},
359+
},
360+
{
361+
LLM_ARCH_UNKNOWN,
362+
{
363+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
364+
},
365+
},
328366
};
329367

330368
static llm_arch llm_arch_from_string(const std::string & name) {
@@ -1605,9 +1643,13 @@ static void llm_load_hparams(
16051643

16061644
GGUF_GET_KEY(ctx, hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
16071645

1608-
if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
1609-
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head));
1646+
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
1647+
if (hparams.n_rot != hparams.n_embd / hparams.n_head) {
1648+
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head));
1649+
}
16101650
}
1651+
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
1652+
// gpt-j n_rot = rotary_dim
16111653
}
16121654

16131655
// arch-specific KVs

0 commit comments

Comments (0)