Mirror of https://github.com/ikawrakow/ik_llama.cpp.git
iq2_tn: AVX512
With this tweak we get TG-128 = 19.58 / 35.18 t/s for 1 / 2 threads. At 4 threads throughput saturates at 48.41 t/s, and performance then slowly degrades as the number of threads increases.
@@ -698,8 +698,6 @@ struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
     inline void new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] __m256 * accm, [[maybe_unused]] __m512i * scales) {
         d = GGML_FP16_TO_FP32(x[i].d);
         bits.prepare(x[i].qs);
-        //process_mins_16(_mm256_set1_epi16(1), q8, i, -d, accm);
-        //scales[0] = scales[1] = _mm512_set1_epi16(1);
     }
     Q2Bits bits;
 };
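The two commented-out lines above were leftovers from the generic Q2_K path: IQ2_TN carries no per-block mins, so there is nothing to feed into accm and no per-block scales to set. All new_block() needs to do is load the super-block scale and unpack the 2-bit quants. For reference, a plausible block layout consistent with the x[i].d / x[i].qs accesses above (an illustrative sketch, not copied from the repository):

    // Hypothetical sketch of the IQ2_TN block layout assumed above; QK_K = 256.
    // Each weight is stored in 2 bits with values 0..2, encoding the ternary
    // weight plus one, i.e. the set {-1, 0, +1}.
    struct block_iq2_tn {
        ggml_half d;           // one fp16 scale per super-block
        uint8_t   qs[QK_K/4];  // 256 weights x 2 bits = 64 bytes
    };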
@@ -972,6 +970,16 @@ inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i *
     accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
 }
 
+template <typename Q8>
+inline void compute_block_iq2tn(int iy, int i, float d, const Q8& q8, const __m512i * values, __m512 * accd) {
+    auto sumi_scales = _mm256_madd_epi16(_mm256_set1_epi16(-1), q8.load_bsums(iy, i));
+    auto sumi = _mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(_mm512_dpbusd_epi32(
+                    _mm512_inserti32x8(_mm512_setzero_si512(), sumi_scales, 0),
+                    values[0], q8.load_quants64(iy, i, 0)), values[1], q8.load_quants64(iy, i, 1)),
+                    values[2], q8.load_quants64(iy, i, 2)), values[3], q8.load_quants64(iy, i, 3));
+    accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
+}
+
 template <typename Dequantizer, int nrc_y>
 static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n % QK_K == 0);
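In scalar terms the new kernel computes d * q8_scale * (dot(q, y) - sum(y)) per super-block: the stored 2-bit values q are the ternary weights plus one, so dot(q - 1, y) = dot(q, y) - sum(y). The _mm256_madd_epi16 on the Q8 block sums seeds the accumulator with -sum(y), and the four _mm512_dpbusd_epi32 steps add the unsigned-times-signed dot product over all 256 values. A scalar reference under those assumptions (the function name is illustrative):

    // Scalar reference for one super-block of QK_K = 256 values (sketch only).
    // q2[j] in {0,1,2} encodes the ternary weight q2[j] - 1; y holds the Q8 quants.
    float compute_block_iq2tn_ref(float d, float q8_scale,
                                  const uint8_t * q2, const int8_t * y) {
        int32_t sumi = 0;   // dot(q2, y), unsigned x signed as in _mm512_dpbusd_epi32
        int32_t sumy = 0;   // sum(y), which q8.load_bsums(...) provides precomputed
        for (int j = 0; j < 256; ++j) {
            sumi += int32_t(q2[j]) * y[j];
            sumy += y[j];
        }
        // dot(q2 - 1, y) = dot(q2, y) - sum(y)
        return d * q8_scale * float(sumi - sumy);
    }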
@@ -1054,19 +1062,33 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const
 
             for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
 
-            for (int kx = 0; kx < k_nx; ++kx) {
-                compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
+            if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
+                for (int kx = 0; kx < k_nx; ++kx) {
+                    compute_block_iq2tn(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, &accd);
+                }
+            } else {
+                for (int kx = 0; kx < k_nx; ++kx) {
+                    compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
+                }
             }
 
         }
         if (2*(nb/2) < nb) {
             int i0 = 2*(nb/2);
             deq[0]->new_block(i0, q8, &accm, scales);
-            compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
+            if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
+                compute_block_iq2tn(0, i0, deq[0]->d, q8, deq[0]->bits.values, &accd);
+            } else {
+                compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
+            }
         }
 
-        auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
-        info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
+        if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
+            info.store(ix, 0, _mm512_reduce_add_ps(accd));
+        } else {
+            auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
+            info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
+        }
     }
 }
 
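The dispatch on std::is_same_v<Dequantizer, DequantizerIQ2TN> is resolved at compile time, so the generic loop pays nothing for the new path. The final store also changes: because IQ2_TN never accumulates per-block mins, accm stays zero and the 512-bit accumulator can be summed directly with _mm512_reduce_add_ps instead of being folded to 256 bits and combined with accm. A minimal sketch of why the two stores agree when accm is zero (hsum_float_8 is approximated here with standard SSE steps; its exact implementation in the repository may differ):

    #include <immintrin.h>

    // Generic path: fold 512 -> 256 lanes, add the mins accumulator, sum 8 lanes.
    float reduce_generic(__m512 accd, __m256 accm) {
        __m256 sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd),
                                      _mm512_extractf32x8_ps(accd, 1));
        sum256 = _mm256_add_ps(accm, sum256);
        __m128 s = _mm_add_ps(_mm256_castps256_ps128(sum256),
                              _mm256_extractf128_ps(sum256, 1));
        s = _mm_add_ps(s, _mm_movehl_ps(s, s));     // lanes 2,3 into 0,1
        s = _mm_add_ss(s, _mm_movehdup_ps(s));      // lane 1 into 0
        return _mm_cvtss_f32(s);
    }

    // IQ2_TN path: accm is identically zero, so a plain horizontal add suffices.
    float reduce_iq2tn(__m512 accd) {
        return _mm512_reduce_add_ps(accd);
    }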