q8_KV: use it in FA on NEON

This commit is contained in:
Iwan Kawrakow
2025-02-18 13:53:01 +02:00
parent 58c13d0574
commit c82f4194c3

View File

@@ -13740,6 +13740,48 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
}
}
template <int nrc_y>
void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
int32x4_t acc[2*nrc_y] = {};
float dy[nrc_y];
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
q8y[iy] = (const int8_t *)(dptr + 2);
}
for (int ix = 0; ix < nrc_x; ix += 8) {
const float * dptr = (const float *)((const char *)vx + ix*bx);
auto q8x = (const int8_t *)(dptr + 8);
for (int ib = 0; ib < n/16; ++ib) {
auto q1 = vld1q_s8_x4(q8x + 128*ib + 0);
auto q2 = vld1q_s8_x4(q8x + 128*ib + 64);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8y[iy]+16*ib);
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0);
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0);
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1);
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1);
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2);
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2);
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3);
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3);
}
}
auto scale1_x = vld1q_f32(dptr+0);
auto scale2_x = vld1q_f32(dptr+4);
for (int iy = 0; iy < nrc_y; ++iy) {
auto scale_y = vdupq_n_f32(dy[iy]);
auto scale1 = vmulq_f32(scale1_x, scale_y);
auto scale2 = vmulq_f32(scale2_x, scale_y);
info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0])));
info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1])));
acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
}
}
}
void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<1, block_q8_0_x4> q8(info);
@@ -15827,7 +15869,9 @@ struct FlashQKfp32 {
}
else if constexpr (std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
#else
#ifdef HAVE_FANCY_SIMD
if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
@@ -15844,14 +15888,10 @@ struct FlashQKfp32 {
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
#else
#ifdef HAVE_FANCY_SIMD
if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_r8_q8_KV<16>, 16);
#endif
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
#ifdef __aarch64__