iq3_ks: AVX2 GEMM/GEMV

This commit is contained in:
Iwan Kawrakow
2025-07-01 11:12:05 +03:00
parent 4fca652130
commit 3e6341d72a

View File

@@ -984,6 +984,40 @@ struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
constexpr static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
};
struct DequantizerIQ3KS final : public BaseDequantizer<block_iq3_ks, true, true> {
DequantizerIQ3KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
template <typename Q8>
inline __m256i new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accd) {
uint32_t aux32; std::memcpy(&aux32, x[i].scales, 4);
auto scl = _mm_cvtepi8_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(aux32), _mm_set_epi32(0, 0, 4, 0)), _mm_set1_epi8(0xf)));
auto sch = _mm_cmpeq_epi16(_mm_and_si128(_mm_set1_epi16(x[i].extra), mask), mask);
auto scales128 = _mm_add_epi16(scl, _mm_and_si128(sch, _mm_set1_epi16(16)));
scales128 = _mm_sub_epi16(scales128, _mm_set1_epi16(16));
return MM256_SET_M128I(scales128, scales128);
}
inline void prepare(int i, int j) {
uint8_t extra = x[i].extra >> (8 + 4*j);
hbits = j == 0 ? _mm256_loadu_si256((const __m256i *)x[i].qh) : _mm256_srli_epi16(hbits, 4);
bits.prepare(x[i].qs, j);
bits.values[0] = _mm256_add_epi8(_mm256_set1_epi8((extra << 3) & 8), _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)));
bits.values[1] = _mm256_add_epi8(_mm256_set1_epi8((extra << 2) & 8), _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)));
bits.values[2] = _mm256_add_epi8(_mm256_set1_epi8((extra << 1) & 8), _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh)));
bits.values[3] = _mm256_add_epi8(_mm256_set1_epi8((extra << 0) & 8), _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh)));
for (int k = 0; k < 4; ++k) bits.values[k] = _mm256_shuffle_epi8(values, bits.values[k]);
}
inline __m256i load_values() {
auto v = _mm_loadu_si128((const __m128i *)iq3nl_values);
return MM256_SET_M128I(v, v);
}
Q2Bits bits;
__m256i hbits;
const __m256i values;
const __m256i mh = _mm256_set1_epi8(4);
const __m128i mask = _mm_setr_epi16(0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80);
};
struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
DequantizerIQ4KSS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
template <typename Q8>
@@ -1348,7 +1382,7 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
set_scales_8(all_scales, j, scales);
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS>) {
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS> || std::is_same_v<Dequantizer, DequantizerIQ3KS>) {
multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
} else {
multiply_add(deq.bits, scales, j, i, q8, sumi);