mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 19:31:48 +00:00
Refactor iqk: Factor out float GEMM (AVX2/AVX512)
This commit is contained in:
@@ -258,8 +258,12 @@ set (GGML_HEADERS_IQK iqk/iqk_config.h)
|
||||
if (GGML_IQK_MUL_MAT)
|
||||
message(STATUS "Using optimized iqk matrix multiplications")
|
||||
add_compile_definitions(GGML_USE_IQK_MULMAT)
|
||||
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp iqk/iqk_flash_attn.cpp)
|
||||
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h iqk/iqk_flash_impl.h)
|
||||
set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp
|
||||
iqk/iqk_flash_attn.cpp
|
||||
iqk/iqk_gemm_floats.cpp)
|
||||
set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h
|
||||
iqk/iqk_flash_impl.h
|
||||
iqk/iqk_gemm_floats.h)
|
||||
if (GGML_IQK_FLASH_ATTENTION)
|
||||
message(STATUS "Enabling IQK Flash Attention kernels")
|
||||
add_compile_definitions(GGML_IQK_FLASH_ATTENTION)
|
||||
|
||||
@@ -79,8 +79,6 @@ struct Perf {
|
||||
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
typedef struct {
|
||||
int32_t i1;
|
||||
int32_t i2;
|
||||
|
||||
@@ -2,7 +2,566 @@
|
||||
|
||||
#ifdef IQK_IMPLEMENT
|
||||
|
||||
#include "ggml-impl.h"
|
||||
|
||||
#define GGML_COMMON_IMPL_C
|
||||
#include "ggml-common.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// float matrices - we handle f16, bf16 (if native bf16 support is available) and f32, but only to f32 result
|
||||
|
||||
struct QFBase {
|
||||
#ifdef __AVX512F__
|
||||
constexpr static int k_step = 16;
|
||||
using Data = __m512;
|
||||
using Acc = __m512;
|
||||
static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); }
|
||||
static inline Data load(const float * x) { return _mm512_loadu_ps(x); }
|
||||
static inline Data load(const ggml_bf16_t * x) {
|
||||
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)x)), 16));
|
||||
}
|
||||
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
|
||||
return _mm512_fmadd_ps(y, x, prev);
|
||||
}
|
||||
static inline Acc acc_first(const Data& y, const Data& x) {
|
||||
return _mm512_mul_ps(y, x);
|
||||
}
|
||||
static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); }
|
||||
static inline float hsum(Acc acc) {
|
||||
return _mm512_reduce_add_ps(acc);
|
||||
}
|
||||
template <typename Float>
|
||||
static inline Data load4Floats(const Float * x) {
|
||||
return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0);
|
||||
}
|
||||
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
|
||||
acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);
|
||||
acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
|
||||
acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
|
||||
acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
|
||||
return acc;
|
||||
}
|
||||
static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
|
||||
auto acc = _mm512_mul_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00));
|
||||
acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
|
||||
acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
|
||||
acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
|
||||
return acc;
|
||||
}
|
||||
static inline __m128 hsum_r4(Acc acc) {
|
||||
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1));
|
||||
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3));
|
||||
return _mm_add_ps(sum1, sum2);
|
||||
}
|
||||
#else
|
||||
constexpr static int k_step = 8;
|
||||
using Data = __m256;
|
||||
using Acc = __m256;
|
||||
static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); }
|
||||
static inline Data load(const float * x) { return _mm256_loadu_ps(x); }
|
||||
static inline Data load(const ggml_bf16_t * x) {
|
||||
return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16));
|
||||
}
|
||||
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
|
||||
return _mm256_fmadd_ps(y, x, prev);
|
||||
}
|
||||
static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); }
|
||||
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
|
||||
acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
|
||||
acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
|
||||
acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
|
||||
acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
|
||||
return acc;
|
||||
}
|
||||
static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
|
||||
auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00));
|
||||
acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
|
||||
acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
|
||||
acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
|
||||
return acc;
|
||||
}
|
||||
static inline Acc acc_first(const Data& y, const Data& x) {
|
||||
return _mm256_mul_ps(y, x);
|
||||
}
|
||||
static inline float hsum(Acc acc) {
|
||||
return hsum_float_8(acc);
|
||||
}
|
||||
static inline __m128 hsum_r4(Acc acc) {
|
||||
return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
|
||||
}
|
||||
template <typename Float>
|
||||
static inline Data load4Floats(const Float * x) {
|
||||
return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0);
|
||||
}
|
||||
#endif
|
||||
static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); }
|
||||
static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); }
|
||||
static inline __m128 load128(const ggml_bf16_t * x) {
|
||||
return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Float, int nrc_in> struct QFT final : public QFBase {
|
||||
constexpr static int nrc = nrc_in;
|
||||
QFT(const DataInfo& info) {
|
||||
for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)info.src1_row(iy);
|
||||
}
|
||||
QFT(const char * cx, size_t bx) {
|
||||
for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)(cx + iy*bx);
|
||||
}
|
||||
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
|
||||
IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); }
|
||||
IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const {
|
||||
xv[0] = load1(ix+0, i);
|
||||
xv[1] = load1(ix+1, i);
|
||||
xv[2] = load1(ix+2, i);
|
||||
xv[3] = load1(ix+3, i);
|
||||
#ifdef __AVX512F__
|
||||
auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);
|
||||
auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);
|
||||
auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);
|
||||
auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]);
|
||||
xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
|
||||
xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
|
||||
xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
|
||||
xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
|
||||
#else
|
||||
auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]);
|
||||
auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]);
|
||||
auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]);
|
||||
auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]);
|
||||
xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
|
||||
xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
|
||||
xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
|
||||
xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
|
||||
#endif
|
||||
}
|
||||
const Float * y[nrc];
|
||||
};
|
||||
|
||||
// TBD if we want this
|
||||
//template <typename Qy, typename Qx>
|
||||
//IQK_NOINLINE void mul_mat_Qx_Qy_Mx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
// static_assert(Qy::nrc == 1);
|
||||
// int nb = n/QFBase::k_step;
|
||||
// int nb4 = n/4;
|
||||
// Qy y(info);
|
||||
// Qx x(cx + ix0*bx, bx);
|
||||
// QFBase::Data xv[2*Qx::nrc];
|
||||
// QFBase::Acc acc[2*Qx::nrc];
|
||||
// auto yv1 = y.load1(0, 0);
|
||||
// auto yv2 = y.load1(0, 1);
|
||||
// for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
// xv[2*ix+0] = x.load1(ix, 0);
|
||||
// xv[2*ix+1] = x.load1(ix, 1);
|
||||
// acc[2*ix+0] = QFBase::acc_first(yv1, xv[2*ix+0]);
|
||||
// acc[2*ix+1] = QFBase::acc_first(yv2, xv[2*ix+1]);
|
||||
// }
|
||||
// for (int i = 1; i < nb/2; ++i) {
|
||||
// yv1 = y.load1(0, 2*i+0);
|
||||
// yv2 = y.load1(0, 2*i+1);
|
||||
// for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
// xv[2*ix+0] = x.load1(ix, 2*i+0);
|
||||
// xv[2*ix+1] = x.load1(ix, 2*i+1);
|
||||
// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[2*ix+0]);
|
||||
// acc[2*ix+1] = QFBase::acc(acc[2*ix+1], yv2, xv[2*ix+1]);
|
||||
// }
|
||||
// }
|
||||
// for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
|
||||
// yv1 = y.load_tail(0, i);
|
||||
// for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
// xv[ix] = x.load_tail(ix, i);
|
||||
// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[ix]);
|
||||
// }
|
||||
// }
|
||||
// for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(QFBase::add(acc[2*ix+0], acc[2*ix+1])));
|
||||
//}
|
||||
|
||||
template <typename Qy, typename Qx>
|
||||
IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
int nb = n/QFBase::k_step;
|
||||
int nb4 = n/4;
|
||||
Qy y(info);
|
||||
Qx x(cx + ix0*bx, bx);
|
||||
QFBase::Data xv[Qx::nrc];
|
||||
QFBase::Acc acc[Qx::nrc*Qy::nrc];
|
||||
auto yv = y.load1(0, 0);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load1(ix, 0);
|
||||
acc[ix] = QFBase::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < Qy::nrc; ++iy) {
|
||||
yv = y.load1(iy, 0);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int i = 1; i < nb; ++i) {
|
||||
yv = y.load1(0, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load1(ix, i);
|
||||
acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < Qy::nrc; ++iy) {
|
||||
yv = y.load1(iy, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
|
||||
yv = y.load_tail(0, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load_tail(ix, i);
|
||||
acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < Qy::nrc; ++iy) {
|
||||
yv = y.load_tail(iy, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
|
||||
}
|
||||
|
||||
template <typename Qy, typename Qx>
|
||||
inline void mul_mat_Qx_Qy_MxN_fa(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
int nb = n/QFBase::k_step;
|
||||
Qy y(info);
|
||||
Qx x(cx + ix0*bx, bx);
|
||||
QFBase::Data xv[Qx::nrc];
|
||||
QFBase::Acc acc[Qx::nrc*Qy::nrc];
|
||||
auto yv = y.load1(0, 0);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load1(ix, 0);
|
||||
acc[ix] = QFBase::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < Qy::nrc; ++iy) {
|
||||
yv = y.load1(iy, 0);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int i = 1; i < nb; ++i) {
|
||||
yv = y.load1(0, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load1(ix, i);
|
||||
acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < Qy::nrc; ++iy) {
|
||||
yv = y.load1(iy, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
|
||||
}
|
||||
|
||||
template <typename Qy, typename Qx>
|
||||
inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
static_assert(Qx::nrc%4 == 0);
|
||||
int nb = D/QFBase::k_step;
|
||||
Qy y(info);
|
||||
Qx x(cx + ix0*bx, bx);
|
||||
QFBase::Data xv[Qx::nrc];
|
||||
QFBase::Acc acc[Qx::nrc*Qy::nrc/4] = {};
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
for (int ix = 0; ix < Qx::nrc/4; ++ix) x.load_r4(4*ix, i, xv + 4*ix);
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) {
|
||||
auto yv = y.load1(iy, i);
|
||||
for (int ix = 0; ix < Qx::nrc/4; ++ix) acc[ix*Qy::nrc + iy] = QFBase::acc_r4(acc[ix*Qy::nrc + iy], xv + 4*ix, yv);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) {
|
||||
for (int ix = 0; ix < Qx::nrc/4; ++ix) info.store(ix0+4*ix, iy, QFBase::hsum_r4(acc[ix*Qy::nrc + iy]));
|
||||
}
|
||||
}
|
||||
|
||||
// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done
|
||||
// in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in
|
||||
// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
|
||||
template <int nrc_y, typename FloatX, typename FloatY>
|
||||
void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
const char * cx = (const char *)vx;
|
||||
// TBD if we want this
|
||||
//if constexpr (nrc_y == 1) {
|
||||
// constexpr int k_nx = 2;
|
||||
// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
// mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
|
||||
// }
|
||||
// if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) {
|
||||
// int nx = nrc_x - lastx;
|
||||
// switch (nx) {
|
||||
// case 1: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info); break;
|
||||
// case 2: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, lastx, info); break;
|
||||
// case 3: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, lastx, info); break;
|
||||
// }
|
||||
// //mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info);
|
||||
// }
|
||||
// return;
|
||||
//}
|
||||
#ifdef __AVX512F__
|
||||
constexpr int k_nx = 5;
|
||||
#else
|
||||
constexpr int k_nx = nrc_y == 1 ? 4 : 2;
|
||||
#endif
|
||||
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
|
||||
}
|
||||
int last_x = k_nx*(nrc_x/k_nx);
|
||||
if (last_x == nrc_x) return;
|
||||
int nx = nrc_x - last_x;
|
||||
#ifdef __AVX512F__
|
||||
switch (nx) {
|
||||
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
|
||||
case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
|
||||
case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
|
||||
case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
|
||||
}
|
||||
#else
|
||||
if constexpr (nrc_y == 1) {
|
||||
switch (nx) {
|
||||
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
|
||||
case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
|
||||
case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
|
||||
}
|
||||
} else {
|
||||
switch (nx) {
|
||||
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
template <int nrc_y>
|
||||
static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%16 == 0);
|
||||
const ggml_bf16_t * y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
|
||||
for (int ix = 0; ix < nrc_x/32; ++ix) {
|
||||
__m512 acc[2*nrc_y] = {};
|
||||
__m512bh qx[8];
|
||||
const ggml_bf16_t * b8_1 = (const ggml_bf16_t *)((const char *)vx + (32*ix+ 0)*bx);
|
||||
const ggml_bf16_t * b8_2 = (const ggml_bf16_t *)((const char *)vx + (32*ix+16)*bx);
|
||||
for (int ib = 0; ib < n/8; ++ib) {
|
||||
qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+0);
|
||||
qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+1);
|
||||
qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+2);
|
||||
qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_1+4*ib+3);
|
||||
qx[4] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+0);
|
||||
qx[5] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+1);
|
||||
qx[6] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+2);
|
||||
qx[7] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8_2+4*ib+3);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
|
||||
//auto y = _mm512_broadcast_i32x4(y128);
|
||||
auto y256 = MM256_SET_M128I(y128, y128);
|
||||
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
|
||||
acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
|
||||
acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
|
||||
acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
|
||||
acc[2*iy+0] = _mm512_dpbf16_ps(acc[2*iy+0], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
|
||||
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[4], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
|
||||
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[5], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
|
||||
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[6], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
|
||||
acc[2*iy+1] = _mm512_dpbf16_ps(acc[2*iy+1], qx[7], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(32*ix+ 0, iy, acc[2*iy+0]);
|
||||
info.store(32*ix+16, iy, acc[2*iy+1]);
|
||||
}
|
||||
}
|
||||
for (int ix = 32*(nrc_x/32); ix < nrc_x; ix += 16) {
|
||||
__m512 acc[nrc_y] = {};
|
||||
__m512bh qx[4];
|
||||
const ggml_bf16_t * b8 = (const ggml_bf16_t *)((const char *)vx + (ix+0)*bx);
|
||||
for (int ib = 0; ib < n/8; ++ib) {
|
||||
qx[0] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+0);
|
||||
qx[1] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+1);
|
||||
qx[2] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+2);
|
||||
qx[3] = (__m512bh)_mm512_loadu_si512((const __m512i *)b8+4*ib+3);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y128 = _mm_loadu_si128((const __m128i*)y[iy]+ib);
|
||||
auto y256 = MM256_SET_M128I(y128, y128);
|
||||
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
|
||||
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[0], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
|
||||
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[1], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
|
||||
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[2], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
|
||||
acc[iy] = _mm512_dpbf16_ps(acc[iy], qx[3], (__m512bh)_mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, acc[iy]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct QFBaseBF16 {
|
||||
constexpr static int k_step = 32;
|
||||
using Data = __m512bh;
|
||||
using Acc = __m512;
|
||||
static inline Data load(const ggml_bf16_t * x) { return __m512bh(_mm512_loadu_si512((const __m512i *)x)); }
|
||||
static inline Acc acc(Acc prev, Data y, Data x) {
|
||||
return _mm512_dpbf16_ps(prev, y, x);
|
||||
}
|
||||
static inline Acc acc_first(const Data& y, const Data& x) {
|
||||
return _mm512_dpbf16_ps(_mm512_setzero_ps(), y, x);
|
||||
}
|
||||
static inline float hsum(Acc acc) {
|
||||
return _mm512_reduce_add_ps(acc);
|
||||
}
|
||||
};
|
||||
template <int nrc_in> struct QFTBF16 final : public QFBaseBF16 {
|
||||
constexpr static int nrc = nrc_in;
|
||||
QFTBF16(const DataInfo& info) {
|
||||
for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
|
||||
}
|
||||
QFTBF16(const char * cx, size_t bx) {
|
||||
for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)(cx + iy*bx);
|
||||
}
|
||||
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
|
||||
const ggml_bf16_t * y[nrc];
|
||||
};
|
||||
|
||||
template <int nrc_y, int nrc_x>
|
||||
IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
int nb = n/QFBaseBF16::k_step;
|
||||
QFTBF16<nrc_y> y(info);
|
||||
QFTBF16<nrc_x> x(cx + ix0*bx, bx);
|
||||
QFBaseBF16::Data xv[nrc_x];
|
||||
QFBaseBF16::Acc acc[nrc_x*nrc_y];
|
||||
auto yv = y.load1(0, 0);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
xv[ix] = x.load1(ix, 0);
|
||||
acc[ix] = QFBaseBF16::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < nrc_y; ++iy) {
|
||||
yv = y.load1(iy, 0);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int i = 1; i < nb; ++i) {
|
||||
yv = y.load1(0, i);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
xv[ix] = x.load1(ix, i);
|
||||
acc[ix] = QFBaseBF16::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < nrc_y; ++iy) {
|
||||
yv = y.load1(iy, i);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16::hsum(acc[nrc_x*iy+ix]));
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
constexpr int k_nx = nrc_y <= 2 ? 8 : 5;
|
||||
const char * cx = (const char *)vx;
|
||||
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
mul_mat_Qx_Qy_MxN<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info);
|
||||
}
|
||||
int last_x = k_nx*(nrc_x/k_nx);
|
||||
if (last_x == nrc_x) return;
|
||||
int nx = nrc_x - last_x;
|
||||
if constexpr (nrc_y <= 2) {
|
||||
if (nx >= 4) {
|
||||
mul_mat_Qx_Qy_MxN<nrc_y, 4>(n, cx, bx, last_x, info);
|
||||
last_x += 4;
|
||||
if (last_x == nrc_x) return;
|
||||
nx = nrc_x - last_x;
|
||||
}
|
||||
}
|
||||
switch (nx) {
|
||||
case 1: mul_mat_Qx_Qy_MxN<nrc_y, 1>(n, cx, bx, last_x, info); break;
|
||||
case 2: mul_mat_Qx_Qy_MxN<nrc_y, 2>(n, cx, bx, last_x, info); break;
|
||||
case 3: mul_mat_Qx_Qy_MxN<nrc_y, 3>(n, cx, bx, last_x, info); break;
|
||||
case 4: mul_mat_Qx_Qy_MxN<nrc_y, 4>(n, cx, bx, last_x, info); break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
template <typename FloatX, typename FloatY>
|
||||
void set_mul_mat_f(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
for (auto& f : funcs) f = nullptr;
|
||||
funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>;
|
||||
funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>;
|
||||
funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>;
|
||||
funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>;
|
||||
funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>;
|
||||
#ifndef __AVX512F__
|
||||
funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>;
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
void set_mul_mat_bf16(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
for (auto& f : funcs) f = nullptr;
|
||||
funcs[0] = mul_mat_fX_fY_T<1>;
|
||||
funcs[1] = mul_mat_fX_fY_T<2>;
|
||||
funcs[2] = mul_mat_fX_fY_T<3>;
|
||||
funcs[3] = mul_mat_fX_fY_T<4>;
|
||||
funcs[4] = mul_mat_fX_fY_T<5>;
|
||||
}
|
||||
void set_mul_mat_bf16_r16(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
|
||||
for (auto& f : funcs) f = nullptr;
|
||||
funcs[0] = mul_mat_bf16_r16_bf16<1>;
|
||||
funcs[1] = mul_mat_bf16_r16_bf16<2>;
|
||||
funcs[2] = mul_mat_bf16_r16_bf16<3>;
|
||||
funcs[3] = mul_mat_bf16_r16_bf16<4>;
|
||||
funcs[4] = mul_mat_bf16_r16_bf16<5>;
|
||||
funcs[5] = mul_mat_bf16_r16_bf16<6>;
|
||||
funcs[6] = mul_mat_bf16_r16_bf16<7>;
|
||||
funcs[7] = mul_mat_bf16_r16_bf16<8>;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
|
||||
bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
|
||||
|
||||
if (typeA == GGML_TYPE_BF16) {
|
||||
if (ne00 % 32) return false;
|
||||
switch (typeB) {
|
||||
#ifdef __AVX512BF16__
|
||||
case GGML_TYPE_BF16: set_mul_mat_bf16(kernels); break;
|
||||
#else
|
||||
case GGML_TYPE_BF16: set_mul_mat_f<ggml_bf16_t, ggml_bf16_t>(kernels); break;
|
||||
case GGML_TYPE_F32: set_mul_mat_f<ggml_bf16_t, float>(kernels); break;
|
||||
#endif
|
||||
default: return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (typeA == GGML_TYPE_BF16_R16) {
|
||||
if (ne00 % 16) return false;
|
||||
switch (typeB) {
|
||||
#ifdef __AVX512BF16__
|
||||
case GGML_TYPE_BF16: set_mul_mat_bf16_r16(kernels); break;
|
||||
#endif
|
||||
default: return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) {
|
||||
if (ne00 % 4) return false;
|
||||
}
|
||||
if (typeA == GGML_TYPE_F16) {
|
||||
switch (typeB) {
|
||||
case GGML_TYPE_F16: set_mul_mat_f<ggml_half, ggml_half>(kernels); break;
|
||||
case GGML_TYPE_F32: set_mul_mat_f<ggml_half, float>(kernels); break;
|
||||
default: return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (typeA == GGML_TYPE_F32) {
|
||||
switch (typeB) {
|
||||
case GGML_TYPE_F16: set_mul_mat_f<float, ggml_half>(kernels); break;
|
||||
case GGML_TYPE_F32: set_mul_mat_f<float, float>(kernels); break;
|
||||
default: return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "iqk_mul_mat.h"
|
||||
#include "iqk_quantize.h"
|
||||
#include "iqk_flash_impl.h"
|
||||
#include "iqk_gemm_floats.h"
|
||||
|
||||
#define GGML_COMMON_IMPL_C
|
||||
#include "ggml-common.h"
|
||||
@@ -43,116 +44,10 @@
|
||||
// For fp16/fp32 matrix multiplications tiling is used to improve
|
||||
// performance.
|
||||
|
||||
#define FA_TIMING 0
|
||||
|
||||
#include <utility>
|
||||
#include <array>
|
||||
#if FA_TIMING
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
// Minimal profiling helper (only compiled when FA_TIMING is on): accumulates
// wall-clock milliseconds into a fixed set of buckets; the reporting instance
// prints a summary from its destructor.
struct Perf {
    using TimePoint = std::chrono::time_point<std::chrono::high_resolution_clock>;
    std::array<double, 5> times = {};   // per-bucket totals, in ms
    std::mutex mutex;
    bool report;                        // print a summary on destruction?
    Perf(bool r) : report(r) {}
    static auto cur_time() { return std::chrono::high_resolution_clock::now(); }
    // Elapsed time between two time points, in milliseconds.
    static double delta(const TimePoint& t1, const TimePoint& t2) {
        return 1e-6*std::chrono::duration_cast<std::chrono::nanoseconds>(t2-t1).count();
    }
    // Add the time elapsed since t1 to bucket `what` (thread-safe).
    inline void accum(int what, const TimePoint& t1) {
        const double ms = delta(t1, cur_time());
        std::lock_guard<std::mutex> lock(mutex);
        times[what] += ms;
    }
    // Same, without taking the lock — for single-threaded use.
    inline void accum_nolock(int what, const TimePoint& t1) {
        times[what] += delta(t1, cur_time());
    }
    // Merge another instance's buckets into this one (thread-safe).
    inline void add(const Perf& other) {
        std::lock_guard<std::mutex> lock(mutex);
        for (int i = 0; i < int(times.size()); ++i) times[i] += other.times[i];
    }
    ~Perf() {
        if (!report) return;
        double tot = 0;
        for (auto& t : times) tot += t;
        if (!tot) return;   // nothing was timed
        printf("======================= Timing: %g ms in total\n", tot);
        for (int i = 0; i < int(times.size()); ++i) {
            if (times[i]) {
                printf("%d: %g ms -> %g%c\n", i, times[i], 100*times[i]/tot, '%');
            }
        }
    }
    // Process-wide reporting instance.
    static Perf& instance() {
        static Perf p(true);
        return p;
    }
};
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
// Row indirection entry for indirect ("matrix-of-ids") multiplication:
// DataInfo::src1_row() uses i1 modulo ne11 as the src1 row and i2 as the
// matrix index in the 3rd dimension; DataInfo::dst_row() places the result
// at row i1 of matrix i2 (via the bs2 stride).
typedef struct {
    int32_t i1;   // row index (folded modulo ne11 on the src1 side)
    int32_t i2;   // matrix (3rd-dimension) index
} mmid_row_mapping;
|
||||
|
||||
// Describes where GEMM results go and where the right-hand-side (src1) rows
// come from. When row_mapping is set (indirect multiplication), logical row
// indices are translated through it on both the src1 and the dst side.
struct DataInfo {
    float * s;                  // dst base pointer (f32 results)
    const char * cy;            // src1 base pointer (byte-addressed)
    size_t bs;                  // dst row stride, in floats
    size_t by;                  // src1 row stride, in bytes
    int cur_y = 0;              // first logical src1 row handled by this call
    int ne11;                   // src1 rows per matrix (mapped i1 is folded modulo ne11)
    const mmid_row_mapping * row_mapping = nullptr; // optional row indirection
    size_t bs2 = 0;             // dst stride between 3rd-dimension matrices, in floats

    // Pointer to the iy-th src1 row; with a mapping, the logical row is first
    // translated to (i11, i12) = (row within matrix, matrix index).
    inline const char * src1_row(int iy) const {
        if (!row_mapping) return cy + (cur_y + iy)*by;
        int i11 = row_mapping[cur_y + iy].i1 % ne11;
        int i12 = row_mapping[cur_y + iy].i2;
        return cy + (i11 + i12*ne11)*by;
    }

    // Scalar store of one result element at column ix of result row iy.
    inline void store(int ix, int iy, float result) const {
        *(dst_row(iy) + ix) = result;
    }
#ifdef __AVX__
    // Vector stores of 4 / 8 consecutive result elements.
    inline void store(int ix, int iy, __m128 result) const {
        _mm_storeu_ps(dst_row(iy) + ix, result);
    }
    inline void store(int ix, int iy, __m256 result) const {
        _mm256_storeu_ps(dst_row(iy) + ix, result);
    }
#endif
#ifdef __AVX512F__
    // Vector store of 16 consecutive result elements.
    inline void store(int ix, int iy, __m512 result) const {
        _mm512_storeu_ps(dst_row(iy) + ix, result);
    }
#endif
#ifdef __ARM_NEON
    // Vector store of 4 consecutive result elements (NEON).
    inline void store(int ix, int iy, float32x4_t result) const {
        vst1q_f32(dst_row(iy) + ix, result);
    }
#endif
    // Destination row for the iy-th result row; with a mapping the result
    // lands at row i1 of 3rd-dimension matrix i2.
    inline float * dst_row(int iy) const {
        if (!row_mapping) return s + (cur_y + iy)*bs;
        int i12 = row_mapping[cur_y + iy].i2;
        int i1 = row_mapping[cur_y + iy].i1;
        int i2 = i12;
        return s + i1*bs + i2*bs2;
    }
};
|
||||
|
||||
typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x);
|
||||
|
||||
struct MulMat {
|
||||
std::array<mul_mat_t, 8> funcs = {};
|
||||
std::array<mul_mat_t, IQK_MAX_NY> funcs = {};
|
||||
mul_mat_t func16 = nullptr;
|
||||
inline void mul_mat_NxM(int n, const void * vx, size_t bx, DataInfo& info, int nrc_x, int nrc_y) {
|
||||
#ifdef __aarch64__
|
||||
@@ -457,7 +352,7 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
|
||||
if (Nx >= 256 && Nx%32 == 0) {
|
||||
int nx32 = Nx/32;
|
||||
int nchunk = nx32*ne02;
|
||||
if (r2 <= 8) {
|
||||
if (r2 <= IQK_MAX_NY) {
|
||||
MulMat mm;
|
||||
if (!MulMat::prepare(typeA, typeB, ne00, mm, r2)) return false;
|
||||
int ny = mm.funcs.size();
|
||||
@@ -9254,478 +9149,6 @@ struct Q6_0_1_Unpacker final : public Q_Unpacker<block_q6_0, ScaleHelperQ_0_1<32
|
||||
inline static int block_size() { return QK6_0; }
|
||||
};
|
||||
|
||||
// float matrices - we handle f16, bf16 (if native bf16 support is available) and f32, but only to f32 result
|
||||
|
||||
struct QFBase {
|
||||
#ifdef __AVX512F__
|
||||
constexpr static int k_step = 16;
|
||||
using Data = __m512;
|
||||
using Acc = __m512;
|
||||
static inline Data load(const ggml_half * x) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)x)); }
|
||||
static inline Data load(const float * x) { return _mm512_loadu_ps(x); }
|
||||
static inline Data load(const ggml_bf16_t * x) {
|
||||
return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)x)), 16));
|
||||
}
|
||||
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
|
||||
return _mm512_fmadd_ps(y, x, prev);
|
||||
}
|
||||
static inline Acc acc_first(const Data& y, const Data& x) {
|
||||
return _mm512_mul_ps(y, x);
|
||||
}
|
||||
static inline Acc add(Acc x, Acc y) { return _mm512_add_ps(x, y); }
|
||||
static inline float hsum(Acc acc) {
|
||||
return _mm512_reduce_add_ps(acc);
|
||||
}
|
||||
template <typename Float>
|
||||
static inline Data load4Floats(const Float * x) {
|
||||
return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0);
|
||||
}
|
||||
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
|
||||
acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);
|
||||
acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
|
||||
acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
|
||||
acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
|
||||
return acc;
|
||||
}
|
||||
static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
|
||||
auto acc = _mm512_mul_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00));
|
||||
acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
|
||||
acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
|
||||
acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
|
||||
return acc;
|
||||
}
|
||||
static inline __m128 hsum_r4(Acc acc) {
|
||||
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1));
|
||||
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3));
|
||||
return _mm_add_ps(sum1, sum2);
|
||||
}
|
||||
#else
|
||||
constexpr static int k_step = 8;
|
||||
using Data = __m256;
|
||||
using Acc = __m256;
|
||||
static inline Data load(const ggml_half * x) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)x)); }
|
||||
static inline Data load(const float * x) { return _mm256_loadu_ps(x); }
|
||||
static inline Data load(const ggml_bf16_t * x) {
|
||||
return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i*)x)), 16));
|
||||
}
|
||||
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
|
||||
return _mm256_fmadd_ps(y, x, prev);
|
||||
}
|
||||
static inline Acc add(Acc x, Acc y) { return _mm256_add_ps(x, y); }
|
||||
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
|
||||
acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
|
||||
acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
|
||||
acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
|
||||
acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
|
||||
return acc;
|
||||
}
|
||||
static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
|
||||
auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00));
|
||||
acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
|
||||
acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
|
||||
acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
|
||||
return acc;
|
||||
}
|
||||
static inline Acc acc_first(const Data& y, const Data& x) {
|
||||
return _mm256_mul_ps(y, x);
|
||||
}
|
||||
static inline float hsum(Acc acc) {
|
||||
return hsum_float_8(acc);
|
||||
}
|
||||
static inline __m128 hsum_r4(Acc acc) {
|
||||
return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
|
||||
}
|
||||
template <typename Float>
|
||||
static inline Data load4Floats(const Float * x) {
|
||||
return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0);
|
||||
}
|
||||
#endif
|
||||
static inline __m128 load128(const ggml_half * x) { return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x)); }
|
||||
static inline __m128 load128(const float * x) { return _mm_loadu_ps(x); }
|
||||
static inline __m128 load128(const ggml_bf16_t * x) {
|
||||
return _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i*)x)), 16));
|
||||
}
|
||||
};
|
||||
template <typename Float, int nrc_in> struct QFT final : public QFBase {
|
||||
constexpr static int nrc = nrc_in;
|
||||
QFT(const DataInfo& info) {
|
||||
for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)info.src1_row(iy);
|
||||
}
|
||||
QFT(const char * cx, size_t bx) {
|
||||
for (int iy = 0; iy < nrc; ++iy) y[iy] = (const Float *)(cx + iy*bx);
|
||||
}
|
||||
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
|
||||
IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); }
|
||||
IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const {
|
||||
xv[0] = load1(ix+0, i);
|
||||
xv[1] = load1(ix+1, i);
|
||||
xv[2] = load1(ix+2, i);
|
||||
xv[3] = load1(ix+3, i);
|
||||
#ifdef __AVX512F__
|
||||
auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);
|
||||
auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);
|
||||
auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);
|
||||
auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]);
|
||||
xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
|
||||
xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
|
||||
xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
|
||||
xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
|
||||
#else
|
||||
auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]);
|
||||
auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]);
|
||||
auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]);
|
||||
auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]);
|
||||
xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
|
||||
xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
|
||||
xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
|
||||
xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
|
||||
#endif
|
||||
}
|
||||
const Float * y[nrc];
|
||||
};
|
||||
|
||||
// TBD if we want this
|
||||
//template <typename Qy, typename Qx>
|
||||
//IQK_NOINLINE void mul_mat_Qx_Qy_Mx1(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
// static_assert(Qy::nrc == 1);
|
||||
// int nb = n/QFBase::k_step;
|
||||
// int nb4 = n/4;
|
||||
// Qy y(info);
|
||||
// Qx x(cx + ix0*bx, bx);
|
||||
// QFBase::Data xv[2*Qx::nrc];
|
||||
// QFBase::Acc acc[2*Qx::nrc];
|
||||
// auto yv1 = y.load1(0, 0);
|
||||
// auto yv2 = y.load1(0, 1);
|
||||
// for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
// xv[2*ix+0] = x.load1(ix, 0);
|
||||
// xv[2*ix+1] = x.load1(ix, 1);
|
||||
// acc[2*ix+0] = QFBase::acc_first(yv1, xv[2*ix+0]);
|
||||
// acc[2*ix+1] = QFBase::acc_first(yv2, xv[2*ix+1]);
|
||||
// }
|
||||
// for (int i = 1; i < nb/2; ++i) {
|
||||
// yv1 = y.load1(0, 2*i+0);
|
||||
// yv2 = y.load1(0, 2*i+1);
|
||||
// for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
// xv[2*ix+0] = x.load1(ix, 2*i+0);
|
||||
// xv[2*ix+1] = x.load1(ix, 2*i+1);
|
||||
// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[2*ix+0]);
|
||||
// acc[2*ix+1] = QFBase::acc(acc[2*ix+1], yv2, xv[2*ix+1]);
|
||||
// }
|
||||
// }
|
||||
// for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
|
||||
// yv1 = y.load_tail(0, i);
|
||||
// for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
// xv[ix] = x.load_tail(ix, i);
|
||||
// acc[2*ix+0] = QFBase::acc(acc[2*ix+0], yv1, xv[ix]);
|
||||
// }
|
||||
// }
|
||||
// for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, 0, QFBase::hsum(QFBase::add(acc[2*ix+0], acc[2*ix+1])));
|
||||
//}
|
||||
|
||||
template <typename Qy, typename Qx>
|
||||
IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
int nb = n/QFBase::k_step;
|
||||
int nb4 = n/4;
|
||||
Qy y(info);
|
||||
Qx x(cx + ix0*bx, bx);
|
||||
QFBase::Data xv[Qx::nrc];
|
||||
QFBase::Acc acc[Qx::nrc*Qy::nrc];
|
||||
auto yv = y.load1(0, 0);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load1(ix, 0);
|
||||
acc[ix] = QFBase::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < Qy::nrc; ++iy) {
|
||||
yv = y.load1(iy, 0);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int i = 1; i < nb; ++i) {
|
||||
yv = y.load1(0, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load1(ix, i);
|
||||
acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < Qy::nrc; ++iy) {
|
||||
yv = y.load1(iy, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int i = (QFBase::k_step/4)*nb; i < nb4; ++i) {
|
||||
yv = y.load_tail(0, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load_tail(ix, i);
|
||||
acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < Qy::nrc; ++iy) {
|
||||
yv = y.load_tail(iy, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
|
||||
}
|
||||
|
||||
template <typename Qy, typename Qx>
|
||||
inline void mul_mat_Qx_Qy_MxN_fa(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
int nb = n/QFBase::k_step;
|
||||
Qy y(info);
|
||||
Qx x(cx + ix0*bx, bx);
|
||||
QFBase::Data xv[Qx::nrc];
|
||||
QFBase::Acc acc[Qx::nrc*Qy::nrc];
|
||||
auto yv = y.load1(0, 0);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load1(ix, 0);
|
||||
acc[ix] = QFBase::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < Qy::nrc; ++iy) {
|
||||
yv = y.load1(iy, 0);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int i = 1; i < nb; ++i) {
|
||||
yv = y.load1(0, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load1(ix, i);
|
||||
acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < Qy::nrc; ++iy) {
|
||||
yv = y.load1(iy, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
|
||||
}
|
||||
|
||||
template <typename Qy, typename Qx>
|
||||
inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
static_assert(Qx::nrc%4 == 0);
|
||||
int nb = D/QFBase::k_step;
|
||||
Qy y(info);
|
||||
Qx x(cx + ix0*bx, bx);
|
||||
QFBase::Data xv[Qx::nrc];
|
||||
QFBase::Acc acc[Qx::nrc*Qy::nrc/4] = {};
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
for (int ix = 0; ix < Qx::nrc/4; ++ix) x.load_r4(4*ix, i, xv + 4*ix);
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) {
|
||||
auto yv = y.load1(iy, i);
|
||||
for (int ix = 0; ix < Qx::nrc/4; ++ix) acc[ix*Qy::nrc + iy] = QFBase::acc_r4(acc[ix*Qy::nrc + iy], xv + 4*ix, yv);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) {
|
||||
for (int ix = 0; ix < Qx::nrc/4; ++ix) info.store(ix0+4*ix, iy, QFBase::hsum_r4(acc[ix*Qy::nrc + iy]));
|
||||
}
|
||||
}
|
||||
|
||||
// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done
|
||||
// in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in
|
||||
// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
|
||||
template <int nrc_y, typename FloatX, typename FloatY>
|
||||
void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
const char * cx = (const char *)vx;
|
||||
// TBD if we want this
|
||||
//if constexpr (nrc_y == 1) {
|
||||
// constexpr int k_nx = 2;
|
||||
// for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
// mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
|
||||
// }
|
||||
// if (int lastx = k_nx*(nrc_x/k_nx); lastx < nrc_x) {
|
||||
// int nx = nrc_x - lastx;
|
||||
// switch (nx) {
|
||||
// case 1: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info); break;
|
||||
// case 2: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, lastx, info); break;
|
||||
// case 3: mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, lastx, info); break;
|
||||
// }
|
||||
// //mul_mat_Qx_Qy_Mx1<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, lastx, info);
|
||||
// }
|
||||
// return;
|
||||
//}
|
||||
#ifdef __AVX512F__
|
||||
constexpr int k_nx = 5;
|
||||
#else
|
||||
constexpr int k_nx = nrc_y == 1 ? 4 : 2;
|
||||
#endif
|
||||
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, k_nx>>(n, cx, bx, ix*k_nx, info);
|
||||
}
|
||||
int last_x = k_nx*(nrc_x/k_nx);
|
||||
if (last_x == nrc_x) return;
|
||||
int nx = nrc_x - last_x;
|
||||
#ifdef __AVX512F__
|
||||
switch (nx) {
|
||||
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
|
||||
case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
|
||||
case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
|
||||
case 4: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 4>>(n, cx, bx, last_x, info); break;
|
||||
}
|
||||
#else
|
||||
if constexpr (nrc_y == 1) {
|
||||
switch (nx) {
|
||||
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
|
||||
case 2: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 2>>(n, cx, bx, last_x, info); break;
|
||||
case 3: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 3>>(n, cx, bx, last_x, info); break;
|
||||
}
|
||||
} else {
|
||||
switch (nx) {
|
||||
case 1: mul_mat_Qx_Qy_MxN<QFT<FloatY, nrc_y>, QFT<FloatX, 1>>(n, cx, bx, last_x, info); break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
struct QFBaseBF16 {
|
||||
constexpr static int k_step = 32;
|
||||
using Data = __m512bh;
|
||||
using Acc = __m512;
|
||||
static inline Data load(const ggml_bf16_t * x) { return __m512bh(_mm512_loadu_si512((const __m512i *)x)); }
|
||||
//static inline Acc acc(Acc prev, const Data& y, const Data& x) {
|
||||
static inline Acc acc(Acc prev, Data y, Data x) {
|
||||
return _mm512_dpbf16_ps(prev, y, x);
|
||||
}
|
||||
static inline Acc acc_first(const Data& y, const Data& x) {
|
||||
return _mm512_dpbf16_ps(_mm512_setzero_ps(), y, x);
|
||||
}
|
||||
static inline float hsum(Acc acc) {
|
||||
return _mm512_reduce_add_ps(acc);
|
||||
}
|
||||
};
|
||||
template <int nrc_in> struct QFTBF16 final : public QFBaseBF16 {
|
||||
constexpr static int nrc = nrc_in;
|
||||
QFTBF16(const DataInfo& info) {
|
||||
for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)info.src1_row(iy);
|
||||
}
|
||||
QFTBF16(const char * cx, size_t bx) {
|
||||
for (int iy = 0; iy < nrc; ++iy) y[iy] = (const ggml_bf16_t *)(cx + iy*bx);
|
||||
}
|
||||
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
|
||||
const ggml_bf16_t * y[nrc];
|
||||
};
|
||||
|
||||
template <int nrc_y, int nrc_x>
|
||||
IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
int nb = n/QFBaseBF16::k_step;
|
||||
QFTBF16<nrc_y> y(info);
|
||||
QFTBF16<nrc_x> x(cx + ix0*bx, bx);
|
||||
QFBaseBF16::Data xv[nrc_x];
|
||||
QFBaseBF16::Acc acc[nrc_x*nrc_y];
|
||||
auto yv = y.load1(0, 0);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
xv[ix] = x.load1(ix, 0);
|
||||
acc[ix] = QFBaseBF16::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < nrc_y; ++iy) {
|
||||
yv = y.load1(iy, 0);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int i = 1; i < nb; ++i) {
|
||||
yv = y.load1(0, i);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
xv[ix] = x.load1(ix, i);
|
||||
acc[ix] = QFBaseBF16::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < nrc_y; ++iy) {
|
||||
yv = y.load1(iy, i);
|
||||
for (int ix = 0; ix < nrc_x; ++ix) acc[nrc_x*iy + ix] = QFBaseBF16::acc(acc[nrc_x*iy + ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < nrc_x; ++ix) info.store(ix0+ix, iy, QFBaseBF16::hsum(acc[nrc_x*iy+ix]));
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_fX_fY_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
constexpr int k_nx = nrc_y <= 2 ? 8 : 5;
|
||||
const char * cx = (const char *)vx;
|
||||
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
mul_mat_Qx_Qy_MxN<nrc_y, k_nx>(n, cx, bx, ix*k_nx, info);
|
||||
}
|
||||
int last_x = k_nx*(nrc_x/k_nx);
|
||||
if (last_x == nrc_x) return;
|
||||
int nx = nrc_x - last_x;
|
||||
if constexpr (nrc_y <= 2) {
|
||||
if (nx >= 4) {
|
||||
mul_mat_Qx_Qy_MxN<nrc_y, 4>(n, cx, bx, last_x, info);
|
||||
last_x += 4;
|
||||
if (last_x == nrc_x) return;
|
||||
nx = nrc_x - last_x;
|
||||
}
|
||||
}
|
||||
switch (nx) {
|
||||
case 1: mul_mat_Qx_Qy_MxN<nrc_y, 1>(n, cx, bx, last_x, info); break;
|
||||
case 2: mul_mat_Qx_Qy_MxN<nrc_y, 2>(n, cx, bx, last_x, info); break;
|
||||
case 3: mul_mat_Qx_Qy_MxN<nrc_y, 3>(n, cx, bx, last_x, info); break;
|
||||
case 4: mul_mat_Qx_Qy_MxN<nrc_y, 4>(n, cx, bx, last_x, info); break;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
//
|
||||
// Tiled Q8_0 x Q8_0 implementation. Not used as the templated legacy quant implementation
|
||||
// above is faster. Left behind so we remember we tried.
|
||||
//
|
||||
template <int nrc> struct Q80 {
|
||||
constexpr static int nrc_y = nrc;
|
||||
Q80(const DataInfo& info) {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy);
|
||||
}
|
||||
IQK_ALWAYS_INLINE __m256i load1(int iy, int i) const { return _mm256_loadu_si256((const __m256i *)y[iy][i].qs); }
|
||||
IQK_ALWAYS_INLINE float scale(int iy, int i) const { return GGML_FP16_TO_FP32(y[iy][i].d); }
|
||||
|
||||
const block_q8_0 * y[nrc_y];
|
||||
};
|
||||
inline __m256i mul_q80(__m256i x, __m256i y) {
|
||||
auto ux = _mm256_sign_epi8(x, x);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
return _mm256_dpbusd_epi32(_mm256_setzero_si256(), ux, _mm256_sign_epi8(y, x));
|
||||
#else
|
||||
return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(ux, _mm256_sign_epi8(y, x)));
|
||||
#endif
|
||||
}
|
||||
template <int nrc_y>
|
||||
void mul_mat_q80_q80_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK8_0 == 0);
|
||||
constexpr int k_nx = 4;
|
||||
int nb = n/QK8_0;
|
||||
Q80<nrc_y> q8(info);
|
||||
const block_q8_0 * x[k_nx];
|
||||
float ds[k_nx];
|
||||
__m256 acc[k_nx*nrc_y];
|
||||
__m256i xv[k_nx];
|
||||
for (int ix = 0; ix < nrc_x/k_nx; ++ix) {
|
||||
int ix0 = k_nx*ix;
|
||||
for (int kx = 0; kx < k_nx; ++kx) {
|
||||
x[kx] = (const block_q8_0 *)((const char *)vx + (ix0 + kx)*bx);
|
||||
ds[kx] = GGML_FP16_TO_FP32(x[kx][0].d);
|
||||
xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][0].qs);
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto yv = q8.load1(iy, 0);
|
||||
float d = q8.scale(iy, 0);
|
||||
for (int kx = 0; kx < k_nx; ++kx) {
|
||||
auto dot = mul_q80(yv, xv[kx]);
|
||||
acc[k_nx*iy + kx] = _mm256_mul_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot));
|
||||
}
|
||||
}
|
||||
for (int i = 1; i < nb; ++i) {
|
||||
for (int kx = 0; kx < k_nx; ++kx) {
|
||||
ds[kx] = GGML_FP16_TO_FP32(x[kx][i].d);
|
||||
xv[kx] = _mm256_loadu_si256((const __m256i *)x[kx][i].qs);
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto yv = q8.load1(iy, i);
|
||||
float d = q8.scale(iy, i);
|
||||
for (int kx = 0; kx < k_nx; ++kx) {
|
||||
auto dot = mul_q80(yv, xv[kx]);
|
||||
acc[k_nx*iy + kx] = _mm256_fmadd_ps(_mm256_set1_ps(ds[kx]*d), _mm256_cvtepi32_ps(dot), acc[k_nx*iy + kx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
for (int kx = 0; kx < k_nx; ++kx) info.store(ix0+kx, iy, hsum_float_8(acc[k_nx*iy+kx]));
|
||||
}
|
||||
}
|
||||
int last_x = k_nx*(nrc_x/k_nx);
|
||||
if (last_x == nrc_x) return;
|
||||
// TODO: handle remaining rows
|
||||
}
|
||||
|
||||
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
|
||||
if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
|
||||
std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
|
||||
@@ -9862,88 +9285,12 @@ template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FloatX, typename FloatY>
|
||||
void set_mul_mat_f(MulMat& mm) {
|
||||
for (auto& f : mm.funcs) f = nullptr;
|
||||
mm.funcs[0] = mul_mat_fX_fY_T<1, FloatX, FloatY>;
|
||||
mm.funcs[1] = mul_mat_fX_fY_T<2, FloatX, FloatY>;
|
||||
mm.funcs[2] = mul_mat_fX_fY_T<3, FloatX, FloatY>;
|
||||
mm.funcs[3] = mul_mat_fX_fY_T<4, FloatX, FloatY>;
|
||||
mm.funcs[4] = mul_mat_fX_fY_T<5, FloatX, FloatY>;
|
||||
#ifndef __AVX512F__
|
||||
mm.funcs[5] = mul_mat_fX_fY_T<6, FloatX, FloatY>;
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef __AVX512BF16__
|
||||
void set_mul_mat_bf16(MulMat& mm) {
|
||||
for (auto& f : mm.funcs) f = nullptr;
|
||||
mm.funcs[0] = mul_mat_fX_fY_T<1>;
|
||||
mm.funcs[1] = mul_mat_fX_fY_T<2>;
|
||||
mm.funcs[2] = mul_mat_fX_fY_T<3>;
|
||||
mm.funcs[3] = mul_mat_fX_fY_T<4>;
|
||||
mm.funcs[4] = mul_mat_fX_fY_T<5>;
|
||||
}
|
||||
void set_mul_mat_bf16_r16(MulMat& mm) {
|
||||
for (auto& f : mm.funcs) f = nullptr;
|
||||
mm.funcs[0] = mul_mat_bf16_r16_bf16<1>;
|
||||
mm.funcs[1] = mul_mat_bf16_r16_bf16<2>;
|
||||
mm.funcs[2] = mul_mat_bf16_r16_bf16<3>;
|
||||
mm.funcs[3] = mul_mat_bf16_r16_bf16<4>;
|
||||
mm.funcs[4] = mul_mat_bf16_r16_bf16<5>;
|
||||
mm.funcs[5] = mul_mat_bf16_r16_bf16<6>;
|
||||
mm.funcs[6] = mul_mat_bf16_r16_bf16<7>;
|
||||
mm.funcs[7] = mul_mat_bf16_r16_bf16<8>;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
|
||||
(void)Ny;
|
||||
|
||||
if (typeA == GGML_TYPE_BF16) {
|
||||
if (ne00 % 32) return false;
|
||||
switch (typeB) {
|
||||
#ifdef __AVX512BF16__
|
||||
case GGML_TYPE_BF16: set_mul_mat_bf16(mm); break;
|
||||
#else
|
||||
case GGML_TYPE_BF16: set_mul_mat_f<ggml_bf16_t, ggml_bf16_t>(mm); break;
|
||||
case GGML_TYPE_F32: set_mul_mat_f<ggml_bf16_t, float>(mm); break;
|
||||
#endif
|
||||
default: return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (typeA == GGML_TYPE_BF16_R16) {
|
||||
if (ne00 % 16) return false;
|
||||
switch (typeB) {
|
||||
#ifdef __AVX512BF16__
|
||||
case GGML_TYPE_BF16: set_mul_mat_bf16_r16(mm); break;
|
||||
#endif
|
||||
default: return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32) {
|
||||
if (ne00 % 4) return false;
|
||||
}
|
||||
if (typeA == GGML_TYPE_F16) {
|
||||
switch (typeB) {
|
||||
case GGML_TYPE_F16: set_mul_mat_f<ggml_half, ggml_half>(mm); break;
|
||||
case GGML_TYPE_F32: set_mul_mat_f<ggml_half, float>(mm); break;
|
||||
default: return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (typeA == GGML_TYPE_F32) {
|
||||
switch (typeB) {
|
||||
case GGML_TYPE_F16: set_mul_mat_f<float, ggml_half>(mm); break;
|
||||
case GGML_TYPE_F32: set_mul_mat_f<float, float>(mm); break;
|
||||
default: return false;
|
||||
}
|
||||
return true;
|
||||
if (typeA == GGML_TYPE_F16 || typeA == GGML_TYPE_F32 || typeA == GGML_TYPE_BF16 || typeA == GGML_TYPE_BF16_R16) {
|
||||
return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs);
|
||||
}
|
||||
|
||||
auto expected_typeB = GGML_TYPE_Q8_K;
|
||||
|
||||
Reference in New Issue
Block a user