From 8d47523e7e3c2e4c2cdd5212d0a05d7424bf4eb1 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Wed, 4 Sep 2024 19:37:05 +0300
Subject: [PATCH] Improve TG speed (when not memory bound)

---
 ggml/src/iqk/iqk_mul_mat.cpp | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index bf6fd6ce..32455e09 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -3336,7 +3336,7 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0,
 }
 template
 void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    constexpr int k_nx = 5;
+    constexpr int k_nx = nrc_y <= 2 ? 8 : 5;
     const char * cx = (const char *)vx;
     for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
         mul_mat_Qx_Qy_MxN(n, cx, bx, ix*k_nx, info);
@@ -3344,6 +3344,14 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
     int last_x = k_nx*(nrc_x/k_nx);
     if (last_x == nrc_x) return;
     int nx = nrc_x - last_x;
+    if constexpr (nrc_y <= 2) {
+        if (nx >= 4) {
+            mul_mat_Qx_Qy_MxN(n, cx, bx, last_x, info);
+            last_x += 4;
+            if (last_x == nrc_x) return;
+            nx = nrc_x - last_x;
+        }
+    }
     switch (nx) {
         case 1: mul_mat_Qx_Qy_MxN(n, cx, bx, last_x, info); break;
         case 2: mul_mat_Qx_Qy_MxN(n, cx, bx, last_x, info); break;