From c4951cbc357a161086029ea44b3ef6d07a4f0aea Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Thu, 1 Aug 2024 20:32:28 +0300
Subject: [PATCH] Softcap: WIP

Fuses scale + tanh + scale as used for softcapping in some models.
Just CPU for now.

~1.4% speedup for PP-512 on Gemma2-9b, no effect on TG.

Somewhat surprisingly, the improvement does not increase as I go to
longer contexts. Gemma2 applies the softcap to K*Q, which grows
quadratically with context length, so I would have thought the benefit
from fusing scale, tanh, scale would increase. But no, no luck.
---
 ggml/include/ggml.h |  14 ++++
 ggml/src/ggml.c     | 189 +++++++++++++++++++++++++++++++++++++++++++-
 src/llama.cpp       |  20 +++-
 3 files changed, 213 insertions(+), 10 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 026993db..17d3cb1a 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -514,6 +514,7 @@ extern "C" {
         GGML_OP_TIMESTEP_EMBEDDING,
         GGML_OP_ARGSORT,
         GGML_OP_LEAKY_RELU,
+        GGML_OP_SOFTCAP,
 
         GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_ATTN_BACK,
@@ -1223,6 +1224,19 @@ extern "C" {
             struct ggml_tensor  * a,
             float                 s);
 
+    GGML_API struct ggml_tensor * ggml_softcap(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s_before,
+            float                 s_after);
+
+    // in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_softcap_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 s_before,
+            float                 s_after);
+
     // b -> view(a,offset,nb1,nb2,3), return modified a
     GGML_API struct ggml_tensor * ggml_set(
             struct ggml_context * ctx,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index e7f1ae61..94c3eb3a 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -2558,6 +2558,14 @@ inline static float32x4_t ggml_v_tanh(float32x4_t x) {
     return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
 }
 
+inline static float32x4_t ggml_v_softcap(float32x4_t x, float s_before, float s_after) {
+    const float32x4_t one = vdupq_n_f32(1.0f);
+    const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f*s_before));
+    const float32x4_t exp_two_x = ggml_v_expf(two_x);
+    const float32x4_t th = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+    return vmulq_f32(th, vdupq_n_f32(s_after));
+}
+
 #elif defined(__AVX512F__) && defined(__AVX512DQ__)
 
 // adapted from arm limited optimized routine
@@ -2607,6 +2615,13 @@ inline static __m512 ggml_v_tanh(__m512 x) {
     return _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
 }
 
+inline static __m512 ggml_v_softcap(__m512 x, __m512 s_before, __m512 s_after) {
+    const __m512 one = _mm512_set1_ps(1.0f);
+    const __m512 exp_two_x = ggml_v_expf(_mm512_mul_ps(x, s_before));
+    const __m512 th = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
+    return _mm512_mul_ps(th, s_after);
+}
+
 #elif defined(__AVX2__) && defined(__FMA__)
 
 // adapted from arm limited optimized routine
@@ -2668,6 +2683,13 @@ inline static __m256 ggml_v_tanh(__m256 x) {
     return _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
 }
 
+inline static __m256 ggml_v_softcap(__m256 x, float s_before, float s_after) {
+    const __m256 one = _mm256_set1_ps(1.0f);
+    const __m256 exp_two_x = ggml_v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f*s_before)));
+    const __m256 th = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
+    return _mm256_mul_ps(th, _mm256_set1_ps(s_after));
+}
+
 #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
 
 #if defined(__FMA__)
@@ -2728,6 +2750,13 @@ inline static __m128 ggml_v_tanh(__m128 x) {
     return _mm_div_ps(_mm_sub_ps(exp_two_x, one), _mm_add_ps(exp_two_x, one));
 }
 
+inline static __m128 ggml_v_softcap(__m128 x, float s_before, float s_after) {
+    const __m128 one = _mm_set1_ps(1.0f);
+    const __m128 exp_two_x = ggml_v_expf(_mm_mul_ps(x, _mm_set1_ps(2.f*s_before)));
+    const __m128 th = _mm_div_ps(_mm_sub_ps(exp_two_x, one), _mm_add_ps(exp_two_x, one));
+    return _mm_mul_ps(th, _mm_set1_ps(s_after));
+}
+
 #endif // __ARM_NEON / __AVX2__ / __SSE2__
 
 static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
@@ -2778,6 +2807,42 @@ static void ggml_vec_tanh_f32(const int n, float * y, const float * x) {
     }
 }
 
+static void ggml_vec_softcap_f32(const int n, float * x, float s_before, float s_after) {
+    int i = 0;
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+    __m512 vs_before = _mm512_set1_ps(2.f*s_before);
+    __m512 vs_after = _mm512_set1_ps(s_after);
+    //for (; i + 63 < n; i += 64) {
+    //    __m512 x1 = _mm512_loadu_ps(x + i);
+    //    __m512 x2 = _mm512_loadu_ps(x + i + 16);
+    //    __m512 x3 = _mm512_loadu_ps(x + i + 32);
+    //    __m512 x4 = _mm512_loadu_ps(x + i + 48);
+    //    _mm512_storeu_ps(x + i +  0, ggml_v_softcap(x1, vs_before, vs_after));
+    //    _mm512_storeu_ps(x + i + 16, ggml_v_softcap(x2, vs_before, vs_after));
+    //    _mm512_storeu_ps(x + i + 32, ggml_v_softcap(x3, vs_before, vs_after));
+    //    _mm512_storeu_ps(x + i + 48, ggml_v_softcap(x4, vs_before, vs_after));
+    //}
+    for (; i + 15 < n; i += 16) {
+        _mm512_storeu_ps(x + i, ggml_v_softcap(_mm512_loadu_ps(x + i), vs_before, vs_after));
+    }
+#elif defined(__AVX2__) && defined(__FMA__)
+    for (; i + 7 < n; i += 8) {
+        _mm256_storeu_ps(x + i, ggml_v_softcap(_mm256_loadu_ps(x + i), s_before, s_after));
+    }
+#elif defined(__SSE2__)
+    for (; i + 3 < n; i += 4) {
+        _mm_storeu_ps(x + i, ggml_v_softcap(_mm_loadu_ps(x + i), s_before, s_after));
+    }
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+    for (; i + 3 < n; i += 4) {
+        vst1q_f32(x + i, ggml_v_softcap(vld1q_f32(x + i), s_before, s_after));
+    }
+#endif
+    for (; i < n; ++i) {
+        x[i] = s_after*tanhf(x[i]*s_before);
+    }
+}
+
 static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
     int i = 0;
     ggml_float sum = 0;
@@ -2968,6 +3033,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TIMESTEP_EMBEDDING",
     "ARGSORT",
     "LEAKY_RELU",
+    "SOFTCAP",
 
     "FLASH_ATTN_EXT",
     "FLASH_ATTN_BACK",
@@ -2995,7 +3061,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -3056,6 +3122,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "timestep_embedding(timesteps, dim, max_period)",
     "argsort(x)",
     "leaky_relu(x)",
+    "k2*tanh(k1*x)",
 
     "flash_attn_ext(x)",
     "flash_attn_back(x)",
@@ -3083,7 +3150,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 75, "GGML_OP_COUNT != 75");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5742,6 +5809,50 @@ struct ggml_tensor * ggml_scale_inplace(
     return ggml_scale_impl(ctx, a, s, true);
 }
 
+// ggml_softcap
+
+static struct ggml_tensor * ggml_softcap_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        float s_before,
+        float s_after,
+        bool inplace) {
+    GGML_ASSERT(ggml_is_padded_1d(a));
+
+    bool is_node = false;
+
+    if (a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    float params[2] = {s_before, s_after};
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op   = GGML_OP_SOFTCAP;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_softcap(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        float s_before,
+        float s_after) {
+    return ggml_softcap_impl(ctx, a, s_before, s_after, false);
+}
+
+struct ggml_tensor * ggml_softcap_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        float s_before,
+        float s_after) {
+    return ggml_softcap_impl(ctx, a, s_before, s_after, true);
+}
+
 // ggml_set
 
 static struct ggml_tensor * ggml_set_impl(
@@ -13324,6 +13435,71 @@ static void ggml_compute_forward_scale(
     }
 }
 
+// ggml_compute_forward_softcap
+
+static void ggml_compute_forward_softcap_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    // softcap parameters
+    float val[2];
+    memcpy(val, dst->op_params, sizeof(val));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src0->ne[0];
+    const int nr = ggml_nrows(src0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const size_t nb01 = src0->nb[1];
+
+    const size_t nb1 = dst->nb[1];
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        if (dst->data != src0->data) {
+            // src0 is same shape as dst => same indices
+            memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
+        }
+        // TODO: better implementation
+        float * row = (float *) ((char *) dst->data + i1*nb1);
+        ggml_vec_softcap_f32(nc, row, val[0], val[1]);
+        //ggml_vec_scale_f32(nc, row, val[0]);
+        //ggml_vec_tanh_f32(nc, row, row);
+        //ggml_vec_scale_f32(nc, row, val[1]);
+    }
+}
+
+static void ggml_compute_forward_softcap(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_softcap_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_set
 
 static void ggml_compute_forward_set_f32(
@@ -17175,6 +17351,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_scale(params, tensor);
             } break;
+        case GGML_OP_SOFTCAP:
+            {
+                ggml_compute_forward_softcap(params, tensor);
+            } break;
         case GGML_OP_SET:
             {
                 ggml_compute_forward_set(params, tensor);
@@ -17917,6 +18097,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         zero_table);
             }
         } break;
+        case GGML_OP_SOFTCAP:
+            {
+                GGML_ASSERT(false); // TODO: not implemented
+            } break;
         case GGML_OP_SET:
             {
                 const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -18928,6 +19112,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = 1; //TODO
             } break;
         case GGML_OP_SCALE:
+        case GGML_OP_SOFTCAP:
        case GGML_OP_SOFT_MAX:
             {
                 n_tasks = MIN(n_threads, ggml_nrows(node->src[0]));
diff --git a/src/llama.cpp b/src/llama.cpp
index ba18a37c..9d989749 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -8317,14 +8317,17 @@ static struct ggml_tensor * llm_build_kqv(
             //try from phi2
             //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 
-            kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
-            kq = ggml_scale(ctx, kq, 30);
+            //kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
+            //kq = ggml_scale(ctx, kq, 30);
+
+            kq = ggml_softcap(ctx, kq, 0.08838834764831845f/30.0f, 30.f);
         }
 
         if (hparams.attn_soft_cap) {
-            kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
-            kq = ggml_tanh(ctx, kq);
-            kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+            kq = ggml_softcap(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping, hparams.f_attn_logit_softcapping);
+            //kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+            //kq = ggml_tanh(ctx, kq);
+            //kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
         }
 
         kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
@@ -11935,9 +11938,10 @@ struct llm_build_context {
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
 
         // final logit soft-capping
-        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
-        cur = ggml_tanh(ctx0, cur);
-        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+        cur = ggml_softcap(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping, hparams.f_final_logit_softcapping);
+        //cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        //cur = ggml_tanh(ctx0, cur);
+        //cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
 
         cb(cur, "result_output", -1);
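
For reference, the fused op computes y = s_after * tanh(s_before * x) per element. Note that ggml_vec_softcap_f32 pre-doubles s_before for the AVX512 path, whose ggml_v_softcap variant takes already-broadcast vector arguments, while the other variants double it internally. A standalone scalar sanity check of the fused form against the original scale -> tanh -> scale chain (an illustrative sketch, not part of the patch):

// softcap_check.c: fused softcap vs. the unfused scale -> tanh -> scale chain.
// Build: cc -O2 softcap_check.c -lm && ./a.out
#include <math.h>
#include <stdio.h>

// same math as the scalar tail of ggml_vec_softcap_f32
static float softcap_ref(float x, float s_before, float s_after) {
    return s_after*tanhf(x*s_before);
}

int main(void) {
    const float cap = 30.0f; // final-logit cap used for Gemma2 in this patch
    for (int i = -2; i <= 2; ++i) {
        const float x = 50.0f*i;
        // unfused chain: scale by 1/cap, then tanh, then scale by cap
        const float chained = cap*tanhf(x*(1.0f/cap));
        printf("x = %7.1f  fused = %9.5f  chained = %9.5f\n",
               x, softcap_ref(x, 1.0f/cap, cap), chained);
    }
    return 0;
}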
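The backward pass is stubbed out above with GGML_ASSERT(false). Should it be needed, the gradient follows directly from the chain rule: with y = s_after*tanh(s_before*x), dL/dx = dL/dy * s_before*s_after*(1 - tanh(s_before*x)^2). A scalar per-row sketch of what a future implementation could compute (an assumption, not code from this patch; a real version would follow the GGML_OP_SCALE pattern in ggml_compute_backward):

#include <math.h>

// hypothetical helper: dx[i] = dy[i] * d/dx [ s_after*tanh(s_before*x[i]) ]
static void softcap_backward_row_ref(int n, float * dx, const float * dy,
                                     const float * x, float s_before, float s_after) {
    for (int i = 0; i < n; ++i) {
        const float t = tanhf(s_before*x[i]);
        dx[i] = dy[i]*s_before*s_after*(1.0f - t*t);
    }
}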
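Graph-side usage mirrors the llama.cpp hunks above: one fused node replaces the three-node scale -> tanh -> scale chain. A minimal sketch, assuming an existing ggml context and an F32 tensor (the helper name is illustrative):

#include "ggml.h"

// Caps values smoothly into (-cap, cap) with a single graph node.
static struct ggml_tensor * build_softcapped(struct ggml_context * ctx,
                                             struct ggml_tensor * logits,
                                             float cap) {
    return ggml_softcap(ctx, logits, 1.0f/cap, cap);
}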