From 4b6f0ff9c12e1b8a14c9a7293799af7b90e9af86 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Wed, 18 Jun 2025 13:58:31 +0300
Subject: [PATCH] q2_K 202 t/s -> 364 t/s. q2_k_r4 is at 247 t/s.

---
 ggml/src/iqk/iqk_gemm_kquants.cpp | 86 +++++++++++++++++++++++++++++++
 ggml/src/iqk/iqk_mul_mat.cpp      |  3 +-
 2 files changed, 88 insertions(+), 1 deletion(-)

diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp
index 48381b5a..b46077f8 100644
--- a/ggml/src/iqk/iqk_gemm_kquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_kquants.cpp
@@ -2018,6 +2018,91 @@ typedef struct {
     int8_t qs[8*QK8_1];
 } block_q8_1_r8;
 
+void iqk_convert_q2_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_q2_K * x8[8];
+
+    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+    float f_values[QK_K];
+    uint32_t block[8];
+
+    __m256i xv[4];
+
+    auto ml = _mm256_set1_epi8(0x03);
+    auto sign_bit = _mm256_set1_ps(-0.0f);
+    auto perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_q2_K *)((const char *)vx + (ix + k)*bx);
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                auto vd = _mm256_set1_ps(GGML_FP16_TO_FP32(x8[k][i].d));
+                auto vm = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x8[k][i].dmin)), _mm256_set1_ps(-1.f));
+                auto block_max = _mm256_setzero_ps();
+                for (int i128 = 0; i128 < 2; ++i128) {
+                    auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+i128);
+                    xv[0] = _mm256_and_si256(bits, ml);
+                    xv[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), ml);
+                    xv[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), ml);
+                    xv[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), ml);
+                    for (int l = 0; l < 4; ++l) {
+                        auto q1 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xv[l]));
+                        auto q2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xv[l], 1));
+                        q1 = _mm256_mullo_epi16(q1, _mm256_set1_epi16(x8[k][i].scales[8*i128 + 2*l + 0] & 0xf));
+                        q2 = _mm256_mullo_epi16(q2, _mm256_set1_epi16(x8[k][i].scales[8*i128 + 2*l + 1] & 0xf));
+                        auto m1 = _mm256_mul_ps(vm, _mm256_set1_ps(x8[k][i].scales[8*i128 + 2*l + 0] >> 4));
+                        auto m2 = _mm256_mul_ps(vm, _mm256_set1_ps(x8[k][i].scales[8*i128 + 2*l + 1] >> 4));
+                        auto v0 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q1))), vd, m1);
+                        auto v1 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q1, 1))), vd, m1);
+                        auto v2 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q2))), vd, m2);
+                        auto v3 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q2, 1))), vd, m2);
+                        auto max = _mm256_max_ps(_mm256_max_ps(_mm256_andnot_ps(sign_bit, v0), _mm256_andnot_ps(sign_bit, v1)),
+                                                 _mm256_max_ps(_mm256_andnot_ps(sign_bit, v2), _mm256_andnot_ps(sign_bit, v3)));
+                        block_max = _mm256_max_ps(block_max, max);
+                        _mm256_storeu_ps(f_values + 128*i128 + 32*l + 0, v0);
+                        _mm256_storeu_ps(f_values + 128*i128 + 32*l + 8, v1);
+                        _mm256_storeu_ps(f_values + 128*i128 + 32*l + 16, v2);
+                        _mm256_storeu_ps(f_values + 128*i128 + 32*l + 24, v3);
+                    }
+                }
+                auto max4 = _mm_max_ps(_mm256_extractf128_ps(block_max, 1), _mm256_castps256_ps128(block_max));
+                max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+                max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+                float d = _mm_cvtss_f32(max4/127.f);
+                auto id = _mm256_set1_ps(d != 0.0f ? 1/d : 0.0f);
+                y[i].d[k] = GGML_FP32_TO_FP16(d);
+                for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    auto v0 = _mm256_loadu_ps(f_values + 32*ib32 + 0);
+                    auto v1 = _mm256_loadu_ps(f_values + 32*ib32 + 8);
+                    auto v2 = _mm256_loadu_ps(f_values + 32*ib32 + 16);
+                    auto v3 = _mm256_loadu_ps(f_values + 32*ib32 + 24);
+                    auto i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v0, id), _MM_ROUND_NEAREST));
+                    auto i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v1, id), _MM_ROUND_NEAREST));
+                    auto i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v2, id), _MM_ROUND_NEAREST));
+                    auto i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v3, id), _MM_ROUND_NEAREST));
+                    i0 = _mm256_packs_epi32(i0, i1);
+                    i2 = _mm256_packs_epi32(i2, i3);
+                    i0 = _mm256_packs_epi16(i0, i2);
+                    i0 = _mm256_permutevar8x32_epi32(i0, perm);
+
+                    _mm256_storeu_si256((__m256i *)block, i0);
+                    auto q8 = (uint32_t *)y[i].qs + 64*ib32;
+                    for (int l = 0; l < 4; ++l) {
+                        q8[8*l + k + 0] = block[l + 0];
+                        q8[8*l + k + 32] = block[l + 4];
+                    }
+                }
+            }
+        }
+        y += nb;
+    }
+}
+
 void iqk_convert_q4_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
     GGML_ASSERT(n%QK_K == 0);
     GGML_ASSERT(nrc_x%8 == 0);
@@ -2608,6 +2693,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat
 
 bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
     switch (ggml_type(type)) {
+        case GGML_TYPE_Q2_K: iqk_convert_q2_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
         case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
        case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -[...],6 +[...],7 @@
 [...]                 : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
         case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
         case GGML_TYPE_IQ1_M : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+        case GGML_TYPE_Q2_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
         case GGML_TYPE_Q3_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
         case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
         case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
@@ -364,7 +365,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
         //case GGML_TYPE_BF16:
         //case GGML_TYPE_BF16_R16:
         //    return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs);
-        //case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
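
A scalar model of what the new kernel computes may help review, since the AVX2 version fuses three steps into one pass: q2_K dequantization, per-row absmax search, and requantization into the interleaved Q8_K_R8 layout. The sketch below is not part of the patch; the helper names q2k_dequant_block and q8_k_r8_pack_row are invented here, and it assumes ggml's block_q2_K plus a block_q8_k_r8 of the shape { ggml_half d[8]; int8_t qs[8*QK_K]; } with GGML_FP16_TO_FP32/GGML_FP32_TO_FP16 in scope:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

// Dequantize one 256-weight q2_K super-block into f[0..255]. Each of the
// 16 sub-blocks of 16 weights carries a 4-bit scale and a 4-bit min in
// scales[j]; the 2-bit quants are packed 128 values per 32 bytes of qs.
static void q2k_dequant_block(const block_q2_K * x, float * f) {
    const float d    = GGML_FP16_TO_FP32(x->d);
    const float dmin = GGML_FP16_TO_FP32(x->dmin);
    for (int j = 0; j < 16; ++j) {
        const float sc = d    * (x->scales[j] & 0xf);
        const float m  = dmin * (x->scales[j] >>  4);
        for (int l = 0; l < 16; ++l) {
            const int idx = 16*j + l;
            const int q = (x->qs[32*(idx/128) + idx%32] >> (2*((idx%128)/32))) & 3;
            f[idx] = sc*q - m;   // the kernel's fmadd: vd*(scale*q) + (-dmin)*min
        }
    }
}

// Requantize row k (of the 8 rows sharing this block) into y->d[k]/y->qs.
// Layout: within each group of 32 values (ib32), the 4-byte chunk l of row k
// lands at dword 8*l + k for the first 16 values and at 32 + 8*l + k for the
// last 16, matching the pack/permute/store sequence in the kernel.
static void q8_k_r8_pack_row(const float * f, int k, block_q8_k_r8 * y) {
    float amax = 0.0f;
    for (int j = 0; j < QK_K; ++j) amax = std::max(amax, std::fabs(f[j]));
    const float d = amax/127.f, id = d != 0.0f ? 1/d : 0.0f;
    y->d[k] = GGML_FP32_TO_FP16(d);
    for (int ib32 = 0; ib32 < 8; ++ib32) {
        int8_t q[32];
        for (int j = 0; j < 32; ++j) q[j] = (int8_t)nearbyintf(id*f[32*ib32 + j]);
        uint32_t * q8 = (uint32_t *)y->qs + 64*ib32;
        for (int l = 0; l < 4; ++l) {
            std::memcpy(&q8[8*l + k],      q + 4*l,      4);
            std::memcpy(&q8[8*l + k + 32], q + 16 + 4*l, 4);
        }
    }
}

The point of the interleaving is that the GEMM kernel can then read the same 32-value group of all eight rows as contiguous 32-byte loads. That is what makes the conversion pay off at large batch sizes (hence the nrc_y >= 32 threshold in the dispatch) and why the converter asserts nrc_x%8 == 0.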