iqk_mul_mat: minor improvements

Current performance:
| model             |       size |  threads |    test |              t/s |
| ----------------- | ---------: | -------: | ------: | ---------------: |
| llama 7B IQ3_S    |   2.75 GiB |       16 |   pp512 |    100.21 ± 0.32 |
| llama 7B IQ3_XXS  |   2.41 GiB |       16 |   pp512 |    105.25 ± 0.75 |
| llama 7B IQ2_M    |   2.20 GiB |       16 |   pp512 |    117.88 ± 0.15 |
| llama 7B IQ2_XS   |   1.89 GiB |       16 |   pp512 |    136.38 ± 0.24 |
| llama 7B IQ2_XXS  |   1.73 GiB |       16 |   pp512 |    128.47 ± 0.39 |
                                                     mean: 117.64
| ----------------- | ---------: | -------: | ------: | ---------------: |
| llama 7B IQ2_XXS  |   1.73 GiB |        8 |   tg128 |     23.94 ± 0.04 |
| llama 7B IQ2_XS   |   1.89 GiB |        8 |   tg128 |     23.27 ± 0.03 |
| llama 7B IQ2_M    |   2.20 GiB |        8 |   tg128 |     18.88 ± 0.03 |
| llama 7B IQ3_XXS  |   2.41 GiB |        8 |   tg128 |     19.07 ± 0.04 |
| llama 7B IQ3_S    |   2.75 GiB |        8 |   tg128 |     15.44 ± 0.05 |
                                                     mean:  20.12
This commit is contained in:
Kawrakow
2024-06-05 19:43:08 +03:00
parent e85753e1ad
commit 5039ea8930

View File

@@ -1192,10 +1192,12 @@ template <typename Bits>
inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, const __m256i * q8, __m256i * sumi) {
if (j == 0) {
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));
sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));
auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);
sumi[0] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_packs_epi32(p1, p2));
sumi[1] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[1], _mm256_packs_epi32(p3, p4));
#else
const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
@@ -1206,10 +1208,12 @@ inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, cons
#endif
} else {
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[2], _mm256_maddubs_epi16(bits.values[2], q8[2]));
sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[3], _mm256_maddubs_epi16(bits.values[3], q8[3]));
auto p1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[0], q8[0]);
auto p2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[1], q8[1]);
auto p3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[2], q8[2]);
auto p4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), bits.values[3], q8[3]);
sumi[0] = _mm256_dpwssd_epi32(sumi[0], scales[0], _mm256_packs_epi32(p1, p2));
sumi[1] = _mm256_dpwssd_epi32(sumi[1], scales[1], _mm256_packs_epi32(p3, p4));
#else
const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8[0]));
const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8[1]));
@@ -1221,12 +1225,33 @@ inline void multiply_add_1(int j, const Bits& bits, const __m256i * scales, cons
}
}
inline void set_scales_8_iq(int j, const __m256i& all_scales, __m256i * scales) {
#ifdef HAVE_FANCY_SIMD
auto shuffle = j == 0 ? _mm256_set_epi64x(0x0302030203020302, 0x0100010001000100, 0x0302030203020302, 0x0100010001000100)
: _mm256_set_epi64x(0x0b0a0b0a0b0a0b0a, 0x0908090809080908, 0x0b0a0b0a0b0a0b0a, 0x0908090809080908);
scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);
scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(4)));
#else
set_scales_8(all_scales, j, scales);
#endif
}
inline void set_scales_16_iq(const __m256i& all_scales, __m256i * scales) {
#ifdef HAVE_FANCY_SIMD
auto shuffle = _mm256_set_epi64x(0x0706070607060706, 0x0302030203020302, 0x0504050405040504, 0x0100010001000100);
scales[0] = _mm256_shuffle_epi8(all_scales, shuffle);
scales[1] = _mm256_shuffle_epi8(all_scales, _mm256_add_epi8(shuffle, _mm256_set1_epi8(8)));
#else
set_scales_16(all_scales, scales);
#endif
}
template <typename Dequantizer>
static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_K;
Q8<1> q8(info);
Dequantizer deq(vx, bx);
__m256i scales[4];
__m256i scales[2];
__m256i q8_quants[4];
for (int ix = 0; ix < nrc_x; ++ix) {
@@ -1241,9 +1266,9 @@ static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const Data
for (int j = 0; j < QK_K/128; ++j) {
deq.prepare(i, j, q8, q8_quants);
if constexpr (Dequantizer::num_blocks == 8) {
set_scales_8(all_scales[0], j, scales);
set_scales_8_iq(j, all_scales[0], scales);
} else {
set_scales_16(all_scales[j], scales);
set_scales_16_iq(all_scales[j], scales);
}
multiply_add_1(j, deq.bits, scales, q8_quants, sumi);
}
@@ -1254,6 +1279,32 @@ static void mul_mat_qX_K_q8_K_IQ_1(int n, const void * vx, size_t bx, const Data
}
}
// So, if I uncomment this function and the call to it in mul_mat_qX_K_q8_K_IQ_N() below,
// PP performance improves by ~2-3% (when we have __AVX512VNNI__ and __AVX512VL__).
// But TG performance for iq3_xs drops by 35%. Seriously? I mean, c'mon,
// what does the compilation of mul_mat_qX_K_q8_K_IQ_1 (which gets invoked during TG)
// have to do with the compilation of mul_mat_qX_K_q8_K_IQ_N (invoked during PP)?
//template <typename Q8, typename Bits>
//inline void multiply_add_iq(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) {
//#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
// for (int iy = 0; iy < Q8::nrc_y; ++iy) {
// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0)));
// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1)));
// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2)));
// sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3)));
// }
//#else
// for (int iy = 0; iy < Q8::nrc_y; ++iy) {
// const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4*j+0)));
// const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 4*j+1)));
// const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 4*j+2)));
// const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 4*j+3)));
// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3));
// sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4));
// }
//#endif
//}
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_K;
@@ -1271,6 +1322,7 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
for (int i = 0; i < nb; ++i) {
__m256i sumi[nrc_y], all_scales[Dequantizer::num_blocks/8];
//for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
__m256i mins;
float dmin = deq.new_block(i, all_scales, mins);
for (int iy = 0; iy < nrc_y; ++iy) {
@@ -1286,6 +1338,7 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
} else {
set_scales_16(all_scales[j], scales);
}
//multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
multiply_add(deq.bits, scales, j, i, q8, sumi);
}
for (int iy = 0; iy < nrc_y; ++iy) {