iq2_kl: AVX2 GEMM/GEMV

This commit is contained in:
Iwan Kawrakow
2025-07-11 15:40:53 +03:00
parent 738031ba0e
commit 4a3b5e3119

View File

@@ -340,14 +340,8 @@ struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true> {
DequantizerIQ2KL(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); }
inline __m128i make_scales(int i) const {
//uint16_t aux[8];
//auto h = x[i].scales_h;
//for (int k = 0; k < 4; ++k) { aux[k+0] = (x[i].scales_l[k] & 0xf) | ((h << 4) & 0x30); h >>= 2; }
//for (int k = 0; k < 4; ++k) { aux[k+4] = (x[i].scales_l[k] >> 4) | ((h << 4) & 0x30); h >>= 2; }
//return _mm_sub_epi16(_mm_loadu_si128((const __m128i *)aux), _mm_set1_epi16(32));
uint32_t aux32; std::memcpy(&aux32, x[i].scales_l, 4);
auto scl = _mm_cvtepu8_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(aux32), _mm_set_epi32(0, 0, 4, 0)), _mm_set1_epi8(0xf)));
// 0x000a000800060004
auto sch = _mm_srlv_epi16(_mm_sllv_epi64(_mm_set1_epi16(x[i].scales_h), _mm_set_epi64x(0, 8)), _mm_set1_epi64x(0x000a000800060004));
auto scales128 = _mm_sub_epi16(_mm_or_si128(scl, _mm_and_si128(sch, _mm_set1_epi16(0x30))), _mm_set1_epi16(32));
return scales128;
@@ -376,14 +370,6 @@ struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true>
inline void prepare(int i) {
__m512i ql[2], qs[4];
__mmask64 mask[2];
//// TODO: optimize this
//for (int k = 0; k < 2; ++k) {
// auto b1 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+0);
// auto b2 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+1);
// auto c1 = MM256_SET_M128I(_mm_srli_epi16(b1, 4), b1);
// auto c2 = MM256_SET_M128I(_mm_srli_epi16(b2, 4), b2);
// ql[k] = _mm512_and_si512(m4, _mm512_inserti32x8(_mm512_castsi256_si512(c1), c2, 1));
//}
auto lbits = _mm512_loadu_si512((const __m512i *)x[i].qs);
ql[0] = _mm512_and_si512(lbits, m4);
ql[1] = _mm512_and_si512(_mm512_srli_epi16(lbits, 4), m4);
@@ -1013,6 +999,68 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true>
const __m128i shift = _mm_set_epi32(0, 0, 4, 0);
};
struct DequantizerIQ2KL final : public BaseDequantizer<block_iq2_kl, true, true> {
DequantizerIQ2KL(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); }
template <typename Q8>
inline __m256i new_block(int i, const Q8& q8, __m256 * accm) {
auto hbits128 = _mm_loadu_si128((const __m128i *)x[i].qh);
hbits = MM256_SET_M128I(_mm_srli_epi16(hbits128, 1), hbits128);
auto scales128 = make_scales(i);
auto scales_s = _mm_mullo_epi16(scales128, _mm_set1_epi16(-64));
s8k.accum_mins(scales_s, q8, i, d, accm);
return MM256_SET_M128I(scales128, scales128);
}
inline void prepare(int i, int j) {
__m256i ql[2], mask[2];
auto b1 = _mm_loadu_si128((const __m128i *)x[i].qs+2*j+0);
auto b2 = _mm_loadu_si128((const __m128i *)x[i].qs+2*j+1);
ql[0] = _mm256_and_si256(_mm256_set1_epi8(0xf), MM256_SET_M128I(_mm_srli_epi16(b1, 4), b1));
ql[1] = _mm256_and_si256(_mm256_set1_epi8(0xf), MM256_SET_M128I(_mm_srli_epi16(b2, 4), b2));
mask[0] = _mm256_cmpeq_epi8(_mm256_and_si256(hbits, _mm256_set1_epi8(0x1)), _mm256_set1_epi8(0x1));
mask[1] = _mm256_cmpeq_epi8(_mm256_and_si256(hbits, _mm256_set1_epi8(0x4)), _mm256_set1_epi8(0x4));
for (int k = 0; k < 2; ++k) {
auto v0 = _mm256_shuffle_epi8(values[0], ql[k]);
auto v1 = _mm256_shuffle_epi8(values[1], ql[k]);
auto v2 = _mm256_shuffle_epi8(values[2], ql[k]);
auto v3 = _mm256_shuffle_epi8(values[3], ql[k]);
auto q1 = _mm256_or_si256(_mm256_and_si256(mask[k], v1), _mm256_andnot_si256(mask[k], v0));
auto q2 = _mm256_or_si256(_mm256_and_si256(mask[k], v3), _mm256_andnot_si256(mask[k], v2));
auto q1l = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(q1));
auto q1h = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(q1, 1));
auto q2l = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(q2));
auto q2h = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(q2, 1));
bits.values[2*k+0] = _mm256_or_si256(q1l, _mm256_slli_epi16(q2l, 8));
bits.values[2*k+1] = _mm256_or_si256(q1h, _mm256_slli_epi16(q2h, 8));
}
hbits = _mm256_srli_epi16(hbits, 4);
}
inline __m128i make_scales(int i) const {
uint32_t aux32; std::memcpy(&aux32, x[i].scales_l, 4);
auto scl = _mm_cvtepu8_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(aux32), shift), _mm_set1_epi8(0xf)));
auto sch = _mm_srlv_epi32(_mm_set1_epi16(x[i].scales_h), _mm_set_epi32(12, 8, 4, 0));
sch = _mm_and_si128(sch, _mm_set1_epi32(0x000c0003));
sch = _mm_mullo_epi16(sch, _mm_set1_epi32(0x00040010));
auto scales128 = _mm_sub_epi16(_mm_or_si128(scl, sch), _mm_set1_epi16(32));
return scales128;
}
void load_values() {
static const uint8_t k_values[64] = {
1, 1, 24, 24, 24, 24, 41, 41, 41, 41, 41, 54, 54, 54, 54, 65, 65, 65, 65, 65, 77, 77, 77, 77, 77, 92, 92, 92, 92, 92, 111, 111,
41, 77, 1, 54, 77, 111, 24, 41, 65, 77, 92, 1, 65, 77, 111, 41, 54, 65, 77, 92, 24, 41, 54, 65, 77, 1, 41, 65, 92, 111, 41, 77,
};
for (int k = 0; k < 4; ++k) {
auto v128 = _mm_loadu_si128((const __m128i *)k_values + k);
values[k] = MM256_SET_M128I(v128, v128);
}
}
struct { __m256i values[4]; } bits;
Scales8KBase s8k;
__m256i values[4];
__m256i hbits;
const __m128i shift = _mm_set_epi32(0, 0, 4, 0);
};
struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(5, -32), values(load_values()) {}
template <typename Q8>
@@ -1488,8 +1536,7 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
set_scales_8(all_scales, j, scales);
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS> || std::is_same_v<Dequantizer, DequantizerIQ3KS> ||
std::is_same_v<Dequantizer, DequantizerIQ2KL) {
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS> || std::is_same_v<Dequantizer, DequantizerIQ3KS>) {
multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
} else {
multiply_add(deq.bits, scales, j, i, q8, sumi);