From 9fe58aac13e26e09a9c0ba5a7971338c40fefbf1 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 15 Jun 2025 10:52:26 +0300 Subject: [PATCH] q3_K: don't scale when all quants in a block are <= 127 when repacking --- ggml/src/iqk/iqk_gemm_kquants.cpp | 42 ++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp index c173b4dc..43eff43c 100644 --- a/ggml/src/iqk/iqk_gemm_kquants.cpp +++ b/ggml/src/iqk/iqk_gemm_kquants.cpp @@ -2383,7 +2383,11 @@ void iqk_convert_q3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int auto max4 = _mm_cvtepi32_ps(imax4); max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); - float dnew = std::max(1.f, _mm_cvtss_f32(max4) / 127); + bool needs_scaling = true; + float dnew = _mm_cvtss_f32(max4) / 127; + if (dnew < 1.f) { + dnew = 1.f; needs_scaling = false; + } d *= dnew; y[i].d[k] = GGML_FP32_TO_FP16(d); auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f); @@ -2392,20 +2396,28 @@ void iqk_convert_q3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[ib32], 1)); q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(helper.val[2*ib32+0])); q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(helper.val[2*ib32+1])); - auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l)); - auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1)); - auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h)); - auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1)); - i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST)); - i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST)); - i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST)); - i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST)); - i0 = _mm256_packs_epi32(i0, i1); - i2 = _mm256_packs_epi32(i2, i3); - i0 = _mm256_packs_epi16(i0, i2); - i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); - _mm256_storeu_si256((__m256i *)block, i0); - + if (needs_scaling) { + auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l)); + auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1)); + auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h)); + auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1)); + i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST)); + i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST)); + i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST)); + i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST)); + i0 = _mm256_packs_epi32(i0, i1); + i2 = _mm256_packs_epi32(i2, i3); + i0 = _mm256_packs_epi16(i0, i2); + i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); + _mm256_storeu_si256((__m256i *)block, i0); + } else { + // 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 + auto i0 = _mm256_packs_epi16(q16_l, q16_h); + auto i0_l = _mm256_castsi256_si128(i0); + auto i0_h = _mm256_extracti128_si256(i0, 1); + _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h)); + _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h)); + } auto qs = (uint32_t *)y[i].qs + 64*ib32; for (int l = 0; l < 8; ++l) { qs[8*l + k] = block[l];