iqk_mul_mat: slightly faster FANCY_SIMD dot product

About 2% faster for q4_K.
This commit is contained in:
Iwan Kawrakow
2024-06-09 18:01:52 +03:00
parent 8a80a31ddd
commit 09d86e5876

View File

@@ -723,6 +723,17 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
};
template <typename Q8>
inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1));
const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2));
const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3));
auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
}
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
@@ -748,6 +759,7 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da
deq.new_block(i, q8, accm, scales);
for (int iy = 0; iy < nrc_y; ++iy) {
//compute_block(iy, i, deq.d, q8, deq.bits.values, scales, accd);
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
@@ -767,6 +779,51 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da
}
}
template <typename Dequantizer>
static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
constexpr int k_nx = 2;
Q8<1> q8(info);
Dequantizer deq1(vx, bx);
Dequantizer deq2(vx, bx);
Dequantizer * deq[k_nx];
deq[0] = &deq1;
deq[1] = &deq2;
__m512i scales[2*k_nx];
for (int ix = 0; ix < nrc_x; ++ix) {
auto accd = _mm512_setzero_ps();
auto accm = _mm256_setzero_ps();
for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix);
for (int i = 0; i < nb/k_nx; ++i) {
for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
for (int kx = 0; kx < k_nx; ++kx) {
compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
}
}
if (2*(nb/2) < nb) {
int i0 = 2*(nb/2);
deq[0]->new_block(i0, q8, &accm, scales);
compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
}
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
}
}
#else
// ===================================== Vanilla AVX2 =====================================
@@ -2267,7 +2324,7 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
}
else {
#ifdef HAVE_FANCY_SIMD
m.funcs[0] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 1>;
m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
m.funcs[1] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 2>;
m.funcs[2] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 3>;
m.funcs[3] = mul_mat_qX_K_q8_K_AVX512<Dequantizer, 4>;