mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
Slightly better gemv for not repacked fp16
This commit is contained in:
@@ -6971,6 +6971,7 @@ struct QFBase {
|
||||
// Builds a 512-bit vector whose lowest 128-bit lane holds load128(x) and whose
// remaining lanes are zero.
static inline Data load4Floats(const Float * x) {
    const auto zeros = _mm512_setzero_ps();
    const auto low_lane = load128(x);
    return _mm512_insertf32x4(zeros, low_lane, 0);
}
|
||||
// Element-wise single-precision sum of two 512-bit vectors.
static inline Data add(Data x, Data y) {
    return _mm512_add_ps(x, y);
}
|
||||
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
|
||||
acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);
|
||||
acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
|
||||
@@ -6999,6 +7000,7 @@ struct QFBase {
|
||||
// Fused multiply-add step: returns y*x + prev, element by element.
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
    const Acc updated = _mm256_fmadd_ps(y, x, prev);
    return updated;
}
|
||||
// Element-wise single-precision sum of two 256-bit vectors.
static inline Data add(Data x, Data y) {
    return _mm256_add_ps(x, y);
}
|
||||
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
|
||||
acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
|
||||
acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
|
||||
@@ -7110,6 +7112,38 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0,
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
|
||||
}
|
||||
|
||||
template <typename Qy, typename Qx>
|
||||
void mul_mat_Qx_Qy_Mx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
static_assert(Qy::nrc == 1);
|
||||
static_assert(Qx::nrc == 1);
|
||||
int nb = n/QFBase::k_step;
|
||||
int nb4 = n/4;
|
||||
Qy y(info);
|
||||
Qx x(cx + ix0*bx, bx);
|
||||
QFBase::Acc acc[4] = {};
|
||||
for (int i = 0; i < nb/4; ++i) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto yv = y.load1(0, 4*i+k);
|
||||
auto xv = x.load1(0, 4*i+k);
|
||||
acc[k] = QFBase::acc(acc[k], yv, xv);
|
||||
}
|
||||
}
|
||||
for (int i = 4*(nb/4); i < nb; ++i) {
|
||||
auto yv = y.load1(0, i);
|
||||
auto xv = x.load1(0, i);
|
||||
acc[0] = QFBase::acc(acc[0], yv, xv);
|
||||
}
|
||||
for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
|
||||
auto yv = y.load_tail(0, i);
|
||||
auto xv = x.load_tail(0, i);
|
||||
acc[0] = QFBase::acc(acc[0], yv, xv);
|
||||
}
|
||||
acc[0] = QFBase::add(acc[0], acc[1]);
|
||||
acc[2] = QFBase::add(acc[2], acc[3]);
|
||||
acc[0] = QFBase::add(acc[0], acc[2]);
|
||||
info.store(ix0, 0, QFBase::hsum(acc[0]));
|
||||
}
|
||||
|
||||
template <typename Qy, typename Qx>
|
||||
inline void mul_mat_Qx_Qy_MxN_fa(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
int nb = n/QFBase::k_step;
|
||||
@@ -7165,12 +7199,17 @@ inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, co
|
||||
// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
|
||||
template <int nrc_y, typename FloatX, typename FloatY>
|
||||
void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
const char * cx = (const char *)vx;
|
||||
if constexpr (nrc_y == 1) {
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
mul_mat_Qx_Qy_Mx1<QFT<FloatY, 1>, QFT<FloatX, 1>>(n, cx, bx, ix, info);
|
||||
}
|
||||
} else {
|
||||
#ifdef __AVX512F__
|
||||
constexpr int k_nx = 5;
|
||||
#else
|
||||
constexpr int k_nx = 2;
|
||||
#endif
|
||||
const char * cx = (const char *)vx;
|
||||
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
|
||||
}
|
||||
@@ -7185,6 +7224,7 @@ void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, in
|
||||
case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
|
||||
Reference in New Issue
Block a user