Fuse the attention gate in Step-3.5-Flash (#1244)

* WIP

* This works but is slow

* Turn off the up / gate clamps for now

* OK we need the clamping

* Fuse the clamp (CUDA)

* Fuse the clamp (CPU)

* WIP

* Be able to use merged q, k, v

* Be able to use merged up/gate experts

* Fuse the clamp (CUDA mmvq)

* WIP: graph parallel for Step-3.5

* WIP

* This should be it

* Cleanup

* Fix merge

* Non-working attempt to extend fused_mul_unary to the Step-3.5 case

* It works now, but the performance gain is very minor
This commit is contained in:
Kawrakow
2026-02-07 07:56:58 +02:00
committed by GitHub
parent 90d7499c2c
commit 82c4f27332
3 changed files with 106 additions and 23 deletions

View File

@@ -9683,13 +9683,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
 GGML_ASSERT(wqkv_gate && wqkv_gate->splits[id]);
 auto gate = llm_build_lora_mm(lctx, ctx0, wqkv_gate->splits[id], input_normed);
 cb(gate, "attn_gate", il_cb);
-gate = ggml_sigmoid(ctx0, gate);
-cb(gate, "attn_gate_sigmoid", il_cb);
 int nh = split_wo->ne[0]/n_embd_head_v;
 auto attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, nh, n_tokens);
 auto gate_3d = ggml_reshape_3d(ctx0, gate, 1, nh, n_tokens);
-gate_3d = ggml_repeat(ctx0, gate_3d, attn_3d);
-cur = ggml_mul(ctx0, attn_3d, gate_3d);
+cur = ggml_fused_mul_unary(ctx0, gate_3d, attn_3d, GGML_UNARY_OP_SIGMOID);
 cb(attn_3d, "attn_gated_3d", il_cb);
 }
@@ -9777,17 +9774,12 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
 cb(cur, "wqkv", il);
 auto gate = llm_build_lora_mm(lctx, ctx0, wqkv_gate, input_normed); // [n_head_l, n_tokens]
 cb(gate, "attn_gate", il);
-gate = ggml_sigmoid(ctx0, gate);
-cb(gate, "attn_gate_sigmoid", il);
-// reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens]
 int n_head_l = hparams.n_head(il);
-ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, n_head_l, n_tokens);
-ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
-gate_3d = ggml_repeat(ctx0, gate_3d, attn_3d);
-cb(gate_3d, "attn_gate_bcast", il);
-attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
-cb(attn_3d, "attn_gated_3d", il);
-cur = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
+auto attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, n_head_l, n_tokens);
+auto gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
+cur = ggml_fused_mul_unary(ctx0, gate_3d, attn_3d, GGML_UNARY_OP_SIGMOID);
+cb(cur, "attn_gated_3d", il);
+cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v * n_head_l, n_tokens);
 cb(cur, "attn_gated", il);
 cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
 if (model.layers[il].bo) {