From 52ad57b04275d652610be5578a00dd3e9e2cbfca Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 23 Jun 2025 18:38:31 +0200
Subject: [PATCH] q3_K 45.7 t/s -> 170.8 t/s. q3_k_r4 is at 110.3 t/s.

---
 ggml/src/iqk/iqk_gemm_kquants.cpp | 108 +++++++++++++++++++++++++++++-
 ggml/src/iqk/iqk_mul_mat.cpp      |   1 +
 2 files changed, 108 insertions(+), 1 deletion(-)

diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp
index 57686613..2caa2604 100644
--- a/ggml/src/iqk/iqk_gemm_kquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_kquants.cpp
@@ -3797,12 +3797,118 @@ void iqk_convert_q2_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int
     }
 }
 
+void iqk_convert_q3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_q3_K * x8[8];
+
+    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+    uint32_t block[8];
+    int8x16x2_t xv[8];
+    uint32_t aux32[4];
+
+    auto ml = vdupq_n_s8(0x03);
+    auto mh = vdupq_n_s8(0x04);
+
+    union { int8x16_t vec; int8_t val[16]; } helper;
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_q3_K *)((const char *)vx + (ix + k)*bx);
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                float d = GGML_FP16_TO_FP32(x8[k][i].d);
+                auto sc16 = (const uint16_t *)x8[k][i].scales;
+                uint32_t aux0 = sc16[0] | (sc16[1] << 16);
+                uint32_t aux1 = sc16[2] | (sc16[3] << 16);
+                uint32_t aux2 = sc16[4] | (sc16[5] << 16);
+                aux32[0] =  (aux0       & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);
+                aux32[1] =  (aux1       & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
+                aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
+                aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
+                helper.vec = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
+                auto hbits = vld1q_u8_x2(x8[k][i].hmask);
+                auto max_i16 = vdupq_n_u16(0);
+                for (int i128 = 0; i128 < 2; ++i128) {
+                    auto q2bits = vld1q_u8_x2(x8[k][i].qs + 32*i128);
+                    xv[4*i128+0].val[0] = vsubq_s8(vorrq_s8(vandq_s8(q2bits.val[0], ml), vandq_s8(vshlq_n_u8(hbits.val[0], 2), mh)), mh);
+                    xv[4*i128+0].val[1] = vsubq_s8(vorrq_s8(vandq_s8(q2bits.val[1], ml), vandq_s8(vshlq_n_u8(hbits.val[1], 2), mh)), mh);
+                    xv[4*i128+1].val[0] = vsubq_s8(vorrq_s8(vandq_s8(vshrq_n_u8(q2bits.val[0], 2), ml), vandq_s8(vshlq_n_u8(hbits.val[0], 1), mh)), mh);
+                    xv[4*i128+1].val[1] = vsubq_s8(vorrq_s8(vandq_s8(vshrq_n_u8(q2bits.val[1], 2), ml), vandq_s8(vshlq_n_u8(hbits.val[1], 1), mh)), mh);
+                    xv[4*i128+2].val[0] = vsubq_s8(vorrq_s8(vandq_s8(vshrq_n_u8(q2bits.val[0], 4), ml), vandq_s8(hbits.val[0], mh)), mh);
+                    xv[4*i128+2].val[1] = vsubq_s8(vorrq_s8(vandq_s8(vshrq_n_u8(q2bits.val[1], 4), ml), vandq_s8(hbits.val[1], mh)), mh);
+                    xv[4*i128+3].val[0] = vsubq_s8(vorrq_s8(vshrq_n_u8(q2bits.val[0], 6), vandq_s8(vshrq_n_u8(hbits.val[0], 1), mh)), mh);
+                    xv[4*i128+3].val[1] = vsubq_s8(vorrq_s8(vshrq_n_u8(q2bits.val[1], 6), vandq_s8(vshrq_n_u8(hbits.val[1], 1), mh)), mh);
+                    hbits.val[0] = vshrq_n_u8(hbits.val[0], 4);
+                    hbits.val[1] = vshrq_n_u8(hbits.val[1], 4);
+
+                    for (int l = 0; l < 4; ++l) {
+                        auto s1 = vdup_n_s8(helper.val[8*i128+2*l+0]);
+                        auto s2 = vdup_n_s8(helper.val[8*i128+2*l+1]);
+                        auto q16_1 = vmull_s8(s1, vget_low_s8 (xv[4*i128+l].val[0]));
+                        auto q16_2 = vmull_s8(s1, vget_high_s8(xv[4*i128+l].val[0]));
+                        auto q16_3 = vmull_s8(s2, vget_low_s8 (xv[4*i128+l].val[1]));
+                        auto q16_4 = vmull_s8(s2, vget_high_s8(xv[4*i128+l].val[1]));
+                        auto max1 = vmaxq_s16(vabsq_s16(q16_1), vabsq_s16(q16_2));
+                        auto max2 = vmaxq_s16(vabsq_s16(q16_3), vabsq_s16(q16_4));
+                        max_i16 = vmaxq_s16(max_i16, vmaxq_s16(max1, max2));
+                    }
+                }
+                auto imax16 = vmaxvq_s16(max_i16);
+                bool needs_scaling = true;
+                float dnew = float(imax16) / 127;
+                if (dnew < 1.f) {
+                    dnew = 1.f; needs_scaling = false;
+                }
+                d *= dnew;
+                y[i].d[k] = GGML_FP32_TO_FP16(d);
+                auto scale = vdupq_n_f32(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
+                for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    auto s1 = vdup_n_s8(helper.val[2*ib32+0]);
+                    auto s2 = vdup_n_s8(helper.val[2*ib32+1]);
+                    auto q16_1 = vmull_s8(s1, vget_low_s8 (xv[ib32].val[0]));
+                    auto q16_2 = vmull_s8(s1, vget_high_s8(xv[ib32].val[0]));
+                    auto q16_3 = vmull_s8(s2, vget_low_s8 (xv[ib32].val[1]));
+                    auto q16_4 = vmull_s8(s2, vget_high_s8(xv[ib32].val[1]));
+                    if (needs_scaling) {
+                        int32x4x4_t i1{vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (q16_1))))),
+                                       vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16_1))))),
+                                       vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (q16_2))))),
+                                       vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16_2)))))};
+                        int32x4x4_t i2{vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (q16_3))))),
+                                       vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16_3))))),
+                                       vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (q16_4))))),
+                                       vcvtnq_s32_f32(vmulq_f32(scale, vcvtq_f32_s32(vmovl_s16(vget_high_s16(q16_4)))))};
+                        int16x8x4_t i3{vcombine_s16(vmovn_s32(i1.val[0]), vmovn_s32(i1.val[1])),
+                                       vcombine_s16(vmovn_s32(i1.val[2]), vmovn_s32(i1.val[3])),
+                                       vcombine_s16(vmovn_s32(i2.val[0]), vmovn_s32(i2.val[1])),
+                                       vcombine_s16(vmovn_s32(i2.val[2]), vmovn_s32(i2.val[3]))};
+                        vst1q_s8((int8_t *)block +  0, vcombine_s8(vmovn_s16(i3.val[0]), vmovn_s16(i3.val[1])));
+                        vst1q_s8((int8_t *)block + 16, vcombine_s8(vmovn_s16(i3.val[2]), vmovn_s16(i3.val[3])));
+                    } else {
+                        vst1q_s8((int8_t *)block +  0, vcombine_s8(vmovn_s16(q16_1), vmovn_s16(q16_2)));
+                        vst1q_s8((int8_t *)block + 16, vcombine_s8(vmovn_s16(q16_3), vmovn_s16(q16_4)));
+                    }
+                    auto qs = (uint32_t *)y[i].qs + 64*ib32;
+                    for (int l = 0; l < 8; ++l) {
+                        qs[8*l + k] = block[l];
+                    }
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
 }
 
 bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
     switch (ggml_type(type)) {
         case GGML_TYPE_Q2_K: iqk_convert_q2_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
-        // case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
+        case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
         // case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
         // case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
         // case GGML_TYPE_Q6_K: iqk_convert_q6_k_q8_0_r8(n, vx, bx, vy, nrc_x); break;
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index c402710c..4082c03d 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -272,6 +272,7 @@ struct MulMat {
 #else
         switch (type) {
             case GGML_TYPE_Q2_K   : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+            case GGML_TYPE_Q3_K   : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
            case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
            case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
            case GGML_TYPE_IQ2_S  : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
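
Notes added for review; the sketches below are reader aids, not part of the commit.

The scale unpacking in the patch (aux0/aux1/aux2 -> aux32[0..3]) is easier to
follow in scalar form. q3_K packs 16 6-bit block scales into 12 bytes: bytes
0-7 hold the low 4 bits of scales 0-15 (one nibble each), bytes 8-11 hold the
high 2 bits, and every scale is stored with a +32 bias. A minimal scalar sketch
of the same bit manipulation, assuming a little-endian target as the NEON path
does (the helper name is mine):

    #include <cstdint>
    #include <cstring>

    // Unpack the 16 6-bit scales of one q3_K superblock from the 12 packed
    // bytes pointed to by 'scales12' (block_q3_K::scales).
    static void unpack_q3_k_scales(const uint8_t * scales12, int8_t * out) {
        uint32_t aux0, aux1, aux2;
        std::memcpy(&aux0, scales12 + 0, 4);
        std::memcpy(&aux1, scales12 + 4, 4);
        std::memcpy(&aux2, scales12 + 8, 4);
        uint32_t packed[4];
        packed[0] = ( aux0       & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030); // scales  0.. 3
        packed[1] = ( aux1       & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030); // scales  4.. 7
        packed[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030); // scales  8..11
        packed[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030); // scales 12..15
        auto bytes = (const uint8_t *)packed;
        for (int j = 0; j < 16; ++j) out[j] = (int8_t)(bytes[j] - 32); // remove the +32 bias
    }

The vaddq_s8(..., vdupq_n_s8(-32)) step in the patch is the same bias removal,
applied to all 16 scales at once.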
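
The xv[] construction decodes the quants themselves: each value keeps its two
low bits in qs and its third bit in hmask, with a fixed offset of -4, so the
decoded values lie in [-4, 3]. A scalar equivalent for a single value index j
in 0..255 (function name hypothetical):

    #include <cstdint>

    // Decode quant j of a q3_K superblock. 'qs' is block_q3_K::qs (64 bytes,
    // 2 bits per value), 'hmask' is block_q3_K::hmask (32 bytes, 1 bit per value).
    static int8_t q3_k_value(const uint8_t * qs, const uint8_t * hmask, int j) {
        int low  = (qs[(j % 32) + 32*(j / 128)] >> (2*((j / 32) % 4))) & 3;
        int high = (hmask[j % 32] >> (j / 32)) & 1;
        return (int8_t)((low | (high << 2)) - 4);
    }

A set hmask bit yields values in [0, 3] and a cleared bit values in [-4, -1],
which matches the reference q3_K dequantization.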
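
With scales in [-32, 31] and quants in [-4, 3], the per-value products
scale*quant can reach +/-128, just past the int8 maximum of 127 that q8_k_r8
stores. The needs_scaling logic therefore tracks the largest |product| per
superblock and, when it does not fit, shrinks the values by 127/max, folding
the factor into the fp16 superblock scale. A scalar sketch of that decision,
equivalent in output to the NEON code above (names mine):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdlib>

    // 'prod' holds the 256 products block_scale[j/16]*q[j] of one superblock
    // (each fits in int16). Writes int8 values to 'out' and returns the factor
    // the caller multiplies into the superblock scale d.
    static float requantize_q8(const int16_t * prod, int n, int8_t * out) {
        int max_abs = 0;
        for (int j = 0; j < n; ++j) max_abs = std::max(max_abs, std::abs(int(prod[j])));
        if (max_abs <= 127) {                       // everything fits: store exactly
            for (int j = 0; j < n; ++j) out[j] = (int8_t)prod[j];
            return 1.f;
        }
        float dnew = max_abs/127.f;
        float id   = 1.f/dnew;
        for (int j = 0; j < n; ++j)
            out[j] = (int8_t)std::nearbyint(id*prod[j]); // round to nearest, like vcvtnq_s32_f32
        return dnew;
    }

The final store interleaves eight rows: for 32-value chunk ib32, the
qs[8*l + k] = block[l] writes place the l-th 4-byte group of row k next to the
same group of the other seven rows, which is the layout the Q8_K_R8 GEMM
kernels consume.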