From b127c6ccedce44aabd293c8f662d04ab78415c66 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Fri, 23 Aug 2024 09:56:28 +0300
Subject: [PATCH] WIP: Fusing K*Q and softmax - not working yet

---
 ggml/src/ggml.c              | 117 ++++++++++++-
 ggml/src/iqk/iqk_mul_mat.cpp | 310 +++++++++++++++++++++++++++++++++++
 ggml/src/iqk/iqk_mul_mat.h   |   7 +
 3 files changed, 430 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 60a89591..98fb4fa1 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -12761,6 +12761,8 @@ static void ggml_compute_forward_mul_mat(
 
 #if GGML_USE_IQK_MULMAT
     if (dst->type == GGML_TYPE_F32 && (ne12*ne13)%nth == 0) {
+        //if (ith == 0) printf("%s(%-10s): %d multiplies with Nx = %d, Ny = %d ne00 = %d\n", __func__, dst->name,
+        //        (int)(ne12*ne13), (int)ne01, (int)ne11, (int)ne00);
         int counter = 0;
         for (int64_t i13 = 0; i13 < ne13; i13++) {
             for (int64_t i12 = 0; i12 < ne12; i12++) {
@@ -12773,6 +12775,10 @@ static void ggml_compute_forward_mul_mat(
                 }
             }
         }
+#if IK_PRINT_TIMING
+        if (ith == 0) printf("%s(%s, 0): %g GFLOP, %g MiB\n", __func__, dst->name, 2e-9*ggml_nelements(dst)*ne00,
+                4.*ggml_nelements(dst)/(1024*1024));
+#endif
         return;
     }
     if (dst->type == GGML_TYPE_F32) {
@@ -12783,6 +12789,10 @@ static void ggml_compute_forward_mul_mat(
                     src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
                     (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
                     ith, nth)) goto IQK_MulMat_Not_Available1;
+#if IK_PRINT_TIMING
+        if (ith == 0) printf("%s(%s, 1): %g GFLOP, %g MiB\n", __func__, dst->name, 2e-9*ggml_nelements(dst)*ne00,
+                4.*ggml_nelements(dst)/(1024*1024));
+#endif
         return;
     }
     IQK_MulMat_Not_Available1:;
@@ -14302,6 +14312,93 @@ static void ggml_compute_forward_diag_mask_zero(
     }
 }
 
+// Fused K*Q + soft_max: computes soft_max(mul_mat) directly, keeping the intermediate
+// attention matrix in thread-local work memory. Returns true iff the fused path handled
+// both ops (the caller then skips the separate SOFT_MAX node).
+static bool ggml_fused_mul_mat_softmax(const struct ggml_compute_params * params,
+        struct ggml_tensor * mul_mat,
+        struct ggml_tensor * soft_max) {
+
+    if (!(mul_mat->src[0]->type == GGML_TYPE_F16 || mul_mat->src[0]->type == GGML_TYPE_F32) ||
+        !(mul_mat->src[1]->type == GGML_TYPE_F16 || mul_mat->src[1]->type == GGML_TYPE_F32) ||
+        !(soft_max->type == GGML_TYPE_F16 || soft_max->type == GGML_TYPE_F32) ||
+        !ggml_is_contiguous(soft_max) || !ggml_are_same_shape(mul_mat, soft_max) ||
+        (soft_max->src[1] && soft_max->src[1]->type != GGML_TYPE_F16)) { // iqk kernel assumes an f16 mask
+        return false;
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    //if (ith == 0) printf("%s: %s = softmax(%s = %s x %s)\n", __func__, soft_max->name, mul_mat->name, mul_mat->src[0]->name, mul_mat->src[1]->name);
+
+    const struct ggml_tensor * dst  = mul_mat;
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+    const enum ggml_type type = src0->type;
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    const int64_t r2 = ne12 / ne02; // broadcast factors for grouped heads
+    const int64_t r3 = ne13 / ne03;
+
+    float op_params[2]; // soft_max op_params: [0] = scale, [1] = max_bias (ALiBi)
+    memcpy(op_params, soft_max->op_params, sizeof(op_params));
+
+    const uint32_t n_head      = ne02;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+    const float m0 = powf(2.0f, -(op_params[0]       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(op_params[0] / 2.0f) / n_head_log2);
+
+    //if ((ne12*ne13)%nth == 0) {
+    //    int counter = 0;
+    //    for (int64_t i13 = 0; i13 < ne13; i13++) {
+    //        for (int64_t i12 = 0; i12 < ne12; i12++) {
+    //            if (counter++ % nth == ith) {
+    //                const uint32_t h = i12;
+    //                const float slope = (op_params[1] > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+    //                if (!iqk_fused_mul_mat_softmax(ne01, ne11, ne00,
+    //                            src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
+    //                            src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
+    //                            (float *)((char *)soft_max->data + i12*nb2 + i13*nb3), nb1/sizeof(float),
+    //                            params->wdata, params->wsize,
+    //                            soft_max->src[1]->data, op_params[0], slope, 0, 1)) return false;
+    //            }
+    //        }
+    //    }
+    //} else {
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            const uint32_t h = i12;
+            const float slope = (op_params[1] > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+            // BUGFIX: soft_max->src[1] (the mask) is optional -> guard against NULL dereference
+            if (!iqk_fused_mul_mat_softmax(ne01, ne11, ne00,
+                        src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
+                        src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
+                        (float *)((char *)soft_max->data + i12*nb2 + i13*nb3), nb1/sizeof(float),
+                        params->wdata, params->wsize,
+                        soft_max->src[1] ? soft_max->src[1]->data : NULL, op_params[0], slope, ith, nth)) {
+                if (ith == 0) printf("iqk_fused_mul_mat_softmax returned false!\n");
+                return false;
+            }
+        }
+    }
+    //}
+
+    //if (ith == 0) printf("    success!\n");
+    return true;
+}
+
 // ggml_compute_forward_soft_max
 
 static void ggml_compute_forward_soft_max_f32(
@@ -14340,6 +14437,10 @@ static void ggml_compute_forward_soft_max_f32(
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);
 
+#if IK_PRINT_TIMING
+    if (ith == 0) printf("%s: %d x %d, %g MiB\n", __func__, nc, nr, 4.*nc*nr/(1024*1024));
+#endif
+
     // rows per thread
     const int dr = (nr + nth - 1)/nth;
 
@@ -17357,11 +17458,12 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 
 /////////////////////////////////
 
-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+static bool
ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
     GGML_ASSERT(params);
 
+    bool result = false;
     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
-        return;
+        return result;
     }
 
 #if IK_PRINT_TIMING
@@ -17459,7 +17561,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor);
+                if (next && next->op == GGML_OP_SOFT_MAX) {
+                    result = ggml_fused_mul_mat_softmax(params, tensor, next);
+                }
+                // BUGFIX: fall back to the plain mul_mat when fusion is unavailable or bails out;
+                // otherwise the MUL_MAT result is never computed and the following SOFT_MAX reads garbage
+                if (!result) ggml_compute_forward_mul_mat(params, tensor);
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
@@ -17701,6 +17807,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     int64_t t2 = ggml_time_us(); if (params->ith == 0) printf("%s(%s): %d us\n", ggml_op_name(tensor->op), tensor->name, (int)(t2 - t1));
 #endif
+    return result;
 }
 
 ////////////////////////////////////////////////////////////////////////////////
 
@@ -19526,7 +19633,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
 
         if (ggml_is_noop(node)) continue;
 
-        ggml_compute_forward(&params, node);
+        bool skip_next = ggml_compute_forward(&params, node, node_n < cgraph->n_nodes - 1 ?
cgraph->nodes[node_n + 1] : NULL); if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { state->shared->ec = GGML_STATUS_ABORTED; @@ -19537,6 +19644,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { if (state->shared->ec != GGML_STATUS_SUCCESS) { break; } + + if (skip_next) ++node_n; } return 0; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 32ddb3ff..51156a74 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -17,6 +17,7 @@ #include #include +#include // for hardware_destructive_interference_size #if defined IQK_IMPLEMENT @@ -5915,6 +5916,305 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { #endif // __aarch64__ +namespace { + +#if defined(__ARM_NEON) && defined(__aarch64__) +// copy-pasted from Justine Tunney's contribution to llama.cpp +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline float32x4_t v_expf(float32x4_t x) { + const float32x4_t r = vdupq_n_f32(0x1.8p23f); + const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f)); + const float32x4_t n = vsubq_f32(z, r); + const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n, + vdupq_n_f32(0x1.7f7d1cp-20f)); + const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23); + const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1)))); + const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126)); + const float32x4_t u = vmulq_f32(b, b); + const float32x4_t j = vfmaq_f32( + vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b), + vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b), + vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u); + if (!vpaddd_u64(vreinterpretq_u64_u32(c))) + return vfmaq_f32(k, j, k); + 
const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000)); + const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000))); + const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d)); + return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1), + vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j))); +} +#endif + +#if defined(__AVX512F__) && defined(__AVX512DQ__) + +// copy-pasted from Justine Tunney's contribution to llama.cpp +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to zero +inline __m512 v_expf(__m512 x) { + const __m512 r = _mm512_set1_ps(0x1.8p23f); + const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); + const __m512 n = _mm512_sub_ps(z, r); + const __m512 b = + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); + const __mmask16 d = + _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); + const __m512 u = _mm512_mul_ps(b, b); + const __m512 j = _mm512_fmadd_ps( + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, + _mm512_set1_ps(0x1.573e2ep-5f)), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, + _mm512_set1_ps(0x1.fffdb6p-2f))), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); + const __m512 res = _mm512_scalef_ps(j, n); + if (_mm512_kortestz(d, d)) + return res; + const __m512 zero = _mm512_setzero_ps(); + const __m512 alt = _mm512_mask_blend_ps( + _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); + return _mm512_mask_blend_ps(d, res, alt); +} +#endif + +#if defined(__AVX2__) && defined(__FMA__) + +// adapted from arm limited optimized routine +// the maximum error is 1.45358 plus 0.5 ulps +// numbers above 88.38 will flush to infinity +// numbers beneath -103.97 will flush to 
zero +inline __m256 v_expf(__m256 x) { + const __m256 r = _mm256_set1_ps(0x1.8p23f); + const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r); + const __m256 n = _mm256_sub_ps(z, r); + const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f), + _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x)); + const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23); + const __m256 k = _mm256_castsi256_ps( + _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1)))); + const __m256i c = _mm256_castps_si256( + _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), + _mm256_set1_ps(126), _CMP_GT_OQ)); + const __m256 u = _mm256_mul_ps(b, b); + const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b, + _mm256_set1_ps(0x1.573e2ep-5f)), u, + _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b, + _mm256_set1_ps(0x1.fffdb6p-2f))), + u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b)); + if (!_mm256_movemask_ps(_mm256_castsi256_ps(c))) + return _mm256_fmadd_ps(j, k, k); + const __m256i g = _mm256_and_si256( + _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)), + _mm256_set1_epi32(0x82000000u)); + const __m256 s1 = + _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u))); + const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g)); + const __m256i d = _mm256_castps_si256( + _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), + _mm256_set1_ps(192), _CMP_GT_OQ)); + return _mm256_or_ps( + _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)), + _mm256_andnot_ps( + _mm256_castsi256_ps(d), + _mm256_or_ps( + _mm256_and_ps(_mm256_castsi256_ps(c), + _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)), + _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k))))); +} +#endif + +inline float prepare_softmax(int nc, float * sp, float scale, float slope, const char * mp, bool use_fp16) { + __m512 vscale = _mm512_set1_ps(scale); + __m512 vmax = 
_mm512_set1_ps(-INFINITY); + float scalar_max = -INFINITY; + if (mp) { + __m512 vslope = _mm512_set1_ps(slope); + if (use_fp16) { + const ggml_fp16_t * mp_f16 = (const ggml_fp16_t *)mp; + __m512 vmax1 = _mm512_set1_ps(-INFINITY); + for (int i = 0; i < nc/32; ++i) { + const __m512 m1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + 2*i+0)); + const __m512 m2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)mp_f16 + 2*i+1)); + const __m512 x1 = _mm512_loadu_ps(sp + 32*i + 0); + const __m512 x2 = _mm512_loadu_ps(sp + 32*i + 16); + const __m512 y1 = _mm512_fmadd_ps(vslope, m1, _mm512_mul_ps(vscale, x1)); + const __m512 y2 = _mm512_fmadd_ps(vslope, m2, _mm512_mul_ps(vscale, x2)); + vmax = _mm512_max_ps(vmax , y1); + vmax1 = _mm512_max_ps(vmax1, y2); + _mm512_storeu_ps(sp + 32*i + 0, y1); + _mm512_storeu_ps(sp + 32*i + 16, y2); + } + vmax = _mm512_max_ps(vmax, vmax1); + for (int i = 32*(nc/32); i < nc; ++i) { + sp[i] = scale*sp[i] + slope*GGML_FP16_TO_FP32(mp_f16[i]); + scalar_max = std::max(scalar_max, sp[i]); + } + } else { + const float * mp_f32 = (const float *)mp; + for (int i = 0; i < nc/16; ++i) { + const __m512 m = _mm512_loadu_ps(mp_f32 + 16*i); + const __m512 x = _mm512_loadu_ps(sp + 16*i); + const __m512 y = _mm512_fmadd_ps(vslope, m, _mm512_mul_ps(vscale, x)); + vmax = _mm512_max_ps(vmax, y); + _mm512_storeu_ps(sp + 16*i, y); + } + for (int i = 16*(nc/16); i < nc; ++i) { + sp[i] = scale*sp[i] + slope*mp_f32[i]; + scalar_max = std::max(scalar_max, sp[i]); + } + } + } else { + for (int i = 0; i < nc/16; ++i) { + const __m512 x = _mm512_loadu_ps(sp + 16*i); + const __m512 y = _mm512_mul_ps(vscale, x); + vmax = _mm512_max_ps(vmax, y); + _mm512_storeu_ps(sp + 16*i, y); + } + for (int i = 16*(nc/16); i < nc; ++i) { + sp[i] = scale*sp[i]; + scalar_max = std::max(scalar_max, sp[i]); + } + } + float vector_max = _mm512_reduce_max_ps(vmax); + return std::max(scalar_max, vector_max); +} + +inline float do_soft_max(int nc, float * sp, float max) { + 
__m512 vmax = _mm512_set1_ps(-max); + __m512 vsum = _mm512_setzero_ps(); + for (int i = 0; i < nc/16; ++i) { + auto x = _mm512_loadu_ps(sp + 16*i); + auto y = v_expf(_mm512_add_ps(x, vmax)); + vsum = _mm512_add_ps(vsum, y); + _mm512_storeu_ps(sp + 16*i, y); + } + float sum = _mm512_reduce_add_ps(vsum); + for (int i = 16*(nc/16); i < nc; ++i) { + float y = expf(sp[i] - max); + sum += y; + sp[i] = y; + } + return sum; +} + +inline void do_scale(int nc, const float * sp, float * dp, float scale) { + auto vscale = _mm512_set1_ps(scale); + for (int i = 0; i < nc/16; ++i) { + auto x = _mm512_loadu_ps(sp + 16*i); + auto y = _mm512_mul_ps(x, vscale); + _mm512_storeu_ps(dp + 16*i, y); + } + for (int i = 16*(nc/16); i < nc; ++i) { + dp[i] = sp[i]*scale; + } +} + +void softmax_extended(int nc, float * sp, float * dp, float scale, float slope, const char * mp, bool mask_is_fp16) { + auto max = prepare_softmax(nc, sp, scale, slope, mp, mask_is_fp16); + auto sum = do_soft_max(nc, sp, max); + do_scale(nc, sp, dp, 1.f/sum); +} + +} + +bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00, + int int_typeA, const void * A, long strideA, + int int_typeB, const void * B, long strideB, + float * C, long stride_C, + char * work_buffer, long work_size, + const char * mask, float scale, float slope, + int ith, int nth) { + + constexpr int k_y_step = 5; // TODO: make this CPU dependent + +#if defined(__cpp_lib_hardware_interference_size) + constexpr int k_cache_line = std::hardware_destructive_interference_size; +#else + constexpr int k_cache_line = 64; +#endif + + auto typeA = ggml_type(int_typeA); + auto typeB = ggml_type(int_typeB); + + if (!(typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) || !(typeB == GGML_TYPE_F16 || typeB == GGML_TYPE_F32)) { + return false; + } + + int ny_per_thread = (Ny + nth - 1)/nth; + int first_y = ny_per_thread*ith; + if (first_y >= Ny) { + return true; + } + int n_cols_per_iter = std::min(ny_per_thread, k_y_step); + int needed_work_size_per_thread = 
k_cache_line*((n_cols_per_iter*Nx*sizeof(float) + k_cache_line - 1)/k_cache_line);
+    int needed_work_size = needed_work_size_per_thread*nth;
+    if (needed_work_size > work_size) {
+        return false;
+    }
+
+    std::array<mul_mat_t, k_y_step> funcs; // NOTE(review): template args were stripped in transit; funcs[0..4] below imply k_y_step entries - confirm
+    if (typeA == GGML_TYPE_F16) {
+        if (typeB == GGML_TYPE_F16) {
+            funcs[0] = mul_mat_fX_fY_T<1, ggml_half, ggml_half>;
+            funcs[1] = mul_mat_fX_fY_T<2, ggml_half, ggml_half>;
+            funcs[2] = mul_mat_fX_fY_T<3, ggml_half, ggml_half>;
+            funcs[3] = mul_mat_fX_fY_T<4, ggml_half, ggml_half>;
+            funcs[4] = mul_mat_fX_fY_T<5, ggml_half, ggml_half>;
+        } else {
+            funcs[0] = mul_mat_fX_fY_T<1, ggml_half, float>;
+            funcs[1] = mul_mat_fX_fY_T<2, ggml_half, float>;
+            funcs[2] = mul_mat_fX_fY_T<3, ggml_half, float>;
+            funcs[3] = mul_mat_fX_fY_T<4, ggml_half, float>;
+            funcs[4] = mul_mat_fX_fY_T<5, ggml_half, float>;
+        }
+    } else {
+        if (typeB == GGML_TYPE_F16) {
+            funcs[0] = mul_mat_fX_fY_T<1, float, ggml_half>;
+            funcs[1] = mul_mat_fX_fY_T<2, float, ggml_half>;
+            funcs[2] = mul_mat_fX_fY_T<3, float, ggml_half>;
+            funcs[3] = mul_mat_fX_fY_T<4, float, ggml_half>;
+            funcs[4] = mul_mat_fX_fY_T<5, float, ggml_half>;
+        } else {
+            funcs[0] = mul_mat_fX_fY_T<1, float, float>;
+            funcs[1] = mul_mat_fX_fY_T<2, float, float>;
+            funcs[2] = mul_mat_fX_fY_T<3, float, float>;
+            funcs[3] = mul_mat_fX_fY_T<4, float, float>;
+            funcs[4] = mul_mat_fX_fY_T<5, float, float>;
+        }
+    }
+
+    auto row_size_qx = strideA*ggml_type_size(typeA);
+    auto row_size_qy = strideB*ggml_type_size(typeB);
+
+    DataInfo info{(float *)(work_buffer + needed_work_size_per_thread*ith), (const char *)B + first_y*row_size_qy,
+        Nx*sizeof(float), row_size_qy, 0, 1, nullptr, 0};
+
+    C += first_y*stride_C;
+
+    const char * mp = mask ? mask + first_y*Nx*sizeof(ggml_half) : nullptr; // NOTE(review): assumes f16 mask, matching softmax_extended(..., true) below
+
+    int my_ny = std::min(ny_per_thread, int(Ny - first_y)), n_step = (my_ny + k_y_step - 1)/k_y_step; // BUGFIX: clamp so the last thread cannot run past Ny
+    for (int i_step = 1; i_step <= n_step; ++i_step) {
+        int this_ny = i_step*k_y_step <= my_ny ? k_y_step : my_ny - (i_step - 1)*k_y_step;
+        funcs[this_ny-1](ne00, A, row_size_qx, info, Nx);
+        // Now we need to compute the softmax and store the result in C
+        for (int iy = 0; iy < this_ny; ++iy) {
+            softmax_extended(Nx, info.s + iy*Nx, C, scale, slope, mp, true);
+            C += stride_C;
+            if (mp) mp += Nx*sizeof(ggml_half);
+        }
+        info.cy += row_size_qy*this_ny;
+    }
+
+    return true;
+}
+
 #else  // IQK_IMPLEMENT
 
 bool iqk_mul_mat(int, long, long, long, int, const void *, long, int, const void *, long, float *, long, int, int) {
@@ -5926,4 +6226,14 @@ bool iqk_mul_mat_moe(long, long, long, int, int, const void *, long, int, const
     return false;
 }
 
+bool iqk_fused_mul_mat_softmax(long, long, long,
+                               int, const void *, long,
+                               int, const void *, long,
+                               float *, long,
+                               char *, long,
+                               const char *, float, float,
+                               int, int) { // BUGFIX: dropped stray uint32_t so the stub matches the header declaration
+    return false;
+}
+
 #endif
diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h
index 6bed5f5a..956ee8f7 100644
--- a/ggml/src/iqk/iqk_mul_mat.h
+++ b/ggml/src/iqk/iqk_mul_mat.h
@@ -21,6 +21,13 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
                      int typeB, const void * B, long strideB,
                      float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
 
+bool iqk_fused_mul_mat_softmax(long Nx, long Ny, long ne00,
+                               int typeA, const void * A, long strideA,
+                               int typeB, const void * B, long strideB,
+                               float * C, long stride_C,
+                               char * work_buffer, long work_size,
+                               const char * mask, float scale, float slope,
+                               int ith, int nth);
 
 #ifdef __cplusplus
 }