iqk_mul_mat: fp16 for Arm

~2% slower than tinyBLAS - not sure why.
Author: Kawrakow
Date: 2024-06-10 08:16:52 +02:00
Parent: 6ec0fcc5c7
Commit: baf6aaa31b
2 changed files with 113 additions and 3 deletions
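At its core the new Arm path is an fp16 dot product that stays in half precision until the final reduction, the pattern implemented by the QF16Base helpers in the diff below. A minimal sketch of that pattern (illustrative names, not part of the patch; assumes n is a multiple of 8 and a build flag such as -march=armv8.2-a+fp16 so that __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is available):

#include <arm_neon.h>

// Dot product of two fp16 vectors: multiply-accumulate in fp16 (vfmaq_f16),
// widen to fp32 only for the final horizontal sum, as QF16Base::hsum does.
static float dot_f16(const __fp16 * x, const __fp16 * y, int n) {
    float16x8_t acc = vmulq_f16(vld1q_f16(y), vld1q_f16(x));       // first block of 8
    for (int i = 8; i < n; i += 8) {
        acc = vfmaq_f16(acc, vld1q_f16(y + i), vld1q_f16(x + i));  // acc += y*x, still in fp16
    }
    // add the two fp16 halves, convert to fp32, then sum the 4 lanes
    float32x4_t sum = vcvt_f32_f16(vadd_f16(vget_low_f16(acc), vget_high_f16(acc)));
    return vaddvq_f32(sum);
}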


@@ -3761,6 +3761,93 @@ static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInf
    mul_mat_qX_Y_q8_Y(n, deq1, deq2, q8, info, nrc_x);
}
struct QF16Base {
    constexpr static int k_step = 8;
    using Data = float16x8_t;
    using Acc = float16x8_t;
    static inline Data load(const __fp16 * x) { return vld1q_f16(x); }
    static inline Acc acc(Acc prev, const Data& y, const Data& x) {
        return vfmaq_f16(prev, y, x);
    }
    static inline Acc acc_first(const Data& y, const Data& x) {
        return vmulq_f16(y, x);
    }
    //constexpr static int k_step = 16;
    //using Data = float16x8x2_t;
    //static inline Data load(const __fp16 * x) { return vld1q_f16_x2(x); }
    //static inline Acc acc(Acc prev, const Data& y, const Data& x) {
    //    return vfmaq_f16(vfmaq_f16(prev, y.val[0], x.val[0]), y.val[1], x.val[1]);
    //}
    //static inline Acc acc_first(const Data& y, const Data& x) {
    //    return vfmaq_f16(vmulq_f16(y.val[0], x.val[0]), y.val[1], x.val[1]);
    //}
    static inline float hsum(Acc acc) {
        float32x4_t sum = vcvt_f32_f16(vadd_f16(vget_low_f16(acc), vget_high_f16(acc)));
        return vaddvq_f32(sum);
    }
};
template <int nrc> struct QF16 final : public QF16Base {
    constexpr static int nrc_y = nrc;
    QF16(const DataInfo& info) {
        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)info.src1_row(iy);
    }
    QF16(const char * cx, size_t bx) {
        for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const __fp16 *)(cx + iy*bx);
    }
    IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
    const __fp16 * y[nrc_y];
};
template <int nrc_y, int nrc_x>
IQK_NOINLINE void mul_mat_f16_f16_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
    assert(n%QF16Base::k_step == 0);
    int nb = n/QF16Base::k_step;
    QF16<nrc_y> y(info);
    QF16<nrc_x> x(cx + ix0*bx, bx);
    QF16Base::Data xv[nrc_x];
    QF16Base::Acc acc[nrc_x*nrc_y];
    auto yv = y.load1(0, 0);
    for (int ix = 0; ix < nrc_x; ++ix) {
        xv[ix] = x.load1(ix, 0);
        acc[ix] = QF16Base::acc_first(yv, xv[ix]);
    }
    for (int iy = 1; iy < nrc_y; ++iy) {
        yv = y.load1(iy, 0);
        for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc_first(yv, xv[ix]);
    }
    for (int i = 1; i < nb; ++i) {
        yv = y.load1(0, i);
        for (int ix = 0; ix < nrc_x; ++ix) {
            xv[ix] = x.load1(ix, i);
            acc[ix] = QF16Base::acc(acc[ix], yv, xv[ix]);
        }
        for (int iy = 1; iy < nrc_y; ++iy) {
            yv = y.load1(iy, i);
            for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QF16Base::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
        }
    }
    for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QF16Base::hsum(acc[nrc_x*iy+ix]));
}
template <int nrc_y>
void mul_mat_f16_f16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n%QF16Base::k_step == 0);
    constexpr int k_nx = 5;
    const char * cx = (const char *)vx;
    for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
        mul_mat_f16_f16_NxN<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info);
    }
    int last_x = k_nx*(nrc_x/k_nx);
    if (last_x == nrc_x) return;
    int nx = nrc_x - last_x;
    switch (nx) {
        case 1: mul_mat_f16_f16_NxN<nrc_y, 1>(n, cx, bx, last_x, info); break;
        case 2: mul_mat_f16_f16_NxN<nrc_y, 2>(n, cx, bx, last_x, info); break;
        case 3: mul_mat_f16_f16_NxN<nrc_y, 3>(n, cx, bx, last_x, info); break;
        case 4: mul_mat_f16_f16_NxN<nrc_y, 4>(n, cx, bx, last_x, info); break;
    }
}
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
    if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||
                  std::is_same_v<Dequantizer, DequantizerQ80>) {
@@ -3798,6 +3885,19 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& m, int& row_size_q8, int /*Ny*/) {
    row_size_q8 = ggml_row_size(GGML_TYPE_Q8_K, ne00);
    if (typeA == GGML_TYPE_F16) {
        for (auto& f : m.funcs) f = nullptr;
        m.funcs[0] = mul_mat_f16_f16_T<1>;
        m.funcs[1] = mul_mat_f16_f16_T<2>;
        m.funcs[2] = mul_mat_f16_f16_T<3>;
        m.funcs[3] = mul_mat_f16_f16_T<4>;
        m.funcs[4] = mul_mat_f16_f16_T<5>;
        //m.funcs[5] = mul_mat_f16_f16_T<6>;
        //m.funcs[6] = mul_mat_f16_f16_T<7>;
        row_size_q8 = ggml_row_size(GGML_TYPE_F16, ne00);
        return true;
    }
    switch (typeA) {
        case GGML_TYPE_Q2_K:
            MulMat::set_functions<DequantizerQ2K>(m);
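The NxN kernel above tiles that dot product: it keeps k_nx = 5 loaded x vectors, one y vector and up to 25 fp16 accumulators live at once, i.e. 31 of the 32 NEON vector registers, which is presumably why k_nx is 5. Functionally it computes a plain fp16 x fp16 -> fp32 matrix product; a scalar version like the one below (layout and names are illustrative assumptions, not taken from the patch) is a handy baseline when experimenting with such kernels, keeping in mind that the fp16-accumulating kernel will differ slightly from an fp32-accumulating reference:

#include <cstdint>

// Scalar reference: C[j,i] = sum_l A[i,l] * B[j,l], with each row of A and B
// holding k contiguous fp16 values and C stored with leading dimension ldc.
static void ref_mul_mat_f16(int64_t m, int64_t n, int64_t k,
                            const __fp16 * A, int64_t lda,
                            const __fp16 * B, int64_t ldb,
                            float * C, int64_t ldc) {
    for (int64_t j = 0; j < n; ++j) {
        for (int64_t i = 0; i < m; ++i) {
            float sum = 0.0f;
            for (int64_t l = 0; l < k; ++l) sum += (float)A[i*lda + l] * (float)B[j*ldb + l];
            C[j*ldc + i] = sum;
        }
    }
}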


@@ -866,10 +866,20 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
    if (Ctype != GGML_TYPE_F32)
        return false;

-    if (task == GGML_TASK_TYPE_COMPUTE && k >= 256 && Atype == GGML_TYPE_F16 && Btype == GGML_TYPE_F32) {
-        if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
-            return true;
-        }
-    }
+    if (task == GGML_TASK_TYPE_COMPUTE && k >= 256 && Atype == GGML_TYPE_F16) {
+#if defined __AVX2__ && defined __FMA__
+        if (Btype == GGML_TYPE_F32) {
+            if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
+                return true;
+            }
+        }
+#elif defined __ARM_FEATURE_FP16_VECTOR_ARITHMETIC && defined __ARM_FEATURE_FMA
+        if (Btype == GGML_TYPE_F16) {
+            if (iqk_mul_mat(m, n, k, Atype, A, B, (float *)C, ldc, ith, nth)) {
+                return true;
+            }
+        }
+#endif
+    }

    switch (Atype) {