mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
q8_KV: use it in FA on NEON
This commit is contained in:
@@ -13740,6 +13740,48 @@ static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInf
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%8 == 0);
|
||||
int32x4_t acc[2*nrc_y] = {};
|
||||
float dy[nrc_y];
|
||||
const int8_t * q8y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto dptr = (const float *)info.src1_row(iy);
|
||||
dy[iy] = dptr[0];
|
||||
q8y[iy] = (const int8_t *)(dptr + 2);
|
||||
}
|
||||
for (int ix = 0; ix < nrc_x; ix += 8) {
|
||||
const float * dptr = (const float *)((const char *)vx + ix*bx);
|
||||
auto q8x = (const int8_t *)(dptr + 8);
|
||||
for (int ib = 0; ib < n/16; ++ib) {
|
||||
auto q1 = vld1q_s8_x4(q8x + 128*ib + 0);
|
||||
auto q2 = vld1q_s8_x4(q8x + 128*ib + 64);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = vld1q_s8(q8y[iy]+16*ib);
|
||||
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0);
|
||||
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0);
|
||||
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1);
|
||||
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1);
|
||||
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2);
|
||||
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2);
|
||||
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3);
|
||||
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3);
|
||||
}
|
||||
}
|
||||
auto scale1_x = vld1q_f32(dptr+0);
|
||||
auto scale2_x = vld1q_f32(dptr+4);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto scale_y = vdupq_n_f32(dy[iy]);
|
||||
auto scale1 = vmulq_f32(scale1_x, scale_y);
|
||||
auto scale2 = vmulq_f32(scale2_x, scale_y);
|
||||
info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0])));
|
||||
info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1])));
|
||||
acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
Q8<1, block_q8_0_x4> q8(info);
|
||||
@@ -15827,7 +15869,9 @@ struct FlashQKfp32 {
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ8KV<D, k_step>>) {
|
||||
#ifdef __aarch64__
|
||||
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
|
||||
if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
|
||||
if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
|
||||
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
|
||||
#else
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
|
||||
@@ -15844,14 +15888,10 @@ struct FlashQKfp32 {
|
||||
#endif
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ8KVR8<D, k_step>>) {
|
||||
#ifdef __aarch64__
|
||||
MAKE_FUNCS(mul_mat_qX_0_q8_0<DequantizerQ80, nq);
|
||||
#else
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_r8_q8_KV<16>, 16);
|
||||
#endif
|
||||
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
|
||||
#endif
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
|
||||
#ifdef __aarch64__
|
||||
|
||||
Reference in New Issue
Block a user