mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
NEON Flash Attention: quantized K*Q for q8_0
This makes quite a bit of difference: for Gemma2-2b, PP-8192 is 228 t/s with quantized K*Q vs. 178 t/s when converting to fp16 and using fp16 matrix multiplication. PP-512 is 307 t/s, so PP-8192 now reaches ~75% of PP-512 performance. In contrast, with llama.cpp and a Q8_0 cache, PP-8192 reaches only 38% of PP-512.
This commit is contained in:
@@ -5461,6 +5461,27 @@ struct DequantizerQ80 final : public BaseLegacyDequantizer<block_q8_0> {
|
||||
|
||||
};
|
||||
|
||||
// TODO: handle case where row size is not a multiple of 128
|
||||
// Dequantizer that reads `block_q8_0_x4` entries and fills the int8 registers in `bits`.
// Presumably each entry packs four Q8_0 blocks (4 fp16 scales in `d`, 128 int8 quants
// in `qs`) -- TODO confirm against the block_q8_0_x4 declaration.
struct DequantizerQ80_x4 final : public BaseLegacyDequantizer<block_q8_0_x4> {

    // vx: start of the quantized data, bx: stride argument; both are forwarded
    // unchanged to the base class.
    DequantizerQ80_x4(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}

    // Loads the first 32 int8 quants of entry `i` into bits.b[0] and bits.b[1].
    // NOTE(review): `i` indexes `x` in units of block_q8_0_x4 here -- confirm this is
    // what callers of prepare1 expect.
    inline void prepare1(int i) {
        const auto * quants = x[i].qs;
        bits.b[0] = vld1q_s8(quants);
        bits.b[1] = vld1q_s8(quants + 16);
    }

    // Loads 128 int8 quants of entry `i` into bits.b[0..7] (16 per register) and
    // returns the four fp16 values read from x[i].d.
    inline float16x4_t new_block(int i) {
        const auto * quants = x[i].qs;
        for (int j = 0; j < 8; ++j) {
            bits.b[j] = vld1q_s8(quants + 16*j);
        }
        return vld1_f16((const float16_t *)x[i].d);
    }

};
|
||||
|
||||
struct DequantizerQ51 final : public BaseLegacyDequantizer<block_q5_1> {
|
||||
|
||||
DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {}
|
||||
@@ -5529,9 +5550,9 @@ inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& i
|
||||
q8.process_scales(i, deq, sc16, acc);
|
||||
sum_4(i, deq, q8, sc16, acc);
|
||||
}
|
||||
for (int i = 4*(nb/4); i < nb; ++i) {
|
||||
q8.process_1_block(i, deq, acc);
|
||||
}
|
||||
//for (int i = 4*(nb/4); i < nb; ++i) {
|
||||
// q8.process_1_block(i, deq, acc);
|
||||
//}
|
||||
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
info.store(ix, iy, vaddvq_f32(acc[iy]));
|
||||
@@ -5591,9 +5612,9 @@ inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8&
|
||||
q8.process_scales(i, deq1, sc16, acc);
|
||||
sum_4(i, deq1, q8, sc16, acc);
|
||||
}
|
||||
for (int i = 4*(nb/4); i < nb; ++i) {
|
||||
q8.process_1_block(i, deq1, acc);
|
||||
}
|
||||
//for (int i = 4*(nb/4); i < nb; ++i) {
|
||||
// q8.process_1_block(i, deq1, acc);
|
||||
//}
|
||||
|
||||
info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1])));
|
||||
}
|
||||
@@ -6802,6 +6823,7 @@ template <int D, int step>
|
||||
struct HelperQ80 final : public BaseHelper<step> {
|
||||
static_assert(step == QK8_0);
|
||||
using Base = BaseHelper<step>;
|
||||
using block_q8 = block_q8_0;
|
||||
HelperQ80(const char * data, int stride) : Base(data, stride) {}
|
||||
|
||||
inline void load(int l1, F16::Data * vk) const {
|
||||
@@ -7616,6 +7638,10 @@ struct FlashQKfp32 {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
|
||||
mul_mat_qX_0_q8_0<DequantizerQ40, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
|
||||
mul_mat_qX_0_q8_0<DequantizerQ80_x4, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
|
||||
mul_mat_qX_1_q8_1<DequantizerQ41, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
@@ -7651,6 +7677,18 @@ struct FlashQKfp32 {
|
||||
case 7: mul_mat_qX_0_q8_0<DequantizerQ40, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, D*sizeof(float), (D/QK8_0)*sizeof(block_q8), 0, 1, nullptr};
|
||||
switch (nq) {
|
||||
case 1: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 4: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 4>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 5: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_0_q8_0<DequantizerQ80_x4, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
|
||||
switch (nq) {
|
||||
@@ -7801,7 +7839,8 @@ struct FlashAttn {
|
||||
template <typename KHelper, typename VHelper>
|
||||
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float * qkv) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>>) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
|
||||
compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user