NEON Flash Attention: quantized K*Q for q8_0

This makes quite a bit of difference:
for Gemma2-2b, PP-8192 runs at 228 t/s with quantized K*Q vs
178 t/s when converting K to fp16 and using fp16
matrix multiplication.
We have PP-512 = 307 t/s, so PP-8192 is now ~75% of the
performance of PP-512. In contrast, llama.cpp with Q8_0
cache is 38% of PP-512.
This commit is contained in:
Iwan Kawrakow
2024-09-12 10:38:44 +02:00
parent c3dc5a27bb
commit 4ff2c6d188

View File

@@ -5461,6 +5461,27 @@ struct DequantizerQ80 final : public BaseLegacyDequantizer<block_q8_0> {
};
// TODO: handle case where row size is not a multiple of 128
// Loader for block_q8_0_x4 data: copies the int8 quants of one block into the
// inherited `bits.b` registers and hands back the block's four fp16 scales.
// `x` and `bits` are not declared here; they come from BaseLegacyDequantizer
// (defined outside this view).
struct DequantizerQ80_x4 final : public BaseLegacyDequantizer<block_q8_0_x4> {

    DequantizerQ80_x4(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}

    // Loads the first 32 quants of x[i] into bits.b[0] and bits.b[1].
    // NOTE(review): `i` indexes x4 blocks here, so this always reads the start
    // of the 128-quant group; confirm against callers before re-enabling any
    // single-block tail loop.
    inline void prepare1(int i) {
        const auto * quants = x[i].qs;
        bits.b[0] = vld1q_s8(quants);
        bits.b[1] = vld1q_s8(quants + 16);
    }

    // Loads all 128 quants of x[i] into bits.b[0..7] (16 lanes per register,
    // in memory order) and returns the four fp16 scales stored in x[i].d.
    inline float16x4_t new_block(int i) {
        const auto & blk = x[i];
        const auto scales = vld1_f16((const float16_t *)blk.d);
        // Register j receives bytes [16*j, 16*j + 16) of the quant array.
        for (int j = 0; j < 8; ++j) {
            bits.b[j] = vld1q_s8(blk.qs + 16*j);
        }
        return scales;
    }

};
struct DequantizerQ51 final : public BaseLegacyDequantizer<block_q5_1> {
DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
@@ -5529,9 +5550,9 @@ inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& i
q8.process_scales(i, deq, sc16, acc);
sum_4(i, deq, q8, sc16, acc);
}
for (int i = 4*(nb/4); i < nb; ++i) {
q8.process_1_block(i, deq, acc);
}
//for (int i = 4*(nb/4); i < nb; ++i) {
// q8.process_1_block(i, deq, acc);
//}
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
info.store(ix, iy, vaddvq_f32(acc[iy]));
@@ -5591,9 +5612,9 @@ inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8&
q8.process_scales(i, deq1, sc16, acc);
sum_4(i, deq1, q8, sc16, acc);
}
for (int i = 4*(nb/4); i < nb; ++i) {
q8.process_1_block(i, deq1, acc);
}
//for (int i = 4*(nb/4); i < nb; ++i) {
// q8.process_1_block(i, deq1, acc);
//}
info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1])));
}
@@ -6802,6 +6823,7 @@ template <int D, int step>
struct HelperQ80 final : public BaseHelper<step> {
static_assert(step == QK8_0);
using Base = BaseHelper<step>;
using block_q8 = block_q8_0;
HelperQ80(const char * data, int stride) : Base(data, stride) {}
inline void load(int l1, F16::Data * vk) const {
@@ -7616,6 +7638,10 @@ struct FlashQKfp32 {
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
mul_mat_qX_0_q8_0<DequantizerQ40, q_step>(D, kh.block, kh.stride, info, k_step);
}
else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
mul_mat_qX_0_q8_0<DequantizerQ80_x4, q_step>(D, kh.block, kh.stride, info, k_step);
}
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
mul_mat_qX_1_q8_1<DequantizerQ41, q_step>(D, kh.block, kh.stride, info, k_step);
@@ -7651,6 +7677,18 @@ struct FlashQKfp32 {
case 7: mul_mat_qX_0_q8_0<DequantizerQ40, 7>(D, kh.block, kh.stride, info, k_step); break;
}
}
else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
switch (nq) {
case 1: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 2>(D, kh.block, kh.stride, info, k_step); break;
case 3: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 3>(D, kh.block, kh.stride, info, k_step); break;
case 4: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 4>(D, kh.block, kh.stride, info, k_step); break;
case 5: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 7>(D, kh.block, kh.stride, info, k_step); break;
}
}
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
switch (nq) {
@@ -7801,7 +7839,8 @@ struct FlashAttn {
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
} else {