// Review notes for this patch (free text placed ahead of the `diff --git` header).
//
// Target file: ggml/src/iqk/iqk_mul_mat.cpp. The hunks below make three changes:
//
// 1. QFBase gains a static `add(Data x, Data y)` helper in both visible variants:
//    the one built on _mm512_* intrinsics returns _mm512_add_ps(x, y), the one built
//    on _mm256_* intrinsics returns _mm256_add_ps(x, y).
//
// 2. A new kernel `mul_mat_Qx_Qy_Mx1(n, cx, bx, ix0, info)` is added. It statically
//    requires Qy::nrc == 1 and Qx::nrc == 1, builds `y` from `info` and `x` from
//    `cx + ix0*bx`, and accumulates QFBase::acc(acc, yv, xv) products in three phases:
//      - main loop: groups of four consecutive k_step-wide steps, step k of each
//        group going to its own accumulator acc[k] (four accumulators in total);
//      - leftover full steps, indices 4*(nb/4) .. nb-1 with nb = n/QFBase::k_step,
//        all accumulated into acc[0];
//      - tail: indices (QFBase::k_step/4)*nb .. n/4-1 loaded with load_tail(), also
//        accumulated into acc[0].
//    The four accumulators are then summed pairwise with QFBase::add, and
//    QFBase::hsum of the total is written with info.store(ix0, 0, ...).
//    NOTE(review): the four independent accumulators are presumably there to shorten
//    the FMA dependency chain -- confirm with the author / benchmarks.
//    NOTE(review): load_tail() is defined outside this patch; the index arithmetic
//    suggests it consumes 4 elements per index. If so, any n % 4 remainder is never
//    accumulated -- confirm that callers guarantee n is a multiple of 4.
//
// 3. mul_mat_fX_fY_T now declares `cx` first and branches with
//    `if constexpr (nrc_y == 1)`: in that case it calls mul_mat_Qx_Qy_Mx1 once per
//    row ix in [0, nrc_x); otherwise it runs the pre-existing path that processes
//    k_nx rows per call (k_nx = 5 when __AVX512F__ is defined, 2 otherwise).
//
// NOTE(review): in this copy of the patch the original line breaks are collapsed and
// the template parameter/argument lists look stripped (`template` is followed directly
// by `void`, and the kernel calls read `mul_mat_Qx_Qy_Mx1, QFT>(...)`). This is
// presumably an extraction artifact -- compare against the upstream commit before
// applying or compiling anything from this text.
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 5f39b587..47b5542a 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6971,6 +6971,7 @@ struct QFBase { static inline Data load4Floats(const Float * x) { return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0); } + static inline Data add(Data x, Data y) { return _mm512_add_ps(x, y); } static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) { acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc); acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc); @@ -6999,6 +7000,7 @@ struct QFBase { static inline Acc acc(Acc prev, const Data& y, const Data& x) { return _mm256_fmadd_ps(y, x, prev); } + static inline Data add(Data x, Data y) { return _mm256_add_ps(x, y); } static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) { acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc); acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc); @@ -7110,6 +7112,38 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix])); } +template +void mul_mat_Qx_Qy_Mx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { + static_assert(Qy::nrc == 1); + static_assert(Qx::nrc == 1); + int nb = n/QFBase::k_step; + int nb4 = n/4; + Qy y(info); + Qx x(cx + ix0*bx, bx); + QFBase::Acc acc[4] = {}; + for (int i = 0; i < nb/4; ++i) { + for (int k = 0; k < 4; ++k) { + auto yv = y.load1(0, 4*i+k); + auto xv = x.load1(0, 4*i+k); + acc[k] = QFBase::acc(acc[k], yv, xv); + } + } + for (int i = 4*(nb/4); i < nb; ++i) { + auto yv = y.load1(0, i); + auto xv = x.load1(0, i); + acc[0] = QFBase::acc(acc[0], yv, xv); + } + for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) { + auto yv = y.load_tail(0, i); + auto xv = x.load_tail(0, i); + acc[0] = QFBase::acc(acc[0], 
yv, xv); + } + acc[0] = QFBase::add(acc[0], acc[1]); + acc[2] = QFBase::add(acc[2], acc[3]); + acc[0] = QFBase::add(acc[0], acc[2]); + info.store(ix0, 0, QFBase::hsum(acc[0])); +} + template inline void mul_mat_Qx_Qy_MxN_fa(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) { int nb = n/QFBase::k_step; @@ -7165,12 +7199,17 @@ inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, co // f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now. template void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + const char * cx = (const char *)vx; + if constexpr (nrc_y == 1) { + for (int ix = 0; ix < nrc_x; ++ix) { + mul_mat_Qx_Qy_Mx1, QFT>(n, cx, bx, ix, info); + } + } else { #ifdef __AVX512F__ constexpr int k_nx = 5; #else constexpr int k_nx = 2; #endif - const char * cx = (const char *)vx; for (int ix = 0; ix < nrc_x/k_nx; ++ix) { mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, ix*k_nx, info); } @@ -7185,6 +7224,7 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in case 4: mul_mat_Qx_Qy_MxN, QFT>(n, cx, bx, last_x, info); break; #endif } + } } #ifdef __AVX512BF16__