Slightly better GEMV for non-repacked fp16

This commit is contained in:
Iwan Kawrakow
2025-01-23 09:04:07 +02:00
parent 4941c043bb
commit 8a5a81b4dc

View File

@@ -6971,6 +6971,7 @@ struct QFBase {
// Returns a 512-bit vector whose lowest 128-bit lane is load128(x) and whose
// remaining lanes are zero (insert at lane index 0 into an all-zero vector).
// NOTE(review): load128 is defined outside this view - presumably it reads four
// values from x and yields them as floats; confirm.
static inline Data load4Floats(const Float * x) {
return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0);
}
// Lane-wise sum of two 512-bit packed-float vectors.
static inline Data add(Data x, Data y) { return _mm512_add_ps(x, y); }
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);
acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
@@ -6999,6 +7000,7 @@ struct QFBase {
// Fused multiply-add on 256-bit packed-float vectors: returns y*x + prev, lane by lane.
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
return _mm256_fmadd_ps(y, x, prev);
}
// Lane-wise sum of two 256-bit packed-float vectors.
static inline Data add(Data x, Data y) { return _mm256_add_ps(x, y); }
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
@@ -7110,6 +7112,38 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0,
for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
}
template <typename Qy, typename Qx>
void mul_mat_Qx_Qy_Mx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
static_assert(Qy::nrc == 1);
static_assert(Qx::nrc == 1);
int nb = n/QFBase::k_step;
int nb4 = n/4;
Qy y(info);
Qx x(cx + ix0*bx, bx);
QFBase::Acc acc[4] = {};
for (int i = 0; i < nb/4; ++i) {
for (int k = 0; k < 4; ++k) {
auto yv = y.load1(0, 4*i+k);
auto xv = x.load1(0, 4*i+k);
acc[k] = QFBase::acc(acc[k], yv, xv);
}
}
for (int i = 4*(nb/4); i < nb; ++i) {
auto yv = y.load1(0, i);
auto xv = x.load1(0, i);
acc[0] = QFBase::acc(acc[0], yv, xv);
}
for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
auto yv = y.load_tail(0, i);
auto xv = x.load_tail(0, i);
acc[0] = QFBase::acc(acc[0], yv, xv);
}
acc[0] = QFBase::add(acc[0], acc[1]);
acc[2] = QFBase::add(acc[2], acc[3]);
acc[0] = QFBase::add(acc[0], acc[2]);
info.store(ix0, 0, QFBase::hsum(acc[0]));
}
template <typename Qy, typename Qx>
inline void mul_mat_Qx_Qy_MxN_fa(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
int nb = n/QFBase::k_step;
@@ -7165,12 +7199,17 @@ inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, co
// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
template <int nrc_y, typename FloatX, typename FloatY>
void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const char * cx = (const char *)vx;
if constexpr (nrc_y == 1) {
for (int ix = 0; ix < nrc_x; ++ix) {
mul_mat_Qx_Qy_Mx1<QFT<FloatY, 1>, QFT<FloatX, 1>>(n, cx, bx, ix, info);
}
} else {
#ifdef __AVX512F__
constexpr int k_nx = 5;
#else
constexpr int k_nx = 2;
#endif
const char * cx = (const char *)vx;
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
}
@@ -7185,6 +7224,7 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
#endif
}
}
}
#ifdef __AVX512BF16__