diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 376165f9..ed79e1a1 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -4335,8 +4335,8 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); #else auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); -#endif auto m128 = _mm256_set1_epi8(-128); +#endif int nbl = n / QK_K; __m256 acc[nrc_y] = {}; __m256i qx[4]; @@ -4379,23 +4379,32 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4); qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4); - // 0, 4, 1, 5 +#ifdef HAVE_FANCY_SIMD + auto q5vl = _mm256_shuffle_epi8(values[0], qx[0]); + auto q5vh = _mm256_shuffle_epi8(values[1], qx[0]); + qx[0] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01)), q5vl, q5vh); - // This is slower -//#ifdef HAVE_FANCY_SIMD -// auto mask1 = _mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01)); -// auto mask2 = _mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x10)), _mm256_set1_epi8(0x10)); -// auto mask3 = _mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x02)), _mm256_set1_epi8(0x02)); -// auto mask4 = _mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x20)), _mm256_set1_epi8(0x20)); -// qx[0] = _mm256_mask_shuffle_epi8(_mm256_maskz_shuffle_epi8(_knot_mask64(mask1), values[0], qx[0]), mask1, values[1], qx[0]); -// qx[1] = _mm256_mask_shuffle_epi8(_mm256_maskz_shuffle_epi8(_knot_mask64(mask2), values[0], qx[1]), mask2, values[1], qx[1]); -// qx[2] = _mm256_mask_shuffle_epi8(_mm256_maskz_shuffle_epi8(_knot_mask64(mask3), values[0], qx[2]), mask3, values[1], qx[2]); -// qx[3] = _mm256_mask_shuffle_epi8(_mm256_maskz_shuffle_epi8(_knot_mask64(mask4), values[0], qx[3]), mask4, values[1], qx[3]); -// qx[0] = _mm256_add_epi8(qx[0], shift); -// qx[1] = _mm256_add_epi8(qx[1], shift); -// qx[2] = _mm256_add_epi8(qx[2], shift); -// qx[3] = _mm256_add_epi8(qx[3], shift); -//#else + q5vl = _mm256_shuffle_epi8(values[0], qx[1]); + q5vh = _mm256_shuffle_epi8(values[1], qx[1]); + qx[1] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x10)), _mm256_set1_epi8(0x10)), q5vl, q5vh); + + q5vl = _mm256_shuffle_epi8(values[0], qx[2]); + q5vh = _mm256_shuffle_epi8(values[1], qx[2]); + qx[2] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x02)), _mm256_set1_epi8(0x02)), q5vl, q5vh); + + q5vl = _mm256_shuffle_epi8(values[0], qx[3]); + q5vh = _mm256_shuffle_epi8(values[1], qx[3]); + qx[3] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x20)), _mm256_set1_epi8(0x20)), q5vl, q5vh); + + if constexpr (nrc_y == 1) { + auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1); + shift = _mm256_shuffle_epi8(shift, shift_shuffle); + qx[0] = _mm256_add_epi8(qx[0], shift); + qx[1] = _mm256_add_epi8(qx[1], shift); + qx[2] = _mm256_add_epi8(qx[2], shift); + qx[3] = _mm256_add_epi8(qx[3], shift); + } +#else auto qh = _mm256_and_si256(_mm256_slli_epi16(hb, 7), m128); auto q5vl = _mm256_or_si256(qx[0], qh); auto q5vh = _mm256_or_si256(qx[0], _mm256_xor_si256(qh, m128)); @@ -4416,16 +4425,6 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI q5vh = _mm256_or_si256(qx[3], _mm256_xor_si256(qh, m128)); qx[3] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); -#ifdef HAVE_FANCY_SIMD - if constexpr (nrc_y == 1) { - auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1); - shift = _mm256_shuffle_epi8(shift, shift_shuffle); - qx[0] = _mm256_add_epi8(qx[0], shift); - qx[1] = _mm256_add_epi8(qx[1], shift); - qx[2] = _mm256_add_epi8(qx[2], shift); - qx[3] = _mm256_add_epi8(qx[3], shift); - } -#else auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1); shift = _mm256_shuffle_epi8(shift, shift_shuffle); qx[0] = _mm256_add_epi8(qx[0], shift);