mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
q3_k_r4: NEON implementation.
We get PP-512(LLaMA-3.1-8B) = 106.9 t/s — 1.93X faster than q3_K_S.
This commit is contained in:
@@ -8364,6 +8364,73 @@ IQK_ALWAYS_INLINE void prepare_q4_k_quants(const uint8x16_t& m4, const uint8x16x
|
||||
qx[7] = vshrq_n_u8(bits.val[3], 4); // 28..31
|
||||
}
|
||||
|
||||
// NEON kernel: multiply nrc_x rows of 4-row-interleaved Q3_K quants (q3_k_r4)
// by nrc_y rows of Q8_K-quantized activations, accumulating float results.
//
//   n     - number of columns (must be a multiple of QK_K; nbl = n / QK_K super-blocks)
//   vx    - pointer to the interleaved quantized rows (block_q3_k_r4 layout)
//   bx    - row stride of vx in bytes
//   info  - project-defined sink providing the Q8 activations and result store
//   nrc_x - number of x rows; must be a multiple of 4 (one r4 group per iteration)
//
// Q3_K decode used below: each quant is (low 2 bits) + 4*(high bit) - 4, i.e. a
// signed value in [-4, 3]. The high bits are complemented once up front so the
// subtraction of 4 applies exactly where the original high bit was 0.
template <int nrc_y>
void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);            // project wrapper over the y-side Q8_K rows — see Q8 definition elsewhere in file
    auto mf = vdupq_n_u8(0x0f);                // low-nibble mask (low 4 bits of the 6-bit block scales)
    auto m30 = vdupq_n_u8(0x30);               // mask for the two high bits of the 6-bit scales
    auto m32 = vdupq_n_s8(-32);                // offset that re-centers the 6-bit scales to signed [-32, 31]
    auto m03 = vdupq_n_u8(0x03);               // 2-bit quant mask
    auto m04 = vdupq_n_u8(0x04);               // bit-2 mask: isolates the (shifted) high bit, worth 4
    int nbl = n / QK_K;                        // number of QK_K super-blocks per row
    int8x16_t qx[4];                           // 4 decoded quant vectors fed to interleaved_dotq
    float32x4_t acc[nrc_y] = {};               // one float accumulator (4 rows wide) per y row
    int8x16x4_t i8scales;                      // 64 signed 6-bit block scales for the 4 interleaved rows
    int16x8x4_t i16scales;                     // scales widened to 16 bit for one half of a super-block
    for (int ix = 0; ix < nrc_x; ix += 4) {    // process 4 interleaved x rows at a time
        const block_q3_k_r4 * iq3 = (const block_q3_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            int32x4_t isum[nrc_y] = {};        // integer dot-product accumulators for this super-block
            // Per-row fp16 super-block scales -> f32 (4 rows at once).
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
            auto sl = vld1q_u8_x2(iq3[ibl].scales_l);
            auto sh = vld1q_u8(iq3[ibl].scales_h);
            // Reassemble the 6-bit block scales: low nibble from scales_l, two high
            // bits from scales_h, then subtract 32 (add m32) to make them signed.
            i8scales.val[0] = vaddq_s8(m32, vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m30)));
            i8scales.val[1] = vaddq_s8(m32, vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(vshlq_n_u8(sh, 2), m30)));
            i8scales.val[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m30)));
            i8scales.val[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m30)));
            for (int is = 0; is < 2; ++is) {   // two halves of the super-block (128 quants each)
                // Widen this half's scales to 16 bit; they are widened again to
                // 32 bit below, one 4-row group at a time.
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib);   // low 2 bits, 4 quants per byte
                    auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib);       // high bits, 8 quants per byte
                    // Complement the high bits so that (m04 & shifted hbits) is 4
                    // exactly where the original high bit was 0; subtracting it
                    // implements the "-4 unless high bit set" part of Q3_K.
                    hbits = veorq_u8(hbits, vdupq_n_u8(0xff));
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    // First 16 bytes: high bit j of each qh byte is moved to bit
                    // position 2 (shl 2, shl 1, none, shr 1 for j = 0..3).
                    qx[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8( lbits.val[0], m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshlq_n_u8(hbits, 2))));
                    qx[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshlq_n_u8(hbits, 1))));
                    qx[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03)), vreinterpretq_s8_u8(vandq_u8(m04, hbits)));
                    qx[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 1))));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        // interleaved_dotq: project helper — presumably yields the 4
                        // per-row dot products of qx against y (defined elsewhere).
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    // Second 16 bytes: same decode, bits j = 4..7 of qh
                    // (shr 2..5 brings each down to bit position 2).
                    qx[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8( lbits.val[1], m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 2))));
                    qx[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 3))));
                    qx[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 4))));
                    qx[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 5))));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            // Fold this super-block into the float accumulators:
            // acc += (row scales d4) * (y-side block scale) * integer sums.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        // Write out the 4-row result for each y row and reset for the next x group.
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
@@ -8965,6 +9032,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k);
|
||||
expected_Btype = GGML_TYPE_Q8_K;
|
||||
break;
|
||||
case GGML_TYPE_Q3_K_R4:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q3_k_r4_q8_k);
|
||||
expected_Btype = GGML_TYPE_Q8_K;
|
||||
break;
|
||||
case GGML_TYPE_Q4_K_R4:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q4_k_r4_q8_k);
|
||||
expected_Btype = GGML_TYPE_Q8_K32;
|
||||
|
||||
Reference in New Issue
Block a user