Make q8_0_r4 work with tensor row sizes that are not a multiple of 128

.. on NEON
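
In short: the NEON q8_0_r4 kernel previously asserted that the number of q8_0 blocks per row is a multiple of 4 (QK8_0 = 32 values per block, so row sizes had to be a multiple of 128). This commit factors the per-block dot product for 8 interleaved rows into a helper, keeps the fast path over whole groups of 4 blocks, and adds a tail loop over the remaining 1-3 blocks. A minimal sketch of the loop split (illustrative only, plain C++; the actual kernel in the diff below works on the interleaved layout with NEON intrinsics):

    int nb = n / QK8_0;                      // q8_0 blocks per row; nb%4 == 0 no longer required
    for (int ib4 = 0; ib4 < nb/4; ++ib4) {
        // fast path: 4 blocks (128 values) at a time via block_q8_0_x4
    }
    for (int ib = 4*(nb/4); ib < nb; ++ib) {
        // new tail: the remaining 1..3 blocks, addressed as plain block_q8_0
    }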
Iwan Kawrakow
2025-01-28 19:26:59 +02:00
parent 3c974f5076
commit d3545680b9


@@ -12439,12 +12439,32 @@ struct Q6_0_R4_Dequantizer {
     const int8x16_t m32 = vdupq_n_s8(-32);
 };
 
+inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& sumi1, int32x4_t& sumi2) {
+    auto y = vld1q_s8_x2(qy);
+    sumi1 = sumi2 = vdupq_n_s32(0);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
+    sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
+    sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
+}
+
 template <int nrc_y>
 void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%8 == 0);
     Q8<nrc_y, block_q8_0_x4> q8(info);
     int nb = n / QK8_0;
-    GGML_ASSERT(nb%4 == 0);
     float32x4_t acc[2*nrc_y] = {};
     int8x16_t qx[16];
     float d8[4*nrc_y];
@@ -12459,32 +12479,29 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
                 auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
                 auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
                 for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
+                int32x4_t sumi1, sumi2;
                 for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
-                    auto sumi1 = vdupq_n_s32(0);
-                    auto sumi2 = vdupq_n_s32(0);
-                    sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
-                    sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
-                    sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
-                    sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
-                    sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
-                    sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
-                    sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
-                    sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
-                    sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
-                    sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
-                    sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
-                    sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
-                    sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
-                    sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
-                    sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
-                    sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
+                    qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
                     auto dy = vdupq_n_f32(d8[4*iy+k]);
                     acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
                     acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
                 }
             }
         }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
+            auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
+            auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
+            for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
+            int32x4_t sumi1, sumi2;
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_0 *)q8.y[iy];
+                qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
+                auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
+                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
+                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
+            }
+        }
         for (int iy = 0; iy < nrc_y; ++iy) {
             info.store(ix+0, iy, acc[2*iy+0]);
             info.store(ix+4, iy, acc[2*iy+1]);
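
For reference, here is a scalar model of what the new qx_0_q8_0_dot helper computes, based on my reading of vdotq_laneq_s32 (each call adds a 4-byte dot product into each of four 32-bit lanes, with the lane argument selecting which 4 bytes of y participate). qx holds one 32-value block of 8 interleaved x rows (16 vectors of 16 bytes); qy is one 32-value q8_0 block of y; sumi1 receives the integer dot products for rows 0-3 and sumi2 for rows 4-7. This is an illustration of the interleaved layout, not code from the repository:

    #include <cstdint>

    // Scalar model of qx_0_q8_0_dot: qx8 is qx viewed as 256 bytes.
    static void qx_0_q8_0_dot_ref(const int8_t qx8[256], const int8_t qy[32],
                                  int32_t sumi1[4], int32_t sumi2[4]) {
        for (int j = 0; j < 4; ++j) sumi1[j] = sumi2[j] = 0;
        for (int g = 0; g < 8; ++g) {          // eight 4-byte groups of the y block
            for (int j = 0; j < 4; ++j) {      // four rows per accumulator
                for (int b = 0; b < 4; ++b) {  // one vdot lane = one 4-byte dot
                    sumi1[j] += qx8[16*(2*g+0) + 4*j + b] * qy[4*g + b]; // rows 0-3
                    sumi2[j] += qx8[16*(2*g+1) + 4*j + b] * qy[4*g + b]; // rows 4-7
                }
            }
        }
    }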