Refactor iqk: factor out floats (NEON)

This commit is contained in:
Iwan Kawrakow
2025-05-18 18:09:39 +03:00
parent c805a19202
commit f4ab917e9e
2 changed files with 382 additions and 367 deletions

View File

@@ -569,6 +569,384 @@ bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t,
#else
// ----------------------------------- __aarch64__ -----------------------------------------------
namespace {
struct QF16Base {
constexpr static int k_step = 8;
using Data = float16x8_t;
using Acc = float16x8_t;
static inline Data load(const __fp16 * x) { return vld1q_f16(x); }
static inline Data load4(const __fp16 * x) { return vcombine_f16(vld1_f16(x), vdup_n_f16(0)); }
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
return vfmaq_f16(prev, y, x);
}
static inline Acc acc_first(const Data& y, const Data& x) {
return vmulq_f16(y, x);
}
//constexpr static int k_step = 16;
//using Data = float16x8x2_t;
//static inline Data load(const __fp16 * x) { return vld1q_f16_x2(x); }
//static inline Acc acc(Acc prev, const Data& y, const Data& x) {
// return vfmaq_f16(vfmaq_f16(prev, y.val[0], x.val[0]), y.val[1], x.val[1]);
//}
//static inline Acc acc_first(const Data& y, const Data& x) {
// return vfmaq_f16(vmulq_f16(y.val[0], x.val[0]), y.val[1], x.val[1]);
//}
static inline float hsum(Acc acc) {
float32x4_t sum = vcvt_f32_f16(vadd_f16(vget_low_f16(acc), vget_high_f16(acc)));
return vaddvq_f32(sum);
}
};
template <int nrc> struct QF16 final : public QF16Base {
using Base = QF16Base;
constexpr static int nrc_y = nrc;
QF16(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)info.src1_row(iy);
}
QF16(const char * cx, size_t bx) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)(cx + iy*bx);
}
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4(y[iy] + 4*i); }
IQK_ALWAYS_INLINE float16x8x4_t loadx(int iy, int i) const { return vld1q_f16_x4(y[iy] + 4*k_step*i); }
const __fp16 * y[nrc_y];
};
struct QBF16Base {
constexpr static int k_step = 4;
using Data = float32x4_t;
using Acc = float32x4_t;
static inline Data load(const uint16_t * x) { return vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16(x)), 16)); }
static inline Data load4(const uint16_t * x) { return load(x); }
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
return vfmaq_f32(prev, y, x);
}
static inline Acc acc_first(const Data& y, const Data& x) {
return vmulq_f32(y, x);
}
static inline float hsum(Acc acc) { return vaddvq_f32(acc); }
};
template <int nrc> struct QBF16 final : public QBF16Base {
using Base = QBF16Base;
constexpr static int nrc_y = nrc;
QBF16(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const uint16_t *)info.src1_row(iy);
}
QBF16(const char * cx, size_t bx) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const uint16_t *)(cx + iy*bx);
}
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load(y[iy] + 4*i); }
const uint16_t * y[nrc_y];
};
struct QF32Base {
constexpr static int k_step = 4;
using Data = float32x4_t;
using Acc = float32x4_t;
static inline Data load(const float * x) { return vld1q_f32(x); }
static inline Data load4(const float * x) { return load(x); }
static inline Acc acc(Acc prev, const Data& y, const Data& x) { return vfmaq_f32(prev, y, x); }
static inline Acc acc_first(const Data& y, const Data& x) { return vmulq_f32(y, x); }
static inline float hsum(Acc acc) { return vaddvq_f32(acc); }
};
template <int nrc> struct QF32 final : public QF32Base {
using Base = QF32Base;
constexpr static int nrc_y = nrc;
QF32(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
}
QF32(const char * cx, size_t bx) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)(cx + iy*bx);
}
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load(y[iy] + 4*i); }
const float * y[nrc_y];
};
template <typename Qy, typename Qx, bool is_multiple_of_k_step>
IQK_NOINLINE void mul_mat_Qx_Qy_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
GGML_ASSERT(Qx::Base::k_step == Qy::Base::k_step);
int nb = n/Qx::Base::k_step;
Qy y(info);
Qx x(cx + ix0*bx, bx);
typename Qx::Base::Data xv[Qx::nrc_y];
typename Qx::Base::Acc acc[Qx::nrc_y*Qy::nrc_y];
auto yv = y.load1(0, 0);
for (int ix = 0; ix < Qx::nrc_y; ++ix) {
xv[ix] = x.load1(ix, 0);
acc[ix] = Qx::Base::acc_first(yv, xv[ix]);
}
for (int iy = 1; iy < Qy::nrc_y; ++iy) {
yv = y.load1(iy, 0);
for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc_first(yv, xv[ix]);
}
for (int i = 1; i < nb; ++i) {
yv = y.load1(0, i);
for (int ix = 0; ix < Qx::nrc_y; ++ix) {
xv[ix] = x.load1(ix, i);
acc[ix] = Qx::Base::acc(acc[ix], yv, xv[ix]);
}
for (int iy = 1; iy < Qy::nrc_y; ++iy) {
yv = y.load1(iy, i);
for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc(acc[Qx::nrc_y*iy + ix], yv, xv[ix]);
}
}
if constexpr (Qx::Base::k_step > 4 && !is_multiple_of_k_step) {
int nb4 = n/4;
for (int i = (Qx::Base::k_step/4)*nb; i < nb4; ++i) {
yv = y.load_tail(0, i);
for (int ix = 0; ix < Qx::nrc_y; ++ix) {
xv[ix] = x.load_tail(ix, i);
acc[ix] = Qx::Base::acc(acc[ix], yv, xv[ix]);
}
for (int iy = 1; iy < Qy::nrc_y; ++iy) {
yv = y.load_tail(iy, i);
for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc(acc[Qx::nrc_y*iy + ix], yv, xv[ix]);
}
}
}
for (int iy = 0; iy < Qy::nrc_y; ++iy) for (int ix = 0; ix < Qx::nrc_y; ++ix) info.store(ix0+ix, iy, Qx::Base::hsum(acc[Qx::nrc_y*iy+ix]));
}
template <int nrc_y, int nrc_x, bool is_multiple_of_k_step>
IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
assert(n%QF16Base::k_step == 0);
int nb = n/QF16Base::k_step;
QF16<nrc_y> y(info);
QF16<nrc_x> x(cx + ix0*bx, bx);
QF16Base::Data xv[nrc_x];
QF16Base::Acc acc[nrc_x*nrc_y];
auto yv = y.load1(0, 0);
for (int ix = 0; ix < nrc_x; ++ix) {
xv[ix] = x.load1(ix, 0);
acc[ix] = QF16Base::acc_first(yv, xv[ix]);
}
for (int iy = 1; iy < nrc_y; ++iy) {
yv = y.load1(iy, 0);
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc_first(yv, xv[ix]);
}
for (int i = 1; i < nb; ++i) {
yv = y.load1(0, i);
for (int ix = 0; ix < nrc_x; ++ix) {
xv[ix] = x.load1(ix, i);
acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
}
for (int iy = 1; iy < nrc_y; ++iy) {
yv = y.load1(iy, i);
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
}
}
if constexpr (!is_multiple_of_k_step) {
int nb4 = n/4;
for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
yv = y.load_tail(0, i);
for (int ix = 0; ix < nrc_x; ++ix) {
xv[ix] = x.load_tail(ix, i);
acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
}
for (int iy = 1; iy < nrc_y; ++iy) {
yv = y.load_tail(iy, i);
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF16Base::hsum(acc[nrc_x*iy+ix]));
}
template <typename Qy, template<int> typename Qx>
void mul_mat_Qx_Qy_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%4 == 0);
constexpr int k_nx = 5;
const char * cx = (const char *)vx;
if (n%Qx<k_nx>::Base::k_step == 0) {
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_Qx_Qy_NxN<Qy, Qx<k_nx>, true>(n, cx, bx, ix*k_nx, info);
}
int last_x = k_nx*(nrc_x/k_nx);
if (last_x == nrc_x) return;
int nx = nrc_x - last_x;
switch (nx) {
case 1: mul_mat_Qx_Qy_NxN<Qy, Qx<1>, true>(n, cx, bx, last_x, info); break;
case 2: mul_mat_Qx_Qy_NxN<Qy, Qx<2>, true>(n, cx, bx, last_x, info); break;
case 3: mul_mat_Qx_Qy_NxN<Qy, Qx<3>, true>(n, cx, bx, last_x, info); break;
case 4: mul_mat_Qx_Qy_NxN<Qy, Qx<4>, true>(n, cx, bx, last_x, info); break;
}
} else {
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_Qx_Qy_NxN<Qy, Qx<k_nx>, false>(n, cx, bx, ix*k_nx, info);
}
int last_x = k_nx*(nrc_x/k_nx);
if (last_x == nrc_x) return;
int nx = nrc_x - last_x;
switch (nx) {
case 1: mul_mat_Qx_Qy_NxN<Qy, Qx<1>, false>(n, cx, bx, last_x, info); break;
case 2: mul_mat_Qx_Qy_NxN<Qy, Qx<2>, false>(n, cx, bx, last_x, info); break;
case 3: mul_mat_Qx_Qy_NxN<Qy, Qx<3>, false>(n, cx, bx, last_x, info); break;
case 4: mul_mat_Qx_Qy_NxN<Qy, Qx<4>, false>(n, cx, bx, last_x, info); break;
}
}
}
template <int nrc_y>
void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%4 == 0);
constexpr int k_nx = 5;
const char * cx = (const char *)vx;
if (n%QF16Base::k_step == 0) {
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_f16_f16_NxN<nrc_y, k_nx, true>(n, cx, bx, ix*k_nx, info);
}
int last_x = k_nx*(nrc_x/k_nx);
if (last_x == nrc_x) return;
int nx = nrc_x - last_x;
switch (nx) {
case 1: mul_mat_f16_f16_NxN<nrc_y, 1, true>(n, cx, bx, last_x, info); break;
case 2: mul_mat_f16_f16_NxN<nrc_y, 2, true>(n, cx, bx, last_x, info); break;
case 3: mul_mat_f16_f16_NxN<nrc_y, 3, true>(n, cx, bx, last_x, info); break;
case 4: mul_mat_f16_f16_NxN<nrc_y, 4, true>(n, cx, bx, last_x, info); break;
}
} else {
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_f16_f16_NxN<nrc_y, k_nx, false>(n, cx, bx, ix*k_nx, info);
}
int last_x = k_nx*(nrc_x/k_nx);
if (last_x == nrc_x) return;
int nx = nrc_x - last_x;
switch (nx) {
case 1: mul_mat_f16_f16_NxN<nrc_y, 1, false>(n, cx, bx, last_x, info); break;
case 2: mul_mat_f16_f16_NxN<nrc_y, 2, false>(n, cx, bx, last_x, info); break;
case 3: mul_mat_f16_f16_NxN<nrc_y, 3, false>(n, cx, bx, last_x, info); break;
case 4: mul_mat_f16_f16_NxN<nrc_y, 4, false>(n, cx, bx, last_x, info); break;
}
}
}
template <int nrc_x, bool is_multiple_of_k_step>
IQK_NOINLINE void mul_mat_f16_f16_Nx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
assert(n%QF16Base::k_step == 0);
int nb = n/QF16Base::k_step;
QF16<1> y(info);
QF16<nrc_x> x(cx + ix0*bx, bx);
QF16Base::Acc acc[4*nrc_x];
auto yv = y.loadx(0, 0);
for (int ix = 0; ix < nrc_x; ++ix) {
for (int k = 0; k < 4; ++k) {
auto xv = x.load1(ix, k);
acc[4*ix+k] = QF16Base::acc_first(yv.val[k], xv);
}
}
for (int i = 1; i < nb/4; ++i) {
yv = y.loadx(0, i);
for (int ix = 0; ix < nrc_x; ++ix) {
for (int k = 0; k < 4; ++k) {
auto xv = x.load1(ix, 4*i+k);
acc[4*ix+k] = QF16Base::acc(acc[4*ix+k], yv.val[k], xv);
}
}
}
for (int i = 4*(nb/4); i < nb; ++i) {
auto yv1 = y.load1(0, i);
for (int ix = 0; ix < nrc_x; ++ix) {
auto xv1 = x.load1(ix, i);
acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1);
}
}
if constexpr (!is_multiple_of_k_step) {
int nb4 = n/4;
for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
auto yv1 = y.load_tail(0, i);
for (int ix = 0; ix < nrc_x; ++ix) {
auto xv1 = x.load_tail(ix, i);
acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1);
}
}
}
for (int ix = 0; ix < nrc_x; ++ix) {
auto v1 = vaddq_f16(acc[4*ix+0], acc[4*ix+1]);
auto v2 = vaddq_f16(acc[4*ix+2], acc[4*ix+3]);
info.store(ix0+ix, 0, QF16Base::hsum(vaddq_f16(v1, v2)));
}
}
// At least on my M2-Max the version below, which does the multiplication row-by-row, is faster.
// But let's keep this version commented out for now.
//void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
// GGML_ASSERT(n%4 == 0);
// constexpr int k_nx = 2;
// const char * cx = (const char *)vx;
// if (n%QF16Base::k_step == 0) {
// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
// mul_mat_f16_f16_Nx1<k_nx, true>(n, cx, bx, ix*k_nx, info);
// }
// int last_x = k_nx*(nrc_x/k_nx);
// if (last_x == nrc_x) return;
// int nx = nrc_x - last_x;
// switch (nx) {
// case 1: mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, last_x, info); break;
// //case 2: mul_mat_f16_f16_Nx1<2, true>(n, cx, bx, last_x, info); break;
// //case 3: mul_mat_f16_f16_Nx1<3, true>(n, cx, bx, last_x, info); break;
// }
// } else {
// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
// mul_mat_f16_f16_Nx1<k_nx, false>(n, cx, bx, ix*k_nx, info);
// }
// int last_x = k_nx*(nrc_x/k_nx);
// if (last_x == nrc_x) return;
// int nx = nrc_x - last_x;
// switch (nx) {
// case 1: mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, last_x, info); break;
// //case 2: mul_mat_f16_f16_Nx1<2, false>(n, cx, bx, last_x, info); break;
// //case 3: mul_mat_f16_f16_Nx1<3, false>(n, cx, bx, last_x, info); break;
// }
// }
//}
void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%4 == 0);
const char * cx = (const char *)vx;
if (n%QF16Base::k_step == 0) {
for (int ix = 0; ix < nrc_x; ++ix) {
mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, ix, info);
}
} else {
for (int ix = 0; ix < nrc_x; ++ix) {
mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, ix, info);
}
}
}
}
bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
if (ne00%4 == 0) {
if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F16) {
for (auto& f : kernels) f = nullptr;
kernels[0] = mul_mat_f16_f16_1;
kernels[1] = mul_mat_f16_f16_T<2>;
kernels[2] = mul_mat_f16_f16_T<3>;
kernels[3] = mul_mat_f16_f16_T<4>;
kernels[4] = mul_mat_f16_f16_T<5>;
return true;
}
else if (typeA == GGML_TYPE_BF16 && typeB == GGML_TYPE_F32) {
for (auto& f : kernels) f = nullptr;
kernels[0] = mul_mat_Qx_Qy_T<QF32<1>, QBF16>;
kernels[1] = mul_mat_Qx_Qy_T<QF32<2>, QBF16>;
kernels[2] = mul_mat_Qx_Qy_T<QF32<3>, QBF16>;
kernels[3] = mul_mat_Qx_Qy_T<QF32<4>, QBF16>;
kernels[4] = mul_mat_Qx_Qy_T<QF32<5>, QBF16>;
return true;
}
}
return false;
}
#endif
#endif

View File

@@ -2221,351 +2221,6 @@ static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInf
mul_mat_qX_Y_q8_Y(n, deq1, deq2, q8, info, nrc_x);
}
struct QF16Base {
constexpr static int k_step = 8;
using Data = float16x8_t;
using Acc = float16x8_t;
static inline Data load(const __fp16 * x) { return vld1q_f16(x); }
static inline Data load4(const __fp16 * x) { return vcombine_f16(vld1_f16(x), vdup_n_f16(0)); }
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
return vfmaq_f16(prev, y, x);
}
static inline Acc acc_first(const Data& y, const Data& x) {
return vmulq_f16(y, x);
}
//constexpr static int k_step = 16;
//using Data = float16x8x2_t;
//static inline Data load(const __fp16 * x) { return vld1q_f16_x2(x); }
//static inline Acc acc(Acc prev, const Data& y, const Data& x) {
// return vfmaq_f16(vfmaq_f16(prev, y.val[0], x.val[0]), y.val[1], x.val[1]);
//}
//static inline Acc acc_first(const Data& y, const Data& x) {
// return vfmaq_f16(vmulq_f16(y.val[0], x.val[0]), y.val[1], x.val[1]);
//}
static inline float hsum(Acc acc) {
float32x4_t sum = vcvt_f32_f16(vadd_f16(vget_low_f16(acc), vget_high_f16(acc)));
return vaddvq_f32(sum);
}
};
template <int nrc> struct QF16 final : public QF16Base {
using Base = QF16Base;
constexpr static int nrc_y = nrc;
QF16(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)info.src1_row(iy);
}
QF16(const char * cx, size_t bx) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)(cx + iy*bx);
}
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4(y[iy] + 4*i); }
IQK_ALWAYS_INLINE float16x8x4_t loadx(int iy, int i) const { return vld1q_f16_x4(y[iy] + 4*k_step*i); }
const __fp16 * y[nrc_y];
};
struct QBF16Base {
constexpr static int k_step = 4;
using Data = float32x4_t;
using Acc = float32x4_t;
static inline Data load(const uint16_t * x) { return vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16(x)), 16)); }
static inline Data load4(const uint16_t * x) { return load(x); }
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
return vfmaq_f32(prev, y, x);
}
static inline Acc acc_first(const Data& y, const Data& x) {
return vmulq_f32(y, x);
}
static inline float hsum(Acc acc) { return vaddvq_f32(acc); }
};
template <int nrc> struct QBF16 final : public QBF16Base {
using Base = QBF16Base;
constexpr static int nrc_y = nrc;
QBF16(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const uint16_t *)info.src1_row(iy);
}
QBF16(const char * cx, size_t bx) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const uint16_t *)(cx + iy*bx);
}
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load(y[iy] + 4*i); }
const uint16_t * y[nrc_y];
};
struct QF32Base {
constexpr static int k_step = 4;
using Data = float32x4_t;
using Acc = float32x4_t;
static inline Data load(const float * x) { return vld1q_f32(x); }
static inline Data load4(const float * x) { return load(x); }
static inline Acc acc(Acc prev, const Data& y, const Data& x) { return vfmaq_f32(prev, y, x); }
static inline Acc acc_first(const Data& y, const Data& x) { return vmulq_f32(y, x); }
static inline float hsum(Acc acc) { return vaddvq_f32(acc); }
};
template <int nrc> struct QF32 final : public QF32Base {
using Base = QF32Base;
constexpr static int nrc_y = nrc;
QF32(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)info.src1_row(iy);
}
QF32(const char * cx, size_t bx) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float *)(cx + iy*bx);
}
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load(y[iy] + 4*i); }
const float * y[nrc_y];
};
template <typename Qy, typename Qx, bool is_multiple_of_k_step>
IQK_NOINLINE void mul_mat_Qx_Qy_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
GGML_ASSERT(Qx::Base::k_step == Qy::Base::k_step);
int nb = n/Qx::Base::k_step;
Qy y(info);
Qx x(cx + ix0*bx, bx);
typename Qx::Base::Data xv[Qx::nrc_y];
typename Qx::Base::Acc acc[Qx::nrc_y*Qy::nrc_y];
auto yv = y.load1(0, 0);
for (int ix = 0; ix < Qx::nrc_y; ++ix) {
xv[ix] = x.load1(ix, 0);
acc[ix] = Qx::Base::acc_first(yv, xv[ix]);
}
for (int iy = 1; iy < Qy::nrc_y; ++iy) {
yv = y.load1(iy, 0);
for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc_first(yv, xv[ix]);
}
for (int i = 1; i < nb; ++i) {
yv = y.load1(0, i);
for (int ix = 0; ix < Qx::nrc_y; ++ix) {
xv[ix] = x.load1(ix, i);
acc[ix] = Qx::Base::acc(acc[ix], yv, xv[ix]);
}
for (int iy = 1; iy < Qy::nrc_y; ++iy) {
yv = y.load1(iy, i);
for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc(acc[Qx::nrc_y*iy + ix], yv, xv[ix]);
}
}
if constexpr (Qx::Base::k_step > 4 && !is_multiple_of_k_step) {
int nb4 = n/4;
for (int i = (Qx::Base::k_step/4)*nb; i < nb4; ++i) {
yv = y.load_tail(0, i);
for (int ix = 0; ix < Qx::nrc_y; ++ix) {
xv[ix] = x.load_tail(ix, i);
acc[ix] = Qx::Base::acc(acc[ix], yv, xv[ix]);
}
for (int iy = 1; iy < Qy::nrc_y; ++iy) {
yv = y.load_tail(iy, i);
for (int ix = 0; ix < Qx::nrc_y; ++ix) acc[Qx::nrc_y*iy + ix] = Qx::Base::acc(acc[Qx::nrc_y*iy + ix], yv, xv[ix]);
}
}
}
for (int iy = 0; iy < Qy::nrc_y; ++iy) for (int ix = 0; ix < Qx::nrc_y; ++ix) info.store(ix0+ix, iy, Qx::Base::hsum(acc[Qx::nrc_y*iy+ix]));
}
template <int nrc_y, int nrc_x, bool is_multiple_of_k_step>
IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
assert(n%QF16Base::k_step == 0);
int nb = n/QF16Base::k_step;
QF16<nrc_y> y(info);
QF16<nrc_x> x(cx + ix0*bx, bx);
QF16Base::Data xv[nrc_x];
QF16Base::Acc acc[nrc_x*nrc_y];
auto yv = y.load1(0, 0);
for (int ix = 0; ix < nrc_x; ++ix) {
xv[ix] = x.load1(ix, 0);
acc[ix] = QF16Base::acc_first(yv, xv[ix]);
}
for (int iy = 1; iy < nrc_y; ++iy) {
yv = y.load1(iy, 0);
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc_first(yv, xv[ix]);
}
for (int i = 1; i < nb; ++i) {
yv = y.load1(0, i);
for (int ix = 0; ix < nrc_x; ++ix) {
xv[ix] = x.load1(ix, i);
acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
}
for (int iy = 1; iy < nrc_y; ++iy) {
yv = y.load1(iy, i);
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
}
}
if constexpr (!is_multiple_of_k_step) {
int nb4 = n/4;
for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
yv = y.load_tail(0, i);
for (int ix = 0; ix < nrc_x; ++ix) {
xv[ix] = x.load_tail(ix, i);
acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
}
for (int iy = 1; iy < nrc_y; ++iy) {
yv = y.load_tail(iy, i);
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF16Base::hsum(acc[nrc_x*iy+ix]));
}
template <typename Qy, template<int> typename Qx>
void mul_mat_Qx_Qy_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%4 == 0);
constexpr int k_nx = 5;
const char * cx = (const char *)vx;
if (n%Qx<k_nx>::Base::k_step == 0) {
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_Qx_Qy_NxN<Qy, Qx<k_nx>, true>(n, cx, bx, ix*k_nx, info);
}
int last_x = k_nx*(nrc_x/k_nx);
if (last_x == nrc_x) return;
int nx = nrc_x - last_x;
switch (nx) {
case 1: mul_mat_Qx_Qy_NxN<Qy, Qx<1>, true>(n, cx, bx, last_x, info); break;
case 2: mul_mat_Qx_Qy_NxN<Qy, Qx<2>, true>(n, cx, bx, last_x, info); break;
case 3: mul_mat_Qx_Qy_NxN<Qy, Qx<3>, true>(n, cx, bx, last_x, info); break;
case 4: mul_mat_Qx_Qy_NxN<Qy, Qx<4>, true>(n, cx, bx, last_x, info); break;
}
} else {
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_Qx_Qy_NxN<Qy, Qx<k_nx>, false>(n, cx, bx, ix*k_nx, info);
}
int last_x = k_nx*(nrc_x/k_nx);
if (last_x == nrc_x) return;
int nx = nrc_x - last_x;
switch (nx) {
case 1: mul_mat_Qx_Qy_NxN<Qy, Qx<1>, false>(n, cx, bx, last_x, info); break;
case 2: mul_mat_Qx_Qy_NxN<Qy, Qx<2>, false>(n, cx, bx, last_x, info); break;
case 3: mul_mat_Qx_Qy_NxN<Qy, Qx<3>, false>(n, cx, bx, last_x, info); break;
case 4: mul_mat_Qx_Qy_NxN<Qy, Qx<4>, false>(n, cx, bx, last_x, info); break;
}
}
}
template <int nrc_y>
void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%4 == 0);
constexpr int k_nx = 5;
const char * cx = (const char *)vx;
if (n%QF16Base::k_step == 0) {
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_f16_f16_NxN<nrc_y, k_nx, true>(n, cx, bx, ix*k_nx, info);
}
int last_x = k_nx*(nrc_x/k_nx);
if (last_x == nrc_x) return;
int nx = nrc_x - last_x;
switch (nx) {
case 1: mul_mat_f16_f16_NxN<nrc_y, 1, true>(n, cx, bx, last_x, info); break;
case 2: mul_mat_f16_f16_NxN<nrc_y, 2, true>(n, cx, bx, last_x, info); break;
case 3: mul_mat_f16_f16_NxN<nrc_y, 3, true>(n, cx, bx, last_x, info); break;
case 4: mul_mat_f16_f16_NxN<nrc_y, 4, true>(n, cx, bx, last_x, info); break;
}
} else {
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
mul_mat_f16_f16_NxN<nrc_y, k_nx, false>(n, cx, bx, ix*k_nx, info);
}
int last_x = k_nx*(nrc_x/k_nx);
if (last_x == nrc_x) return;
int nx = nrc_x - last_x;
switch (nx) {
case 1: mul_mat_f16_f16_NxN<nrc_y, 1, false>(n, cx, bx, last_x, info); break;
case 2: mul_mat_f16_f16_NxN<nrc_y, 2, false>(n, cx, bx, last_x, info); break;
case 3: mul_mat_f16_f16_NxN<nrc_y, 3, false>(n, cx, bx, last_x, info); break;
case 4: mul_mat_f16_f16_NxN<nrc_y, 4, false>(n, cx, bx, last_x, info); break;
}
}
}
template <int nrc_x, bool is_multiple_of_k_step>
IQK_NOINLINE void mul_mat_f16_f16_Nx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
assert(n%QF16Base::k_step == 0);
int nb = n/QF16Base::k_step;
QF16<1> y(info);
QF16<nrc_x> x(cx + ix0*bx, bx);
QF16Base::Acc acc[4*nrc_x];
auto yv = y.loadx(0, 0);
for (int ix = 0; ix < nrc_x; ++ix) {
for (int k = 0; k < 4; ++k) {
auto xv = x.load1(ix, k);
acc[4*ix+k] = QF16Base::acc_first(yv.val[k], xv);
}
}
for (int i = 1; i < nb/4; ++i) {
yv = y.loadx(0, i);
for (int ix = 0; ix < nrc_x; ++ix) {
for (int k = 0; k < 4; ++k) {
auto xv = x.load1(ix, 4*i+k);
acc[4*ix+k] = QF16Base::acc(acc[4*ix+k], yv.val[k], xv);
}
}
}
for (int i = 4*(nb/4); i < nb; ++i) {
auto yv1 = y.load1(0, i);
for (int ix = 0; ix < nrc_x; ++ix) {
auto xv1 = x.load1(ix, i);
acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1);
}
}
if constexpr (!is_multiple_of_k_step) {
int nb4 = n/4;
for (int i = (QF16Base::k_step/4)*nb; i < nb4; ++i) {
auto yv1 = y.load_tail(0, i);
for (int ix = 0; ix < nrc_x; ++ix) {
auto xv1 = x.load_tail(ix, i);
acc[4*ix] = QF16Base::acc(acc[4*ix], yv1, xv1);
}
}
}
for (int ix = 0; ix < nrc_x; ++ix) {
auto v1 = vaddq_f16(acc[4*ix+0], acc[4*ix+1]);
auto v2 = vaddq_f16(acc[4*ix+2], acc[4*ix+3]);
info.store(ix0+ix, 0, QF16Base::hsum(vaddq_f16(v1, v2)));
}
}
// At least on my M2-Max the version below, which does the multiplication row-by-row, is faster.
// But let's keep this version commented out for now.
//void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
// GGML_ASSERT(n%4 == 0);
// constexpr int k_nx = 2;
// const char * cx = (const char *)vx;
// if (n%QF16Base::k_step == 0) {
// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
// mul_mat_f16_f16_Nx1<k_nx, true>(n, cx, bx, ix*k_nx, info);
// }
// int last_x = k_nx*(nrc_x/k_nx);
// if (last_x == nrc_x) return;
// int nx = nrc_x - last_x;
// switch (nx) {
// case 1: mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, last_x, info); break;
// //case 2: mul_mat_f16_f16_Nx1<2, true>(n, cx, bx, last_x, info); break;
// //case 3: mul_mat_f16_f16_Nx1<3, true>(n, cx, bx, last_x, info); break;
// }
// } else {
// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
// mul_mat_f16_f16_Nx1<k_nx, false>(n, cx, bx, ix*k_nx, info);
// }
// int last_x = k_nx*(nrc_x/k_nx);
// if (last_x == nrc_x) return;
// int nx = nrc_x - last_x;
// switch (nx) {
// case 1: mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, last_x, info); break;
// //case 2: mul_mat_f16_f16_Nx1<2, false>(n, cx, bx, last_x, info); break;
// //case 3: mul_mat_f16_f16_Nx1<3, false>(n, cx, bx, last_x, info); break;
// }
// }
//}
void mul_mat_f16_f16_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%4 == 0);
const char * cx = (const char *)vx;
if (n%QF16Base::k_step == 0) {
for (int ix = 0; ix < nrc_x; ++ix) {
mul_mat_f16_f16_Nx1<1, true>(n, cx, bx, ix, info);
}
} else {
for (int ix = 0; ix < nrc_x; ++ix) {
mul_mat_f16_f16_Nx1<1, false>(n, cx, bx, ix, info);
}
}
}
template <int nrc> struct Q8_K64 {
constexpr static int nrc_y = nrc;
@@ -4657,31 +4312,13 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
if (typeA == GGML_TYPE_F16 && typeB == GGML_TYPE_F16) {
if (ne00%4) return false;
for (auto& f : m.funcs) f = nullptr;
m.funcs[0] = mul_mat_f16_f16_1;
m.funcs[1] = mul_mat_f16_f16_T<2>;
m.funcs[2] = mul_mat_f16_f16_T<3>;
m.funcs[3] = mul_mat_f16_f16_T<4>;
m.funcs[4] = mul_mat_f16_f16_T<5>;
return true;
}
if (typeA == GGML_TYPE_BF16 && typeB == GGML_TYPE_F32) {
if (ne00%4) return false;
for (auto& f : m.funcs) f = nullptr;
m.funcs[0] = mul_mat_Qx_Qy_T<QF32<1>, QBF16>;
m.funcs[1] = mul_mat_Qx_Qy_T<QF32<2>, QBF16>;
m.funcs[2] = mul_mat_Qx_Qy_T<QF32<3>, QBF16>;
m.funcs[3] = mul_mat_Qx_Qy_T<QF32<4>, QBF16>;
m.funcs[4] = mul_mat_Qx_Qy_T<QF32<5>, QBF16>;
return true;
}
auto expected_Btype = GGML_TYPE_Q8_K;
switch (typeA) {
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_F32:
return iqk_set_kernels_float(ne00, typeA, typeB, m.funcs);
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K: