FA: dedicated mat mul for D = 128 also for ARM_NEON

This commit is contained in:
Iwan Kawrakow
2025-01-19 16:29:03 +01:00
parent e9951656f8
commit 4ecfaaea48

View File

@@ -12063,6 +12063,46 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
template <int nrc_y>
void mul_mat_q8_0_r4_q8_0_128(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
GGML_ASSERT(n == 128);
int8x16x4_t qx[8];
float32x4_t scales[4];
float32x4_t scales_y[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx);
for (int k = 0; k < 4; ++k) {
scales[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[k].d));
qx[2*k+0] = vld1q_s8_x4(iq8[k].qs);
qx[2*k+1] = vld1q_s8_x4(iq8[k].qs+64);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto by = (const block_q8_0_x4 *)info.src1_row(iy);
auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)by->d));
scales_y[0] = vmulq_laneq_f32(scales[0], d8, 0);
scales_y[1] = vmulq_laneq_f32(scales[1], d8, 1);
scales_y[2] = vmulq_laneq_f32(scales[2], d8, 2);
scales_y[3] = vmulq_laneq_f32(scales[3], d8, 3);
auto sumf = vdupq_n_f32(0.f);
for (int k = 0; k < 4; ++k) {
auto y = vld1q_s8_x2(by->qs+32*k);
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[0], y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[1], y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[2], y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[3], y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[0], y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[1], y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[2], y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[3], y.val[1], 3);
sumf = vfmaq_f32(sumf, scales_y[k], vcvtq_f32_s32(sumi));
}
info.store(ix, iy, sumf);
}
}
}
#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
m.funcs[0] = func<Dequantizer, 1>;\
m.funcs[1] = func<Dequantizer, 2>;\
@@ -13645,7 +13685,23 @@ struct FlashQKfp32 {
}
else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {
#ifdef __aarch64__
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
if constexpr (D == 128) {
if (q_step >= 64 && nq >= 64) {
return std::make_pair(mul_mat_q8_0_r4_q8_0_128<64>, 64);
}
else if (q_step >= 32 && nq >= 32) {
return std::make_pair(mul_mat_q8_0_r4_q8_0_128<32>, 32);
}
else if (q_step >= 16 && nq >= 16) {
return std::make_pair(mul_mat_q8_0_r4_q8_0_128<16>, 16);
}
else {
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0_128, nq);
}
} else {
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
}
//MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
#else
#ifdef HAVE_FANCY_SIMD
if constexpr (D == 128) {