iq2_kl: better Zen4

This commit is contained in:
Iwan Kawrakow
2025-07-11 13:09:58 +03:00
parent cd32c732f5
commit b805f69c5a

View File

@@ -376,14 +376,20 @@ struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true>
inline void prepare(int i) {
__m512i ql[2], qs[4];
__mmask64 mask[2];
// TODO: optimize this
for (int k = 0; k < 2; ++k) {
auto b1 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+0);
auto b2 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+1);
auto c1 = MM256_SET_M128I(_mm_srli_epi16(b1, 4), b1);
auto c2 = MM256_SET_M128I(_mm_srli_epi16(b2, 4), b2);
ql[k] = _mm512_and_si512(m4, _mm512_inserti32x8(_mm512_castsi256_si512(c1), c2, 1));
}
//// TODO: optimize this
//for (int k = 0; k < 2; ++k) {
// auto b1 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+0);
// auto b2 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+1);
// auto c1 = MM256_SET_M128I(_mm_srli_epi16(b1, 4), b1);
// auto c2 = MM256_SET_M128I(_mm_srli_epi16(b2, 4), b2);
// ql[k] = _mm512_and_si512(m4, _mm512_inserti32x8(_mm512_castsi256_si512(c1), c2, 1));
//}
auto lbits = _mm512_loadu_si512((const __m512i *)x[i].qs);
ql[0] = _mm512_and_si512(lbits, m4);
ql[1] = _mm512_and_si512(_mm512_srli_epi16(lbits, 4), m4);
auto tmp = _mm512_permutex2var_epi64(ql[0], permute1, ql[1]);
ql[1] = _mm512_permutex2var_epi64(ql[0], permute2, ql[1]);
ql[0] = tmp;
auto h128 = _mm_loadu_si128((const __m128i *)x[i].qh);
auto h256 = MM256_SET_M128I(_mm_srli_epi16(h128, 1), h128);
auto h512 = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1);
@@ -401,12 +407,16 @@ struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true>
for (int l = 0; l < 4; ++l) qs[l] = _mm512_shuffle_epi8(values[l], ql[k]);
auto q1 = _mm512_mask_blend_epi8(mask[k], qs[0], qs[1]);
auto q2 = _mm512_mask_blend_epi8(mask[k], qs[2], qs[3]);
auto q1l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q1));
auto q1h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q1, 1));
auto q2l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q2));
auto q2h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q2, 1));
bits.values[2*k+0] = _mm512_or_si512(q1l, _mm512_slli_epi16(q2l, 8));
bits.values[2*k+1] = _mm512_or_si512(q1h, _mm512_slli_epi16(q2h, 8));
auto t1 = _mm512_unpacklo_epi8(q1, q2); // 0...15, 32...47, 64...79, 96...111
auto t2 = _mm512_unpackhi_epi8(q1, q2); // 16...31, 48...63, 80...95, 112...127
bits.values[2*k+0] = _mm512_permutex2var_epi64(t1, permute1, t2);
bits.values[2*k+1] = _mm512_permutex2var_epi64(t1, permute2, t2);
//auto q1l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q1));
//auto q1h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q1, 1));
//auto q2l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q2));
//auto q2h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q2, 1));
//bits.values[2*k+0] = _mm512_or_si512(q1l, _mm512_slli_epi16(q2l, 8));
//bits.values[2*k+1] = _mm512_or_si512(q1h, _mm512_slli_epi16(q2h, 8));
}
}
void load_values() {
@@ -426,6 +436,8 @@ struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true>
const __m512i m01 = _mm512_set1_epi8(0x01);
const __m512i m10 = _mm512_set1_epi8(0x10);
const __m512i m4 = _mm512_set1_epi8(0xf);
const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
__m512i values[4];
const __m512i shuffles[4] = {
_mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),