WIP: absorb adding input into std_attn and std_ffn

This commit is contained in:
Iwan Kawrakow
2025-12-21 06:47:23 +00:00
parent 5562605076
commit e2f325fad3
2 changed files with 69 additions and 28 deletions

View File

@@ -637,7 +637,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
ggml_tensor * act_scales,
llm_ffn_op_type type_op,
llm_ffn_gate_type type_gate,
const llm_build_cb & cb, int il, ggml_cgraph * graph) {
const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input) {
if (!up_b && !up_s && !gate_b && !gate_s && !down_b && !down_s &&
up->extra && gate->extra && down->extra && type_gate == LLM_FFN_PAR &&
@@ -661,11 +661,11 @@ ggml_tensor * llm_build_context::llm_build_ffn(
if (ffn_norm && ffn_norm->extra) {
auto norm = (ggml_split_tensor_t *)ffn_norm->extra;
GGML_ASSERT(norm->splits[id]);
cur = llm_build_norm(ctx, input, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
cb(cur, "ffn_inp_normed", il_cb);
}
else if (input->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx, input, GGML_TYPE_F32);
else if (cur->type != GGML_TYPE_F32) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F32);
}
cur = ggml_fused_up_gate(ctx, split_u, split_g, cur, unary_op);
cb(cur, "ffn_up_gate", il_cb);
@@ -683,6 +683,10 @@ ggml_tensor * llm_build_context::llm_build_ffn(
}
ffn.push_back(cur);
}
if (add_input) {
ffn.back() = ggml_add(ctx, ffn.back(), input);
cb(ffn.back(), "ffn_with_inp", il);
}
if (ffn.size() == 1) return ffn.front();
auto cur = ggml_add(ctx, ffn[0], ffn[1]);
cb(cur, "combine_ffn", il);
@@ -849,6 +853,11 @@ ggml_tensor * llm_build_context::llm_build_ffn(
cb(cur, "ffn_down_s", il);
}
if (add_input) {
cur = ggml_add(ctx, cur, input);
cb(cur, "ffn_out_with_inp", il);
}
return cur;
}
@@ -868,7 +877,9 @@ ggml_tensor * llm_build_context::llm_build_moe_ffn(
bool scale_w,
float w_scale,
llm_expert_gating_func_type gating_op,
const llm_build_cb & cb, int il, ggml_cgraph * graph) {
const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input) {
auto input = cur;
int64_t n_embd = cur->ne[0];
int64_t n_tokens = cur->ne[1];
@@ -1040,20 +1051,30 @@ llm_expert_gating_func_type gating_op,
if (lctx.cparams.fused_mmad) {
experts = ggml_mul_multi_add(ctx, experts, weights);
cb(experts, "ffn_moe_weighted", il);
if (add_input) {
experts = ggml_add(ctx, experts, input);
cb(experts, "ffn_out_with_inp", il);
}
return experts;
}
experts = ggml_mul(ctx, experts, weights);
cb(experts, "ffn_moe_weighted", il);
}
ggml_tensor * result;
if (n_expert_used == 1) {
return ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
result = ggml_cont(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0));
}
if (n_expert_used == 2) {
return ggml_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0),
result = ggml_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0),
ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], experts->nb[1]));
}
return ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used);
result = ggml_multi_add(ctx, ggml_view_2d(ctx, experts, n_embd, n_tokens, experts->nb[2], 0), n_expert_used);
if (add_input) {
cb(result, "ffn_out", il);
result = ggml_add(ctx, result, input);
}
return result;
}
@@ -1076,7 +1097,7 @@ ggml_tensor * llm_build_context::llm_build_std_moe_ffn(ggml_context * ctx, llama
float w_scale,
llm_expert_gating_func_type gating_op,
llm_ffn_op_type type_op_shexp,
const llm_build_cb & cb, int il, ggml_cgraph * graph) {
const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input) {
auto split_up_exps = (ggml_split_tensor_t *)up_exps->extra;
auto split_gate_exps = (ggml_split_tensor_t *)gate_exps->extra;
@@ -1110,7 +1131,7 @@ llm_expert_gating_func_type gating_op,
the_exp_probs_b,
n_expert, n_expert_used,
type_op, norm_w, scale_w, w_scale,
gating_op, cb, il, graph);
gating_op, cb, il, graph, add_input);
cb(routed_out, "routed_out", il);
ggml_build_forward_expand(graph, routed_out);
@@ -1206,7 +1227,7 @@ llm_expert_gating_func_type gating_op,
split_exp_probs_b ? split_exp_probs_b->splits[id] : nullptr,
n_expert, n_expert_used,
type_op, norm_w, scale_w, w_scale,
gating_op, cb, il, graph);
gating_op, cb, il, graph, add_input);
cb(routed_out, "routed_out", il_cb);
if (split_up_shexp) {
@@ -1754,7 +1775,7 @@ ggml_cgraph * llm_build_context::build_llama() {
// self-attention
if (use_rope) {
cur = build_std_attention(gf, model.layers[il].attn_norm, inpL, inp_pos, nullptr,
this_KQ_mask, nullptr, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il);
this_KQ_mask, nullptr, nullptr, kq_scale, hparams.f_attention_scale, this_n_swa, il, true, false, true);
}
else {
@@ -1807,9 +1828,11 @@ ggml_cgraph * llm_build_context::build_llama() {
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
n_tokens = n_outputs;
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
cb(cur, "last_attn", il);
cb(inpSA, "last_ffn_inp", il);
if (use_rope) {
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
cb(inpSA, "last_ffn_inp", il);
}
}
// For Granite architecture
@@ -1818,8 +1841,13 @@ ggml_cgraph * llm_build_context::build_llama() {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
ggml_tensor * ffn_inp;
if (use_rope) {
ffn_inp = cur;
} else {
ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
}
// feed-forward network
if (model.layers[il].ffn_gate_inp == nullptr) {
@@ -1829,7 +1857,7 @@ ggml_cgraph * llm_build_context::build_llama() {
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf);
LLM_FFN_SILU, LLM_FFN_PAR, cb, il, gf, true);
cb(cur, "ffn_out", il);
} else if (model.arch == LLM_ARCH_LLAMA4) {
// llama4 MoE
@@ -1846,7 +1874,7 @@ ggml_cgraph * llm_build_context::build_llama() {
LLM_FFN_SILU, false,
false, 0.0,
LLM_EXPERT_GATING_FUNC_SIGMOID,
cb, il, gf);
cb, il, gf, true);
// Shared experts
ggml_tensor * shexp_out = llm_build_ffn(ctx0, lctx, nullptr, ffn_inp_normed,
@@ -1875,7 +1903,7 @@ ggml_cgraph * llm_build_context::build_llama() {
LLM_FFN_SILU, true,
false, 0.0,
LLM_EXPERT_GATING_FUNC_SOFTMAX,
cb, il, gf);
cb, il, gf, true);
cb(cur, "ffn_moe_out", il);
}
@@ -1885,8 +1913,8 @@ ggml_cgraph * llm_build_context::build_llama() {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
//cur = ggml_add(ctx0, cur, ffn_inp);
//cb(cur, "ffn_out", il);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);
@@ -9312,7 +9340,7 @@ ggml_cgraph * llm_build_context::llama_build_graph(
ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tensor * the_attn_norm,
ggml_tensor * input, ggml_tensor * inp_pos, ggml_tensor * rope_factors_in,
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale,
int n_swa, int il, bool do_rope, bool add_graph_split) {
int n_swa, int il, bool do_rope, bool add_graph_split, bool add_input) {
if (!model.layers[il].wqkv && !model.layers[il].wqk && cparams.flash_attn &&
model.layers[il].wq->extra && model.layers[il].wk->extra && model.layers[il].wv->extra && model.layers[il].wo->extra) {
if (kv_self.k_l[il]->extra && kv_self.v_l[il]->extra) {
@@ -9489,6 +9517,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
attn.push_back(cur);
}
GGML_ASSERT(!attn.empty());
if (add_input) {
attn.back() = ggml_add(ctx0, attn.back(), input);
cb(attn.back(), "attn_out_with_input", il);
}
if (attn.size() == 1) return attn.front();
//if (attn.size() > 2 && attn.size()%2 == 0) {
// for (int id = 0; id < int(attn.size()/2); ++id) {
@@ -9515,6 +9547,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
if (attn.size() > 2) {
cur->op_params[0] = 0xff;
}
//if (add_input) {
// cur = ggml_add(ctx0, cur, input);
// cb(cur, "combine_attn_inp", il);
//}
return cur;
}
}
@@ -9549,5 +9585,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, KQ_scale, cb, il, sinks, n_swa);
if (add_input) {
cb(cur, "attn_out", il);
cur = ggml_add(ctx0, cur, input);
}
return cur;
}

View File

@@ -335,7 +335,7 @@ struct llm_build_context {
ggml_tensor * act_scales,
llm_ffn_op_type type_op,
llm_ffn_gate_type type_gate,
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr);
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr, bool add_input = false);
static ggml_tensor * llm_build_moe_ffn(ggml_context * ctx, llama_context & lctx,
ggml_tensor * cur,
@@ -351,7 +351,7 @@ struct llm_build_context {
bool scale_w,
float w_scale,
llm_expert_gating_func_type gating_op,
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr);
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr, bool add_input = false);
static ggml_tensor * llm_build_moe_ffn(ggml_context * ctx, llama_context & lctx,
ggml_tensor * cur,
@@ -367,7 +367,7 @@ llm_expert_gating_func_type gating_op,
bool scale_w,
float w_scale,
llm_expert_gating_func_type gating_op,
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr) {
const llm_build_cb & cb, int il, ggml_cgraph * graph = nullptr, bool add_input = false) {
return llm_build_moe_ffn(ctx, lctx, cur,
gate_inp, nullptr,
up_exps, nullptr,
@@ -376,7 +376,7 @@ llm_expert_gating_func_type gating_op,
exp_probs_b,
n_expert, n_expert_used,
type_op, norm_w, scale_w, w_scale,
gating_op, cb, il, graph);
gating_op, cb, il, graph, add_input);
}
static ggml_tensor * llm_build_std_moe_ffn(ggml_context * ctx, llama_context & lctx,
@@ -398,7 +398,7 @@ llm_expert_gating_func_type gating_op,
float w_scale,
llm_expert_gating_func_type gating_op,
llm_ffn_op_type type_op_shexp,
const llm_build_cb & cb, int il, ggml_cgraph * graph);
const llm_build_cb & cb, int il, ggml_cgraph * graph, bool add_input = false);
static ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids);
@@ -410,6 +410,6 @@ llm_expert_gating_func_type gating_op,
ggml_tensor * build_std_attention(ggml_cgraph * gf, ggml_tensor * attn_norm, ggml_tensor * cur, ggml_tensor * inp_pos, ggml_tensor * rope_factors,
ggml_tensor * KQ_mask, ggml_tensor * sinks, ggml_tensor * inp_attn_scale, float KQ_scale, float f_attn_scale,
int n_swa, int il, bool do_rope = true, bool add_graph_split = false);
int n_swa, int il, bool do_rope = true, bool add_graph_split = false, bool add_input = false);
};