From 3fb04caef44e4944c8c3663722e43201a28383dd Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Mon, 23 Dec 2024 13:28:12 +0200
Subject: [PATCH] iq3_s_r4: rearranged quants - AVX2

---
 ggml/src/iqk/iqk_mul_mat.cpp | 49 ++++++++++--------------------------
 1 file changed, 13 insertions(+), 36 deletions(-)

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 8d49248b..23f9e799 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -4026,39 +4026,14 @@ struct IndexHelperIQ3S {
 };
 #endif
 
-//void dequantize_row_iq3_s_r4(const block_iq3_s_r4 * x, float * y, int64_t k) {
-//    auto n_per_row = k/4;
-//    float * y4[4] = {y, y + n_per_row, y + 2*n_per_row, y + 3*n_per_row};
-//    int nblock = n_per_row/QK_K;
-//    for (int ibl = 0; ibl < nblock; ++ibl) {
-//        for (int k = 0; k < 4; ++k) {
-//            const float d = GGML_FP16_TO_FP32(x[ibl].d[k]);
-//            for (int ib = 0; ib < QK_K/32; ++ib) {
-//                int l = 4*ib + k;
-//                float dl = d * (1 + 2*((x[ibl].scales[l%16] >> 4*(l/16)) & 0xf));
-//                for (int i = 0; i < 4; ++i) {
-//                    auto grid1 = (const uint8_t *)(iq3s_grid + x[ibl].qs[32*ib+k+8*i+0] + ((x[ibl].qh[4*ib+k] << (8-i)) & 0x100));
-//                    auto grid2 = (const uint8_t *)(iq3s_grid + x[ibl].qs[32*ib+k+8*i+4] + ((x[ibl].qh[4*ib+k] << (4-i)) & 0x100));
-//                    for (int j = 0; j < 4; ++j) {
-//                        y4[k][QK_K*ibl+32*ib+4*i+ 0+j] = dl * grid1[j] * (x[ibl].signs[16*ib+4*k+j] & (1 << (i+0)) ? -1 : 1);
-//                        y4[k][QK_K*ibl+32*ib+4*i+16+j] = dl * grid2[j] * (x[ibl].signs[16*ib+4*k+j] & (1 << (i+4)) ? -1 : 1);
-//                    }
-//                }
-//            }
-//        }
-//    }
-//}
-
-
 template <int nrc_y>
 static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
     Q8<nrc_y, block_q8_K> q8(info);
     int nbl = n / QK_K;
 #ifndef HAVE_FANCY_SIMD
-    auto smask = _mm256_set1_epi64x(0x8040201008040201);
-    auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
-    auto m4 = _mm256_set1_epi8(4);
+    //auto smask = _mm256_set1_epi64x(0x8040201008040201);
+    //auto sign_shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000);
 #endif
     auto smask = _mm256_set1_epi8(1);
     union { __m256i vec; uint32_t val[8]; } helper;
@@ -4066,7 +4041,9 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
     __m256 acc[nrc_y] = {};
     __m256i isum[nrc_y] = {};
     __m256i qx[4];
+#ifdef HAVE_FANCY_SIMD
     __mmask32 mask[4];
+#endif
     for (int ix = 0; ix < nrc_x; ix += 4) {
         auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx);
         for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
@@ -4087,10 +4064,10 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
                         iq3s_grid[hidx.val[3]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[0]]);
                 }
                 qs += 32; qh += 4;
-                auto scales = _mm256_cvtepi8_epi32(_mm_set1_epi32(helper.val[ib]));
                 auto signs128 = _mm_loadu_si128((const __m128i*)iq3[ibl].signs + ib);
                 auto signs = MM256_SET_M128I(_mm_srli_epi16(signs128, 4), signs128);
 #ifdef HAVE_FANCY_SIMD
+                auto scales = _mm256_cvtepi8_epi32(_mm_set1_epi32(helper.val[ib]));
                 mask[0] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1);
                 mask[1] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1);
                 mask[2] = _mm256_cmpeq_epi8_mask(_mm256_and_si256(signs, smask), smask); signs = _mm256_srli_epi16(signs, 1);
@@ -4109,20 +4086,20 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
                     isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(sumi, scales));
                 }
 #else
+                auto scales16 = _mm256_cvtepi8_epi16(_mm_set1_epi32(helper.val[ib]));
+                auto scales = _mm256_unpacklo_epi16(scales16, scales16);
                 auto s1 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
                 auto s2 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
                 auto s3 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask); signs = _mm256_srli_epi16(signs, 1);
                 auto s4 = _mm256_or_si256(_mm256_cmpeq_epi8(_mm256_and_si256(signs, smask), smask), smask);
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     auto y = _mm256_loadu_si256((const __m256i *)q8.y[iy][ibl].qs + ib);
-                    auto sumi1 = _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(y, s1)); // 16x0
-                    auto sumi2 = _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(y, s2)); // 16x1
-                    auto sumi3 = _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(y, s3)); // 16x2
-                    auto sumi4 = _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(y, s4)); // 16x3
-                    auto s12 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)); // 0,0,1,1, 0,0,1,1, 0,0,1,1, 0,0,1,1
-                    auto s34 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)); // 2,2,3,3, 2,2,3,3, 2,2,3,3, 2,2,3,3
-                    auto s1234 = _mm256_add_epi16(_mm256_unpacklo_epi64(s12, s34), _mm256_unpackhi_epi64(s12, s34)); // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3
-                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales, s1234));
+                    auto sumi = _mm256_setzero_si256();
+                    sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), s1)));
+                    sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), s2)));
+                    sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), s3)));
+                    sumi = _mm256_add_epi16(sumi, _mm256_maddubs_epi16(qx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), s4)));
+                    isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales, sumi));
                 }
 #endif
             }
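-- 
Annotation, not part of the patch (placed below the signature separator so it is
ignored by git am): the plain-AVX2 (#else) branch above expands the packed sign bits
into per-byte +/-1 factors with _mm256_cmpeq_epi8 + _mm256_or_si256 and applies them
through _mm256_sign_epi8, replacing the old unpack/shuffle reduction over sumi1..sumi4.
A minimal self-contained sketch of that sign trick, assuming AVX2; apply_signs is a
hypothetical helper name for illustration, not a function from the patch:

#include <immintrin.h>

// Negate each byte of y whose sign bit (bit 0 of the matching byte in
// `signs`) is set; all other bytes pass through unchanged.
static inline __m256i apply_signs(__m256i y, __m256i signs) {
    const __m256i one = _mm256_set1_epi8(1);
    // 0xff where the sign bit is set, 0x00 elsewhere
    __m256i is_neg = _mm256_cmpeq_epi8(_mm256_and_si256(signs, one), one);
    // OR with 1 turns the mask into bytes that are exactly -1 or +1
    __m256i pm1 = _mm256_or_si256(is_neg, one);
    // _mm256_sign_epi8 negates y where pm1 < 0 and keeps it where pm1 > 0
    return _mm256_sign_epi8(y, pm1);
}

In the kernel the s1..s4 vectors are computed once per block and reused for every
right-hand-side row iy. The signs are applied to the q8 activations rather than to
the qx grid bytes because _mm256_maddubs_epi16 treats its first operand as unsigned,
so qx must stay non-negative while y carries the sign.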