mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-01-26 17:20:01 +00:00
iqk_mul_mat(NEON): special case for n not divisible by 8
Else fp16 PP performance drops by nearly a factor of 2 compared to what we had before.
This commit is contained in:
@@ -4235,11 +4235,10 @@ template <int nrc> struct QF16 final : public QF16Base {
|
||||
const __fp16 * y[nrc_y];
|
||||
};
|
||||
|
||||
template <int nrc_y, int nrc_x>
|
||||
template <int nrc_y, int nrc_x, bool is_multiple_of_k_step>
|
||||
IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
assert(n%QF16Base::k_step == 0);
|
||||
int nb = n/QF16Base::k_step;
|
||||
int nb4 = n/4;
|
||||
QF16<nrc_y> y(info);
|
||||
QF16<nrc_x> x(cx + ix0*bx, bx);
|
||||
QF16Base::Data xv[nrc_x];
|
||||
@@ -4264,15 +4263,18 @@ IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0
|
||||
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
|
||||
yv = y.load_tail(0, i);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
xv[ix] = x.load_tail(ix, i);
|
||||
acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < nrc_y; ++iy) {
|
||||
yv = y.load_tail(iy, i);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
|
||||
if constexpr (!is_multiple_of_k_step) {
|
||||
int nb4 = n/4;
|
||||
for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
|
||||
yv = y.load_tail(0, i);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
xv[ix] = x.load_tail(ix, i);
|
||||
acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < nrc_y; ++iy) {
|
||||
yv = y.load_tail(iy, i);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF16Base::hsum(acc[nrc_x*iy+ix]));
|
||||
@@ -4283,17 +4285,32 @@ void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info,
|
||||
GGML_ASSERT(n%4 == 0);
|
||||
constexpr int k_nx = 5;
|
||||
const char * cx = (const char *)vx;
|
||||
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
mul_mat_f16_f16_NxN<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info);
|
||||
}
|
||||
int last_x = k_nx*(nrc_x/k_nx);
|
||||
if (last_x == nrc_x) return;
|
||||
int nx = nrc_x - last_x;
|
||||
switch (nx) {
|
||||
case 1: mul_mat_f16_f16_NxN<nrc_y, 1>(n, cx, bx, last_x, info); break;
|
||||
case 2: mul_mat_f16_f16_NxN<nrc_y, 2>(n, cx, bx, last_x, info); break;
|
||||
case 3: mul_mat_f16_f16_NxN<nrc_y, 3>(n, cx, bx, last_x, info); break;
|
||||
case 4: mul_mat_f16_f16_NxN<nrc_y, 4>(n, cx, bx, last_x, info); break;
|
||||
if (n%QF16Base::k_step == 0) {
|
||||
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
mul_mat_f16_f16_NxN<nrc_y, k_nx, true>(n, cx, bx, ix*k_nx, info);
|
||||
}
|
||||
int last_x = k_nx*(nrc_x/k_nx);
|
||||
if (last_x == nrc_x) return;
|
||||
int nx = nrc_x - last_x;
|
||||
switch (nx) {
|
||||
case 1: mul_mat_f16_f16_NxN<nrc_y, 1, true>(n, cx, bx, last_x, info); break;
|
||||
case 2: mul_mat_f16_f16_NxN<nrc_y, 2, true>(n, cx, bx, last_x, info); break;
|
||||
case 3: mul_mat_f16_f16_NxN<nrc_y, 3, true>(n, cx, bx, last_x, info); break;
|
||||
case 4: mul_mat_f16_f16_NxN<nrc_y, 4, true>(n, cx, bx, last_x, info); break;
|
||||
}
|
||||
} else {
|
||||
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
mul_mat_f16_f16_NxN<nrc_y, k_nx, false>(n, cx, bx, ix*k_nx, info);
|
||||
}
|
||||
int last_x = k_nx*(nrc_x/k_nx);
|
||||
if (last_x == nrc_x) return;
|
||||
int nx = nrc_x - last_x;
|
||||
switch (nx) {
|
||||
case 1: mul_mat_f16_f16_NxN<nrc_y, 1, false>(n, cx, bx, last_x, info); break;
|
||||
case 2: mul_mat_f16_f16_NxN<nrc_y, 2, false>(n, cx, bx, last_x, info); break;
|
||||
case 3: mul_mat_f16_f16_NxN<nrc_y, 3, false>(n, cx, bx, last_x, info); break;
|
||||
case 4: mul_mat_f16_f16_NxN<nrc_y, 4, false>(n, cx, bx, last_x, info); break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user