q8_k_r4: NEON

We get PP-512(LLaMA-3.1-8B) = 159.2 t/s.
Compare this to the 128 t/s we have for Q8_0_R4.
This commit is contained in:
Iwan Kawrakow
2024-12-13 18:57:07 +01:00
parent 6d6d12fc86
commit 89510ccce7

View File

@@ -9228,6 +9228,55 @@ void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf
}
}
template <int nrc_y>
// Matrix-vector/matrix product: A stored as q8_k_r8 (8 interleaved rows per
// block), B as nrc_y rows of block_q8_K. Processes 8 A-rows per outer
// iteration; results go to info.store().
//
// n      - number of columns (must be a multiple of QK_K)
// vx, bx - base pointer and row stride (bytes) of the interleaved A data
// info   - destination accessor / B-row source
// nrc_x  - number of A rows to process (must be a multiple of 8)
void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%8 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    int nbl = n / QK_K;
    // Two f32x4 accumulators per y-row: [2*iy+0] covers interleaved rows 0-3,
    // [2*iy+1] covers rows 4-7.
    float32x4_t acc[2*nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 8) {
        const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            // Per-row block scales: d4l for rows 0-3, d4h for rows 4-7.
            auto d4l = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+0));
            auto d4h = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+4));
            int32x4_t isum[2*nrc_y] = {};
            for (int ib = 0; ib < QK_K/16; ++ib) {
                // 128 bytes = 16 quants for each of the 8 interleaved rows.
                auto q1 = vld1q_u8_x4(iq8[ibl].qs + 128*ib +  0);
                auto q2 = vld1q_u8_x4(iq8[ibl].qs + 128*ib + 64);
                for (int k = 0; k < 4; ++k) {
                    // Flip the sign bit: quants were stored offset by +128 so the
                    // unsigned dot product can be used; the constant bias is
                    // removed below via the row sums.
                    q1.val[k] = veorq_u8(q1.val[k], vdupq_n_u8(0x80));
                    q2.val[k] = veorq_u8(q2.val[k], vdupq_n_u8(0x80));
                }
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = vld1q_s8(q8.y[iy][ibl].qs+16*ib);
                    // Each vdotq_laneq_s32 multiplies 4 consecutive y quants
                    // (selected by the lane) against the matching 4 quants of
                    // 4 interleaved rows.
                    isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[0], y, 0);
                    isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[1], y, 0);
                    isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[2], y, 1);
                    isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[3], y, 1);
                    isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[0], y, 2);
                    isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[1], y, 2);
                    isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[2], y, 3);
                    isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[3], y, 3);
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
                // NOTE(review): bsums is read as float here — assumes the Q8_KR8
                // B-quantization stores the block's total quant sum as a float in
                // bsums[0]; verify against the quantization side.
                const float * bsum = (const float *)q8.y[iy][ibl].bsums;
                // Compensation for the +128 offset applied to the A quants:
                // subtract 128 * sum(y) scaled by each row's d.
                auto m8 = vdupq_n_f32(-128.f*bsum[0]);
                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[2*iy+0]));
                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[2*iy+1]));
                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], d4l, m8);
                // BUG FIX: was d4l — the rows 4-7 accumulator must use the
                // rows 4-7 scales (d4h) for the offset compensation as well.
                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], d4h, m8);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix+0, iy, acc[2*iy+0]);
            info.store(ix+4, iy, acc[2*iy+1]);
            acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
        }
    }
}
void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<1, block_q8_0_x4> q8(info);
@@ -9645,6 +9694,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q6_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_Q8_K_R8:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k);
expected_Btype = GGML_TYPE_Q8_KR8;
break;
case GGML_TYPE_IQ4_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;