Fused norm (#1086)

* Adding fused_norm - same idea as fused_rms_norm

* Avoid computing the attention reduce op for cohere2

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-12-24 15:22:43 +01:00
committed by GitHub
parent 1ace5b7526
commit fbb67fa2bd
7 changed files with 273 additions and 29 deletions

View File

@@ -678,9 +678,10 @@ ggml_tensor * llm_build_context::llm_build_ffn(
auto norm = (ggml_split_tensor_t *)ffn_norm->extra;
GGML_ASSERT(norm->splits[id]);
if (is_norm) {
cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM, cb, il);
GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
//cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM, cb, il);
//GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
//cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
cur = ggml_fused_norm(ctx, cur, norm->splits[id], lctx.model.hparams.f_norm_eps);
} else {
cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
}
@@ -700,6 +701,13 @@ ggml_tensor * llm_build_context::llm_build_ffn(
if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
}
if (add_extra && add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1) {
// When the reduce op is turned off via op_params[3] == 1, we need to add each src
// rtaher than add the reduced add_extra result to the ffn reduced ffn result.
GGML_ASSERT(add_extra->src[id]); // TODO: fix this! It can be null if the splits of the attention and ffn tensors are different
cur = ggml_add(ctx, cur, add_extra->src[id]);
cb(cur, "ffn_with_extra", il_cb);
}
if (graph) {
ggml_build_forward_expand(graph, cur);
}
@@ -711,7 +719,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
ffn[id_last] = ggml_add(ctx, ffn[id_last], input);
cb(ffn[id_last], "ffn_with_inp", il);
}
if (add_extra) {
if (add_extra && !(add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1)) {
ffn[id_last] = ggml_add(ctx, ffn[id_last], add_extra);
cb(ffn[id_last], "ffn_with_inp", il);
}
@@ -7287,6 +7295,8 @@ ggml_cgraph * llm_build_context::build_cohere2() {
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
}
attn_out->op_params[3] = 1; // i.e., turn off the reduce operation as it is not required
// feed-forward network
cur = llm_build_ffn(ctx0, lctx, model.layers[il].attn_norm, inpL, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate,
NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
@@ -9379,9 +9389,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
auto cur = get_input_tensor_sm_graph(input, id);
if (attn_norm) {
if (is_norm) {
cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM, cb, il);
GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
//cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM, cb, il);
//GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
//cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
cur = ggml_fused_norm(ctx0, cur, attn_norm->splits[id], lctx.model.hparams.f_norm_eps);
} else {
cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
}