From 5432108e9c7b8efbd7459389f8dfc72fd07a90f0 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 12 Jun 2025 13:10:24 +0300 Subject: [PATCH] q5_K: GEMM with q8_2_X4 and repack to q8_1_r8 --- ggml/src/ggml.c | 4 ++ ggml/src/iqk/iqk_gemm_kquants.cpp | 102 +++++++++++++++++++++++++++--- ggml/src/iqk/iqk_mul_mat.cpp | 3 +- 3 files changed, 100 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7afc5287..069533ae 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1006,7 +1006,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q5_K, .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref, .vec_dot = ggml_vec_dot_q5_K_q8_K, +#ifdef __AVX2__ + .vec_dot_type = GGML_TYPE_Q8_2_X4, +#else .vec_dot_type = GGML_TYPE_Q8_K, +#endif .nrows = 1, .row_meta_size = 0, }, diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp index ff5090cc..b834c5d8 100644 --- a/ggml/src/iqk/iqk_gemm_kquants.cpp +++ b/ggml/src/iqk/iqk_gemm_kquants.cpp @@ -752,11 +752,6 @@ struct Q4Bits_AVX2 { struct DequantizerQ4K_AVX2 final : public BaseDequantizer { DequantizerQ4K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} - template - inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { - d = GGML_FP16_TO_FP32(x[i].d); - return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd); - } inline void prepare(int i, int j) { bits.prepare(x[i].qs, j); } @@ -765,6 +760,26 @@ struct DequantizerQ4K_AVX2 final : public BaseDequantizer { Scales8K s8k; }; +struct DequantizerQ5K_AVX2 final : public BaseDequantizer { + DequantizerQ5K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + inline void prepare(int i, int j) { + bits.prepare(x[i].qs, j); + hbits = j == 0 ? _mm256_loadu_si256((const __m256i *)x[i].qh) : _mm256_srli_epi16(hbits, 4); + apply_hbits(); + } + inline void apply_hbits() { + bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh)); + bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh)); + bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); + bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)); + } + + const __m256i mh = _mm256_set1_epi8(0x10); + Q4Bits_AVX2 bits; + __m256i hbits; + Scales8K s8k; +}; + template static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); @@ -814,7 +829,7 @@ static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const Data for (int iy = 0; iy < nrc_y; ++iy) { const block_q8_2_x4& y = q8.y[iy][2*i+j]; -#ifdef z_HAVE_FANCY_SIMD +#ifdef HAVE_FANCY_SIMD auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0)); auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1)); auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2)); @@ -1901,6 +1916,75 @@ void iqk_convert_q4_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int } } +void iqk_convert_q5_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_q5_K * x8[8]; + + block_q8_1_r8 * y = (block_q8_1_r8 *)vy; + + ggml_half dh[16]; + uint16_t all_ls[128]; + + uint32_t utmp[4]; + const uint8_t * u8 = (const uint8_t *)utmp; + uint32_t block[8]; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_q5_K *)((const char *)vx + (ix + k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + dh[k+0] = x8[k][i].d; + dh[k+8] = x8[k][i].dmin; + make_q4_scales(x8[k][i].scales, utmp); + auto qs = x8[k][i].qs; + auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh); + for (int ib64 = 0; ib64 < 4; ++ib64) { + all_ls[8*(2*ib64 + 0) + k ] = u8[2*ib64+0]; + all_ls[8*(2*ib64 + 1) + k ] = u8[2*ib64+1]; + all_ls[8*(2*ib64 + 0) + k + 64] = u8[2*ib64+8]; + all_ls[8*(2*ib64 + 1) + k + 64] = u8[2*ib64+9]; + auto bits = _mm256_loadu_si256((const __m256i *)qs+ib64); + auto values1 = _mm256_and_si256(bits, _mm256_set1_epi8(0xf)); + auto values2 = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf)); + values1 = _mm256_or_si256(values1, _mm256_and_si256(_mm256_set1_epi8(0x10), _mm256_slli_epi16(hbits, 4))); + values2 = _mm256_or_si256(values2, _mm256_and_si256(_mm256_set1_epi8(0x10), _mm256_slli_epi16(hbits, 3))); + hbits = _mm256_srli_epi16(hbits, 2); + _mm256_storeu_si256((__m256i *)block, values1); + auto q8 = (uint32_t *)y[2*ib64+0].qs; + for (int l = 0; l < 4; ++l) { + q8[8*l + k + 0] = block[l + 0]; + q8[8*l + k + 32] = block[l + 4]; + } + _mm256_storeu_si256((__m256i *)block, values2); + q8 = (uint32_t *)y[2*ib64+1].qs; + for (int l = 0; l < 4; ++l) { + q8[8*l + k + 0] = block[l + 0]; + q8[8*l + k + 32] = block[l + 4]; + } + } + } + auto vd = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+0)); + auto vm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+1)); + vm = _mm256_mul_ps(_mm256_set1_ps(-1.f), vm); + for (int ib32 = 0; ib32 < QK_K/32; ++ib32) { + auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32); + auto iscales32 = _mm256_cvtepi16_epi32(iscales16); + auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32)); + _mm_storeu_si128((__m128i *)y[ib32].d+0, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32 + 8); + iscales32 = _mm256_cvtepi16_epi32(iscales16); + scales = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(iscales32)); + _mm_storeu_si128((__m128i *)y[ib32].d+1, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT)); + } + y += QK_K/32; + } + } +} + } // namespace @@ -1910,7 +1994,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array(kernels); break; case GGML_TYPE_Q5_K: - set_functions(kernels); + IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ5K_AVX2, kernels); + //set_functions(kernels); break; case GGML_TYPE_Q6_K: set_functions(kernels); @@ -1983,6 +2068,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; + case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; default: break; } #else @@ -345,7 +346,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, //case GGML_TYPE_Q2_K: //case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: - //case GGML_TYPE_Q5_K: + case GGML_TYPE_Q5_K: //case GGML_TYPE_Q6_K: //case GGML_TYPE_IQ4_XS: //case GGML_TYPE_Q2_K_R4: