mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-29 19:01:47 +00:00
NEON Flash Attention - convert Q to f16 before computing Q*K
This commit is contained in:
@@ -6886,8 +6886,8 @@ struct FlashQKfp32 {
|
|||||||
|
|
||||||
constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
|
constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
|
||||||
|
|
||||||
template <bool small = is_small_head, class = std::enable_if<small>>
|
template <bool small = is_small_head, class = std::enable_if<small>, typename q_float>
|
||||||
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
|
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||||
F16::Data * qv, F16::Data * vk, FlashMS<q_step, k_step>& fms) {
|
F16::Data * qv, F16::Data * vk, FlashMS<q_step, k_step>& fms) {
|
||||||
// q index is q_step*i1 + m1
|
// q index is q_step*i1 + m1
|
||||||
// k index is k_step*k1 + l1
|
// k index is k_step*k1 + l1
|
||||||
@@ -6910,8 +6910,8 @@ struct FlashQKfp32 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <bool small = is_small_head, class = std::enable_if<!small>>
|
template <bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
|
||||||
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
|
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||||
F16::Data * vk, FlashMS<q_step, k_step>& fms) {
|
F16::Data * vk, FlashMS<q_step, k_step>& fms) {
|
||||||
// q index is q_step*i1 + m1
|
// q index is q_step*i1 + m1
|
||||||
// k index is k_step*k1 + l1
|
// k index is k_step*k1 + l1
|
||||||
@@ -6928,8 +6928,8 @@ struct FlashQKfp32 {
|
|||||||
fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum);
|
fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
|
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
|
||||||
static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
|
static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||||
FlashMS<q_step, k_step>& fms) {
|
FlashMS<q_step, k_step>& fms) {
|
||||||
F16::Data qv[D/F16::block_size];
|
F16::Data qv[D/F16::block_size];
|
||||||
F16::Data vk[D/(F16::block_size/2)];
|
F16::Data vk[D/(F16::block_size/2)];
|
||||||
@@ -6941,9 +6941,9 @@ struct FlashQKfp32 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
|
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
|
||||||
static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
|
static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
|
||||||
const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||||
F16::Data vk[D/F16::block_size];
|
F16::Data vk[D/F16::block_size];
|
||||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||||
kh.load(l1, vk);
|
kh.load(l1, vk);
|
||||||
@@ -6953,8 +6953,8 @@ struct FlashQKfp32 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
|
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
|
||||||
static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
|
static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||||
FlashMS<q_step, k_step>& fms) {
|
FlashMS<q_step, k_step>& fms) {
|
||||||
F16::Data qv[D/F16::block_size];
|
F16::Data qv[D/F16::block_size];
|
||||||
F16::Data vk[D/(F16::block_size/2)];
|
F16::Data vk[D/(F16::block_size/2)];
|
||||||
@@ -6966,9 +6966,9 @@ struct FlashQKfp32 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
|
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
|
||||||
static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
|
static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
|
||||||
const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
|
||||||
F16::Data vk[D/F16::block_size];
|
F16::Data vk[D/F16::block_size];
|
||||||
for (int l1 = 0; l1 < k_step; ++l1) {
|
for (int l1 = 0; l1 < k_step; ++l1) {
|
||||||
kh.load(l1, vk);
|
kh.load(l1, vk);
|
||||||
@@ -6978,8 +6978,8 @@ struct FlashQKfp32 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename KHelper>
|
template <typename KHelper, typename q_float>
|
||||||
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
|
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||||
FlashMS<q_step, k_step>& fms) {
|
FlashMS<q_step, k_step>& fms) {
|
||||||
if constexpr (is_small_head) {
|
if constexpr (is_small_head) {
|
||||||
mult_mask_kq(kh, stride_q, stride_m, q, mask, fms);
|
mult_mask_kq(kh, stride_q, stride_m, q, mask, fms);
|
||||||
@@ -6993,8 +6993,8 @@ struct FlashQKfp32 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename KHelper>
|
template <typename KHelper, typename q_float>
|
||||||
static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
|
static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
|
||||||
FlashMS<q_step, k_step>& fms) {
|
FlashMS<q_step, k_step>& fms) {
|
||||||
if constexpr (is_small_head) {
|
if constexpr (is_small_head) {
|
||||||
mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms);
|
mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms);
|
||||||
@@ -7007,6 +7007,21 @@ struct FlashQKfp32 {
|
|||||||
fms.update_M_S(j, vk);
|
fms.update_M_S(j, vk);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef __aarch64__
|
||||||
|
static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
|
||||||
|
for (int i = 0; i < nq; ++i) {
|
||||||
|
for (int j = 0; j < D; j += 8) {
|
||||||
|
auto val1_f32 = vld1q_f32(q + j + 0);
|
||||||
|
auto val2_f32 = vld1q_f32(q + j + 4);
|
||||||
|
auto val_f16 = vcombine_f16(vcvt_f16_f32(val1_f32), vcvt_f16_f32(val2_f32));
|
||||||
|
vst1q_f16(q_f16 + j, val_f16);
|
||||||
|
}
|
||||||
|
q += stride_q;
|
||||||
|
q_f16 += D;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
|
template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
|
||||||
@@ -7014,13 +7029,23 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
|||||||
FlashMS<q_step, k_step>& fms,
|
FlashMS<q_step, k_step>& fms,
|
||||||
FlashQKV<D, q_step, k_step>& fqkv,
|
FlashQKV<D, q_step, k_step>& fqkv,
|
||||||
const float * q, const char * mask, float * qkv) {
|
const float * q, const char * mask, float * qkv) {
|
||||||
|
#ifdef __aarch64__
|
||||||
|
float16_t q_f16[D*q_step];
|
||||||
|
#endif
|
||||||
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||||
fms.init_qstep();
|
fms.init_qstep();
|
||||||
kh.reset_block();
|
kh.reset_block();
|
||||||
vh.reset_block();
|
vh.reset_block();
|
||||||
|
#ifdef __aarch64__
|
||||||
|
KQHelper::convert(q_step, stride_q, q, q_f16);
|
||||||
|
#endif
|
||||||
auto mr = mask;
|
auto mr = mask;
|
||||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||||
|
#ifdef __aarch64__
|
||||||
|
KQHelper::multiply_mask_kq(kh, D, stride_m, q_f16, mr, fms);
|
||||||
|
#else
|
||||||
KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
|
KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
|
||||||
|
#endif
|
||||||
fqkv.accumulate_qkv(vh, fms);
|
fqkv.accumulate_qkv(vh, fms);
|
||||||
kh.next_block();
|
kh.next_block();
|
||||||
vh.next_block();
|
vh.next_block();
|
||||||
@@ -7037,9 +7062,16 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
|||||||
fms.init_qstep();
|
fms.init_qstep();
|
||||||
kh.reset_block();
|
kh.reset_block();
|
||||||
vh.reset_block();
|
vh.reset_block();
|
||||||
|
#ifdef __aarch64__
|
||||||
|
KQHelper::convert(n_left, stride_q, q, q_f16);
|
||||||
|
#endif
|
||||||
auto mr = mask;
|
auto mr = mask;
|
||||||
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
|
||||||
|
#ifdef __aarch64__
|
||||||
|
KQHelper::multiply_mask_kq(n_left, kh, D, stride_m, q_f16, mr, fms);
|
||||||
|
#else
|
||||||
KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
|
KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
|
||||||
|
#endif
|
||||||
fqkv.accumulate_qkv(n_left, vh, fms);
|
fqkv.accumulate_qkv(n_left, vh, fms);
|
||||||
kh.next_block();
|
kh.next_block();
|
||||||
vh.next_block();
|
vh.next_block();
|
||||||
|
|||||||
Reference in New Issue
Block a user