From 1375c21debb780bee0933549d9e155238896d4a4 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 21 May 2025 17:16:51 +0300 Subject: [PATCH] Slighty faster iq2_kt --- ggml/src/iqk/iqk_mul_mat.cpp | 53 +++++++++++++----------------------- 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index c98a85f6..684cd6e3 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3655,8 +3655,12 @@ static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataIn Trellis1 trellis; - constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y; + auto shifts = _mm_set_epi32(0, 0, 4, 0); + auto values = _mm_loadu_si128((const __m128i *)iq4k_values); + union { __m256 vec; float val[8]; } s_helper; + + constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y; __m256 accd[k_acc]; const float * y[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy); @@ -3670,47 +3674,28 @@ static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataIn for (int i = 0; i < nb; ++i) { const uint16_t * ql = (const uint16_t *)x[i].ql; + auto s8 = _mm_set1_epi32(*(const uint32_t *)x[i].scales); + s8 = _mm_and_si128(_mm_srlv_epi32(s8, shifts), _mm_set1_epi8(0xf)); + s8 = _mm_shuffle_epi8(values, s8); + auto s32 = _mm256_cvtepi8_epi32(s8); + s_helper.vec = _mm256_cvtepi32_ps(s32); for (int ib = 0; ib < QK_K/64; ++ib) { - auto scale1 = _mm256_set1_ps(iq4k_values[x[i].scales[ib] & 0xf]); - auto scale2 = _mm256_set1_ps(iq4k_values[x[i].scales[ib] >> 4]); + auto scale1 = _mm256_set1_ps(s_helper.val[2*ib+0]); + auto scale2 = _mm256_set1_ps(s_helper.val[2*ib+1]); for (int j = 0; j < 4; ++j) { - uint32_t val1 = ql[4*ib+j+ 0] + 4096; - uint32_t val2 = ql[4*ib+j+16] + 4096; - //const __m256 x_val1 = _mm256_mul_ps(scale1, trellis_gen8(val1)); - //const __m256 x_val2 = _mm256_mul_ps(scale2, trellis_gen8(val2)); - const __m256 x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1))); - const __m256 x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2))); + auto xval1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(ql[8*ib+j+0]+4096))); + auto xval2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(ql[8*ib+j+4]+4096))); if constexpr (nrc_y == 1) { - accd[0] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 32*ib + 8*j ), x_val1, accd[0]); - accd[1] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 32*ib + 8*j + 128), x_val2, accd[1]); + accd[0] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 64*ib + 8*j + 0), xval1, accd[0]); + accd[1] = _mm256_fmadd_ps(_mm256_load_ps(y[0] + i*QK_K + 64*ib + 8*j + 32), xval2, accd[1]); } else { for (int iy = 0; iy < nrc_y; ++iy) { - accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 32*ib + 8*j ), x_val1, accd[iy]); - accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 32*ib + 8*j + 128), x_val2, accd[iy]); + accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 64*ib + 8*j + 0), xval1, accd[iy]); + accd[iy] = _mm256_fmadd_ps(_mm256_load_ps(y[iy] + i*QK_K + 64*ib + 8*j + 32), xval2, accd[iy]); } } } } - //for (int j = 0; j < 128; j+=8) { - // uint32_t val1 = ql[j/8] + 4096; - // uint32_t val2 = ql[j/8+16] + 4096; - // const float x_scale1 = iq4k_values[x[i].scales[j/32] & 0xf]; - // const float x_scale2 = iq4k_values[x[i].scales[j/32] >> 4]; - // const __m256 x_val1 = trellis_gen8(val1); - // const __m256 x_val2 = trellis_gen8(val2); - // for (int iy = 0; iy < nrc_y; ++iy) { - // accd[iy] = _mm256_fmadd_ps( - // _mm256_load_ps(y[iy] + i*QK_K+j), - // _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1), - // accd[iy] - // ); - // accd[iy] = _mm256_fmadd_ps( - // _mm256_load_ps(y[iy] + i*QK_K+j+128), - // _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2), - // accd[iy] - // ); - // } - //} } if constexpr (nrc_y == 1) { @@ -3727,7 +3712,7 @@ static void mul_mat_iq2_kt_F32_T(int n, const void * vx, size_t bx, const DataIn static inline __m256 abs_ps(__m256 vals) { // Clear sign-bit of all the 32-bit floats in vals - __m256 sign_bit = _mm256_set1_ps(-0.0f); + __m256 sign_bit = _mm256_set1_ps(-0.0f); return _mm256_andnot_ps(sign_bit, vals); }