Slightly faster fp16/bf16 gemv on AVX2

It still saturates at the same lower performance for bf16
This commit is contained in:
Iwan Kawrakow
2025-01-22 09:03:57 +02:00
parent 2c2f728afc
commit cc7642c757

View File

@@ -7148,7 +7148,7 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
#ifdef __AVX512F__
constexpr int k_nx = 5;
#else
constexpr int k_nx = 2;
constexpr int k_nx = nrc_y == 1 ? 4 : 2;
#endif
const char * cx = (const char *)vx;
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
@@ -7157,14 +7157,26 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
int last_x = k_nx*(nrc_x/k_nx);
if (last_x == nrc_x) return;
int nx = nrc_x - last_x;
#ifdef __AVX512F__
switch (nx) {
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
#ifdef __AVX512F__
case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
#endif
}
#else
if constexpr (nrc_y == 1) {
switch (nx) {
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
}
} else {
switch (nx) {
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
}
}
#endif
}
#ifdef __AVX512BF16__