Refactor iqk: fix AVX2

This commit is contained in:
Iwan Kawrakow
2025-05-17 17:45:32 +03:00
parent de5660cee3
commit 082a9bd632
2 changed files with 37 additions and 6 deletions

View File

@@ -670,7 +670,14 @@ static void mul_mat_iqX_k_q8_K_AVX512_new(int n, const void * vx, size_t bx, con
#else
truct IQXKScales {
// Builds two 256-bit scale vectors from the two 128-bit halves of all_scales.
//   all_scales : source vector; read in full before anything is written.
//   scales     : output array with room for at least 2 entries.
// scales[0] is built from the low 128 bits and scales[1] from the high 128
// bits; each half is handed to MM256_SET_M128I as both of its arguments.
// NOTE(review): MM256_SET_M128I is defined elsewhere - presumably it packs two
// __m128i values into one __m256i, i.e. the half ends up in both lanes; confirm.
inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
    // Extract both halves up front so the stores below cannot influence them.
    const __m128i halves[2] = {
        _mm256_extracti128_si256(all_scales, 0),
        _mm256_extracti128_si256(all_scales, 1),
    };
    for (int k = 0; k < 2; ++k) {
        scales[k] = MM256_SET_M128I(halves[k], halves[k]);
    }
}
struct IQXKScales {
IQXKScales(int8_t shift, int8_t min_val) : min(_mm256_set1_epi16(min_val)), eshift(_mm_set1_epi8(shift)) {}
template <typename Q8>
inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m256i * scales) const {
@@ -1073,6 +1080,34 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing
};
// Returns the i-th 32-byte row of a byte-shuffle control table.
// In row i the first 16 bytes repeat the index pair (4*i, 4*i+1) and the last
// 16 bytes repeat the pair (4*i+2, 4*i+3).
// The table holds 128 bytes, i.e. 4 rows, so i must be in [0, 3]; the index is
// not range-checked.
inline __m256i get_scale_shuffle_16(int i) {
    static const uint8_t k_pair_table[128] = {
        // row 0
         0,  1,  0,  1,  0,  1,  0,  1,  0,  1,  0,  1,  0,  1,  0,  1,
         2,  3,  2,  3,  2,  3,  2,  3,  2,  3,  2,  3,  2,  3,  2,  3,
        // row 1
         4,  5,  4,  5,  4,  5,  4,  5,  4,  5,  4,  5,  4,  5,  4,  5,
         6,  7,  6,  7,  6,  7,  6,  7,  6,  7,  6,  7,  6,  7,  6,  7,
        // row 2
         8,  9,  8,  9,  8,  9,  8,  9,  8,  9,  8,  9,  8,  9,  8,  9,
        10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11, 10, 11,
        // row 3
        12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13, 12, 13,
        14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15, 14, 15,
    };
    // Each row is exactly one __m256i wide, so stepping the vector pointer by i
    // lands on byte offset 32*i. The load is unaligned because the byte array
    // carries no 32-byte alignment guarantee.
    const __m256i * rows = reinterpret_cast<const __m256i *>(k_pair_table);
    return _mm256_loadu_si256(rows + i);
}
// Fills scales[0..3] with byte-shuffled copies of all_scales.
//   all_scales : source vector for every shuffle.
//   scales     : output array with room for at least 4 entries.
// Entry k is all_scales shuffled with the control vector returned by
// get_scale_shuffle_16(k).
inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
    for (int k = 0; k < 4; ++k) {
        const __m256i control = get_scale_shuffle_16(k);
        scales[k] = _mm256_shuffle_epi8(all_scales, control);
    }
}
// Returns a vector whose sixteen 16-bit elements all hold the same value:
// 2*i in the low byte and 2*i+1 in the high byte (for small non-negative i,
// where both values fit in a byte). Viewed as bytes, the vector is the index
// pair (2*i, 2*i+1) repeated 16 times.
inline __m256i get_scale_shuffle_8(int i) {
    const int low_index  = 2*i;
    const int high_index = low_index + 1;
    const int packed     = low_index | (high_index << 8);
    return _mm256_set1_epi16(packed);
}
// Fills scales[0..3] with byte-shuffled copies of all_scales for group j.
//   all_scales : source vector for every shuffle.
//   j          : group index; entry k uses shuffle control number 4*j + k.
//   scales     : output array with room for at least 4 entries.
// The control vectors come from get_scale_shuffle_8.
inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {
    const int first = 4*j;
    for (int k = 0; k < 4; ++k) {
        const __m256i control = get_scale_shuffle_8(first + k);
        scales[k] = _mm256_shuffle_epi8(all_scales, control);
    }
}
template <typename Dequantizer, int nrc_y>
static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);

View File

@@ -1945,11 +1945,7 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
set_scales_8(all_scales, j, scales);
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS>) {
multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
} else {
multiply_add(deq.bits, scales, j, i, q8, sumi);
}
multiply_add(deq.bits, scales, j, i, q8, sumi);
}