NEON Flash Attention - convert Q to f16 before computing Q*K

This commit is contained in:
Iwan Kawrakow
2024-09-11 07:05:52 +02:00
parent 67bf083f9d
commit 2eb9e212be

View File

@@ -6886,8 +6886,8 @@ struct FlashQKfp32 {
constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size;
template <bool small = is_small_head, class = std::enable_if<small>>
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
template <bool small = is_small_head, class = std::enable_if<small>, typename q_float>
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
F16::Data * qv, F16::Data * vk, FlashMS<q_step, k_step>& fms) {
// q index is q_step*i1 + m1
// k index is k_step*k1 + l1
@@ -6910,8 +6910,8 @@ struct FlashQKfp32 {
}
}
template <bool small = is_small_head, class = std::enable_if<!small>>
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask,
template <bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask,
F16::Data * vk, FlashMS<q_step, k_step>& fms) {
// q index is q_step*i1 + m1
// k index is k_step*k1 + l1
@@ -6928,8 +6928,8 @@ struct FlashQKfp32 {
fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum);
}
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
F16::Data qv[D/F16::block_size];
F16::Data vk[D/(F16::block_size/2)];
@@ -6941,9 +6941,9 @@ struct FlashQKfp32 {
}
}
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m,
const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
F16::Data vk[D/F16::block_size];
for (int l1 = 0; l1 < k_step; ++l1) {
kh.load(l1, vk);
@@ -6953,8 +6953,8 @@ struct FlashQKfp32 {
}
}
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>>
static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
template <typename KHelper, bool small = is_small_head, class = std::enable_if<small>, typename q_float>
static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
F16::Data qv[D/F16::block_size];
F16::Data vk[D/(F16::block_size/2)];
@@ -6966,9 +6966,9 @@ struct FlashQKfp32 {
}
}
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>>
template <typename KHelper, bool small = is_small_head, class = std::enable_if<!small>, typename q_float>
static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m,
const float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
const q_float * q, const char * mask, FlashMS<q_step, k_step>& fms) {
F16::Data vk[D/F16::block_size];
for (int l1 = 0; l1 < k_step; ++l1) {
kh.load(l1, vk);
@@ -6978,8 +6978,8 @@ struct FlashQKfp32 {
}
}
template <typename KHelper>
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
template <typename KHelper, typename q_float>
static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
if constexpr (is_small_head) {
mult_mask_kq(kh, stride_q, stride_m, q, mask, fms);
@@ -6993,8 +6993,8 @@ struct FlashQKfp32 {
}
}
template <typename KHelper>
static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask,
template <typename KHelper, typename q_float>
static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask,
FlashMS<q_step, k_step>& fms) {
if constexpr (is_small_head) {
mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms);
@@ -7007,6 +7007,21 @@ struct FlashQKfp32 {
fms.update_M_S(j, vk);
}
}
#ifdef __aarch64__
static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) {
for (int i = 0; i < nq; ++i) {
for (int j = 0; j < D; j += 8) {
auto val1_f32 = vld1q_f32(q + j + 0);
auto val2_f32 = vld1q_f32(q + j + 4);
auto val_f16 = vcombine_f16(vcvt_f16_f32(val1_f32), vcvt_f16_f32(val2_f32));
vst1q_f16(q_f16 + j, val_f16);
}
q += stride_q;
q_f16 += D;
}
}
#endif
};
template <int D, int q_step, int k_step, typename KHelper, typename VHelper, typename KQHelper>
@@ -7014,13 +7029,23 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
FlashMS<q_step, k_step>& fms,
FlashQKV<D, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv) {
#ifdef __aarch64__
float16_t q_f16[D*q_step];
#endif
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
fms.init_qstep();
kh.reset_block();
vh.reset_block();
#ifdef __aarch64__
KQHelper::convert(q_step, stride_q, q, q_f16);
#endif
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
#ifdef __aarch64__
KQHelper::multiply_mask_kq(kh, D, stride_m, q_f16, mr, fms);
#else
KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms);
#endif
fqkv.accumulate_qkv(vh, fms);
kh.next_block();
vh.next_block();
@@ -7037,9 +7062,16 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
fms.init_qstep();
kh.reset_block();
vh.reset_block();
#ifdef __aarch64__
KQHelper::convert(n_left, stride_q, q, q_f16);
#endif
auto mr = mask;
for (int k1 = 0; k1 < nk1/k_step; ++k1) {
#ifdef __aarch64__
KQHelper::multiply_mask_kq(n_left, kh, D, stride_m, q_f16, mr, fms);
#else
KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms);
#endif
fqkv.accumulate_qkv(n_left, vh, fms);
kh.next_block();
vh.next_block();