diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 551c7cd4..2711dd04 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -621,6 +621,7 @@ ggml_tensor * llm_build_context::llm_build_norm( ggml_tensor * llm_build_context::llm_build_ffn( ggml_context * ctx, llama_context & lctx, + ggml_tensor * ffn_norm, ggml_tensor * input, ggml_tensor * up, ggml_tensor * up_b, @@ -654,7 +655,12 @@ ggml_tensor * llm_build_context::llm_build_ffn( auto split_d = d->splits[id]; GGML_ASSERT((!split_u && !split_g && split_d) || (split_u && split_g && split_d)); if (!split_u) continue; - auto cur = ggml_fused_up_gate(ctx, split_u, split_g, input, unary_op); + auto cur = input; + if (ffn_norm && ffn_norm->extra) { + auto norm = (ggml_split_tensor_t *)ffn_norm->extra; + cur = llm_build_norm(ctx, input, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il); + } + cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op); cb(cur, "ffn_up_gate", il_cb); cur = llm_build_lora_mm(lctx, ctx, split_d, cur); cb(cur, "ffn_down", il_cb); @@ -677,6 +683,11 @@ ggml_tensor * llm_build_context::llm_build_ffn( return cur; } + if (ffn_norm) { + input = llm_build_norm(ctx, input, lctx.model.hparams, ffn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(input, "ffn_norm", il); + } + if (lctx.cparams.fused_up_gate && up && gate && !up_b && !up_s && !gate_b && !gate_s && type_gate == LLM_FFN_PAR && (type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) { @@ -1524,10 +1535,7 @@ ggml_cgraph * llm_build_context::build_llama() { // feed-forward network if (model.layers[il].ffn_gate_inp == nullptr) { // non-MoE - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -1552,7 +1560,7 @@ ggml_cgraph * llm_build_context::build_llama() { cb, il, gf); // Shared experts - ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, ffn_inp_normed, + ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, nullptr, ffn_inp_normed, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -1729,10 +1737,8 @@ ggml_cgraph * llm_build_context::build_deci() { // feed-forward network if (model.layers[il].ffn_gate_inp == nullptr) { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -1843,10 +1849,7 @@ ggml_cgraph * llm_build_context::build_baichuan() { // feed-forward network { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -1938,10 +1941,7 @@ ggml_cgraph * llm_build_context::build_xverse() { // feed-forward network { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -2051,7 +2051,7 @@ ggml_cgraph * llm_build_context::build_falcon() { // feed forward { - cur = llm_build_ffn(ctx0, lctx, attn_norm, // !! use the attn norm, not the result + cur = llm_build_ffn(ctx0, lctx, nullptr, attn_norm, // !! use the attn norm, not the result model.layers[il].ffn_up, NULL, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -2170,7 +2170,7 @@ ggml_cgraph * llm_build_context::build_grok() { cb(moe_out, "ffn_moe_out", il); if (model.layers[il].ffn_up) { - ggml_tensor* ffn_out = llm_build_ffn(ctx0, lctx, cur, + ggml_tensor* ffn_out = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -2409,10 +2409,7 @@ ggml_cgraph * llm_build_context::build_starcoder() { // FF { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -2489,10 +2486,7 @@ ggml_cgraph * llm_build_context::build_refact() { // feed-forward network { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -2669,21 +2663,21 @@ ggml_cgraph * llm_build_context::build_bert() { // feed-forward network if (model.arch == LLM_ARCH_BERT) { - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_SEQ, cb, il); } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, LLM_FFN_GELU, LLM_FFN_PAR, cb, il); } else { - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -2769,10 +2763,7 @@ ggml_cgraph * llm_build_context::build_bloom() { // FF { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -2893,9 +2884,7 @@ ggml_cgraph * llm_build_context::build_mpt() { // feed forward { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -3011,7 +3000,7 @@ ggml_cgraph * llm_build_context::build_stablelm() { // parallel residual cur = inpSA; } - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -3114,10 +3103,7 @@ ggml_cgraph * llm_build_context::build_qwen() { // feed-forward forward { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -3209,10 +3195,7 @@ ggml_cgraph * llm_build_context::build_qwen2() { cb(ffn_inp, "ffn_inp", il); // feed-forward network - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -3312,10 +3295,7 @@ ggml_cgraph * llm_build_context::build_qwen2vl() { cb(ffn_inp, "ffn_inp", il); // feed-forward network - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -3437,7 +3417,7 @@ ggml_cgraph * llm_build_context::build_qwen2moe() { ggml_tensor * cur_gate = ggml_div(ctx0, ggml_silu(ctx0, cur_gate_inp), cur_gate_inp); cb(cur_gate, "ffn_shexp_gate", il); - ggml_tensor * cur_ffn = llm_build_ffn(ctx0, lctx, cur, + ggml_tensor * cur_ffn = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -3543,10 +3523,7 @@ ggml_cgraph * llm_build_context::build_qwen3() { cb(ffn_inp, "ffn_inp", il); // feed-forward network - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -3745,12 +3722,7 @@ ggml_cgraph * llm_build_context::build_qwen3vl() { cb(ffn_inp, "ffn_inp", il); // feed-forward network - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -4019,7 +3991,7 @@ ggml_cgraph * llm_build_context::build_phi2() { // FF { - ffn_output = llm_build_ffn(ctx0, lctx, attn_norm_output, + ffn_output = llm_build_ffn(ctx0, lctx, nullptr, attn_norm_output, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -4132,14 +4104,11 @@ ggml_cgraph * llm_build_context::build_phi3() { cur = ggml_add(ctx0, cur, residual); residual = cur; - cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - // FF // special-case: the up and gate tensors are merged into a single tensor // TOOD: support into llm_build_ffn { - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, cur, model.layers[il].ffn_up, NULL, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -4228,7 +4197,7 @@ ggml_cgraph * llm_build_context::build_plamo() { // feed-forward network { - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -4325,10 +4294,7 @@ ggml_cgraph * llm_build_context::build_gpt2() { // FF { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -4427,10 +4393,7 @@ ggml_cgraph * llm_build_context::build_codeshell() { // FF { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -4518,10 +4481,7 @@ ggml_cgraph * llm_build_context::build_orion() { cb(ffn_inp, "ffn_inp", il); // feed-forward network - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -4611,10 +4571,7 @@ ggml_cgraph * llm_build_context::build_internlm2() { cb(ffn_inp, "ffn_inp", il); // feed-forward network - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -4724,10 +4681,7 @@ ggml_cgraph * llm_build_context::build_minicpm() { // feed-forward network { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -4826,12 +4780,9 @@ ggml_cgraph * llm_build_context::build_gemma() { struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); cb(sa_out, "sa_out", il); - cur = llm_build_norm(ctx0, sa_out, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - // feed-forward network { - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, sa_out, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -4936,12 +4887,9 @@ ggml_cgraph * llm_build_context::build_gemma2() { struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); cb(sa_out, "sa_out", il); - cur = llm_build_norm(ctx0, sa_out, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - // feed-forward network { - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, sa_out, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -5067,11 +5015,8 @@ ggml_cgraph * llm_build_context::build_gemma3() { struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); cb(sa_out, "sa_out", il); - cur = llm_build_norm(ctx0, sa_out, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - // feed-forward network - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, sa_out, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -5165,11 +5110,7 @@ ggml_cgraph * llm_build_context::build_starcoder2() { cb(ffn_inp, "ffn_inp", il); // feed-forward network - - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -5425,7 +5366,7 @@ ggml_cgraph * llm_build_context::build_command_r() { // feed-forward network { - cur = llm_build_ffn(ctx0, lctx, ffn_inp, + cur = llm_build_ffn(ctx0, lctx, nullptr, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -5556,7 +5497,7 @@ ggml_cgraph * llm_build_context::build_olmo() { cur = llm_build_norm(ctx0, ffn_inp, hparams, NULL, NULL, LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -5670,10 +5611,7 @@ ggml_cgraph * llm_build_context::build_openelm() { // feed-forward network { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -5777,7 +5715,7 @@ ggml_cgraph * llm_build_context::build_gptneox() { cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -5805,7 +5743,7 @@ ggml_cgraph * llm_build_context::build_gptneox() { cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il); cb(cur, "ffn_norm", il); - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -5898,10 +5836,7 @@ ggml_cgraph * llm_build_context::build_arctic() { cb(ffn_inp, "ffn_inp", il); // feed-forward network - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -6431,7 +6366,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() { cb(cur, "ffn_norm", il); if ((uint32_t) il < hparams.n_layer_dense_lead) { - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -6456,7 +6391,7 @@ ggml_cgraph * llm_build_context::build_deepseek2() { // FFN shared expert { - ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur, + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -6578,7 +6513,7 @@ ggml_cgraph * llm_build_context::build_glm4_moe() { if ((uint32_t) il < hparams.n_layer_dense_lead) { // dense FFN - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -6601,7 +6536,7 @@ ggml_cgraph * llm_build_context::build_glm4_moe() { cb(routed_out, "routed_out", il); { - struct ggml_tensor * shared_out = llm_build_ffn(ctx0, lctx, cur, + struct ggml_tensor * shared_out = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -6861,10 +6796,7 @@ ggml_cgraph * llm_build_context::build_bitnet_158() { struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, NULL, NULL, NULL, @@ -6978,7 +6910,7 @@ ggml_cgraph * llm_build_context::build_cohere2() { // feed-forward network { - cur = llm_build_ffn(ctx0, lctx, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, + cur = llm_build_ffn(ctx0, lctx, nullptr, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); @@ -7096,11 +7028,8 @@ ggml_cgraph * llm_build_context::build_t5_encoder() { // feed-forward network { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm_enc, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - // T5 uses relu, flan-T5 uses gelu-gated - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm_enc, ffn_inp, model.layers[il].ffn_up_enc, NULL, NULL, model.layers[il].ffn_gate_enc, NULL, NULL, model.layers[il].ffn_down_enc, NULL, NULL, @@ -7284,11 +7213,8 @@ ggml_cgraph * llm_build_context::build_t5_decoder() { // feed-forward network { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - // T5 uses relu, flan-T5 uses gelu-gated - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -7382,10 +7308,7 @@ ggml_cgraph * llm_build_context::build_jais() { // FF { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, @@ -7487,10 +7410,7 @@ ggml_cgraph * llm_build_context::build_chatglm() { // FF { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -7613,12 +7533,8 @@ ggml_cgraph * llm_build_context::build_glm4() { // FF { - // Pre-MLP norm - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - // MLP - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -7731,7 +7647,7 @@ ggml_cgraph * llm_build_context::build_dots1() { cb(cur, "ffn_norm", il); if ((uint32_t) il < hparams.n_layer_dense_lead) { - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -7754,7 +7670,7 @@ ggml_cgraph * llm_build_context::build_dots1() { cb(moe_out, "ffn_moe_out", il); { - ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur, + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -7875,10 +7791,7 @@ ggml_cgraph * llm_build_context::build_ernie4_5() { // feed-forward network { - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -8000,10 +7913,7 @@ ggml_cgraph * llm_build_context::build_ernie4_5_moe() { bool is_moe_layer = static_cast(il) >= hparams.n_layer_dense_lead && (il + 1) % hparams.n_moe_layer_step == 0; if (!is_moe_layer) { - cur = llm_build_norm(ctx0, ffn_inp,hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -8031,7 +7941,7 @@ ggml_cgraph * llm_build_context::build_ernie4_5_moe() { // Shared expert (if present) if (hparams.n_ff_shexp > 0) { - ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur, + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -8148,7 +8058,7 @@ ggml_cgraph * llm_build_context::build_hunyuan_moe() { cb(cur, "ffn_norm", il); // feed-forward network (non-MoE) - ggml_tensor * cur_mlp = llm_build_ffn(ctx0, lctx, cur, + ggml_tensor * cur_mlp = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -8396,7 +8306,7 @@ ggml_cgraph * llm_build_context::build_bailingmoe2() { cb(cur, "ffn_norm", il); if (static_cast(il) < hparams.n_layer_dense_lead) { - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, @@ -8419,7 +8329,7 @@ ggml_cgraph * llm_build_context::build_bailingmoe2() { cb, il, gf); cb(moe_out, "ffn_moe_out", il); - ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur, + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, nullptr, cur, model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, @@ -8639,10 +8549,7 @@ ggml_cgraph* llm_build_context::build_smollm3() { cb(ffn_inp, "ffn_inp", il); // feed-forward network - cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, + cur = llm_build_ffn(ctx0, lctx, model.layers[il].ffn_norm, ffn_inp, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, diff --git a/src/llama-build-context.h b/src/llama-build-context.h index 543ce9ca..3277c27b 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -317,7 +317,7 @@ struct llm_build_context { float kq_scale, const llm_build_cb & cb, int il, ggml_tensor * sinks = nullptr, int n_swa = 0); - static ggml_tensor * llm_build_ffn(ggml_context * ctx, llama_context & lctx, + static ggml_tensor * llm_build_ffn(ggml_context * ctx, llama_context & lctx, ggml_tensor * ffn_norm, ggml_tensor * cur, ggml_tensor * up, ggml_tensor * up_b, diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index 8cefef46..25f2a0ae 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -406,7 +406,7 @@ bool create_tensors_helper::create_llama_tensors(const LLM_TN & tn) { // optional bias tensors layer.bo = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.rope_freqs = create_tensor(ctx_split, tn(LLM_TENSOR_ROPE_FREQS, "weight"), {n_embd/n_head/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); @@ -2944,6 +2944,10 @@ bool create_tensors_helper::create_tensors() { } if (layer.ffn_down && layer.ffn_up && layer.ffn_gate) { + if (layer.ffn_norm) { + auto split = create_split(ggml_nrows(layer.ffn_norm), -1, model.splits); + prepare_split_tensors(-1, ctx_split, layer.ffn_norm, layer.split_ffn_norm, split); + } int ffn_granularity = 16; if (ggml_is_quantized(layer.ffn_down->type)) { auto tt = ggml_internal_get_type_traits(layer.ffn_down->type); diff --git a/src/llama-model.h b/src/llama-model.h index b7ab2cfc..73fa6e83 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -224,6 +224,7 @@ struct llama_layer { llama_split_tensor split_ffn_up; llama_split_tensor split_ffn_gate; llama_split_tensor split_ffn_down; + llama_split_tensor split_ffn_norm; // ff MoE struct ggml_tensor * ffn_gate_inp = nullptr;