iq2_tn: AVX512

With this tweak we get TG-128 = 19.58 / 35.18 t/s for 1 / 2 threads.
At 4 threads we saturate at 48.41 t/s, and then performance slowly
degrades with increasing number of threads.
This commit is contained in:
Iwan Kawrakow
2024-08-05 15:47:12 +03:00
parent c063954c1a
commit d0cc103878

View File

@@ -698,8 +698,6 @@ struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
bits.prepare(x[i].qs);
//process_mins_16(_mm256_set1_epi16(1), q8, i, -d, accm);
//scales[0] = scales[1] = _mm512_set1_epi16(1);
}
Q2Bits bits;
};
@@ -972,6 +970,16 @@ inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i *
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
}
template <typename Q8>
inline void compute_block_iq2tn(int iy, int i, float d, const Q8& q8, const __m512i * values, __m512 * accd) {
auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(
_mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0),
values[0], q8.load_quants64(iy, i, 0)), values[1], q8.load_quants64(iy, i, 1)),
values[2], q8.load_quants64(iy, i, 2)), values[3], q8.load_quants64(iy, i, 3));
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
}
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
@@ -1054,19 +1062,33 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const
for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
for (int kx = 0; kx < k_nx; ++kx) {
compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
for (int kx = 0; kx < k_nx; ++kx) {
compute_block_iq2tn(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, &accd);
}
} else {
for (int kx = 0; kx < k_nx; ++kx) {
compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
}
}
}
if (2*(nb/2) < nb) {
int i0 = 2*(nb/2);
deq[0]->new_block(i0, q8, &accm, scales);
compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
compute_block_iq2tn(0, i0, deq[0]->d, q8, deq[0]->bits.values, &accd);
} else {
compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
}
}
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
info.store(ix, 0, _mm512_reduce_add_ps(accd));
} else {
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
}
}
}