diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 19152174..8776986a 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -1121,10 +1121,10 @@ llm_expert_gating_func_type gating_op,
     if (ffn_norm) {
         auto the_ffn_norm = ffn_norm->extra ? ((ggml_split_tensor_t *)ffn_norm->extra)->splits[lctx.model.main_gpu] : ffn_norm;
         GGML_ASSERT(the_ffn_norm);
-        cur = llm_build_norm(ctx, input, lctx.model.hparams, the_ffn_norm, nullptr, LLM_NORM_RMS, cb, il);
+        cur = llm_build_norm(ctx, cur, lctx.model.hparams, the_ffn_norm, nullptr, LLM_NORM_RMS, cb, il);
         cb(cur, "ffn_inp_normed", il);
     }
-    else if (cur->type != GGML_TYPE_F32) {
+    if (cur->type != GGML_TYPE_F32) {
         cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
     }
     auto the_gate_inp = gate_inp->extra ? ((ggml_split_tensor_t *)gate_inp->extra)->splits[lctx.model.main_gpu] : gate_inp;
@@ -1139,8 +1139,12 @@ llm_expert_gating_func_type gating_op,
             the_exp_probs_b, n_expert, n_expert_used, type_op, norm_w, scale_w, w_scale,
-            gating_op, cb, il, graph, add_input);
+            gating_op, cb, il, graph, false);
     cb(routed_out, "routed_out", il);
+    if (add_input) {
+        routed_out = ggml_add(ctx, routed_out, input);
+        cb(routed_out, "routed_out_with_inp", il);
+    }
     ggml_build_forward_expand(graph, routed_out);
     if (up_shexp && gate_shexp && down_shexp) {
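
For reviewers, a minimal standalone sketch of the pattern the second hunk introduces, not code from this tree: the experts helper is now called with `add_input == false`, and the residual `input` is added back at the call site after the (now unconditional) cast-to-F32 check. `moe_ffn_stub` below is a hypothetical stand-in for `llm_build_moe_ffn`; only the `ggml_cast`/`ggml_add`/graph-expand calls mirror the patch.

```cpp
#include "ggml.h"

// Hypothetical stand-in for llm_build_moe_ffn: pretend the routed experts
// just scale the input; the real helper builds the full expert sub-graph.
static ggml_tensor * moe_ffn_stub(ggml_context * ctx, ggml_tensor * cur) {
    return ggml_scale(ctx, cur, 0.5f);
}

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * input = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_tensor * cur   = input;

    // As in the patch: the routed path always runs in F32, even when the
    // ffn_norm branch was taken.
    if (cur->type != GGML_TYPE_F32) {
        cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
    }

    // The helper no longer adds the residual itself (add_input == false) ...
    ggml_tensor * routed_out = moe_ffn_stub(ctx, cur);

    // ... instead the caller adds the original input back afterwards.
    const bool add_input = true;
    if (add_input) {
        routed_out = ggml_add(ctx, routed_out, input);
    }

    ggml_cgraph * graph = ggml_new_graph(ctx);
    ggml_build_forward_expand(graph, routed_out);

    ggml_free(ctx);
    return 0;
}
```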