diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 9f312863..127a8ac6 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -691,6 +691,7 @@ extern "C" {
         GGML_OP_REDUCE,
         GGML_OP_FAKE_CPY,
+        GGML_OP_FUSED_NORM,
 
         GGML_OP_COUNT,
     };
 
@@ -1487,6 +1488,18 @@ extern "C" {
             struct ggml_tensor  * b,
             float                 eps);
 
+    GGML_API struct ggml_tensor * ggml_fused_norm(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            float                 eps);
+
+    GGML_API struct ggml_tensor * ggml_fused_norm_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            float                 eps);
+
     // group normalize along ne0*ne1*n_groups
     // used in stable-diffusion
     GGML_API struct ggml_tensor * ggml_group_norm(
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index c2e4e688..828ea4c2 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3208,6 +3208,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_fused_rms_norm(ctx, dst);
         } break;
+        case GGML_OP_FUSED_NORM:
+            ggml_cuda_op_fused_rms_norm(ctx, dst, true);
+            break;
         case GGML_OP_MUL_MAT:
             if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
                 GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@@ -4166,6 +4169,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ROPE_FAST:
         case GGML_OP_ROPE_CACHE:
             return true;
+        case GGML_OP_FUSED_NORM:
+            return ggml_is_contiguous(op->src[0]);
         //case GGML_OP_ROPE:
         //    return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index 26a21088..99c69503 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -36,6 +36,42 @@ static __global__ void norm_f32(const T * x, float * dst, const int ncols, const
     }
 }
 
+template <int block_size, typename T>
+static __global__ void fused_norm_f32(const T * x, const float * c, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    float2 mean_var = make_float2(0.f, 0.f);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = (float)x[row*ncols + col];
+        mean_var.x += xi;
+        mean_var.y += xi * xi;
+    }
+
+    // sum up partial sums
+    mean_var = warp_reduce_sum(mean_var);
+    if (block_size > WARP_SIZE) {
+        __shared__ float2 s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = mean_var;
+        }
+        __syncthreads();
+        mean_var = s_sum[lane_id];
+        mean_var = warp_reduce_sum(mean_var);
+    }
+
+    const float mean = mean_var.x / ncols;
+    const float var = mean_var.y / ncols - mean * mean;
+    const float inv_std = rsqrtf(var + eps);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = (T)(((float)x[row*ncols + col] - mean) * inv_std * c[col]);
+    }
+}
+
 template <int block_size>
 static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
     // blockIdx.x: num_groups idx
@@ -310,26 +346,47 @@ static void rms_norm_f32_nc_cuda(
 
 template <typename src_t>
 static void fused_rms_norm_f32_cuda(const src_t * x, const float * y, float * dst,
-        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+        const int ncols, const int nrows, const float eps, bool is_norm, cudaStream_t stream) {
     constexpr int kBlockSize = 256;
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    if (ncols < kBlockSize) {
-        switch (ncols) {
-            case  32: fused_rms_norm_f32< 32><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
-            case  64: fused_rms_norm_f32< 64><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
-            case  96: fused_rms_norm_f32< 96><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
-            case 128: fused_rms_norm_f32<128><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
-            case 160: fused_rms_norm_f32<160><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
-            case 192: fused_rms_norm_f32<192><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
-            default : fused_rms_norm_f32<224><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+    if (is_norm) {
+        if (ncols < kBlockSize) {
+            switch (ncols) {
+                case  32: fused_norm_f32< 32><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+                case  64: fused_norm_f32< 64><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+                case  96: fused_norm_f32< 96><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+                case 128: fused_norm_f32<128><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+                case 160: fused_norm_f32<160><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+                case 192: fused_norm_f32<192><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+                default : fused_norm_f32<224><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+            }
+        }
+        else if (ncols < 1024) {
+            const dim3 block_dims(kBlockSize, 1, 1);
+            fused_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+        } else {
+            const dim3 block_dims(1024, 1, 1);
+            fused_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
         }
-    }
-    else if (ncols < 1024) {
-        const dim3 block_dims(kBlockSize, 1, 1);
-        fused_rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
     } else {
-        const dim3 block_dims(1024, 1, 1);
-        fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+        if (ncols < kBlockSize) {
+            switch (ncols) {
+                case  32: fused_rms_norm_f32< 32><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+                case  64: fused_rms_norm_f32< 64><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+                case  96: fused_rms_norm_f32< 96><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+                case 128: fused_rms_norm_f32<128><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+                case 160: fused_rms_norm_f32<160><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+                case 192: fused_rms_norm_f32<192><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+                default : fused_rms_norm_f32<224><<<nrows, ncols, 0, stream>>>(x, y, dst, ncols, eps); break;
+            }
+        }
+        else if (ncols < 1024) {
+            const dim3 block_dims(kBlockSize, 1, 1);
+            fused_rms_norm_f32<kBlockSize><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+        } else {
+            const dim3 block_dims(1024, 1, 1);
+            fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+        }
     }
 }
 
@@ -427,7 +484,7 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     }
 }
 
-void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm) {
     if (!dst->src[1]) {
         ggml_cuda_op_rms_norm(ctx, dst);
         return;
@@ -453,11 +510,14 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
     if (ggml_is_contiguous(src0)) {
         const int64_t nrows = ggml_nrows(src0);
         if (src0->type == GGML_TYPE_F32) {
-            fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+            fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
         } else {
-            fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+            fused_rms_norm_f32_cuda((const half *)src0_d, src1_d, dst_d, ne00, nrows, eps, is_norm, stream);
         }
     } else {
+        if (is_norm) {
+            GGML_ABORT("Non-contiguous norm is not implemented");
+        }
         auto ts0 = ggml_type_size(src0->type);
         GGML_ASSERT(src0->nb[0] == ts0);
         auto s01 = src0->nb[1] / ts0;
diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh
index cadf7248..8550633b 100644
--- a/ggml/src/ggml-cuda/norm.cuh
+++ b/ggml/src/ggml-cuda/norm.cuh
@@ -6,7 +6,7 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
-void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst, bool is_norm = false);
 
 void ggml_cuda_op_fused_add_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * add, ggml_tensor * dst);
 
diff --git a/ggml/src/ggml-cuda/reduce.cu b/ggml/src/ggml-cuda/reduce.cu
index 3ad4fb1f..dda70252 100644
--- a/ggml/src/ggml-cuda/reduce.cu
+++ b/ggml/src/ggml-cuda/reduce.cu
@@ -54,6 +54,10 @@ void ggml_cuda_op_reduce([[maybe_unused]] ggml_backend_cuda_context & ctx, ggml_
     GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(nhave >=2 && nhave <= nreduce);
+    if (dst->op_params[3] == 1) {
+        // The dst tensor is just a container for the sources and the reduce op is turned off
+        return;
+    }
     auto & info = ggml_cuda_info();
 #ifdef GGML_USE_NCCL
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index f6a0bdee..3d944bf6 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -4294,9 +4294,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "REDUCE",
     "FAKE_CPY",
+    "FUSED_NORM",
 };
 
-static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
+static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -4405,9 +4406,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "reduce(x1,x2,...)",
     "fake_cpy(x,y)",
+    "norm(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 94, "GGML_OP_COUNT != 94");
+static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -7406,6 +7408,67 @@ struct ggml_tensor * ggml_fused_rms_norm_inplace(
     return ggml_fused_rms_norm_impl(ctx, a, b, eps, true);
 }
 
+static struct ggml_tensor * ggml_fused_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        float eps,
+        bool inplace) {
+
+    if (!b) {
+        return ggml_norm_impl(ctx, a, eps, inplace);
+    }
+
+    if (ggml_nrows(b) > 1 || a->ne[0] != b->ne[0]) {
+        struct ggml_tensor * result = ggml_norm_impl(ctx, a, eps, inplace);
+        result = ggml_mul_impl(ctx, result, b, inplace);
+        return result;
+    }
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result;
+    if (inplace) {
+        GGML_ASSERT(a->type == GGML_TYPE_F32);
+        result = ggml_view_tensor(ctx, a);
+    } else {
+        if (a->type == GGML_TYPE_F32) {
+            result = ggml_dup_tensor(ctx, a);
+        } else {
+            result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
+        }
+    }
+
+    ggml_set_op_params(result, &eps, sizeof(eps));
+
+    result->op   = GGML_OP_FUSED_NORM;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_fused_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        float eps) {
+    return ggml_fused_norm_impl(ctx, a, b, eps, false);
+}
+
+struct ggml_tensor * ggml_fused_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        float eps) {
+    return ggml_fused_norm_impl(ctx, a, b, eps, true);
+}
+
 // ggml_rms_norm_back
 
 struct ggml_tensor * ggml_rms_norm_back(
@@ -15404,6 +15467,88 @@ static void ggml_compute_forward_norm(
     }
 }
 
+static void ggml_compute_forward_fused_norm_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    if (!src1) {
+        ggml_compute_forward_norm_f32(params, dst);
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->ne[0] == src0->ne[0]);
+    GGML_ASSERT(ggml_nrows(src1) == 1);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps > 0.0f);
+
+    const float * c = (const float *)src1->data;
+
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    ggml_float xi = (ggml_float)x[i00];
+                    sum += xi;
+                }
+
+                const float mean = sum/ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                ggml_float sum2 = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    float v = x[i00] - mean;
+                    y[i00] = v * c[i00];
+                    sum2 += (ggml_float)(v*v);
+                }
+
+                float variance = sum2/ne00;
+                const float scale = 1.0f/sqrtf(variance + eps);
+
+                ggml_vec_scale_f32(ne00, y, scale);
+
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_fused_norm(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_fused_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_group_rms_norm
 
 static void ggml_compute_forward_rms_norm_f32(
@@ -22853,6 +22998,10 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml
             {
                 ggml_compute_forward_fused_rms_norm(params, tensor);
             } break;
+        case GGML_OP_FUSED_NORM:
+            {
+                ggml_compute_forward_fused_norm(params, tensor);
+            } break;
         case GGML_OP_RMS_NORM_BACK:
             {
                 ggml_compute_forward_rms_norm_back(params, tensor);
@@ -23657,6 +23806,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             }
         } break;
         case GGML_OP_FUSED_RMS_NORM:
+        case GGML_OP_FUSED_NORM:
            {
                GGML_ABORT("fatal error"); // TODO: not implemented
            }
@@ -24817,6 +24967,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
         case GGML_OP_FUSED_RMS_NORM:
+        case GGML_OP_FUSED_NORM:
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_GROUP_NORM:
        case GGML_OP_CONCAT:
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 83e65bf7..cb7edac3 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -678,9 +678,10 @@ ggml_tensor * llm_build_context::llm_build_ffn(
             auto norm = (ggml_split_tensor_t *)ffn_norm->extra;
             GGML_ASSERT(norm->splits[id]);
             if (is_norm) {
-                cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM, cb, il);
-                GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
-                cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
+                //cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM, cb, il);
+                //GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
+                //cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
+                cur = ggml_fused_norm(ctx, cur, norm->splits[id], lctx.model.hparams.f_norm_eps);
             } else {
                 cur = llm_build_norm(ctx, cur, lctx.model.hparams, norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
             }
@@ -700,6 +701,13 @@ ggml_tensor * llm_build_context::llm_build_ffn(
             if (cur->ne[1] > 32 && lctx.cparams.split_mode_f16) {
                 cur = ggml_cast(ctx, cur, GGML_TYPE_F16);
             }
+            if (add_extra && add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1) {
+                // When the reduce op is turned off via op_params[3] == 1, we need to add each src
+                // rather than add the reduced add_extra result to the reduced ffn result.
+                GGML_ASSERT(add_extra->src[id]); // TODO: fix this! It can be null if the splits of the attention and ffn tensors are different
+                cur = ggml_add(ctx, cur, add_extra->src[id]);
+                cb(cur, "ffn_with_extra", il_cb);
+            }
             if (graph) {
                 ggml_build_forward_expand(graph, cur);
             }
@@ -711,7 +719,7 @@ ggml_tensor * llm_build_context::llm_build_ffn(
         ffn[id_last] = ggml_add(ctx, ffn[id_last], input);
         cb(ffn[id_last], "ffn_with_inp", il);
     }
-    if (add_extra) {
+    if (add_extra && !(add_extra->op == GGML_OP_REDUCE && add_extra->op_params[3] == 1)) {
         ffn[id_last] = ggml_add(ctx, ffn[id_last], add_extra);
         cb(ffn[id_last], "ffn_with_inp", il);
     }
@@ -7287,6 +7295,8 @@ ggml_cgraph * llm_build_context::build_cohere2() {
             inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
         }
 
+        attn_out->op_params[3] = 1; // i.e., turn off the reduce operation as it is not required
+
        // feed-forward network
        cur = llm_build_ffn(ctx0, lctx, model.layers[il].attn_norm, inpL, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
                model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR,
@@ -9379,9 +9389,10 @@ ggml_tensor * llm_build_context::build_std_attention(ggml_cgraph * gf, ggml_tens
        auto cur = get_input_tensor_sm_graph(input, id);
        if (attn_norm) {
            if (is_norm) {
-               cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM, cb, il);
-               GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
-               cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
+               //cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM, cb, il);
+               //GGML_ASSERT(cur->src[0]->op == GGML_OP_NORM);
+               //cur->src[0]->op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t) - 1] = 0xff;
+               cur = ggml_fused_norm(ctx0, cur, attn_norm->splits[id], lctx.model.hparams.f_norm_eps);
            } else {
                cur = llm_build_norm(ctx0, cur, lctx.model.hparams, attn_norm->splits[id], NULL, LLM_NORM_RMS, cb, il);
            }
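
Usage sketch (not part of the patch): the snippet below illustrates how the new GGML_OP_FUSED_NORM node is meant to replace the GGML_OP_NORM + GGML_OP_MUL pair added by the header changes above, i.e. it computes y[i] = (x[i] - mean(x)) / sqrt(var(x) + eps) * w[i] per row in a single op. It assumes a plain CPU ggml context; the tensor sizes, eps value, memory budget and thread count are illustrative only. Per the ggml_fused_norm_impl added in ggml.c, the call silently falls back to the two-op path when w has more than one row or a mismatched ne[0].

// usage_sketch.c -- hypothetical example, not part of the diff above
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64*1024*1024,   // illustrative scratch size
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd   = 4096;    // row length (illustrative)
    const int64_t n_tokens = 8;       // number of rows (illustrative)
    const float   eps      = 1e-5f;   // plays the role of hparams.f_norm_eps at the llama.cpp call sites

    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
    struct ggml_tensor * w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);  // per-channel norm weight

    // unfused reference path: GGML_OP_NORM followed by GGML_OP_MUL
    struct ggml_tensor * y_ref   = ggml_mul(ctx, ggml_norm(ctx, x, eps), w);

    // fused path: a single GGML_OP_FUSED_NORM node
    struct ggml_tensor * y_fused = ggml_fused_norm(ctx, x, w, eps);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y_ref);
    ggml_build_forward_expand(gf, y_fused);

    // ... fill x and w with data here, then compare y_ref and y_fused after:
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

    ggml_free(ctx);
    return 0;
}

On the CUDA backend the fused node is routed through ggml_cuda_op_fused_rms_norm(ctx, dst, /*is_norm=*/true) into the new fused_norm_f32 kernel, and only contiguous src0 is supported, as declared in ggml_backend_cuda_supports_op above.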