From 2eb9e212be2dc3e780bc65607e736999eafd6f85 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 11 Sep 2024 07:05:52 +0200 Subject: [PATCH] NEON Flash Attention - convert Q to f16 before computing Q*K --- ggml/src/iqk/iqk_mul_mat.cpp | 64 +++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 16 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index cae02ed1..907f1271 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6886,8 +6886,8 @@ struct FlashQKfp32 { constexpr static bool is_small_head = D <= (F16::num_registers/2)*F16::block_size; - template > - static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, + template , typename q_float> + static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask, F16::Data * qv, F16::Data * vk, FlashMS& fms) { // q index is q_step*i1 + m1 // k index is k_step*k1 + l1 @@ -6910,8 +6910,8 @@ struct FlashQKfp32 { } } - template > - static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const float * q, const char * mask, + template , typename q_float> + static inline void mult_mask_kq_one(int l1, int m1, int stride_q, int stride_m, const q_float * q, const char * mask, F16::Data * vk, FlashMS& fms) { // q index is q_step*i1 + m1 // k index is k_step*k1 + l1 @@ -6928,8 +6928,8 @@ struct FlashQKfp32 { fms.cache[k_step*m1 + l1] = F16::reduce_add(vsum); } - template > - static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask, + template , typename q_float> + static inline void mult_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask, FlashMS& fms) { F16::Data qv[D/F16::block_size]; F16::Data vk[D/(F16::block_size/2)]; @@ -6941,9 +6941,9 @@ struct FlashQKfp32 { } } - template > + template , typename q_float> static inline void mult_mask_kq_l(const KHelper& kh, int stride_q, int stride_m, - const float * q, const char * mask, FlashMS& fms) { + const q_float * q, const char * mask, FlashMS& fms) { F16::Data vk[D/F16::block_size]; for (int l1 = 0; l1 < k_step; ++l1) { kh.load(l1, vk); @@ -6953,8 +6953,8 @@ struct FlashQKfp32 { } } - template > - static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask, + template , typename q_float> + static inline void mult_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask, FlashMS& fms) { F16::Data qv[D/F16::block_size]; F16::Data vk[D/(F16::block_size/2)]; @@ -6966,9 +6966,9 @@ struct FlashQKfp32 { } } - template > + template , typename q_float> static inline void mult_mask_kq_l(int nq, const KHelper& kh, int stride_q, int stride_m, - const float * q, const char * mask, FlashMS& fms) { + const q_float * q, const char * mask, FlashMS& fms) { F16::Data vk[D/F16::block_size]; for (int l1 = 0; l1 < k_step; ++l1) { kh.load(l1, vk); @@ -6978,8 +6978,8 @@ struct FlashQKfp32 { } } - template - static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask, + template + static inline void multiply_mask_kq(const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask, FlashMS& fms) { if constexpr (is_small_head) { mult_mask_kq(kh, stride_q, stride_m, q, mask, fms); @@ -6993,8 +6993,8 @@ struct FlashQKfp32 { } } - template - static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const float * q, const char * mask, + template + static inline void multiply_mask_kq(int nq, const KHelper& kh, int stride_q, int stride_m, const q_float * q, const char * mask, FlashMS& fms) { if constexpr (is_small_head) { mult_mask_kq(nq, kh, stride_q, stride_m, q, mask, fms); @@ -7007,6 +7007,21 @@ struct FlashQKfp32 { fms.update_M_S(j, vk); } } + +#ifdef __aarch64__ + static inline void convert(int nq, int stride_q, const float * q, float16_t * q_f16) { + for (int i = 0; i < nq; ++i) { + for (int j = 0; j < D; j += 8) { + auto val1_f32 = vld1q_f32(q + j + 0); + auto val2_f32 = vld1q_f32(q + j + 4); + auto val_f16 = vcombine_f16(vcvt_f16_f32(val1_f32), vcvt_f16_f32(val2_f32)); + vst1q_f16(q_f16 + j, val_f16); + } + q += stride_q; + q_f16 += D; + } + } +#endif }; template @@ -7014,13 +7029,23 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in FlashMS& fms, FlashQKV& fqkv, const float * q, const char * mask, float * qkv) { +#ifdef __aarch64__ + float16_t q_f16[D*q_step]; +#endif for (int i1 = 0; i1 < nq1/q_step; ++i1) { fms.init_qstep(); kh.reset_block(); vh.reset_block(); +#ifdef __aarch64__ + KQHelper::convert(q_step, stride_q, q, q_f16); +#endif auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { +#ifdef __aarch64__ + KQHelper::multiply_mask_kq(kh, D, stride_m, q_f16, mr, fms); +#else KQHelper::multiply_mask_kq(kh, stride_q, stride_m, q, mr, fms); +#endif fqkv.accumulate_qkv(vh, fms); kh.next_block(); vh.next_block(); @@ -7037,9 +7062,16 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in fms.init_qstep(); kh.reset_block(); vh.reset_block(); +#ifdef __aarch64__ + KQHelper::convert(n_left, stride_q, q, q_f16); +#endif auto mr = mask; for (int k1 = 0; k1 < nk1/k_step; ++k1) { +#ifdef __aarch64__ + KQHelper::multiply_mask_kq(n_left, kh, D, stride_m, q_f16, mr, fms); +#else KQHelper::multiply_mask_kq(n_left, kh, stride_q, stride_m, q, mr, fms); +#endif fqkv.accumulate_qkv(n_left, vh, fms); kh.next_block(); vh.next_block();