From 48c0b7d8f92df51eeb8bd553e8e7381c2de5089d Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 May 2025 16:07:21 +0300 Subject: [PATCH] Slightly faster iq4_kt --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 119 ++++++++++++++++++----------- 1 file changed, 76 insertions(+), 43 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index b7abc2e1..1c9090b1 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -115,7 +115,8 @@ struct Trellis2 { const __m256i mask2 = _mm256_set1_epi32(km32); inline __m256i next8(uint32_t val1, uint32_t val2) { - __m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2); + __m256i mval = MM256_SET_M128I(_mm_set1_epi32(val2), _mm_set1_epi32(val1)); + //__m256i mval = _mm256_setr_epi32(val1, val1, val1, val1, val2, val2, val2, val2); __m256i mres = _mm256_add_epi32(_mm256_mullo_epi32(mval, mka), mkb); return _mm256_and_si256(mres, _mm256_set1_epi32(kmask)) ^ _mm256_set1_epi32(km32); } @@ -251,6 +252,15 @@ static void mul_mat_iq3_kt_F32_T(int n, const void * vx, size_t bx, const DataIn } } +// QuantizerIQKT; +// constexpr static int kSuperBlockSize = QK_K; +// constexpr static int kBlockSize = block_size; -> 32 +// constexpr static int kGroupSize = group_size; -> 4 +// constexpr static int kNg = kBlockSize/kGroupSize; -> 8 +// constexpr static int kNblock = kSuperBlockSize/kBlockSize; -> 8 +// constexpr static int kNumVal = 1 << num_bits; -> 32768 +// constexpr int kNumGroups = Q::kSuperBlockSize/Q::kGroupSize -> 64 + template static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n%QK_K == 0); @@ -259,66 +269,89 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn Trellis2 trellis; + union { __m256 vec; float val[8]; } s_helper; + __m256 accd[nrc_y]; - __m256 accd2[nrc_y]; const float * y[nrc_y]; for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy); for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); - const float d = dptr[0] * 31.75f * 1.01f; - const float row_av = dptr[1]; + auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f); + auto row_av = _mm256_set1_ps(dptr[1]); const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); - for (int iy = 0; iy < nrc_y; ++iy) { - accd[iy] = _mm256_setzero_ps(); - accd2[iy] = _mm256_setzero_ps(); - } + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { const uint32_t * shb = x[i].qs; const uint8_t * ql = (const uint8_t *)(shb + 8); const uint8_t * qh = ql + kNumGroups; - for (int j = 0; j < 128; j+=8) { - const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15); - const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15); - const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64; - const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64; - const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4)); - const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4)); - uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; - uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; - uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; - uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; - const __m256 x_val1 = trellis_gen8(trellis.next8(val1, val3)); - const __m256 x_val2 = trellis_gen8(trellis.next8(val2, val4)); - for (int iy = 0; iy < nrc_y; ++iy) { - accd[iy] = _mm256_fmadd_ps( - _mm256_load_ps(y[iy] + i*QK_K+j), - _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1), - accd[iy] - ); - accd[iy] = _mm256_fmadd_ps( - _mm256_load_ps(y[iy] + i*QK_K+j+128), - _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2), - accd[iy] - ); - accd2[iy] = _mm256_add_ps( - _mm256_load_ps(y[iy] + i*QK_K+j), - accd2[iy] - ); - accd2[iy] = _mm256_add_ps( - _mm256_load_ps(y[iy] + i*QK_K+j+128), - accd2[iy] - ); + auto iscales = _mm256_loadu_si256((const __m256i *)shb); + iscales = _mm256_srli_epi32(_mm256_and_si256(iscales, _mm256_set1_epi32(0xff)), 1); + s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64)))); + for (int ib = 0; ib < 4; ++ib) { + auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]); + auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]); + const uint32_t offset1 = 4096 + ((shb[ib+0] & 1) << 15); + const uint32_t offset2 = 4096 + ((shb[ib+4] & 1) << 15); + for (int j = 0; j < 4; ++j) { + const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); + const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); + // j/4 -> (32*ib+8*j)/4 = 8*ib + 2*j + uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; + uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; + uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; + uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; + auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3))); + auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y1 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+ 0); + auto y2 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128); + accd[iy] = _mm256_fmadd_ps(y1, x_val1, accd[iy]); + accd[iy] = _mm256_fmadd_ps(y2, x_val2, accd[iy]); + accd[iy] = _mm256_fmadd_ps(row_av, _mm256_add_ps(y1, y2), accd[iy]); + } } } + //for (int j = 0; j < 128; j+=8) { + // const uint32_t offset1 = 4096 + ((shb[j/32+0] & 1) << 15); + // const uint32_t offset2 = 4096 + ((shb[j/32+4] & 1) << 15); + // const float x_scale1 = (int)((shb[j/32+0] & 0xff) >> 1) - 64; + // const float x_scale2 = (int)((shb[j/32+4] & 0xff) >> 1) - 64; + // const uint32_t sh1 = shb[j/32+0] >> (8 + 6*((j/8)%4)); + // const uint32_t sh2 = shb[j/32+4] >> (8 + 6*((j/8)%4)); + // uint32_t val1 = ql[j/4+ 0] + ((qh[j/4+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; + // uint32_t val2 = ql[j/4+32] + ((qh[j/4+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; + // uint32_t val3 = ql[j/4+ 1] + ((qh[j/4+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; + // uint32_t val4 = ql[j/4+33] + ((qh[j/4+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; + // const __m256 x_val1 = trellis_gen8(trellis.next8(val1, val3)); + // const __m256 x_val2 = trellis_gen8(trellis.next8(val2, val4)); + // for (int iy = 0; iy < nrc_y; ++iy) { + // accd[iy] = _mm256_fmadd_ps( + // _mm256_load_ps(y[iy] + i*QK_K+j), + // _mm256_mul_ps(_mm256_set1_ps(x_scale1), x_val1), + // accd[iy] + // ); + // accd[iy] = _mm256_fmadd_ps( + // _mm256_load_ps(y[iy] + i*QK_K+j+128), + // _mm256_mul_ps(_mm256_set1_ps(x_scale2), x_val2), + // accd[iy] + // ); + // accd2[iy] = _mm256_add_ps( + // _mm256_load_ps(y[iy] + i*QK_K+j), + // accd2[iy] + // ); + // accd2[iy] = _mm256_add_ps( + // _mm256_load_ps(y[iy] + i*QK_K+j+128), + // accd2[iy] + // ); + // } + //} } for (int iy = 0; iy < nrc_y; ++iy) { - __m256 res = _mm256_mul_ps(_mm256_set1_ps(d), accd[iy]); - __m256 res2 = _mm256_mul_ps(_mm256_set1_ps(row_av), accd2[iy]); - info.store(ix, iy, hsum_float_8(res) + hsum_float_8(res2)); + info.store(ix, iy, hsum_float_8(accd[iy])); } } }