8-rows interleaved q8_0 (NEON)

PP-512 is slightly better (138 t/s vs 132.5 t/s), TG-128 is about the
same.
This commit is contained in:
Iwan Kawrakow
2025-01-26 09:43:22 +02:00
parent 45075579ef
commit 4de6088eef

View File

@@ -12176,41 +12176,54 @@ struct Q6_0_R4_Dequantizer {
template <int nrc_y>
void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
int nb = n / QK8_0;
GGML_ASSERT(nb%4 == 0);
float32x4_t acc[nrc_y] = {};
float32x4_t acc[2*nrc_y] = {};
int8x16_t qx[16];
float d8[4*nrc_y];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx);
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[4*ib4+k].d));
auto qx1 = vld1q_s8_x4(iq8[4*ib4+k].qs);
auto qx2 = vld1q_s8_x4(iq8[4*ib4+k].qs+64);
auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx1.val[0], y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, qx1.val[1], y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, qx1.val[2], y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, qx1.val[3], y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, qx2.val[0], y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, qx2.val[1], y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, qx2.val[2], y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, qx2.val[3], y.val[1], 3);
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
auto sumi1 = vdupq_n_s32(0);
auto sumi2 = vdupq_n_s32(0);
sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
auto dy = vdupq_n_f32(d8[4*iy+k]);
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
info.store(ix+0, iy, acc[2*iy+0]);
info.store(ix+4, iy, acc[2*iy+1]);
acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
}
}
}