Commit 8b185b7

llama : fix expert weighting in the FFN
1 parent 7ea3695

1 file changed: +7 −5 lines changed


llama.cpp

Lines changed: 7 additions & 5 deletions
@@ -4250,11 +4250,13 @@ struct llm_build_context {
         ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]

         // select experts
-        ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
-        //ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [n_tokens, num_experts_per_tok, 1]
-        ggml_tensor * weights = ggml_get_rows(ctx0,
-                ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts);
-        weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok, 1]
+        ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
+        ggml_tensor * weights =
+            ggml_reshape_2d(ctx0,
+                ggml_get_rows(ctx0,
+                    ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts),
+                n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
+        weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok]

         // compute expert outputs
         ggml_tensor * moe_out;
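Why the reshape fixes the weighting: ggml_sum_rows sums along a tensor's innermost dimension. Before this change, weights came out of ggml_get_rows as a [1, n_experts_per_tok, n_tokens] view, so each innermost row held a single element and ggml_div(weights, ggml_sum_rows(ctx0, weights)) divided every weight by itself, flattening all expert weights to 1. Collapsing the result back to [n_experts_per_tok, n_tokens] with ggml_reshape_2d puts the selected experts' probabilities on the innermost dimension, so the division normalizes them by their per-token sum. A minimal plain-C++ sketch of that arithmetic (no ggml calls; the 0.6/0.3 values are made-up softmax outputs for two selected experts):

#include <cstdio>
#include <vector>

int main() {
    // made-up softmax probabilities of the two selected experts for one token
    std::vector<float> w = {0.6f, 0.3f};

    // old behavior: weights viewed as [1, k, n] -> each innermost row has a
    // single element, so w[i] / row_sum == w[i] / w[i] == 1.0 for every expert
    for (float x : w) {
        std::printf("old: %.2f -> %.2f\n", x, x / x);
    }

    // new behavior: weights reshaped to [k, n] -> the innermost row holds all
    // k selected experts, so each weight is divided by their common sum
    float row_sum = 0.0f;
    for (float x : w) row_sum += x;
    for (float x : w) {
        std::printf("new: %.2f -> %.2f\n", x, x / row_sum);
    }
    return 0;
}

The old code prints 1.00 for both experts; the new code prints 0.67 and 0.33, i.e. the experts' softmax probabilities renormalized over the top-k selection.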

0 commit comments
