From ee590519d2b8b924c27c2c97e4da1e28c08078cc Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 8 Oct 2024 19:47:21 +0300 Subject: [PATCH] Fix iq4_xs (Zen4) --- ggml/src/iqk/iqk_mul_mat.cpp | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index d697070d..b9eb2bdd 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -635,8 +635,10 @@ struct DequantizerIQ4XS final : public BaseDequantizer { s8k.accum_mins(scales128, q8, i, -128.f*d, accd); auto scales256 = MM256_SET_M128I(scales128, scales128); auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); - scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); - scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); + scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]); + scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]); + scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]); + scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]); } inline void prepare(const uint8_t * q4) { bits.prepare64(q4); @@ -652,11 +654,17 @@ struct DequantizerIQ4XS final : public BaseDequantizer { } Q4Bits bits; - Scales8K s8k; + Scales8KBase s8k; ScaleIQ4XS siq4; const __m512i values; const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); + const __m512i shuffles[4] = { + _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1), + }; }; struct HighBit5 { @@ -3721,6 +3729,7 @@ template void MulMat::set_functions(MulMat& m) { std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v|| std::is_same_v) { m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512;