Review revision of the flash-attention K*Q patch for ggml/src/iqk/iqk_mul_mat.cpp.

What the patch does: adds 4-row "r4" helpers to QFBase/QFT (acc_r4, acc_r4_first,
hsum_r4, load_r4), adds the mul_mat_Qx_Qy_MxN_fa / mul_mat_Qx_Qy_MxN_fa4 kernels, and
rewrites FlashQKfp32::multiply_mask_kq to compute K*Q through those kernels, passing the
mask row to FlashMS::update_M_S instead of testing the mask inside the dot-product loops.

Changes made in review (both multiply_mask_kq(int nq, ...) overloads, line counts unchanged):
  * qrem is derived from the runtime row count nq instead of the compile-time q_step.
    The tile loop runs over nq/nrc_q, so the leftover rows are nq % nrc_q; with q_step the
    remainder rows were never multiplied.
  * the update_M_S loop runs over nq rows, as the replaced code did, instead of q_step rows.

NOTE(review): the patch text under review had lost its line breaks and its angle-bracket
argument lists (template headers, QFT<...>, FlashMS<...>, mul_mat_f16_f16_NxN<...>).
They are restored here from usage. The lists on removed ('-') lines and the element types
passed to QFT are reconstructions -- confirm against the tree before applying. The blob
ids on the index line are those of the submitted patch.

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index fbdd4789..070b79d3 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -6895,6 +6895,25 @@ struct QFBase {
     static inline Data load4Floats(const Float * x) {
         return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0);
     }
+    static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
+        acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);
+        acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
+        acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
+        acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
+        return acc;
+    }
+    static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
+        auto acc = _mm512_mul_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00));
+        acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
+        acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
+        acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
+        return acc;
+    }
+    static inline __m128 hsum_r4(Acc acc) {
+        auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1));
+        auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3));
+        return _mm_add_ps(sum1, sum2);
+    }
 #else
     constexpr static int k_step = 8;
     using Data = __m256;
@@ -6904,12 +6923,29 @@ struct QFBase {
     static inline Acc acc(Acc prev, const Data& y, const Data& x) {
         return _mm256_fmadd_ps(y, x, prev);
     }
+    static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
+        acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
+        acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
+        acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
+        acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
+        return acc;
+    }
+    static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
+        auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00));
+        acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
+        acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
+        acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
+        return acc;
+    }
     static inline Acc acc_first(const Data& y, const Data& x) {
         return _mm256_mul_ps(y, x);
     }
     static inline float hsum(Acc acc) {
         return hsum_float_8(acc);
     }
+    static inline __m128 hsum_r4(Acc acc) {
+        return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
+    }
     template <typename Float>
     static inline Data load4Floats(const Float * x) {
         return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0);
@@ -6928,6 +6964,31 @@ template struct QFT final : public QFBase {
     }
     IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
     IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); }
+    IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const {
+        xv[0] = load1(ix+0, i);
+        xv[1] = load1(ix+1, i);
+        xv[2] = load1(ix+2, i);
+        xv[3] = load1(ix+3, i);
+#ifdef HAVE_FANCY_SIMD
+        auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);
+        auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);
+        auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);
+        auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]);
+        xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
+        xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
+        xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
+        xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
+#else
+        auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]);
+        auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]);
+        auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]);
+        auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]);
+        xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
+        xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
+        xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
+        xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
+#endif
+    }
     const Float * y[nrc];
 };
 
@@ -6973,6 +7034,56 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0,
     for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
 }
 
+template <typename Qy, typename Qx>
+inline void mul_mat_Qx_Qy_MxN_fa(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+    int nb = n/QFBase::k_step;
+    Qy y(info);
+    Qx x(cx + ix0*bx, bx);
+    QFBase::Data xv[Qx::nrc];
+    QFBase::Acc acc[Qx::nrc*Qy::nrc];
+    auto yv = y.load1(0, 0);
+    for (int ix = 0; ix < Qx::nrc; ++ix) {
+        xv[ix] = x.load1(ix, 0);
+        acc[ix] = QFBase::acc_first(yv, xv[ix]);
+    }
+    for (int iy = 1; iy < Qy::nrc; ++iy) {
+        yv = y.load1(iy, 0);
+        for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
+    }
+    for (int i = 1; i < nb; ++i) {
+        yv = y.load1(0, i);
+        for (int ix = 0; ix < Qx::nrc; ++ix) {
+            xv[ix] = x.load1(ix, i);
+            acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
+        }
+        for (int iy = 1; iy < Qy::nrc; ++iy) {
+            yv = y.load1(iy, i);
+            for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
+        }
+    }
+    for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
+}
+
+template <typename Qy, typename Qx>
+inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, const DataInfo& info) {
+    static_assert(Qx::nrc%4 == 0);
+    int nb = D/QFBase::k_step;
+    Qy y(info);
+    Qx x(cx + ix0*bx, bx);
+    QFBase::Data xv[Qx::nrc];
+    QFBase::Acc acc[Qx::nrc*Qy::nrc/4] = {}; // one accumulator per (group of 4 x rows, y row)
+    for (int i = 0; i < nb; ++i) {
+        for (int ix = 0; ix < Qx::nrc/4; ++ix) x.load_r4(4*ix, i, xv + 4*ix);
+        for (int iy = 0; iy < Qy::nrc; ++iy) {
+            auto yv = y.load1(iy, i);
+            for (int ix = 0; ix < Qx::nrc/4; ++ix) acc[ix*Qy::nrc + iy] = QFBase::acc_r4(acc[ix*Qy::nrc + iy], xv + 4*ix, yv);
+        }
+    }
+    for (int iy = 0; iy < Qy::nrc; ++iy) {
+        for (int ix = 0; ix < Qx::nrc/4; ++ix) info.store(ix0+4*ix, iy, QFBase::hsum_r4(acc[ix*Qy::nrc + iy]));
+    }
+}
+
 // This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done
 // in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in
 // f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
@@ -13066,147 +13177,194 @@ struct FlashQKfp32 {
     static_assert(k_step%F16::block_size == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);
 
-#ifdef __aarch64__
-    constexpr static bool is_small_head = false;
-#else
-    constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
-#endif
-
-    template <bool small = is_small_head, class = std::enable_if<small>, typename q_float>
-    static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
-            F16::Data * qv, F16::Data * vk, FlashMS<q_step, k_step>& fms) {
-        // q index is q_step*i1 + m1
-        // k index is k_step*k1 + l1
-        const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
-        fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY;
-        if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) {
-            return;
-        }
-        auto qr = q + m1*stride_q;
-        for (int i = 0; i < D/F16::block_size; ++i) qv[i] = F16::load(qr + F16::block_size*i);
-        if (mp[l1+0] != fms.h_inf) {
-            auto vsum = F16::zero();
-            for (int i = 0; i < D/F16::block_size; ++i) vsum = F16::fmadd(vsum, vk[i], qv[i]);
-            fms.cache[k_step*m1 + l1 + 0] = F16::reduce_add(vsum);
-        }
-        if (mp[l1+1] != fms.h_inf) {
-            auto vsum = F16::zero();
-            for (int i = 0; i < D/F16::block_size; ++i) vsum = F16::fmadd(vsum, vk[i+D/16], qv[i]);
-            fms.cache[k_step*m1 + l1 + 1] = F16::reduce_add(vsum);
-        }
-    }
-
-    template <bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
-    static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
-            F16::Data * vk, FlashMS<q_step, k_step>& fms) {
-        // q index is q_step*i1 + m1
-        // k index is k_step*k1 + l1
-        const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
-        if (mp[l1] == fms.h_inf) {
-            fms.cache[k_step*m1 + l1] = -INFINITY;
-            return;
-        }
-        auto qr = q + m1*stride_q;
-        auto vsum = F16::zero();
-        for (int i = 0; i < D/F16::block_size; ++i) {
-            vsum = F16::fmadd(vsum, vk[i], F16::load(qr + F16::block_size*i));
-        }
-        fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum);
-    }
-
-    template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
-    static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
-            FlashMS<q_step, k_step>& fms) {
-        F16::Data qv[D/F16::block_size];
-        F16::Data vk[D/(F16::block_size/2)];
-        for (int l1 = 0; l1 < k_step; l1 += 2) {
-            kh.load_2(l1, vk);
-            for (int m1 = 0; m1 < q_step; ++m1) {
-                mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms);
-            }
-        }
-    }
-
-    template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
-    static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
-            const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
-        F16::Data vk[D/F16::block_size];
-        for (int l1 = 0; l1 < k_step; ++l1) {
-            kh.load(l1, vk);
-            for (int m1 = 0; m1 < q_step; ++m1) {
-                mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms);
-            }
-        }
-    }
-
-    template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
-    static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
-            FlashMS<q_step, k_step>& fms) {
-        F16::Data qv[D/F16::block_size];
-        F16::Data vk[D/(F16::block_size/2)];
-        for (int l1 = 0; l1 < k_step; l1 += 2) {
-            kh.load_2(l1, vk);
-            for (int m1 = 0; m1 < nq; ++m1) {
-                mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms);
-            }
-        }
-    }
-
-    template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
-    static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
-            const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
-        F16::Data vk[D/F16::block_size];
-        for (int l1 = 0; l1 < k_step; ++l1) {
-            kh.load(l1, vk);
-            for (int m1 = 0; m1 < nq; ++m1) {
-                mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms);
-            }
-        }
-    }
-
+#ifdef __AVX2__
     template <typename KHelper, typename q_float>
     static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
             FlashMS<q_step, k_step>& fms) {
-        if constexpr (is_small_head) {
-            mult_mask_kq(kh, stride_q, stride_m, q, mask, fms);
-        }
-        else {
-            mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms);
-        }
-#ifdef __aarch64__
-        float32x4_t vk[k_step/4];
-        for (int j = 0; j < q_step; ++j) {
-            fms.update_M_S(j, vk);
-        }
+#ifdef HAVE_FANCY_SIMD
+        constexpr int nrc_q = 8;
+        constexpr int nrc_k = 8;
 #else
+        // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
+        constexpr int nrc_q = 4;
+        constexpr int nrc_k = 8;
+#endif
+        constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
+        constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
+        DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+        for (int iq = 0; iq < q_step/nrc_q; ++iq) {
+            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+            }
+            if constexpr (krem > 0) {
+                mul_mat_Qx_Qy_MxN_fa<QFT<q_float, nrc_q>, QFT<ggml_half, krem>>(D, kh.block, kh.stride, k_step - krem, info);
+            }
+            info.cur_y += nrc_q;
+        }
+        if constexpr (qrem > 0) {
+            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, qrem>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+            }
+            if constexpr (krem > 0) {
+                mul_mat_Qx_Qy_MxN_fa<QFT<q_float, qrem>, QFT<ggml_half, krem>>(D, kh.block, kh.stride, k_step - krem, info);
+            }
+        }
         F16::Data vk[k_step/F16::block_size];
         for (int j = 0; j < q_step; ++j) {
-            fms.update_M_S(j, vk);
+            fms.update_M_S(j, vk, mask + stride_m*j);
         }
-#endif
     }
+#else
+    template <typename KHelper, typename q_float>
+    static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
+            FlashMS<q_step, k_step>& fms) {
+        constexpr int nrc_q = 4;
+        constexpr int nrc_k = 6;
+        constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
+        constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
+        DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+        for (int iq = 0; iq < q_step/nrc_q; ++iq) {
+            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+            }
+            if constexpr (krem > 0) {
+                mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+            }
+            info.cur_y += nrc_q;
+        }
+        if constexpr (qrem > 0) {
+            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                mul_mat_f16_f16_NxN<qrem, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+            }
+            if constexpr (krem > 0) {
+                mul_mat_f16_f16_NxN<qrem, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+            }
+        }
+        float32x4_t vk[k_step/4];
+        for (int j = 0; j < q_step; ++j) {
+            fms.update_M_S(j, vk, mask + stride_m*j);
+        }
+    }
+#endif
 
+#ifdef __AVX2__
     template <typename KHelper, typename q_float>
     static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
             FlashMS<q_step, k_step>& fms) {
-        if constexpr (is_small_head) {
-            mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms);
-        }
-        else {
-            mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms);
-        }
-#ifdef __aarch64__
-        float32x4_t vk[k_step/4];
-        for (int j = 0; j < nq; ++j) {
-            fms.update_M_S(j, vk);
-        }
+#ifdef HAVE_FANCY_SIMD
+        constexpr int nrc_q = 8;
+        constexpr int nrc_k = 8;
 #else
-        F16::Data vk[k_step/F16::block_size];
-        for (int j = 0; j < nq; ++j) {
-            fms.update_M_S(j, vk);
-        }
+        // somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
+        constexpr int nrc_q = 4;
+        constexpr int nrc_k = 8;
 #endif
+        static_assert(k_step%nrc_k == 0);
+        int qrem = nq - nrc_q*(nq/nrc_q); // rows left after the nq/nrc_q full tiles handled below
+        DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+        for (int iq = 0; iq < nq/nrc_q; ++iq) {
+            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+            }
+            info.cur_y += nrc_q;
+        }
+        if (qrem > 0) {
+            switch (qrem) {
+                case 1: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 1>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+                case 2: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 2>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+                case 3: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 3>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+#ifdef HAVE_FANCY_SIMD
+                case 4: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 4>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+                case 5: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 5>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+                case 6: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 6>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+                case 7: {
+                    for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                        mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 7>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
+                    }
+                } break;
+#endif
+            }
+        }
+        F16::Data vk[k_step/F16::block_size];
+        for (int j = 0; j < nq; ++j) { // only the nq rows computed above, as in the replaced code
+            fms.update_M_S(j, vk, mask + stride_m*j);
+        }
     }
+#else
+    template <typename KHelper, typename q_float>
+    static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
+            FlashMS<q_step, k_step>& fms) {
+        constexpr int nrc_q = 4;
+        constexpr int nrc_k = 6;
+        constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
+        const int qrem = nq - nrc_q*(nq/nrc_q); // rows left after the nq/nrc_q full tiles handled below
+        DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
+        for (int iq = 0; iq < nq/nrc_q; ++iq) {
+            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+            }
+            if constexpr (krem > 0) {
+                mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+            }
+            info.cur_y += nrc_q;
+        }
+        switch (qrem) {
+            case 0: break;
+            case 1: {
+                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                    mul_mat_f16_f16_NxN<1, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+                }
+                if constexpr (krem > 0) {
+                    mul_mat_f16_f16_NxN<1, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+                }
+            } break;
+            case 2: {
+                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                    mul_mat_f16_f16_NxN<2, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+                }
+                if constexpr (krem > 0) {
+                    mul_mat_f16_f16_NxN<2, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+                }
+            } break;
+            case 3: {
+                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+                    mul_mat_f16_f16_NxN<3, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
+                }
+                if constexpr (krem > 0) {
+                    mul_mat_f16_f16_NxN<3, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
+                }
+            } break;
+        }
+        float32x4_t vk[k_step/4];
+        for (int j = 0; j < nq; ++j) { // only the nq rows computed above, as in the replaced code
+            fms.update_M_S(j, vk, mask + stride_m*j);
+        }
+    }
+#endif
 
 #ifdef __aarch64__
     static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
@@ -13827,7 +13985,7 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
         const float * q, const char * mask, float scale, float softcap, float * qkv) {
 
 #if defined __AVX2__
-    constexpr bool kUseLargeStepsQ = !std::is_same_v<KHelper, HelperF16<D, k_step>>;
+    constexpr bool kUseLargeStepsQ = true; //!std::is_same_v<KHelper, HelperF16<D, k_step>>;
 #else
     constexpr bool kUseLargeStepsQ = true;
 #endif