From d7ebb3eae4305178842ba1c0074af276da3330e0 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 17 May 2025 10:22:38 +0300 Subject: [PATCH] Zen4: faster PP for iq2_ks --- ggml/src/iqk/iqk_mul_mat.cpp | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index f283c701..654cc706 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -2113,16 +2113,26 @@ struct DequantizerIQ2K final : public BaseDequantizer { struct DequantizerIQ2KS final : public BaseDequantizer { DequantizerIQ2KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {} template - inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + inline void compute_block(int i, const Q8& q8, __m512 * acc) { prepare(x[i].qs); auto scales128 = make_scales(x[i].scales, x[i].extra >> 8); auto shifts = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(x[i].extra), hmask), hmask), m5); - auto scales_s = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts))); - s8k.accum_mins(scales_s, q8, i, d, accm); + auto mins128 = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts))); + auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0])); auto scales256 = MM256_SET_M128I(scales128, scales128); auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); - scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); - scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); + __m512i scales[4]; + for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums(iy, i); + auto prod = _mm256_madd_epi16(mins, q8s); + auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0); + for (int k = 0; k < 4; ++k) { + auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k)); + sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]); + } + acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]); + } } inline void prepare(const uint8_t * q2) { bits.prepare(q2); @@ -2147,7 +2157,7 @@ struct DequantizerIQ2KS final : public BaseDequantizer return _mm_cvtepi8_epi16(_mm_add_epi8(scl, sch)); } Q2Bits bits; - Scales8K s8k; + Scales8KBase s8k; const __m512i values; const __m128i m16 = _mm_set1_epi8(-16); @@ -2156,6 +2166,12 @@ struct DequantizerIQ2KS final : public BaseDequantizer const __m128i hmask = _mm_set1_epi64x(0x8040201008040201); const __m128i shuffle = _mm_set1_epi64x(0x0703060205010400); const __m128i shift = _mm_set_epi32(0, 0, 4, 0); + const __m512i shuffles[4] = { + _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1), + }; }; struct DequantizerIQ3K final : public BaseDequantizer { @@ -9794,7 +9810,8 @@ template void MulMat::set_functions(MulMat& m) { m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512; - } else if constexpr (std::is_same_v || + } else if constexpr (std::is_same_v || + std::is_same_v || std::is_same_v) { m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512_new; m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512_new;