In contrast, mainline llama.cpp with a Q8_0 KV cache achieves only 38% of its own PP-512 performance.
vaddvq_f32(acc[iy])); @@ -5591,9 +5612,9 @@ inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8.process_scales(i, deq1, sc16, acc); sum_4(i, deq1, q8, sc16, acc); } - for (int i = 4*(nb/4); i < nb; ++i) { - q8.process_1_block(i, deq1, acc); - } + //for (int i = 4*(nb/4); i < nb; ++i) { + // q8.process_1_block(i, deq1, acc); + //} info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1]))); } @@ -6802,6 +6823,7 @@ template struct HelperQ80 final : public BaseHelper { static_assert(step == QK8_0); using Base = BaseHelper; + using block_q8 = block_q8_0; HelperQ80(const char * data, int stride) : Base(data, stride) {} inline void load(int l1, F16::Data * vk) const { @@ -7616,6 +7638,10 @@ struct FlashQKfp32 { DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); } + else if constexpr (std::is_same_v>) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); + } else if constexpr (std::is_same_v>) { DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); @@ -7651,6 +7677,18 @@ struct FlashQKfp32 { case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; } } + else if constexpr (std::is_same_v>) { + DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr}; + switch (nq) { + case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, 
k_step); break; + case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + } + } else if constexpr (std::is_same_v>) { DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; switch (nq) { @@ -7801,7 +7839,8 @@ struct FlashAttn { template void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { - if constexpr (std::is_same_v> || std::is_same_v>) { + if constexpr (std::is_same_v> || std::is_same_v> || + std::is_same_v>) { compute_helper_q>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } else {