From 6ef4f28aae8da5f8c63cbb4ee269e968fa8a8138 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 30 Sep 2024 08:29:34 +0300 Subject: [PATCH] Adding fused y*unary(x) op --- ggml/include/ggml.h | 13 +++++ ggml/src/ggml.c | 125 +++++++++++++++++++++++++++++++++++++++++++- src/llama.cpp | 8 +++ 3 files changed, 144 insertions(+), 2 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 08fe6a3e..b1aebd21 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -487,6 +487,7 @@ extern "C" { GGML_OP_RMS_NORM_BACK, GGML_OP_GROUP_NORM, GGML_OP_FUSED_RMS_NORM, + GGML_OP_FUSED_MUL_UNARY, GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID, @@ -963,6 +964,18 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + GGML_API struct ggml_tensor * ggml_fused_mul_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op); + + GGML_API struct ggml_tensor * ggml_fused_mul_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op); + GGML_API struct ggml_tensor * ggml_div( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d31713df..16ef1e8c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3258,6 +3258,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "RMS_NORM_BACK", "GROUP_NORM", "FUSED_RMS_NORM", + "FUSED_MUL_UNARY", "MUL_MAT", "MUL_MAT_ID", @@ -3321,7 +3322,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); +static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3349,6 +3350,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rms_norm_back(x)", "group_norm(x)", "fused_rms_norm(x)", + "fused_mul_unary(x)", "X*Y", "X[i]*Y", @@ -3412,7 +3414,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); +static_assert(GGML_OP_COUNT == 78, "GGML_OP_COUNT != 78"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5246,6 +5248,55 @@ struct ggml_tensor * ggml_mul_inplace( struct ggml_tensor * b) { return ggml_mul_impl(ctx, a, b, true); } +// ggml_mul + +static struct ggml_tensor * ggml_fused_mul_unary_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op, + bool inplace) { + GGML_ASSERT(ggml_are_same_shape(b, a)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU); + + bool is_node = false; + + if (!inplace && (a->grad || b->grad)) { + is_node = true; + } + + if (inplace) { + GGML_ASSERT(!is_node); + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, (int32_t) op); + + result->op = GGML_OP_FUSED_MUL_UNARY; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_fused_mul_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op) { + return ggml_fused_mul_unary_impl(ctx, a, b, op, false); +} + +struct ggml_tensor * ggml_fused_mul_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_unary_op op) { + return ggml_fused_mul_unary_impl(ctx, a, b, op, true); +} // ggml_div @@ -12374,6 +12425,67 @@ static void ggml_compute_forward_swiglu( } } +// ggml_compute_forward_fused_mul_unary + +static void ggml_compute_forward_fused_mul_unary_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + enum ggml_unary_op op = (enum ggml_unary_op)dst->op_params[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(op == GGML_UNARY_OP_GELU || op == GGML_UNARY_OP_RELU || op == GGML_UNARY_OP_SILU); + + const int ith = params->ith; + const int nth = params->nth; + + const int nc = dst->ne[0]; + const int nr = ggml_nrows(src0); + + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int i1 = ir0; i1 < ir1; i1++) { + float * z = (float *) ((char *) dst->data + i1*( dst->nb[1])); + const float * x = (const float *) ((char *) src0->data + i1*(src0->nb[1])); + const float * y = (const float *) ((char *) src1->data + i1*(src1->nb[1])); + switch (op) { + case GGML_UNARY_OP_GELU: ggml_vec_gelu_f32(nc, z, x); break; + case GGML_UNARY_OP_RELU: ggml_vec_relu_f32(nc, z, x); break; + case GGML_UNARY_OP_SILU: ggml_vec_silu_f32(nc, z, x); break; + default: GGML_ABORT("fatal error"); + } + ggml_vec_mul_f32(nc, z, z, y); + } +} + +static void ggml_compute_forward_fused_mul_unary( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fused_mul_unary_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_leaky_relu static void ggml_compute_forward_leaky_relu_f32( @@ -17990,6 +18102,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_mul(params, tensor); } break; + case GGML_OP_FUSED_MUL_UNARY: + { + ggml_compute_forward_fused_mul_unary(params, tensor); + } break; case GGML_OP_DIV: { ggml_compute_forward_div(params, tensor); @@ -18715,6 +18831,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; + case GGML_OP_FUSED_MUL_UNARY: + { + GGML_ABORT("fatal error"); // TODO: implement + } case GGML_OP_CONCAT: { GGML_ABORT("fatal error"); // TODO: implement @@ -19813,6 +19933,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { break; case GGML_OP_SILU_BACK: case GGML_OP_MUL: + case GGML_OP_FUSED_MUL_UNARY: case GGML_OP_DIV: case GGML_OP_NORM: case GGML_OP_RMS_NORM: diff --git a/src/llama.cpp b/src/llama.cpp index eb982125..9ed109c6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8083,6 +8083,13 @@ static struct ggml_tensor * llm_build_ffn( cur = tmp; } + if (type_gate == LLM_FFN_PAR && + (type_op == LLM_FFN_SILU || type_op == LLM_FFN_RELU || (type_op == LLM_FFN_GELU && !act_scales))) { + cur = ggml_fused_mul_unary(ctx, cur, tmp, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : + type_op == LLM_FFN_RELU ? GGML_UNARY_OP_RELU : GGML_UNARY_OP_GELU); + } + else { + switch (type_op) { case LLM_FFN_SILU: { @@ -8122,6 +8129,7 @@ static struct ggml_tensor * llm_build_ffn( cur = ggml_mul(ctx, cur, tmp); cb(cur, "ffn_gate_par", il); } + } if (down) { cur = llm_build_lora_mm(lctx, ctx, down, cur);