Fix iq4_xs (Zen4)

This commit is contained in:
Iwan Kawrakow
2024-10-08 19:47:21 +03:00
parent c24ad0d1e7
commit ee590519d2

View File

@@ -635,8 +635,10 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
s8k.accum_mins(scales128, q8, i, -128.f*d, accd); s8k.accum_mins(scales128, q8, i, -128.f*d, accd);
auto scales256 = MM256_SET_M128I(scales128, scales128); auto scales256 = MM256_SET_M128I(scales128, scales128);
auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
} }
inline void prepare(const uint8_t * q4) { inline void prepare(const uint8_t * q4) {
bits.prepare64(q4); bits.prepare64(q4);
@@ -652,11 +654,17 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
} }
Q4Bits bits; Q4Bits bits;
Scales8K s8k; Scales8KBase s8k;
ScaleIQ4XS siq4; ScaleIQ4XS siq4;
const __m512i values; const __m512i values;
const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0); const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4); const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
const __m512i shuffles[4] = {
_mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
_mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
_mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
_mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
};
}; };
struct HighBit5 { struct HighBit5 {
@@ -3721,6 +3729,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
std::is_same_v<Dequantizer, DequantizerIQ5K> || std::is_same_v<Dequantizer, DequantizerIQ5K> ||
std::is_same_v<Dequantizer, DequantizerIQ4K> || std::is_same_v<Dequantizer, DequantizerIQ4K> ||
std::is_same_v<Dequantizer, DequantizerIQ3K> || std::is_same_v<Dequantizer, DequantizerIQ3K> ||
std::is_same_v<Dequantizer, DequantizerIQ4XS>||
std::is_same_v<Dequantizer, DequantizerIQ4XXS>) { std::is_same_v<Dequantizer, DequantizerIQ4XXS>) {
m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>; m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 1>;
m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>; m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 2>;