From 82c4f273323d93d567e4cda9817df673b76b4d14 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Sat, 7 Feb 2026 07:56:58 +0200
Subject: [PATCH] Fuse the attention gate in Step-3.5-Flash (#1244)

* WIP

* This works but is slow

* Turn off the up / gate clamps for now

* OK we need the clamping

* Fuse the clamp (CUDA)

* Fuse the clamp (CPU)

* WIP

* Be able to use merged q, k, v

* Be able to use merged up/gate experts

* Fuse the clamp (CUDA mmvq)

* WIP: graph parallel for Step-3.5

* WIP

* This should be it

* Cleanup

* Fix merge

* Not working attempt to extend fused_mul_unary to the Step-3.5 case

* It works now, but performance gain is very minor

---
 ggml/src/ggml-cuda/unary.cu | 62 ++++++++++++++++++++++++++++++++++++-
 ggml/src/ggml.c             | 47 +++++++++++++++++++++++-----
 src/llama-build-context.cpp | 20 ++++--------
 3 files changed, 106 insertions(+), 23 deletions(-)

diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu
index 75362778..b73a46db 100644
--- a/ggml/src/ggml-cuda/unary.cu
+++ b/ggml/src/ggml-cuda/unary.cu
@@ -72,6 +72,38 @@ static __global__ void fused_mul_silu_f32(const float * x, const float * y, floa
     dst[i] = g * max(-limit, min(limit, y[i]));
 }
 
+static __global__ void fused_mul_silu_f32(int ne0, const float * x, const float * y, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    int row = i / ne0;
+    dst[i] = x[row] * y[i] / (1.0f + expf(-x[row]));
+}
+
+static __global__ void fused_mul_sigmoid_f32(int ne0, const float * x, const float * y, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    int row = i / ne0;
+    dst[i] = y[i] / (1.0f + expf(-x[row]));
+}
+
+static __global__ void fused_mul_silu_f32(int ne0, const float * x, const float * y, float * dst, const int k, float limit) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    int row = i / ne0;
+    float g = x[row] / (1.0f + expf(-x[row]));
+    g = min(g, limit);
+    dst[i] = g * max(-limit, min(limit, y[i]));
+}
+
 static __global__ void fused_mul_silu_f32(const float * x, float * dst, const int k, const int ne0) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -257,6 +289,20 @@ static void fused_mul_silu_f32_cuda(const float * x, const float * y, float * ds
     }
 }
 
+static void fused_mul_silu_f32_cuda(int ne0, const float * x, const float * y, float * dst, const int k, float limit, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+    if (limit < 1e-6f) {
+        fused_mul_silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(ne0, x, y, dst, k);
+    } else {
+        fused_mul_silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(ne0, x, y, dst, k, limit);
+    }
+}
+
+static void fused_mul_sigmoid_f32_cuda(int ne0, const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
+    fused_mul_sigmoid_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(ne0, x, y, dst, k);
+}
+
 static void fused_mul_relu_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
     fused_mul_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
@@ -410,7 +456,21 @@ void ggml_cuda_op_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_tensor *
     GGML_ASSERT(ggml_is_contiguous(src0));
 
     if (src1) {
-        GGML_ASSERT(ggml_are_same_shape(src0, dst));
+        GGML_ASSERT(ggml_are_same_shape(src1, dst));
+        if (!ggml_are_same_shape(src0, src1)) {
+            GGML_ASSERT(src0->ne[0] == 1 && src0->ne[1] == src1->ne[1] && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3]);
+            if (op == GGML_UNARY_OP_SILU) {
+                fused_mul_silu_f32_cuda(src1->ne[0], (const float *)src0->data, (const float *)src1->data, (float *)dst->data,
+                        ggml_nelements(dst), limit, ctx.stream());
+            }
+            else if (op == GGML_UNARY_OP_SIGMOID) {
+                fused_mul_sigmoid_f32_cuda(src1->ne[0], (const float *)src0->data, (const float *)src1->data, (float *)dst->data,
+                        ggml_nelements(dst), ctx.stream());
+            } else {
+                GGML_ABORT("Fatal error");
+            }
+            return;
+        }
         GGML_ASSERT(ggml_are_same_shape(src0, src1));
         ggml_fused_mul_unary(ctx, op, ggml_nelements(dst), (const float *)src0->data, (const float *)src1->data, (float *)dst->data, limit);
     } else {
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 191bdcd1..b899678c 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -6472,9 +6472,20 @@ static struct ggml_tensor * ggml_fused_mul_unary_impl(
         struct ggml_tensor * b,
         enum ggml_unary_op op,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(b, a));
+    GGML_ASSERT(ggml_is_contiguous(a));
+    if (!ggml_are_same_shape(b, a)) {
+        GGML_ASSERT(a->ne[0] == 1 && a->ne[1] == b->ne[1] && a->ne[2] == b->ne[2] && a->ne[3] == b->ne[3]);
+        GGML_ASSERT(op == GGML_UNARY_OP_SILU || op == GGML_UNARY_OP_SIGMOID);
+        struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, b) : ggml_dup_tensor(ctx, b);
+        ggml_set_op_params_i32(result, 0, (int32_t) op);
+        result->op = GGML_OP_FUSED_MUL_UNARY;
+        result->src[0] = a;
+        result->src[1] = b;
+        return result;
+    }
     GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
+    //GGML_ASSERT(ggml_are_same_shape(b, a));
 
     bool is_node = false;
 
@@ -15158,17 +15169,11 @@ static void ggml_compute_forward_fused_mul_unary_f32(
     enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0];
     const float limit = *(const float *)(dst->op_params + 1);
 
-    GGML_ASSERT(ggml_is_contiguous_1(src0));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, src1));
-    GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
-
     const int ith = params->ith;
     const int nth = params->nth;
 
     const int nc = dst->ne[0];
-    const int nr = ggml_nrows(src0);
-
+    const int nr = ggml_nrows(dst);
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -15177,6 +15182,32 @@ static void ggml_compute_forward_fused_mul_unary_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    if (!ggml_are_same_shape(src0, src1)) {
+        GGML_ASSERT(src0->ne[0] == 1 && ggml_nrows(src0) == nr);
+        GGML_ASSERT(op == GGML_UNARY_OP_SILU || op == GGML_UNARY_OP_SIGMOID);
+        for (int i1 = ir0; i1 < ir1; i1++) {
+            float * z = (float *) ((char *) dst->data + i1*( dst->nb[1]));
+            const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1]));
+            const float * y = (const float *) ((char *) src1->data + i1*(src1->nb[1]));
+            float gate = op == GGML_UNARY_OP_SILU ? ggml_silu_f32(x[0]) : 1.0f/(1.0f + expf(-x[0]));
+            if (limit < 1e-6f) {
+                for (int i = 0; i < nc; ++i) z[i] = gate * y[i];
+            } else {
+                gate = MIN(gate, limit);
+                for (int i = 0; i < nc; ++i) {
+                    float up = MAX(-limit, MIN(limit, y[i]));
+                    z[i] = up * gate;
+                }
+            }
+        }
+        return;
+    }
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, src1));
+    GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU);
+
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * z = (float *) ((char *) dst->data + i1*( dst->nb[1]));
         const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1]));
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index b4143827..d2c64475 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -9683,13 +9683,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
             GGML_ASSERT(wqkv_gate && wqkv_gate->splits[id]);
             auto gate = llm_build_lora_mm(lctx, ctx0, wqkv_gate->splits[id], input_normed);
             cb(gate, "attn_gate", il_cb);
-            gate = ggml_sigmoid(ctx0, gate);
-            cb(gate, "attn_gate_sigmoid", il_cb);
             int nh = split_wo->ne[0]/n_embd_head_v;
             auto attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, nh, n_tokens);
             auto gate_3d = ggml_reshape_3d(ctx0, gate, 1, nh, n_tokens);
-            gate_3d = ggml_repeat(ctx0, gate_3d, attn_3d);
-            cur = ggml_mul(ctx0, attn_3d, gate_3d);
+            cur = ggml_fused_mul_unary(ctx0, gate_3d, attn_3d, GGML_UNARY_OP_SIGMOID);
             cb(attn_3d, "attn_gated_3d", il_cb);
         }
 
@@ -9777,17 +9774,12 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
         cb(cur, "wqkv", il);
         auto gate = llm_build_lora_mm(lctx, ctx0, wqkv_gate, input_normed); // [n_head_l, n_tokens]
         cb(gate, "attn_gate", il);
-        gate = ggml_sigmoid(ctx0, gate);
-        cb(gate, "attn_gate_sigmoid", il);
-        // reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens]
         int n_head_l = hparams.n_head(il);
-        ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, n_head_l, n_tokens);
-        ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
-        gate_3d = ggml_repeat(ctx0, gate_3d, attn_3d);
-        cb(gate_3d, "attn_gate_bcast", il);
-        attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
-        cb(attn_3d, "attn_gated_3d", il);
-        cur = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
+        auto attn_3d = ggml_reshape_3d(ctx0, cur, n_embd_head_v, n_head_l, n_tokens);
+        auto gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens);
+        cur = ggml_fused_mul_unary(ctx0, gate_3d, attn_3d, GGML_UNARY_OP_SIGMOID);
+        cb(cur, "attn_gated_3d", il);
+        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v * n_head_l, n_tokens);
         cb(cur, "attn_gated", il);
         cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
         if (model.layers[il].bo) {