From d2225010b903bf16ea12e594279f7854eb94d9b1 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 8 Sep 2024 07:06:38 +0200 Subject: [PATCH] Fused rms_norm WIP --- ggml/src/ggml-metal.m | 1 + src/llama.cpp | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index ea4f6e19..b3f6e60c 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -2616,6 +2616,7 @@ static enum ggml_status ggml_metal_graph_compute( GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ggml_is_contiguous_1(src0)); GGML_ASSERT(src1->ne[0] == src0->ne[0]); + GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(ggml_nrows(src1) == 1); float eps; diff --git a/src/llama.cpp b/src/llama.cpp index 6eee71e1..768aafa7 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7988,9 +7988,12 @@ static struct ggml_tensor * llm_build_norm( const llm_build_cb & cb, int il, float scale_eps = 1) { - if (type == LLM_NORM_RMS && !mb) { + if (type == LLM_NORM_RMS && mw) { cur = ggml_fused_rms_norm(ctx, cur, mw, scale_eps * hparams.f_norm_rms_eps); - cb(cur, "fused_norm", il); + if (mb) { + cb(cur, "fused_norm", il); + cur = ggml_add(ctx, cur, mb); + } return cur; }