iq5_k_r4: better Zen4

But TG is still slower than iq5_k
This commit is contained in:
Iwan Kawrakow
2024-12-18 11:35:40 +02:00
parent a8b0f8148c
commit ece6d0a52b

View File

@@ -4335,8 +4335,8 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#else
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
#endif
auto m128 = _mm256_set1_epi8(-128);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i qx[4];
@@ -4379,23 +4379,32 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4);
qx[3] = _mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4);
// 0, 4, 1, 5
#ifdef HAVE_FANCY_SIMD
auto q5vl = _mm256_shuffle_epi8(values[0], qx[0]);
auto q5vh = _mm256_shuffle_epi8(values[1], qx[0]);
qx[0] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01)), q5vl, q5vh);
// This is slower
//#ifdef HAVE_FANCY_SIMD
// auto mask1 = _mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x01)), _mm256_set1_epi8(0x01));
// auto mask2 = _mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x10)), _mm256_set1_epi8(0x10));
// auto mask3 = _mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x02)), _mm256_set1_epi8(0x02));
// auto mask4 = _mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x20)), _mm256_set1_epi8(0x20));
// qx[0] = _mm256_mask_shuffle_epi8(_mm256_maskz_shuffle_epi8(_knot_mask64(mask1), values[0], qx[0]), mask1, values[1], qx[0]);
// qx[1] = _mm256_mask_shuffle_epi8(_mm256_maskz_shuffle_epi8(_knot_mask64(mask2), values[0], qx[1]), mask2, values[1], qx[1]);
// qx[2] = _mm256_mask_shuffle_epi8(_mm256_maskz_shuffle_epi8(_knot_mask64(mask3), values[0], qx[2]), mask3, values[1], qx[2]);
// qx[3] = _mm256_mask_shuffle_epi8(_mm256_maskz_shuffle_epi8(_knot_mask64(mask4), values[0], qx[3]), mask4, values[1], qx[3]);
// qx[0] = _mm256_add_epi8(qx[0], shift);
// qx[1] = _mm256_add_epi8(qx[1], shift);
// qx[2] = _mm256_add_epi8(qx[2], shift);
// qx[3] = _mm256_add_epi8(qx[3], shift);
//#else
q5vl = _mm256_shuffle_epi8(values[0], qx[1]);
q5vh = _mm256_shuffle_epi8(values[1], qx[1]);
qx[1] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x10)), _mm256_set1_epi8(0x10)), q5vl, q5vh);
q5vl = _mm256_shuffle_epi8(values[0], qx[2]);
q5vh = _mm256_shuffle_epi8(values[1], qx[2]);
qx[2] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x02)), _mm256_set1_epi8(0x02)), q5vl, q5vh);
q5vl = _mm256_shuffle_epi8(values[0], qx[3]);
q5vh = _mm256_shuffle_epi8(values[1], qx[3]);
qx[3] = _mm256_mask_blend_epi8(_mm256_cmpeq_epi8_mask(_mm256_and_si256(hb, _mm256_set1_epi8(0x20)), _mm256_set1_epi8(0x20)), q5vl, q5vh);
if constexpr (nrc_y == 1) {
auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
shift = _mm256_shuffle_epi8(shift, shift_shuffle);
qx[0] = _mm256_add_epi8(qx[0], shift);
qx[1] = _mm256_add_epi8(qx[1], shift);
qx[2] = _mm256_add_epi8(qx[2], shift);
qx[3] = _mm256_add_epi8(qx[3], shift);
}
#else
auto qh = _mm256_and_si256(_mm256_slli_epi16(hb, 7), m128);
auto q5vl = _mm256_or_si256(qx[0], qh);
auto q5vh = _mm256_or_si256(qx[0], _mm256_xor_si256(qh, m128));
@@ -4416,16 +4425,6 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
q5vh = _mm256_or_si256(qx[3], _mm256_xor_si256(qh, m128));
qx[3] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
#ifdef HAVE_FANCY_SIMD
if constexpr (nrc_y == 1) {
auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
shift = _mm256_shuffle_epi8(shift, shift_shuffle);
qx[0] = _mm256_add_epi8(qx[0], shift);
qx[1] = _mm256_add_epi8(qx[1], shift);
qx[2] = _mm256_add_epi8(qx[2], shift);
qx[3] = _mm256_add_epi8(qx[3], shift);
}
#else
auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
shift = _mm256_shuffle_epi8(shift, shift_shuffle);
qx[0] = _mm256_add_epi8(qx[0], shift);