Zen4: faster PP for iq2_ks

This commit is contained in:
Iwan Kawrakow
2025-05-17 10:22:38 +03:00
parent 2f557a0fd6
commit d7ebb3eae4

View File

@@ -2113,16 +2113,26 @@ struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
DequantizerIQ2KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
inline void compute_block(int i, const Q8& q8, __m512 * acc) {
prepare(x[i].qs);
auto scales128 = make_scales(x[i].scales, x[i].extra >> 8);
auto shifts = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(x[i].extra), hmask), hmask), m5);
auto scales_s = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts)));
s8k.accum_mins(scales_s, q8, i, d, accm);
auto mins128 = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts)));
auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0]));
auto scales256 = MM256_SET_M128I(scales128, scales128);
auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
__m512i scales[4];
for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]);
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
auto q8s = q8.load_bsums(iy, i);
auto prod = _mm256_madd_epi16(mins, q8s);
auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0);
for (int k = 0; k < 4; ++k) {
auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k));
sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]);
}
acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]);
}
}
inline void prepare(const uint8_t * q2) {
bits.prepare(q2);
@@ -2147,7 +2157,7 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true>
return _mm_cvtepi8_epi16(_mm_add_epi8(scl, sch));
}
Q2Bits bits;
Scales8K s8k;
Scales8KBase s8k;
const __m512i values;
const __m128i m16 = _mm_set1_epi8(-16);
@@ -2156,6 +2166,12 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true>
const __m128i hmask = _mm_set1_epi64x(0x8040201008040201);
const __m128i shuffle = _mm_set1_epi64x(0x0703060205010400);
const __m128i shift = _mm_set_epi32(0, 0, 4, 0);
const __m512i shuffles[4] = {
_mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
_mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
_mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
_mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
};
};
struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
@@ -9794,7 +9810,8 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
m.funcs[5] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 6>;
m.funcs[6] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 7>;
m.funcs[7] = mul_mat_iqX_k_q8_K_AVX512<Dequantizer, 8>;
} else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4KS> ||
} else if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2KS> ||
std::is_same_v<Dequantizer, DequantizerIQ4KS> ||
std::is_same_v<Dequantizer, DequantizerIQ5KS>) {
m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 1>;
m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512_new<Dequantizer, 2>;