diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 9827cc33..ab6d172d 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1164,14 +1164,12 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - struct ggml_tensor * c, float eps); GGML_API struct ggml_tensor * ggml_fused_rms_norm_inplace( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - struct ggml_tensor * c, float eps); // group normalize along ne0*ne1*n_groups diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index b9a969f6..7e670912 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -197,7 +197,6 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con } } - //fused_rms_norm_f32_cuda(src0_d, src1_d, src2_d, dst_d, ne00, nrows, eps, ne0, ne1, ne2, nb1, nb2, stream); static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); @@ -271,20 +270,24 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - if (!dst->src[1] && !dst->src[2]) { + if (!dst->src[1]) { ggml_cuda_op_rms_norm(ctx, dst); return; } const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + GGML_ASSERT(ggml_nrows(src1) == 1); const int64_t ne00 = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); @@ -292,7 +295,5 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const float * src1_d = (const float *)dst->src[1]->data; - fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 98d1a116..d562002e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5743,18 +5743,18 @@ static struct ggml_tensor * ggml_fused_rms_norm_impl( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - struct ggml_tensor * c, float eps, bool inplace) { - if (!b && !c) { + if (!b) { return ggml_rms_norm_impl(ctx, a, eps, inplace); } - //printf("%s: %zd x %zd x %zd %zd", __func__, a->ne[0], a->ne[1], a->ne[2], a->ne[3]); - //if (b) printf(", b = %zd x %zd x %zd %zd, ", b->ne[0], b->ne[1], b->ne[2], b->ne[3]); - //if (c) printf(", c = %zd x %zd x %zd %zd, ", c->ne[0], c->ne[1], c->ne[2], c->ne[3]); - //printf("\n"); + if (ggml_nrows(b) > 1 || a->ne[0] != b->ne[0]) { + struct ggml_tensor * result = ggml_rms_norm_impl(ctx, a, eps, inplace); + result = ggml_mul_impl(ctx, result, b, inplace); + return result; + } bool is_node = false; @@ -5770,7 +5770,6 @@ static struct ggml_tensor * ggml_fused_rms_norm_impl( result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; } @@ -5779,18 +5778,16 @@ struct ggml_tensor * ggml_fused_rms_norm( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - struct ggml_tensor * c, float eps) { - return ggml_fused_rms_norm_impl(ctx, a, b, c, eps, false); + return ggml_fused_rms_norm_impl(ctx, a, b, eps, false); } struct ggml_tensor * ggml_fused_rms_norm_inplace( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, - struct ggml_tensor * c, float eps) { - return ggml_fused_rms_norm_impl(ctx, a, b, c, eps, true); + return ggml_fused_rms_norm_impl(ctx, a, b, eps, true); } // ggml_rms_norm_back @@ -12517,9 +12514,8 @@ static void ggml_compute_forward_fused_rms_norm_f32( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - const struct ggml_tensor * src2 = dst->src[2]; - if (!src1 && !src2) { + if (!src1) { ggml_compute_forward_rms_norm_f32(params, dst); return; } @@ -12527,8 +12523,9 @@ static void ggml_compute_forward_fused_rms_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT(!src1 || src1->nb[0] == sizeof(float)); - GGML_ASSERT(!src2 || src2->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src1->ne[0] == src0->ne[0]); + GGML_ASSERT(ggml_nrows(src1) == 1); const int ith = params->ith; const int nth = params->nth; @@ -12555,37 +12552,11 @@ static void ggml_compute_forward_fused_rms_norm_f32( float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - memcpy(y, x, ne00 * sizeof(float)); - const float scale = 1.0f/sqrtf(mean + eps); + ggml_vec_mul_f32(ne00, y, x, (const float *)src1->data); ggml_vec_scale_f32(ne00, y, scale); - if (src1) { - const int64_t i13 = i03 % src1->ne[3]; - const int64_t i12 = i02 % src1->ne[2]; - const int64_t i11 = i01 % src1->ne[1]; - const int64_t nr0 = ne00 / src1->ne[0]; - - float * src1_ptr = (float *) ((char *) src1->data + i13*src1->nb[3] + i12*src1->nb[2] + i11*src1->nb[1]); - - for (int64_t r = 0 ; r < nr0; ++r) { - ggml_vec_mul_f32(src1->ne[0], y + r*src1->ne[0], y + r*src1->ne[0], src1_ptr); - } - } - - if (src2) { - const int64_t i23 = i03 % src2->ne[3]; - const int64_t i22 = i02 % src2->ne[2]; - const int64_t i21 = i01 % src2->ne[1]; - const int64_t nr0 = ne00 / src2->ne[0]; - - float * src2_ptr = (float *) ((char *) src2->data + i23*src2->nb[3] + i22*src2->nb[2] + i21*src2->nb[1]); - - for (int64_t r = 0 ; r < nr0; ++r) { - ggml_vec_mul_f32(src1->ne[0], y + r*src1->ne[0], y + r*src1->ne[0], src2_ptr); - } - } } } } diff --git a/src/llama.cpp b/src/llama.cpp index b2510585..6eee71e1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7988,8 +7988,8 @@ static struct ggml_tensor * llm_build_norm( const llm_build_cb & cb, int il, float scale_eps = 1) { - if (type == LLM_NORM_RMS) { - cur = ggml_fused_rms_norm(ctx, cur, mw, mb, scale_eps * hparams.f_norm_rms_eps); + if (type == LLM_NORM_RMS && !mb) { + cur = ggml_fused_rms_norm(ctx, cur, mw, scale_eps * hparams.f_norm_rms_eps); cb(cur, "fused_norm", il); return cur; }