@@ -4250,11 +4250,13 @@ struct llm_build_context {
4250
4250
ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
4251
4251
4252
4252
// select experts
4253
- ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
4254
- //ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [n_tokens, num_experts_per_tok, 1]
4255
- ggml_tensor * weights = ggml_get_rows(ctx0,
4256
- ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts);
4257
- weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok, 1]
4253
+ ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
4254
+ ggml_tensor * weights =
4255
+ ggml_reshape_2d(ctx0,
4256
+ ggml_get_rows(ctx0,
4257
+ ggml_reshape_3d(ctx0, probs, 1, n_experts, n_tokens), selected_experts),
4258
+ n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
4259
+ weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok]
4258
4260
4259
4261
// compute expert outputs
4260
4262
ggml_tensor * moe_out;
0 commit comments