From f682afb407b703da7ac14d1bdb9200245b20cbaa Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 17 Jun 2025 12:42:06 +0300 Subject: [PATCH] iq5_k - there was a bug with the shifts ...and that's why PPL was so high. It is also high on main. This fixes it. --- ggml/src/ggml.c | 4 + ggml/src/iqk/iqk_gemm_iqk_quants.cpp | 155 +++++++++++++++++++-------- 2 files changed, 112 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a6260136..69b1b46d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1699,7 +1699,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq5_k, .from_float_ref = (ggml_from_float_t)quantize_row_iq5_k_ref, .vec_dot = vec_dot_iq5_k_q8_k, +//#ifdef __AVX2__ +// .vec_dot_type = GGML_TYPE_Q8_2_X4, +//#else .vec_dot_type = GGML_TYPE_Q8_K, +//#endif .nrows = 1, .row_meta_size = 0, }, diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp index 85b680ae..d0923d65 100644 --- a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp +++ b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp @@ -2281,50 +2281,6 @@ void iqk_convert_iq5_ks_q8_k_r8(int n, const void * vx, size_t bx, void * vy, in } } -//struct DequantizerIQ5K final : public BaseDequantizer { -// DequantizerIQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(2, 0) { load_values(values); } -// template -// inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { -// d = GGML_FP16_TO_FP32(x[i].d); -// iqxk.process(i, d, x[i].extra, make_scales(x[i].scales_l, (const uint16_t *)x[i].scales_h), q8, accm, scales); -// hbits = _mm256_loadu_si256((const __m256i *)x[i].qh); -// } -// inline void prepare(int i, int j) { -// bits.prepare(x[i].qs, j); -// auto h = j == 0 ? hbits : _mm256_srli_epi16(hbits, 4); -// for (int k = 0; k < 4; ++k) { -// auto qh = _mm256_and_si256(_mm256_slli_epi16(h, 7-k), mh); -// auto q5vl = _mm256_or_si256(bits.values[k], qh); -// auto q5vh = _mm256_or_si256(bits.values[k], _mm256_xor_si256(qh, mh)); -// bits.values[k] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); -// } -// } -// __m128i make_scales(const uint8_t * scales_l, const uint16_t * scales_h) const { -// uint64_t aux64; -// memcpy(&aux64, scales_l, 8); -// auto scl = _mm_and_si128(_mm_set_epi64x(aux64 >> 4, aux64), maskl); -// const uint32_t aux32 = scales_h[0] | (scales_h[1] << 16); -// auto aux = _mm_and_si128(_mm_set_epi32(aux32 >> 2, aux32, aux32 << 2, aux32 << 4), maskh); -// auto sch = _mm_shuffle_epi8(aux, iqxk.hshuff); -// return _mm_add_epi8(_mm_or_si128(scl, sch), m32); -// } -// static void load_values(__m256i * values) { -// auto values128_1 = _mm_loadu_si128((const __m128i *)iq5nl_values + 0); -// auto values128_2 = _mm_loadu_si128((const __m128i *)iq5nl_values + 1); -// values[0] = MM256_SET_M128I(values128_1, values128_1); -// values[1] = MM256_SET_M128I(values128_2, values128_2); -// } -// -// Q4Bits bits; -// const IQXKScales iqxk; -// __m256i hbits; -// __m256i values[2]; -// const __m128i maskl = _mm_set1_epi8(0xf); -// const __m128i maskh = _mm_set1_epi8(0x30); -// const __m128i m32 = _mm_set1_epi8(-32); -// const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing -//}; - void iqk_convert_iq5_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { GGML_ASSERT(n%QK_K == 0); GGML_ASSERT(nrc_x%8 == 0); @@ -2372,12 +2328,12 @@ void iqk_convert_iq5_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int q5vl = _mm256_or_si256(xv[2*ib64+1], qh); q5vh = _mm256_or_si256(xv[2*ib64+1], _mm256_xor_si256(qh, mh)); xv[2*ib64+1] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); - auto shift1 = _mm256_set1_epi8((extra & 1) << 1); - auto shift2 = _mm256_set1_epi8((extra & 2) << 0); + auto shift1 = MM256_SET_M128I(_mm_set1_epi8((extra & 2) << 0), _mm_set1_epi8((extra & 1) << 1)); + auto shift2 = MM256_SET_M128I(_mm_set1_epi8((extra & 8) >> 2), _mm_set1_epi8((extra & 4) >> 1)); xv[2*ib64+0] = _mm256_add_epi8(xv[2*ib64+0], shift1); xv[2*ib64+1] = _mm256_add_epi8(xv[2*ib64+1], shift2); hbits = _mm256_srli_epi16(hbits, 2); - extra >>= 2; + extra >>= 4; } float dnew = convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs); y[i].d[k] = GGML_FP32_TO_FP16(d*dnew); @@ -2387,6 +2343,111 @@ void iqk_convert_iq5_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int } } +void iqk_convert_iq5_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_iq5_k * x8[8]; + + block_q8_0_r8 * y = (block_q8_0_r8 *)vy; + + __m256i values[2]; + { + auto v1 = _mm_loadu_si128((const __m128i *)iq5nl_values+0); + auto v2 = _mm_loadu_si128((const __m128i *)iq5nl_values+1); + values[0] = MM256_SET_M128I(v1, v1); + values[1] = MM256_SET_M128I(v2, v2); + } + + __m256i xv[8]; + uint32_t block[8]; + int16_t ls[16]; + float all_s[64]; + + auto mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_iq5_k *)((const char *)vx + (ix+k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + float d = GGML_FP16_TO_FP32(x8[k][i].d); + auto extra = x8[k][i].extra; + auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh); + for (int ib64 = 0; ib64 < 4; ++ib64) { + ls[4*ib64+0] = ((x8[k][i].scales_l[2*ib64+0] & 0xf) | ((x8[k][i].scales_h[ib64] << 4) & 0x30)) - 32; + ls[4*ib64+1] = ((x8[k][i].scales_l[2*ib64+0] >> 4) | ((x8[k][i].scales_h[ib64] << 2) & 0x30)) - 32; + ls[4*ib64+2] = ((x8[k][i].scales_l[2*ib64+1] & 0xf) | ((x8[k][i].scales_h[ib64] >> 0) & 0x30)) - 32; + ls[4*ib64+3] = ((x8[k][i].scales_l[2*ib64+1] >> 4) | ((x8[k][i].scales_h[ib64] >> 2) & 0x30)) - 32; + auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+ib64); + xv[2*ib64+0] = _mm256_and_si256(bits, _mm256_set1_epi8(0xf)); + xv[2*ib64+1] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf)); + auto qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 7), mh); + auto q5vl = _mm256_or_si256(xv[2*ib64+0], qh); + auto q5vh = _mm256_or_si256(xv[2*ib64+0], _mm256_xor_si256(qh, mh)); + xv[2*ib64+0] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + qh = _mm256_and_si256(_mm256_slli_epi16(hbits, 6), mh); + q5vl = _mm256_or_si256(xv[2*ib64+1], qh); + q5vh = _mm256_or_si256(xv[2*ib64+1], _mm256_xor_si256(qh, mh)); + xv[2*ib64+1] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + auto shift1 = MM256_SET_M128I(_mm_set1_epi8((extra & 2) << 0), _mm_set1_epi8((extra & 1) << 1)); + auto shift2 = MM256_SET_M128I(_mm_set1_epi8((extra & 8) >> 2), _mm_set1_epi8((extra & 4) >> 1)); + xv[2*ib64+0] = _mm256_add_epi8(xv[2*ib64+0], shift1); + xv[2*ib64+1] = _mm256_add_epi8(xv[2*ib64+1], shift2); + hbits = _mm256_srli_epi16(hbits, 2); + extra >>= 4; + } + for (int ib32 = 0; ib32 < 8; ++ib32) { + // We have two blocks of 16 with different scales + // We multiply the quants with the scales, find the max value, and convert to 8-bit quants with a single block scale. + auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xv[ib32])); + auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xv[ib32], 1)); + q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(ls[2*ib32+0])); + q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(ls[2*ib32+1])); + auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l); + auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h); + auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h); + auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1))); + auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1)); + auto max4 = _mm_cvtepi32_ps(imax4); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + float max = _mm_cvtss_f32(max4) / 127; + all_s[8*ib32+k] = d*max; + if (max > 1e-9f) { + auto scale = _mm256_set1_ps(1/max); + auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l)); + auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1)); + auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h)); + auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1)); + i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST)); + i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST)); + i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST)); + i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST)); + i0 = _mm256_packs_epi32(i0, i1); + i2 = _mm256_packs_epi32(i2, i3); + i0 = _mm256_packs_epi16(i0, i2); + i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); + _mm256_storeu_si256((__m256i *)block, i0); + } else { + _mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256()); + } + auto qs = (uint32_t *)y[ib32].qs; + for (int l = 0; l < 4; ++l) { + qs[8*l + k + 0] = block[l + 0]; + qs[8*l + k + 32] = block[l + 4]; + } + } + } + for (int ib32 = 0; ib32 < 8; ++ib32) { + _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT)); + } + y += QK_K/32; + } + } +} + } // namespace bool iqk_convert_iqk_quants_q80_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {