mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 23:54:10 +00:00
FA: slightly faster V*softmax(K*Q)) also for fp16 K-cache
This commit is contained in:
@@ -6895,6 +6895,25 @@ struct QFBase {
|
||||
static inline Data load4Floats(const Float * x) {
|
||||
return _mm512_insertf32x4(_mm512_setzero_ps(), load128(x), 0);
|
||||
}
|
||||
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
|
||||
acc = _mm512_fmadd_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00), acc);
|
||||
acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
|
||||
acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
|
||||
acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
|
||||
return acc;
|
||||
}
|
||||
static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
|
||||
auto acc = _mm512_mul_ps(xv[0], _mm512_shuffle_ps(yv, yv, 0x00));
|
||||
acc = _mm512_fmadd_ps(xv[1], _mm512_shuffle_ps(yv, yv, 0x55), acc);
|
||||
acc = _mm512_fmadd_ps(xv[2], _mm512_shuffle_ps(yv, yv, 0xaa), acc);
|
||||
acc = _mm512_fmadd_ps(xv[3], _mm512_shuffle_ps(yv, yv, 0xff), acc);
|
||||
return acc;
|
||||
}
|
||||
static inline __m128 hsum_r4(Acc acc) {
|
||||
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 0), _mm512_extractf32x4_ps(acc, 1));
|
||||
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc, 2), _mm512_extractf32x4_ps(acc, 3));
|
||||
return _mm_add_ps(sum1, sum2);
|
||||
}
|
||||
#else
|
||||
constexpr static int k_step = 8;
|
||||
using Data = __m256;
|
||||
@@ -6904,12 +6923,29 @@ struct QFBase {
|
||||
static inline Acc acc(Acc prev, const Data& y, const Data& x) {
|
||||
return _mm256_fmadd_ps(y, x, prev);
|
||||
}
|
||||
static inline Acc acc_r4(Acc acc, const Data * xv, const Data& yv) {
|
||||
acc = _mm256_fmadd_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00), acc);
|
||||
acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
|
||||
acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
|
||||
acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
|
||||
return acc;
|
||||
}
|
||||
static inline Acc acc_r4_first(const Data * xv, const Data& yv) {
|
||||
auto acc = _mm256_mul_ps(xv[0], _mm256_shuffle_ps(yv, yv, 0x00));
|
||||
acc = _mm256_fmadd_ps(xv[1], _mm256_shuffle_ps(yv, yv, 0x55), acc);
|
||||
acc = _mm256_fmadd_ps(xv[2], _mm256_shuffle_ps(yv, yv, 0xaa), acc);
|
||||
acc = _mm256_fmadd_ps(xv[3], _mm256_shuffle_ps(yv, yv, 0xff), acc);
|
||||
return acc;
|
||||
}
|
||||
static inline Acc acc_first(const Data& y, const Data& x) {
|
||||
return _mm256_mul_ps(y, x);
|
||||
}
|
||||
static inline float hsum(Acc acc) {
|
||||
return hsum_float_8(acc);
|
||||
}
|
||||
static inline __m128 hsum_r4(Acc acc) {
|
||||
return _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
|
||||
}
|
||||
template <typename Float>
|
||||
static inline Data load4Floats(const Float * x) {
|
||||
return _mm256_insertf128_ps(_mm256_setzero_ps(), load128(x), 0);
|
||||
@@ -6928,6 +6964,31 @@ template <typename Float, int nrc_in> struct QFT final : public QFBase {
|
||||
}
|
||||
IQK_ALWAYS_INLINE Data load1(int iy, int i) const { return load(y[iy] + k_step*i); }
|
||||
IQK_ALWAYS_INLINE Data load_tail(int iy, int i) const { return load4Floats(y[iy] + 4*i); }
|
||||
IQK_ALWAYS_INLINE void load_r4(int ix, int i, Data * xv) const {
|
||||
xv[0] = load1(ix+0, i);
|
||||
xv[1] = load1(ix+1, i);
|
||||
xv[2] = load1(ix+2, i);
|
||||
xv[3] = load1(ix+3, i);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto t0 = _mm512_unpacklo_ps(xv[0], xv[1]);
|
||||
auto t1 = _mm512_unpacklo_ps(xv[2], xv[3]);
|
||||
auto t2 = _mm512_unpackhi_ps(xv[0], xv[1]);
|
||||
auto t3 = _mm512_unpackhi_ps(xv[2], xv[3]);
|
||||
xv[0] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
|
||||
xv[1] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t0), _mm512_castps_pd(t1)));
|
||||
xv[2] = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
|
||||
xv[3] = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(t2), _mm512_castps_pd(t3)));
|
||||
#else
|
||||
auto t0 = _mm256_unpacklo_ps(xv[0], xv[1]);
|
||||
auto t1 = _mm256_unpacklo_ps(xv[2], xv[3]);
|
||||
auto t2 = _mm256_unpackhi_ps(xv[0], xv[1]);
|
||||
auto t3 = _mm256_unpackhi_ps(xv[2], xv[3]);
|
||||
xv[0] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
|
||||
xv[1] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t1)));
|
||||
xv[2] = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
|
||||
xv[3] = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t2), _mm256_castps_pd(t3)));
|
||||
#endif
|
||||
}
|
||||
const Float * y[nrc];
|
||||
};
|
||||
|
||||
@@ -6973,6 +7034,56 @@ IQK_NOINLINE void mul_mat_Qx_Qy_MxN(int n, const char * cx, size_t bx, int ix0,
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
|
||||
}
|
||||
|
||||
template <typename Qy, typename Qx>
|
||||
inline void mul_mat_Qx_Qy_MxN_fa(int n, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
int nb = n/QFBase::k_step;
|
||||
Qy y(info);
|
||||
Qx x(cx + ix0*bx, bx);
|
||||
QFBase::Data xv[Qx::nrc];
|
||||
QFBase::Acc acc[Qx::nrc*Qy::nrc];
|
||||
auto yv = y.load1(0, 0);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load1(ix, 0);
|
||||
acc[ix] = QFBase::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < Qy::nrc; ++iy) {
|
||||
yv = y.load1(iy, 0);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc_first(yv, xv[ix]);
|
||||
}
|
||||
for (int i = 1; i < nb; ++i) {
|
||||
yv = y.load1(0, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) {
|
||||
xv[ix] = x.load1(ix, i);
|
||||
acc[ix] = QFBase::acc(acc[ix], yv, xv[ix]);
|
||||
}
|
||||
for (int iy = 1; iy < Qy::nrc; ++iy) {
|
||||
yv = y.load1(iy, i);
|
||||
for (int ix = 0; ix < Qx::nrc; ++ix) acc[Qx::nrc*iy + ix] = QFBase::acc(acc[Qx::nrc*iy + ix], yv, xv[ix]);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) for (int ix = 0; ix < Qx::nrc; ++ix) info.store(ix0+ix, iy, QFBase::hsum(acc[Qx::nrc*iy+ix]));
|
||||
}
|
||||
|
||||
template <typename Qy, typename Qx>
|
||||
inline void mul_mat_Qx_Qy_MxN_fa4(int D, const char * cx, size_t bx, int ix0, const DataInfo& info) {
|
||||
static_assert(Qx::nrc%4 == 0);
|
||||
int nb = D/QFBase::k_step;
|
||||
Qy y(info);
|
||||
Qx x(cx + ix0*bx, bx);
|
||||
QFBase::Data xv[Qx::nrc];
|
||||
QFBase::Acc acc[Qx::nrc*Qy::nrc/4] = {};
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
for (int ix = 0; ix < Qx::nrc/4; ++ix) x.load_r4(4*ix, i, xv + 4*ix);
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) {
|
||||
auto yv = y.load1(iy, i);
|
||||
for (int ix = 0; ix < Qx::nrc/4; ++ix) acc[ix*Qy::nrc + iy] = QFBase::acc_r4(acc[ix*Qy::nrc + iy], xv + 4*ix, yv);
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < Qy::nrc; ++iy) {
|
||||
for (int ix = 0; ix < Qx::nrc/4; ++ix) info.store(ix0+4*ix, iy, QFBase::hsum_r4(acc[ix*Qy::nrc + iy]));
|
||||
}
|
||||
}
|
||||
|
||||
// This will handle any of f16 x f32, f32 x f16, f16 x f16, f32 x f32, with computations done
|
||||
// in f32 (i.e., f16 is first converted to f32). It is easy to extend to computations done in
|
||||
// f16, but I don't have a CPU capable of f16 vector arithmetic, so not doing it for now.
|
||||
@@ -13066,147 +13177,194 @@ struct FlashQKfp32 {
|
||||
static_assert(k_step%F16::block_size == 0);
|
||||
static_assert(q_step <= 4 || q_step%4 == 0);
|
||||
|
||||
#ifdef __aarch64__
|
||||
constexpr static bool is_small_head = false;
|
||||
#else
|
||||
constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
|
||||
#endif
|
||||
|
||||
template <bool small = is_small_head, class = std::enable_if<small>, typename q_float>
|
||||
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
F16::Data * qv, F16::Data * vk, FlashMS<q_step, k_step>& fms) {
|
||||
// q index is q_step*i1 + m1
|
||||
// k index is k_step*k1 + l1
|
||||
const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
|
||||
fms.cache[k_step*m1 + l1 + 0] = fms.cache[k_step*m1 + l1 + 1] = -INFINITY;
|
||||
if (mp[l1+0] == fms.h_inf && mp[l1+1] == fms.h_inf) {
|
||||
return;
|
||||
}
|
||||
auto qr = q + m1*stride_q;
|
||||
for (int i = 0; i < D/F16::block_size; ++i) qv[i] = F16::load(qr + F16::block_size*i);
|
||||
if (mp[l1+0] != fms.h_inf) {
|
||||
auto vsum = F16::zero();
|
||||
for (int i = 0; i < D/F16::block_size; ++i) vsum = F16::fmadd(vsum, vk[i], qv[i]);
|
||||
fms.cache[k_step*m1 + l1 + 0] = F16::reduce_add(vsum);
|
||||
}
|
||||
if (mp[l1+1] != fms.h_inf) {
|
||||
auto vsum = F16::zero();
|
||||
for (int i = 0; i < D/F16::block_size; ++i) vsum = F16::fmadd(vsum, vk[i+D/16], qv[i]);
|
||||
fms.cache[k_step*m1 + l1 + 1] = F16::reduce_add(vsum);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
|
||||
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
F16::Data * vk, FlashMS<q_step, k_step>& fms) {
|
||||
// q index is q_step*i1 + m1
|
||||
// k index is k_step*k1 + l1
|
||||
const ggml_half * mp = (const ggml_half *)(mask + stride_m*m1);
|
||||
if (mp[l1] == fms.h_inf) {
|
||||
fms.cache[k_step*m1 + l1] = -INFINITY;
|
||||
return;
|
||||
}
|
||||
auto qr = q + m1*stride_q;
|
||||
auto vsum = F16::zero();
|
||||
for (int i = 0; i < D/F16::block_size; ++i) {
|
||||
vsum = F16::fmadd(vsum, vk[i], F16::load(qr + F16::block_size*i));
|
||||
}
|
||||
fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum);
|
||||
}
|
||||
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
|
||||
static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
F16::Data qv[D/F16::block_size];
|
||||
F16::Data vk[D/(F16::block_size/2)];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
kh.load_2(l1, vk);
|
||||
for (int m1 = 0; m1 < q_step; ++m1) {
|
||||
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
|
||||
static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
|
||||
const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
F16::Data vk[D/F16::block_size];
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
kh.load(l1, vk);
|
||||
for (int m1 = 0; m1 < q_step; ++m1) {
|
||||
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
|
||||
static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
F16::Data qv[D/F16::block_size];
|
||||
F16::Data vk[D/(F16::block_size/2)];
|
||||
for (int l1 = 0; l1 < k_step; l1 += 2) {
|
||||
kh.load_2(l1, vk);
|
||||
for (int m1 = 0; m1 < nq; ++m1) {
|
||||
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, qv, vk, fms);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
|
||||
static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
|
||||
const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||
F16::Data vk[D/F16::block_size];
|
||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||
kh.load(l1, vk);
|
||||
for (int m1 = 0; m1 < nq; ++m1) {
|
||||
mult_mask_kq_one(l1, m1, stride_q, stride_m, q, mask, vk, fms);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
template <typename KHelper, typename q_float>
|
||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
if constexpr (is_small_head) {
|
||||
mult_mask_kq(kh, stride_q, stride_m, q, mask, fms);
|
||||
}
|
||||
else {
|
||||
mult_mask_kq_l(kh, stride_q, stride_m, q, mask, fms);
|
||||
}
|
||||
#ifdef __aarch64__
|
||||
float32x4_t vk[k_step/4];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
}
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
constexpr int nrc_q = 8;
|
||||
constexpr int nrc_k = 8;
|
||||
#else
|
||||
// somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
|
||||
constexpr int nrc_q = 4;
|
||||
constexpr int nrc_k = 8;
|
||||
#endif
|
||||
constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
|
||||
constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
|
||||
for (int iq = 0; iq < q_step/nrc_q; ++iq) {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_Qx_Qy_MxN_fa<QFT<q_float, nrc_q>, QFT<ggml_half, krem>>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
info.cur_y += nrc_q;
|
||||
}
|
||||
if constexpr (qrem > 0) {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, qrem>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_Qx_Qy_MxN_fa<QFT<q_float, qrem>, QFT<ggml_half, krem>>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
}
|
||||
F16::Data vk[k_step/F16::block_size];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
template <typename KHelper, typename q_float>
|
||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
constexpr int nrc_q = 4;
|
||||
constexpr int nrc_k = 6;
|
||||
constexpr int qrem = q_step - nrc_q*(q_step/nrc_q);
|
||||
constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
|
||||
for (int iq = 0; iq < q_step/nrc_q; ++iq) {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
info.cur_y += nrc_q;
|
||||
}
|
||||
if constexpr (qrem > 0) {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<qrem, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_f16_f16_NxN<qrem, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
}
|
||||
float32x4_t vk[k_step/4];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef __AVX2__
|
||||
template <typename KHelper, typename q_float>
|
||||
static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
if constexpr (is_small_head) {
|
||||
mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms);
|
||||
}
|
||||
else {
|
||||
mult_mask_kq_l(nq, kh, stride_q, stride_m, q, mask, fms);
|
||||
}
|
||||
#ifdef __aarch64__
|
||||
float32x4_t vk[k_step/4];
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
}
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
constexpr int nrc_q = 8;
|
||||
constexpr int nrc_k = 8;
|
||||
#else
|
||||
F16::Data vk[k_step/F16::block_size];
|
||||
for (int j = 0; j < nq; ++j) {
|
||||
fms.update_M_S(j, vk);
|
||||
}
|
||||
// somewhat surprisingly, nrc_q = 4, nrc_k = 8 is better than nrc_q = 8, nrc_k = 4
|
||||
constexpr int nrc_q = 4;
|
||||
constexpr int nrc_k = 8;
|
||||
#endif
|
||||
static_assert(k_step%nrc_k == 0);
|
||||
int qrem = q_step - nrc_q*(q_step/nrc_q);
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
|
||||
for (int iq = 0; iq < nq/nrc_q; ++iq) {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, nrc_q>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
info.cur_y += nrc_q;
|
||||
}
|
||||
if (qrem > 0) {
|
||||
switch (qrem) {
|
||||
case 1: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 1>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
} break;
|
||||
case 2: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 2>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
} break;
|
||||
case 3: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 3>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
} break;
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
case 4: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 4>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
} break;
|
||||
case 5: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 5>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
} break;
|
||||
case 6: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 6>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
} break;
|
||||
case 7: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_Qx_Qy_MxN_fa4<QFT<q_float, 7>, QFT<ggml_half, nrc_k>>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
} break;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
F16::Data vk[k_step/F16::block_size];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
template <typename KHelper, typename q_float>
|
||||
static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||
FlashMS<q_step, k_step>& fms) {
|
||||
constexpr int nrc_q = 4;
|
||||
constexpr int nrc_k = 6;
|
||||
constexpr int krem = k_step - nrc_k*(k_step/nrc_k);
|
||||
const int qrem = q_step - nrc_q*(q_step/nrc_q);
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, stride_q*sizeof(q_float), 0, 1, nullptr};
|
||||
for (int iq = 0; iq < nq/nrc_q; ++iq) {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<nrc_q, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_f16_f16_NxN<nrc_q, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
info.cur_y += nrc_q;
|
||||
}
|
||||
switch (qrem) {
|
||||
case 0: break;
|
||||
case 1: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<1, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_f16_f16_NxN<1, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
} break;
|
||||
case 2: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<2, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_f16_f16_NxN<2, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
} break;
|
||||
case 3: {
|
||||
for (int ik = 0; ik < k_step/nrc_k; ++ik) {
|
||||
mul_mat_f16_f16_NxN<3, nrc_k, true>(D, kh.block, kh.stride, ik*nrc_k, info);
|
||||
}
|
||||
if constexpr (krem > 0) {
|
||||
mul_mat_f16_f16_NxN<3, krem, true>(D, kh.block, kh.stride, k_step - krem, info);
|
||||
}
|
||||
} break;
|
||||
}
|
||||
float32x4_t vk[k_step/4];
|
||||
for (int j = 0; j < q_step; ++j) {
|
||||
fms.update_M_S(j, vk, mask + stride_m*j);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef __aarch64__
|
||||
static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
|
||||
@@ -13827,7 +13985,7 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
|
||||
const float * q, const char * mask, float scale, float softcap, float * qkv) {
|
||||
|
||||
#if defined __AVX2__
|
||||
constexpr bool kUseLargeStepsQ = !std::is_same_v<KHelper, HelperF16<D, k_step>>;
|
||||
constexpr bool kUseLargeStepsQ = true; //!std::is_same_v<KHelper, HelperF16<D, k_step>>;
|
||||
#else
|
||||
constexpr bool kUseLargeStepsQ = true;
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user