mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
Refactor iqk: FA refactored (NEON)
This commit is contained in:
@@ -1224,120 +1224,43 @@ struct FlashQKfp32 {
|
||||
static_assert(k_step%F16::block_size == 0);
|
||||
static_assert(q_step <= 4 || q_step%4 == 0);
|
||||
|
||||
#ifdef __AVX2__
|
||||
template <typename KHelper, typename q_float>
|
||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
#ifdef __AVX2__
|
||||
constexpr int nrc_k = 8;
|
||||
#else
|
||||
constexpr int nrc_k = 8;
|
||||
#endif
|
||||
static_assert(k_step%nrc_k == 0);
|
||||
#endif
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
|
||||
iqk_gemm_default_floats(D, q_step, kh.block, kh.stride, info, k_step);
|
||||
F16::Data vk[k_step/F16::block_size];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
template <typename KHelper, typename q_float>
|
||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
constexpr int nrc_q = 4;
|
||||
constexpr int nrc_k = 6;
|
||||
constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
|
||||
constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
|
||||
for (int iq = 0; iq < q_step/nrc_q; ++iq) {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
info.cur_y += nrc_q;
|
||||
}
|
||||
if constexpr (qrem > 0) {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<qrem, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_f16_f16_NxN<qrem, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
}
|
||||
float32x4_t vk[k_step/4];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef __AVX2__
|
||||
F16::Data vk[k_step/F16::block_size];
|
||||
#else
|
||||
float32x4_t vk[k_step/4];
|
||||
#endif
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KHelper, typename q_float>
|
||||
static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
#ifdef __AVX2__
|
||||
constexpr int nrc_k = 8;
|
||||
static_assert(k_step%nrc_k == 0);
|
||||
#endif
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
|
||||
iqk_gemm_default_floats(D, nq, kh.block, kh.stride, info, k_step);
|
||||
#ifdef __AVX2__
|
||||
F16::Data vk[k_step/F16::block_size];
|
||||
#else
|
||||
float32x4_t vk[k_step/4];
|
||||
#endif
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
template <typename KHelper, typename q_float>
|
||||
static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
constexpr int nrc_q = 4;
|
||||
constexpr int nrc_k = 6;
|
||||
constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
|
||||
const int qrem = q_step - nrc_q*(q_step/nrc_q);
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
|
||||
for (int iq = 0; iq < nq/nrc_q; ++iq) {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
info.cur_y += nrc_q;
|
||||
}
|
||||
switch (qrem) {
|
||||
case 0: break;
|
||||
case 1: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<1, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_f16_f16_NxN<1, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
} break;
|
||||
case 2: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<2, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_f16_f16_NxN<2, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
} break;
|
||||
case 3: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<3, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_f16_f16_NxN<3, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
} break;
|
||||
}
|
||||
float32x4_t vk[k_step/4];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef __aarch64__
|
||||
static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
|
||||
|
||||
@@ -1008,6 +1008,41 @@ bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t,
|
||||
|
||||
}
|
||||
|
||||
namespace {
|
||||
template <int nrc_q>
|
||||
inline void mm_helper(int D, int nq, const char * cx, size_t bx, DataInfo& info, int k_step) {
|
||||
constexpr int nrc_k = 6;
|
||||
int krem = k_step - nrc_k*(k_step/nrc_k);
|
||||
for (int iq = 0; iq < nq/nrc_q; ++iq) {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, cx, bx, ik*nrc_k, info);
|
||||
}
|
||||
if (krem > 0) {
|
||||
switch (krem) {
|
||||
case 1: mul_mat_f16_f16_NxN<nrc_q, 1, true>(D, cx, bx, k_step - krem, info); break;
|
||||
case 2: mul_mat_f16_f16_NxN<nrc_q, 2, true>(D, cx, bx, k_step - krem, info); break;
|
||||
case 3: mul_mat_f16_f16_NxN<nrc_q, 3, true>(D, cx, bx, k_step - krem, info); break;
|
||||
case 4: mul_mat_f16_f16_NxN<nrc_q, 4, true>(D, cx, bx, k_step - krem, info); break;
|
||||
default: mul_mat_f16_f16_NxN<nrc_q, 5, true>(D, cx, bx, k_step - krem, info); break;
|
||||
}
|
||||
}
|
||||
info.cur_y += nrc_q;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void iqk_gemm_default_floats(int D, int nq, const char * cx, size_t bx, DataInfo& info, int k_step) {
|
||||
constexpr int nrc_q = 4;
|
||||
mm_helper<nrc_q>(D, nq, cx, bx, info, k_step);
|
||||
if (int qrem = nq - nrc_q*(nq/nrc_q); qrem > 0) {
|
||||
switch (qrem) {
|
||||
case 1: mm_helper<1>(D, nq, cx, bx, info, k_step);
|
||||
case 2: mm_helper<2>(D, nq, cx, bx, info, k_step);
|
||||
default: mm_helper<3>(D, nq, cx, bx, info, k_step);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user