mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 19:01:47 +00:00
iqk_mul_mat: turn on AVX512
It makes no difference on my Ryzen-7950X, but perhaps it will be beneficial for CPU's with real AVX512.
This commit is contained in:
@@ -94,7 +94,6 @@ typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& inf
|
|||||||
|
|
||||||
struct MulMat {
|
struct MulMat {
|
||||||
std::array<mul_mat_t, 8> funcs = {};
|
std::array<mul_mat_t, 8> funcs = {};
|
||||||
//std::array<mul_mat_t, 4> funcs = {};
|
|
||||||
inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {
|
inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {
|
||||||
#ifdef __aarch64__
|
#ifdef __aarch64__
|
||||||
constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small)
|
constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small)
|
||||||
@@ -2155,6 +2154,22 @@ struct Q5_1_Unpacker final : public Q_Unpacker<block_q5_1, ScaleHelperQ_1, Q5_1_
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct QF32Base {
|
struct QF32Base {
|
||||||
|
#ifdef __AVX512F__
|
||||||
|
constexpr static int k_step = 16;
|
||||||
|
using Data = __m512;
|
||||||
|
using Acc = __m512;
|
||||||
|
static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); }
|
||||||
|
static inline Data load(const float * x) { return _mm512_loadu_ps(x); }
|
||||||
|
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
|
||||||
|
return _mm512_fmadd_ps(y, x, prev);
|
||||||
|
}
|
||||||
|
static inline Acc acc_first(const Data& y, const Data& x) {
|
||||||
|
return _mm512_mul_ps(y, x);
|
||||||
|
}
|
||||||
|
static inline float hsum(Acc acc) {
|
||||||
|
return _mm512_reduce_add_ps(acc);
|
||||||
|
}
|
||||||
|
#else
|
||||||
constexpr static int k_step = 8;
|
constexpr static int k_step = 8;
|
||||||
using Data = __m256;
|
using Data = __m256;
|
||||||
using Acc = __m256;
|
using Acc = __m256;
|
||||||
@@ -2169,6 +2184,7 @@ struct QF32Base {
|
|||||||
static inline float hsum(Acc acc) {
|
static inline float hsum(Acc acc) {
|
||||||
return hsum_float_8(acc);
|
return hsum_float_8(acc);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
template <int nrc> struct QF32y final : public QF32Base {
|
template <int nrc> struct QF32y final : public QF32Base {
|
||||||
constexpr static int nrc_y = nrc;
|
constexpr static int nrc_y = nrc;
|
||||||
@@ -2188,7 +2204,7 @@ template <int nrc> struct QF32x final : public QF32Base {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <int nrc_y, int nrc_x>
|
template <int nrc_y, int nrc_x>
|
||||||
IQK_NOINLINE void mul_mat_f16_f32_NxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
IQK_NOINLINE void mul_mat_f16_f32_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||||
assert(n%QF16Base::k_step == 0);
|
assert(n%QF16Base::k_step == 0);
|
||||||
int nb = n/QF32Base::k_step;
|
int nb = n/QF32Base::k_step;
|
||||||
QF32y<nrc_y> y(info);
|
QF32y<nrc_y> y(info);
|
||||||
@@ -2228,18 +2244,17 @@ void mul_mat_f16_f32_T(int n, const void * vx, size_t bx, const DataInfo& info,
|
|||||||
#endif
|
#endif
|
||||||
const char * cx = (const char *)vx;
|
const char * cx = (const char *)vx;
|
||||||
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||||
mul_mat_f16_f32_NxN<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info);
|
mul_mat_f16_f32_MxN<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info);
|
||||||
}
|
}
|
||||||
int last_x = k_nx*(nrc_x/k_nx);
|
int last_x = k_nx*(nrc_x/k_nx);
|
||||||
if (last_x == nrc_x) return;
|
if (last_x == nrc_x) return;
|
||||||
int nx = nrc_x - last_x;
|
int nx = nrc_x - last_x;
|
||||||
switch (nx) {
|
switch (nx) {
|
||||||
case 1: mul_mat_f16_f32_NxN<nrc_y, 1>(n, cx, bx, last_x, info); break;
|
case 1: mul_mat_f16_f32_MxN<nrc_y, 1>(n, cx, bx, last_x, info); break;
|
||||||
case 2: mul_mat_f16_f32_NxN<nrc_y, 2>(n, cx, bx, last_x, info); break;
|
#ifdef __AVX512F__
|
||||||
case 3: mul_mat_f16_f32_NxN<nrc_y, 3>(n, cx, bx, last_x, info); break;
|
case 2: mul_mat_f16_f32_MxN<nrc_y, 2>(n, cx, bx, last_x, info); break;
|
||||||
case 4: mul_mat_f16_f32_NxN<nrc_y, 4>(n, cx, bx, last_x, info); break;
|
case 3: mul_mat_f16_f32_MxN<nrc_y, 3>(n, cx, bx, last_x, info); break;
|
||||||
#ifndef __AVX512F__
|
case 4: mul_mat_f16_f32_MxN<nrc_y, 4>(n, cx, bx, last_x, info); break;
|
||||||
case 5: mul_mat_f16_f32_NxN<nrc_y, 5>(n, cx, bx, last_x, info); break;
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2394,7 +2409,6 @@ bool MulMat::set_mul_mat(int typeA, int ne00, MulMat& mm, int& row_size_q8, int
|
|||||||
mm.funcs[2] = mul_mat_f16_f32_T<3>;
|
mm.funcs[2] = mul_mat_f16_f32_T<3>;
|
||||||
mm.funcs[3] = mul_mat_f16_f32_T<4>;
|
mm.funcs[3] = mul_mat_f16_f32_T<4>;
|
||||||
mm.funcs[4] = mul_mat_f16_f32_T<5>;
|
mm.funcs[4] = mul_mat_f16_f32_T<5>;
|
||||||
mm.funcs[4] = mul_mat_f16_f32_T<5>;
|
|
||||||
#ifndef __AVX512F__
|
#ifndef __AVX512F__
|
||||||
mm.funcs[5] = mul_mat_f16_f32_T<6>;
|
mm.funcs[5] = mul_mat_f16_f32_T<6>;
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
Reference in New Issue
Block a user