From a5153881198e577ac45258f93f68b5c12511dbe4 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 9 Dec 2024 16:26:24 +0200 Subject: [PATCH] Minor --- ggml/src/iqk/iqk_mul_mat.cpp | 42 ++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index a754300c..3f1e12df 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3085,7 +3085,7 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D auto m1 = _mm256_set1_epi16(1); #endif int nbl = n / QK_K; - union { __m256i vec; uint32_t val[8]; } hd, hm; + union { __m256i vec; uint32_t val[8]; } hd; __m256 acc[nrc_y] = {}; __m256i qx[4]; for (int ix = 0; ix < nrc_x; ix += 4) { @@ -3093,7 +3093,7 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d)); auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl)); - auto m4 = _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)); + auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1))); if constexpr (nrc_y == 1) { d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl))); } @@ -3101,26 +3101,24 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h); auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4)); hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3)); - hm.vec = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3)); - //if constexpr (nrc_y > 2) { - m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), m4); - auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[0]))))); - auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[5])), _mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[1]))))); - auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[2]))))); - auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(hm.val[3]))))); - for (int iy = 0; iy < nrc_y; ++iy) { - auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums); - acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]); - acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]); - acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]); - acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]); - } - //} else { - // m4 = _mm256_mul_ps(_mm256_set1_ps(-0.5f), m4); - //} + auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3)); + auto shuffle = _mm256_set1_epi64x(0x0000000400000000); + auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); + shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); + auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); + shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); + auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); + shuffle = _mm256_add_epi32(shuffle, _mm256_set1_epi32(1)); + auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm256_castsi256_si128(_mm256_permutevar8x32_epi32(mins, shuffle))))); + for (int iy = 0; iy < nrc_y; ++iy) { + auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums); + acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]); + acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]); + acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]); + acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]); + } for (int ib = 0; ib < QK_K/32; ++ib) { auto scales_d = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib])))); - //auto scales_m = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set1_epi32(hm.val[ib])))); auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0); auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1); qx[0] = _mm256_and_si256(bits1, mf); @@ -3148,10 +3146,6 @@ static void mul_mat_q4_k_r4_q8_k_avx2(int n, const void * vx, size_t bx, const D float d8 = q8.scale(iy, ibl); acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales_d, _mm256_set1_ps(d8)), _mm256_cvtepi32_ps(sumi), acc[iy]); } - //if constexpr (nrc_y <= 2) { - // float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; - // acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]); - //} } } }