diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ed79e1a1..35cf0473 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -4335,7 +4335,6 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff); #else auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100); - auto m128 = _mm256_set1_epi8(-128); #endif int nbl = n / QK_K; __m256 acc[nrc_y] = {}; @@ -4405,25 +4404,22 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI qx[3] = _mm256_add_epi8(qx[3], shift); } #else - auto qh = _mm256_and_si256(_mm256_slli_epi16(hb, 7), m128); - auto q5vl = _mm256_or_si256(qx[0], qh); - auto q5vh = _mm256_or_si256(qx[0], _mm256_xor_si256(qh, m128)); - qx[0] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); - qh = _mm256_and_si256(_mm256_slli_epi16(hb, 3), m128); - q5vl = _mm256_or_si256(qx[1], qh); - q5vh = _mm256_or_si256(qx[1], _mm256_xor_si256(qh, m128)); - qx[1] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + auto q5vl = _mm256_shuffle_epi8(values[0], qx[0]); + auto q5vh = _mm256_shuffle_epi8(values[1], qx[0]); + qx[0] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01))); - qh = _mm256_and_si256(_mm256_slli_epi16(hb, 6), m128); - q5vl = _mm256_or_si256(qx[2], qh); - q5vh = _mm256_or_si256(qx[2], _mm256_xor_si256(qh, m128)); - qx[2] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + q5vl = _mm256_shuffle_epi8(values[0], qx[1]); + q5vh = _mm256_shuffle_epi8(values[1], qx[1]); + qx[1] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x10)), _mm256_set1_epi8(0x10))); - qh = _mm256_and_si256(_mm256_slli_epi16(hb, 2), m128); - q5vl = _mm256_or_si256(qx[3], qh); - q5vh = _mm256_or_si256(qx[3], _mm256_xor_si256(qh, m128)); - qx[3] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh)); + q5vl = _mm256_shuffle_epi8(values[0], qx[2]); + q5vh = _mm256_shuffle_epi8(values[1], qx[2]); + qx[2] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x02)), _mm256_set1_epi8(0x02))); + + q5vl = _mm256_shuffle_epi8(values[0], qx[3]); + q5vh = _mm256_shuffle_epi8(values[1], qx[3]); + qx[3] = _mm256_blendv_epi8(q5vl, q5vh, _mm256_cmpeq_epi8(_mm256_and_si256(hb, _mm256_set1_epi8(0x20)), _mm256_set1_epi8(0x20))); auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1); shift = _mm256_shuffle_epi8(shift, shift_shuffle);