iqk_mul_mat(bitnet): slightly faster AVX2

We now get 214 t/s on the Ryzen-7950X
This commit is contained in:
Iwan Kawrakow
2024-06-17 16:32:25 +03:00
parent 30a771bd6b
commit ddea72453b

View File

@@ -1326,7 +1326,7 @@ template <int nrc> struct Q8_K64 {
};
template <int nrc_y>
static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
Q8_K64<nrc_y> q8(info);
__m256 accd[nrc_y];
@@ -1375,17 +1375,28 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
auto v2 = _mm256_sub_epi8(_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff2), mask1), mask1),
_mm256_cmpeq_epi8(_mm256_and_si256(_mm256_shuffle_epi8(aux2, shuff1), mask1), mask1));
for (int iy = 0; iy < nrc_y; ++iy) {
auto q8_1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), signs[0]);
auto q8_2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), signs[1]);
auto dot1 = _mm256_sign_epi8(q8_1, v1);
auto dot2 = _mm256_sign_epi8(q8_2, v2);
if constexpr (nrc_y == 1) {
auto dot1 = _mm256_sign_epi8(_mm256_sign_epi8(q8.load_quants(0, i, 0), signs[0]), v1);
auto dot2 = _mm256_sign_epi8(_mm256_sign_epi8(q8.load_quants(0, i, 1), signs[1]), v2);
#if defined __AVX512VNNI__ && defined __AVX512VL__
auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
#else
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
#endif
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
accd[0] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(0, i)), _mm256_cvtepi32_ps(dot), accd[0]);
} else {
v1 = _mm256_sign_epi8(v1, signs[0]);
v2 = _mm256_sign_epi8(v2, signs[1]);
for (int iy = 0; iy < nrc_y; ++iy) {
auto dot1 = _mm256_sign_epi8(q8.load_quants(iy, i, 0), v1);
auto dot2 = _mm256_sign_epi8(q8.load_quants(iy, i, 1), v2);
#if defined __AVX512VNNI__ && defined __AVX512VL__
auto dot = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_setzero_si256(), m1_8, dot1), m1_8, dot2);
#else
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(m1_8, dot1), _mm256_maddubs_epi16(m1_8, dot2)));
#endif
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(q8.scale(iy, i)), _mm256_cvtepi32_ps(dot), accd[iy]);
}
}
}
@@ -1393,7 +1404,6 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
info.store(ix, iy, scale.f * hsum_float_8(accd[iy]));
}
//x += step;
}
}