From d5cfbca86bbddc7c48203677c5827796cdd1c787 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 21 Oct 2025 16:17:59 +0300 Subject: [PATCH] Fuse add+add+fused_rms --- ggml/src/ggml-cuda.cu | 19 ++++++++- ggml/src/ggml-cuda/norm.cu | 81 +++++++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/norm.cuh | 2 + src/llama-build-context.cpp | 1 + 4 files changed, 101 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 563433ec..fd627953 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3129,11 +3129,26 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_dup(ctx, dst); break; case GGML_OP_ADD: - if (i + 1 < cgraph->n_nodes && + if (i + 2 < cgraph->n_nodes && + cgraph->nodes[i+1]->op == GGML_OP_ADD && + cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM && + ggml_is_contiguous(dst->src[0]) && + ggml_is_contiguous(dst->src[1]) && + ggml_are_same_shape(dst->src[0], dst->src[1]) && + dst == cgraph->nodes[i+1]->src[0] && + ggml_is_contiguous(cgraph->nodes[i+1]->src[1]) && + ggml_are_same_shape(dst, cgraph->nodes[i+1]->src[1]) && + cgraph->nodes[i+1] == cgraph->nodes[i+2]->src[0]) { + //printf("Fusing add->add->fused_rms of %s, %s, %s\n", dst->name, cgraph->nodes[i+1]->name, cgraph->nodes[i+2]->name); + ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + } + else if (i + 1 < cgraph->n_nodes && cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM && ggml_is_contiguous(dst->src[0]) && ggml_is_contiguous(dst->src[1]) && - ggml_are_same_shape(dst->src[0], dst->src[1])) { + ggml_are_same_shape(dst->src[0], dst->src[1]) && + dst == cgraph->nodes[i+1]->src[0]) { ggml_cuda_op_fused_add_rms_norm(ctx, dst, cgraph->nodes[i+1]); ++i; } else { diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 5a49132a..f296b79f 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -492,6 +492,41 @@ static __global__ void fused_add_rms_norm_f32(const float * a, const float * b, } } +template +static __global__ void fused_add_add_rms_norm_f32(const float * a1, const float * a2, const float * b, const float * c, + float * dst_add, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = a1[row*ncols + col] + a2[row*ncols + col] + b[row*ncols + col]; + tmp += xi * xi; + dst_add[row*ncols + col] = xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f; + tmp = warp_reduce_sum(tmp); + } + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[row*ncols + col] = scale * c[col] * dst_add[row*ncols + col]; + } +} static void fused_add_rms_norm_f32_cuda(const float * a, const float * b, const float * c, float * dst_add, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { @@ -538,3 +573,49 @@ void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tenso src1_d, (float *)add->data, dst_d, ne00, nrows, eps, stream); } +static void fused_add_add_rms_norm_f32_cuda(const float * a1, const float * a2, const float * b, const float * c, float * dst_add, float * dst, + const int ncols, const int nrows, const float eps, cudaStream_t stream) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + if (ncols < 1024) { + const dim3 block_dims(256, 1, 1); + fused_add_add_rms_norm_f32<256><<>>(a1, a2, b, c, dst_add, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + fused_add_add_rms_norm_f32<1024><<>>(a1, a2, b, c, dst_add, dst, ncols, eps); + } +} + +void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx, + ggml_tensor * add1, ggml_tensor * add2, ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + //const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(add1->data == add2->src[0]->data); + GGML_ASSERT(add2->data == src0->data); + GGML_ASSERT(ggml_is_contiguous(src0)); + //GGML_ASSERT(ggml_is_contiguous(add->src[0])); + //GGML_ASSERT(ggml_is_contiguous(add->src[1])); + //GGML_ASSERT(ggml_are_same_shape(add->src[0], add->src[1])); + //GGML_ASSERT(ggml_are_same_shape(add->src[0], src0)); + //GGML_ASSERT(add->src[0]->type == GGML_TYPE_F32); + //GGML_ASSERT(add->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + GGML_ASSERT(ggml_nrows(src1) == 1); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int64_t ne00 = src0->ne[0]; + + const int64_t nrows = ggml_nrows(src0); + fused_add_add_rms_norm_f32_cuda((const float *)add1->src[0]->data, (const float *)add1->src[1]->data, (const float *)add2->src[1]->data, + src1_d, (float *)add2->data, dst_d, ne00, nrows, eps, stream); +} diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 29d67d2e..40f758de 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -9,3 +9,5 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst); + +void ggml_cuda_op_fused_add_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add1, ggml_tensor * add2, ggml_tensor * dst); diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 0ccdcc6e..3e55cc59 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -7793,6 +7793,7 @@ ggml_cgraph * llm_build_context::build_openai_moe() { cur = ffn_inp; cur = llm_build_norm(ctx0, cur, hparams, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, cb, il); + ggml_build_forward_expand(gf, cur); cb(cur, "attn_post_norm", il); bool use_dup_bias = cur->ne[1] < 32 && model.layers[il].ffn_up_exps_b_dup &&