mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
q8_k_r4: NEON
We get PP-512(LLaMA-3.1-8B) = 159.2 t/s. Compare this to the 128 t/s we have for Q8_0_R4.
This commit is contained in:
@@ -9228,6 +9228,55 @@ void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
}
|
||||
}
|
||||
|
||||
// Multiply nrc_x rows of Q8_K_R8 quants (8 rows interleaved per block) with
// nrc_y rows of Q8_K activations using NEON sdot instructions.
//
//   n     : number of columns (must be a multiple of QK_K)
//   vx/bx : base pointer and row-group stride of the interleaved A matrix
//   info  : provides the activation rows (via Q8) and receives the results
//   nrc_x : number of A rows to process (must be a multiple of 8)
//
// acc[2*iy+0]/acc[2*iy+1] accumulate the results for rows ix..ix+3 and
// ix+4..ix+7 against activation row iy.
template <int nrc_y>
void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%8 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    int nbl = n / QK_K;
    float32x4_t acc[2*nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 8) {
        const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            // Per-row scales: d4l for the low 4 interleaved rows, d4h for the high 4.
            auto d4l = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+0));
            auto d4h = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+4));
            int32x4_t isum[2*nrc_y] = {};
            for (int ib = 0; ib < QK_K/16; ++ib) {
                // 128 bytes = 16 quants from each of the 8 interleaved rows.
                auto q1 = vld1q_u8_x4(iq8[ibl].qs + 128*ib + 0);
                auto q2 = vld1q_u8_x4(iq8[ibl].qs + 128*ib + 64);
                // Flip the sign bit: the quants are stored with a +128 offset so the
                // dot products below run on the offset values; the offset is removed
                // later via the -128*bsum correction.
                for (int k = 0; k < 4; ++k) {
                    q1.val[k] = veorq_u8(q1.val[k], vdupq_n_u8(0x80));
                    q2.val[k] = veorq_u8(q2.val[k], vdupq_n_u8(0x80));
                }
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = vld1q_s8(q8.y[iy][ibl].qs+16*ib);
                    // Each 4-byte lane of y is dotted with the matching 4 quants of
                    // all 8 interleaved rows (even isum: rows 0-3, odd: rows 4-7).
                    isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[0], y, 0);
                    isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[1], y, 0);
                    isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[2], y, 1);
                    isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[3], y, 1);
                    isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[0], y, 2);
                    isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[1], y, 2);
                    isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[2], y, 3);
                    isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[3], y, 3);
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
                // NOTE(review): bsums is read as float here, so the Q8_KR8 B-type is
                // presumably prepared with a float row sum in that slot — confirm
                // against the Q8_KR8 quantization code.
                const float * bsum = (const float *)q8.y[iy][ibl].bsums;
                auto m8 = vdupq_n_f32(-128.f*bsum[0]);
                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[2*iy+0]));
                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[2*iy+1]));
                // Remove the +128 offset applied to the quants. The high half must
                // use d4h (was incorrectly d4l, scaling rows 4-7's correction by the
                // low rows' scales).
                acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], d4l, m8);
                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], d4h, m8);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix+0, iy, acc[2*iy+0]);
            info.store(ix+4, iy, acc[2*iy+1]);
            acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
Q8<1, block_q8_0_x4> q8(info);
|
||||
@@ -9645,6 +9694,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q6_k_r4_q8_k);
|
||||
expected_Btype = GGML_TYPE_Q8_K;
|
||||
break;
|
||||
case GGML_TYPE_Q8_K_R8:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_k_r8_q8_k);
|
||||
expected_Btype = GGML_TYPE_Q8_KR8;
|
||||
break;
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_k_r4_q8_k);
|
||||
expected_Btype = GGML_TYPE_Q8_K;
|
||||
|
||||
Reference in New Issue
Block a user