q2_k_r4: NEON

We get PP-512(LLaMA-3.1-8B) = 106.2 t/s.
TG-128 is 36.02 t/s, which is ~10% higher than q2_K_S.
This commit is contained in:
Iwan Kawrakow
2024-12-11 16:52:33 +01:00
parent 2b07aa3f2e
commit ea2de6ee34

View File

@@ -8526,6 +8526,88 @@ IQK_ALWAYS_INLINE void prepare_q4_k_quants(const uint8x16_t& m4, const uint8x16x
qx[7] = vshrq_n_u8(bits.val[3], 4); // 28..31
}
template <int nrc_y>
void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = vdupq_n_u8(0x0f);
auto m03 = vdupq_n_u8(0x03);
int nbl = n / QK_K;
int8x16_t qx[4];
float32x4_t acc[nrc_y] = {};
int16x8x4_t i16scales;
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < nbl; ++ibl) {
int32x4_t isum[nrc_y] = {};
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
auto m4 = vmulq_f32(vdupq_n_f32(-1.f), vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d+4)));
for (int is = 0; is < 2; ++is) {
auto sl = vld1q_u8_x2(iq2[ibl].scales + 32*is);
auto m = vshrq_n_u8(sl.val[0], 4);
i16scales.val[0] = vmovl_u8(vget_low_u8 (m));
i16scales.val[1] = vmovl_u8(vget_high_u8(m));
m = vshrq_n_u8(sl.val[1], 4);
i16scales.val[2] = vmovl_u8(vget_low_u8 (m));
i16scales.val[3] = vmovl_u8(vget_high_u8(m));
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = vdupq_n_s32(0);
auto bsums = vld1q_s16(q8.y[iy][ibl].bsums + 8*is);
auto b8 = vget_low_s16(bsums);
//auto bsums = q8.load_bsums(iy, ibl);
//auto b8 = vget_low_s16(bsums.val[0]);
sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[0]), b8, 0);
sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[0]), b8, 1);
sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[1]), b8, 2);
sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[1]), b8, 3);
b8 = vget_high_s16(bsums);
sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[2]), b8, 0);
sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[2]), b8, 1);
sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[3]), b8, 2);
sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[3]), b8, 3);
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(m4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
}
m = vandq_u8(sl.val[0], mf);
i16scales.val[0] = vmovl_u8(vget_low_u8 (m));
i16scales.val[1] = vmovl_u8(vget_high_u8(m));
m = vandq_u8(sl.val[1], mf);
i16scales.val[2] = vmovl_u8(vget_low_u8 (m));
i16scales.val[3] = vmovl_u8(vget_high_u8(m));
for (int ib = 0; ib < 4; ++ib) {
auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib);
auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
qx[0] = vreinterpretq_s8_u8(vandq_u8( bits.val[0], m03));
qx[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 2), m03));
qx[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 4), m03));
qx[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 6), m03));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
auto sumi = interleaved_dotq(qx, y);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
qx[0] = vreinterpretq_s8_u8(vandq_u8( bits.val[1], m03));
qx[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 2), m03));
qx[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 4), m03));
qx[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 6), m03));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
auto sumi = interleaved_dotq(qx, y);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
}
}
}
template <int nrc_y>
void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -9191,6 +9273,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K32;
break;
case GGML_TYPE_Q2_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q2_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_Q3_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q3_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;