From fb254f0c97401f93acbca621d04e9099bb6b41e1 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 May 2025 20:02:42 +0300 Subject: [PATCH] Slightly faster iq4_kt PP is now almost 50% better than original, TG is ~20% better --- ggml/src/iqk/iqk_gemm_ktquants.cpp | 51 ++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 1c9090b1..9f0e3763 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -269,40 +269,63 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn Trellis2 trellis; - union { __m256 vec; float val[8]; } s_helper; + union { __m256 vec; float val[8]; } s_helper; + union { __m256i vec; uint32_t val[8]; } o_helper; //, q_helper1, q_helper2; __m256 accd[nrc_y]; const float * y[nrc_y]; - for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy); + float row_sum[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + y[iy] = (const float *)info.src1_row(iy); + auto sum = _mm256_setzero_ps(); + for (int i = 0; i < n/8; ++i) sum = _mm256_add_ps(sum, _mm256_loadu_ps(y[iy] + 8*i)); + row_sum[iy] = hsum_float_8(sum); + } for (int ix = 0; ix < nrc_x; ++ix) { const float * dptr = (const float *)((const char*)vx + ix*bx); auto d = _mm256_set1_ps(dptr[0] * 31.75f * 1.01f); - auto row_av = _mm256_set1_ps(dptr[1]); + auto dav = dptr[1]; + //auto row_av = _mm256_set1_ps(dptr[1]); const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2); for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { + auto vshb = _mm256_loadu_si256((const __m256i *)x[i].qs); const uint32_t * shb = x[i].qs; - const uint8_t * ql = (const uint8_t *)(shb + 8); + //const uint8_t * ql = (const uint8_t *)(shb + 8); + const uint8_t * ql = (const uint8_t *)(x[i].qs + 8); const uint8_t * qh = ql + kNumGroups; - auto iscales = _mm256_loadu_si256((const __m256i *)shb); - iscales = _mm256_srli_epi32(_mm256_and_si256(iscales, _mm256_set1_epi32(0xff)), 1); + //auto iscales = _mm256_loadu_si256((const __m256i *)shb); + auto iscales = _mm256_srli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(0xff)), 1); s_helper.vec = _mm256_mul_ps(d, _mm256_cvtepi32_ps(_mm256_sub_epi32(iscales, _mm256_set1_epi32(64)))); + o_helper.vec = _mm256_add_epi32(_mm256_slli_epi32(_mm256_and_si256(vshb, _mm256_set1_epi32(1)), 15), _mm256_set1_epi32(4096)); + //vshb = _mm256_srli_epi32(vshb, 8); for (int ib = 0; ib < 4; ++ib) { auto scale1 = _mm256_set1_ps(s_helper.val[ib+0]); auto scale2 = _mm256_set1_ps(s_helper.val[ib+4]); - const uint32_t offset1 = 4096 + ((shb[ib+0] & 1) << 15); - const uint32_t offset2 = 4096 + ((shb[ib+4] & 1) << 15); + //auto vq1 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(ql + 8*ib + 0))); + //auto vq2 = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(ql + 8*ib + 32))); + //auto vqh = _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(qh + 8*ib))); + //vq1 = _mm256_add_epi32(vq1, _mm256_and_si256(_mm256_slli_epi32(vqh, 8), _mm256_set1_epi32(0xf00))); + //vq2 = _mm256_add_epi32(vq1, _mm256_and_si256(_mm256_slli_epi32(vqh, 4), _mm256_set1_epi32(0xf00))); + //q_helper1.vec = _mm256_add_epi32(vq1, _mm256_set1_epi32(o_helper.val[ib+0])); + //q_helper2.vec = _mm256_add_epi32(vq2, _mm256_set1_epi32(o_helper.val[ib+4])); + //const uint32_t offset1 = 4096 + ((shb[ib+0] & 1) << 15); + //const uint32_t offset2 = 4096 + ((shb[ib+4] & 1) << 15); for (int j = 0; j < 4; ++j) { const uint32_t sh1 = shb[ib+0] >> (8 + 6*j); const uint32_t sh2 = shb[ib+4] >> (8 + 6*j); // j/4 -> (32*ib+8*j)/4 = 8*ib + 2*j - uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + offset1; - uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + offset2; - uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + offset1; - uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + offset2; + uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0]; + uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4]; + uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0]; + uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4]; + //uint32_t val1 = q_helper1.val[2*j+0] + ((sh1 & 7) << 12); + //uint32_t val2 = q_helper2.val[2*j+0] + ((sh2 & 7) << 12); + //uint32_t val3 = q_helper1.val[2*j+1] + ((sh1 & 56) << 9); + //uint32_t val4 = q_helper2.val[2*j+1] + ((sh2 & 56) << 9); auto x_val1 = _mm256_mul_ps(scale1, trellis_gen8(trellis.next8(val1, val3))); auto x_val2 = _mm256_mul_ps(scale2, trellis_gen8(trellis.next8(val2, val4))); for (int iy = 0; iy < nrc_y; ++iy) { @@ -310,7 +333,7 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn auto y2 = _mm256_load_ps(y[iy] + i*QK_K+32*ib+8*j+128); accd[iy] = _mm256_fmadd_ps(y1, x_val1, accd[iy]); accd[iy] = _mm256_fmadd_ps(y2, x_val2, accd[iy]); - accd[iy] = _mm256_fmadd_ps(row_av, _mm256_add_ps(y1, y2), accd[iy]); + //accd[iy] = _mm256_fmadd_ps(row_av, _mm256_add_ps(y1, y2), accd[iy]); } } } @@ -351,7 +374,7 @@ static void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataIn } for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, hsum_float_8(accd[iy])); + info.store(ix, iy, hsum_float_8(accd[iy]) + dav*row_sum[iy]); } } }