mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 15:14:10 +00:00
Refactor iqk: fix AVX2
This commit is contained in:
@@ -670,7 +670,14 @@ static void mul_mat_iqX_k_q8_K_AVX512_new(int n, const void * vx, size_t bx, con
|
||||
|
||||
#else
|
||||
|
||||
truct IQXKScales {
|
||||
inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
|
||||
const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
|
||||
const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
|
||||
scales[0] = MM256_SET_M128I(l_scales, l_scales);
|
||||
scales[1] = MM256_SET_M128I(h_scales, h_scales);
|
||||
}
|
||||
|
||||
struct IQXKScales {
|
||||
IQXKScales(int8_t shift, int8_t min_val) : min(_mm256_set1_epi16(min_val)), eshift(_mm_set1_epi8(shift)) {}
|
||||
template <typename Q8>
|
||||
inline void process(int i, float d, uint16_t extra, __m128i scales8, const Q8& q8, __m256 * accm, __m256i * scales) const {
|
||||
@@ -1073,6 +1080,34 @@ struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
|
||||
const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing
|
||||
};
|
||||
|
||||
inline __m256i get_scale_shuffle_16(int i) {
|
||||
static const uint8_t k_shuffle[128] = {
|
||||
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
|
||||
4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
|
||||
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
|
||||
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
|
||||
};
|
||||
return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
|
||||
}
|
||||
|
||||
inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
|
||||
scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0));
|
||||
scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1));
|
||||
scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2));
|
||||
scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3));
|
||||
}
|
||||
|
||||
inline __m256i get_scale_shuffle_8(int i) {
|
||||
return _mm256_set1_epi16((2*i) | ((2*i+1) << 8));
|
||||
}
|
||||
|
||||
inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {
|
||||
scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0));
|
||||
scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1));
|
||||
scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2));
|
||||
scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3));
|
||||
}
|
||||
|
||||
template <typename Dequantizer, int nrc_y>
|
||||
static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
|
||||
@@ -1945,11 +1945,7 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
|
||||
|
||||
set_scales_8(all_scales, j, scales);
|
||||
|
||||
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS>) {
|
||||
multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
|
||||
} else {
|
||||
multiply_add(deq.bits, scales, j, i, q8, sumi);
|
||||
}
|
||||
multiply_add(deq.bits, scales, j, i, q8, sumi);
|
||||
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user