diff --git a/ggml/src/iqk/fa/iqk_fa_templates.h b/ggml/src/iqk/fa/iqk_fa_templates.h
index 3cbdddd6..6de2acea 100644
--- a/ggml/src/iqk/fa/iqk_fa_templates.h
+++ b/ggml/src/iqk/fa/iqk_fa_templates.h
@@ -1224,120 +1224,43 @@ struct FlashQKfp32 {
     static_assert(k_step%F16::block_size == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);
 
-#ifdef __AVX2__
     template <typename KHelper, typename q_float>
     static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
             FlashMS<q_step, k_step>& fms) {
-#ifdef HAVE_FANCY_SIMD
+#ifdef __AVX2__
         constexpr int nrc_k = 8;
-#else
-        constexpr int nrc_k = 8;
-#endif
         static_assert(k_step%nrc_k == 0);
+#endif  // __AVX2__
         DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
         iqk_gemm_default_floats(D, q_step, kh.block, kh.stride, info, k_step);
-        F16::Data vk[k_step/F16::block_size];
-        for (int j = 0; j < q_step; ++j) {
-            fms.update_M_S(j, vk, mask + stride_m*j);
-        }
-    }
-#else
-    template <typename KHelper, typename q_float>
-    static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
-            FlashMS<q_step, k_step>& fms) {
-        constexpr int nrc_q = 4;
-        constexpr int nrc_k = 6;
-        constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
-        constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
-        DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
-        for (int iq = 0; iq < q_step/nrc_q; ++iq) {
-            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
-                mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
-            }
-            if constexpr (krem > 0) {
-                mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
-            }
-            info.cur_y += nrc_q;
-        }
-        if constexpr (qrem > 0) {
-            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
-                mul_mat_f16_f16_NxN<qrem, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
-            }
-            if constexpr (krem > 0) {
-                mul_mat_f16_f16_NxN<qrem, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
-            }
-        }
-        float32x4_t vk[k_step/4];
-        for (int j = 0; j < q_step; ++j) {
-            fms.update_M_S(j, vk, mask + stride_m*j);
-        }
-    }
-#endif
-
 #ifdef __AVX2__
+        F16::Data vk[k_step/F16::block_size];   // scratch rows handed to update_M_S below
+#else
+        float32x4_t vk[k_step/4];               // NEON path: same scratch, 4 f32 lanes per vector
+#endif
+        for (int j = 0; j < q_step; ++j) {
+            fms.update_M_S(j, vk, mask + stride_m*j);
+        }
+    }
+
     template <typename KHelper, typename q_float>
     static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
             FlashMS<q_step, k_step>& fms) {
+#ifdef __AVX2__
         constexpr int nrc_k = 8;
         static_assert(k_step%nrc_k == 0);
+#endif  // __AVX2__
         DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
         iqk_gemm_default_floats(D, nq, kh.block, kh.stride, info, k_step);
+#ifdef __AVX2__
         F16::Data vk[k_step/F16::block_size];
+#else
+        float32x4_t vk[k_step/4];               // NEON path: same scratch, 4 f32 lanes per vector
+#endif
         for (int j = 0; j < nq; ++j) {
             fms.update_M_S(j, vk, mask + stride_m*j);
         }
     }
-#else
-    template <typename KHelper, typename q_float>
-    static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
-            FlashMS<q_step, k_step>& fms) {
-        constexpr int nrc_q = 4;
-        constexpr int nrc_k = 6;
-        constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
-        const int qrem = q_step - nrc_q*(q_step/nrc_q);
-        DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
-        for (int iq = 0; iq < nq/nrc_q; ++iq) {
-            for (int ik = 0; ik < k_step/nrc_k; ++ik) {
-                mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
-            }
-            if constexpr (krem > 0) {
-                mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
-            }
-            info.cur_y += nrc_q;
-        }
-        switch (qrem) {
-            case 0: break;
-            case 1: {
-                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
-                    mul_mat_f16_f16_NxN<1, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
-                }
-                if constexpr (krem > 0) {
-                    mul_mat_f16_f16_NxN<1, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
-                }
-            } break;
-            case 2: {
-                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
-                    mul_mat_f16_f16_NxN<2, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
-                }
-                if constexpr (krem > 0) {
-                    mul_mat_f16_f16_NxN<2, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
-                }
-            } break;
-            case 3: {
-                for (int ik = 0; ik < k_step/nrc_k; ++ik) {
-                    mul_mat_f16_f16_NxN<3, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
-                }
-                if constexpr (krem > 0) {
-                    mul_mat_f16_f16_NxN<3, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
-                }
-            } break;
-        }
-        float32x4_t vk[k_step/4];
-        for (int j = 0; j < q_step; ++j) {
-            fms.update_M_S(j, vk, mask + stride_m*j);
-        }
-    }
-#endif
 
 #ifdef __aarch64__
     static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
diff --git a/ggml/src/iqk/iqk_gemm_floats.cpp b/ggml/src/iqk/iqk_gemm_floats.cpp
index 664c734b..5165eb98 100644
--- a/ggml/src/iqk/iqk_gemm_floats.cpp
+++ b/ggml/src/iqk/iqk_gemm_floats.cpp
@@ -1008,6 +1008,41 @@ bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array
+namespace {
+template <int nrc_q>
+inline void mm_helper(int D, int nq, const char * cx, size_t bx, DataInfo& info, int k_step) { // nq rows in groups of nrc_q
+    constexpr int nrc_k = 6;
+    int krem = k_step - nrc_k*(k_step/nrc_k);   // columns left over after full 6-wide tiles
+    for (int iq = 0; iq < nq/nrc_q; ++iq) {
+        for (int ik = 0; ik < k_step/nrc_k; ++ik) {
+            mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, cx, bx, ik*nrc_k, info);
+        }
+        if (krem > 0) {
+            switch (krem) {
+                case 1: mul_mat_f16_f16_NxN<nrc_q, 1, true>(D, cx, bx, k_step - krem, info); break;
+                case 2: mul_mat_f16_f16_NxN<nrc_q, 2, true>(D, cx, bx, k_step - krem, info); break;
+                case 3: mul_mat_f16_f16_NxN<nrc_q, 3, true>(D, cx, bx, k_step - krem, info); break;
+                case 4: mul_mat_f16_f16_NxN<nrc_q, 4, true>(D, cx, bx, k_step - krem, info); break;
+                default: mul_mat_f16_f16_NxN<nrc_q, 5, true>(D, cx, bx, k_step - krem, info); break;
+            }
+        }
+        info.cur_y += nrc_q;
+    }
+}
+}
+
+void iqk_gemm_default_floats(int D, int nq, const char * cx, size_t bx, DataInfo& info, int k_step) {
+    constexpr int nrc_q = 4;
+    mm_helper<nrc_q>(D, nq, cx, bx, info, k_step);  // full groups; leaves info.cur_y at nq - qrem
+    if (int qrem = nq - nrc_q*(nq/nrc_q); qrem > 0) {
+        switch (qrem) {  // process exactly the qrem leftover rows once; break so cases don't fall through
+            case 1: mm_helper<1>(D, qrem, cx, bx, info, k_step); break;
+            case 2: mm_helper<2>(D, qrem, cx, bx, info, k_step); break;
+            default: mm_helper<3>(D, qrem, cx, bx, info, k_step); break;
+        }
+    }
+}
+
 #endif
 #endif