Slightly faster FA for Q8_0 KV cache

Iwan Kawrakow
2025-01-14 10:21:11 +02:00
parent 2afe2e1d41
commit 7623f769c0
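
Annotation: the speedup comes from routing the K*Q multiplication for a Q8_0 KV cache through the Q8_1 kernels when AVX512-VNNI is available (HAVE_FANCY_SIMD). The signed Q8_0 quants are converted to unsigned by flipping the sign bit, which makes them a valid unsigned operand for VNNI's dpbusd instruction; the +128 offset this introduces is cancelled later via the block sum that Q8_1 carries. A minimal sketch of the identity being exploited, assuming AVX512-VNNI with VL (the helper name is illustrative, not from the commit):

// Sketch (not part of the commit): for signed bytes q and y,
//     sum(q*y) = sum((q ^ 0x80) * y) - 128*sum(y).
// XOR-ing the sign bit turns q into q+128 as an unsigned byte, so the
// unsigned x signed _mm256_dpbusd_epi32 can be used; the 128*sum(y)
// term is subtracted afterwards.
#include <immintrin.h>
#include <cstdint>

static inline int32_t dot_q8_vnni_sketch(const int8_t * q, const int8_t * y) {
    __m256i vq = _mm256_loadu_si256((const __m256i *)q);        // 32 signed quants
    __m256i vy = _mm256_loadu_si256((const __m256i *)y);
    __m256i uq = _mm256_xor_si256(vq, _mm256_set1_epi8(0x80));  // q + 128, now unsigned
    __m256i acc = _mm256_dpbusd_epi32(_mm256_setzero_si256(), uq, vy);
    // horizontal sum of the eight 32-bit lanes
    __m128i s = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
    s = _mm_hadd_epi32(s, s);
    s = _mm_hadd_epi32(s, s);
    int32_t sum_y = 0;
    for (int i = 0; i < 32; ++i) sum_y += y[i];
    return _mm_cvtsi128_si32(s) - 128 * sum_y;  // recover the signed dot product
}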


@@ -6755,10 +6755,10 @@ struct Q_Unpacker {
     }
 };
-struct Q8_0_x4_Unpacker {
+struct Q8_0_x4_Unpacker_256 {
     using Sum4T = Sum4TypeQ80;
     inline static int block_size() { return QK8_0; }
-    Q8_0_x4_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {}
+    Q8_0_x4_Unpacker_256(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {}
     const char * cx_0;
     const block_q8_0_x4 * x;
@@ -6784,6 +6784,44 @@ struct Q8_0_x4_Unpacker {
     }
 };
+#ifdef HAVE_FANCY_SIMD
+struct Q8_0_x4_Unpacker_512 {
+    using Sum4T = Sum4TypeQ81;
+    inline static int block_size() { return QK8_0; }
+    Q8_0_x4_Unpacker_512(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {}
+    const char * cx_0;
+    const block_q8_0_x4 * x;
+    size_t bx;
+    const __m128 min = _mm_set1_ps(-128.f);
+    __m256i qx[4];
+    inline const __m256i* quants() const { return qx; }
+    inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); }
+    inline auto set_block_4(int i) {
+        auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d));
+        for (int j = 0; j < 4; ++j) {
+            qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
+            qx[j] = _mm256_xor_si256(qx[j], _mm256_set1_epi8(0x80));
+        }
+        return _mm256_set_m128(_mm_mul_ps(scales, min), scales);
+    }
+    inline auto set_block(int i) {
+        auto q8 = (const block_q8_0 *)(x + i);
+        qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs);
+        qx[0] = _mm256_xor_si256(qx[0], _mm256_set1_epi8(0x80));
+        float d = GGML_FP16_TO_FP32(q8->d);
+        return std::make_pair(d, -128.f*d);
+    }
+};
+using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_512;
+#else
+using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_256;
+#endif
 struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
     Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
     using Sum4T = Sum4TypeQ80;
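
Annotation: set_block_4() above returns the four fp16 scales d in the low 128 bits and d * (-128.f) in the high 128 bits; the Sum4TypeQ81 accumulation multiplies the latter by the activation block sums, which cancels the offset introduced by the XOR with 0x80. A scalar model of that bookkeeping for a single 32-quant block (hypothetical helper, annotation only):

// Scalar model of the Q8_1-style accumulation fed by Q8_0_x4_Unpacker_512.
// d is the K-block scale, dy the activation scale, and s = dy * sum(y)
// is the block sum a block_q8_1 carries.
#include <cstdint>

float q8_0_via_q8_1_sketch(float d, const int8_t * q, float dy, const int8_t * y, int n) {
    int sum_uy = 0, sum_y = 0;
    for (int i = 0; i < n; ++i) {
        sum_uy += (uint8_t)(q[i] ^ 0x80) * y[i];  // unsigned x signed, as dpbusd computes
        sum_y  += y[i];
    }
    float s   = dy * sum_y;     // precomputed in block_q8_1
    float min = -128.f * d;     // what set_block_4()/set_block() return alongside d
    // d*dy*sum((q+128)*y) - 128*d*dy*sum(y) == d*dy*sum(q*y), the Q8_0 dot product
    return d * dy * sum_uy + min * s;
}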
@@ -12660,7 +12698,12 @@ void quantize_row_q8_1(const float * x, block_q8_1 * y, int k) {
 template <int D, int step>
 struct HelperQ80 final : public BaseHelper<step> {
     using Base = BaseHelper<step>;
+#ifdef HAVE_FANCY_SIMD
-    //using block_q8 = block_q8_1;
+    using block_q8 = block_q8_1;
+#else
     using block_q8 = block_q8_0;
+#endif
     HelperQ80(const char * data, int stride) : Base(data, stride) {}
     // Needed for v * softmax(k * q)
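
Annotation: this typedef controls how the Q rows are quantized on the fly. With HAVE_FANCY_SIMD they become Q8_1 blocks, which, unlike Q8_0, carry the scaled block sum s that the unpacker's -128*d term multiplies. For reference, the two layouts roughly as in ggml (simplified sketch; newer ggml versions pack d and s into a union):

// Simplified block layouts, roughly as in ggml's ggml-common.h
#include <cstdint>
#define QK8_0 32
#define QK8_1 32
typedef uint16_t ggml_half;  // fp16 bit pattern

typedef struct { ggml_half d;              int8_t qs[QK8_0]; } block_q8_0;  // scale only
typedef struct { ggml_half d; ggml_half s; int8_t qs[QK8_1]; } block_q8_1;  // plus s = d*sum(qs)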
@@ -13406,8 +13449,13 @@ struct FlashQKfp32 {
             mul_mat_qX_0_q8_0<DequantizerQ80_x4, q_step>(D, kh.block, kh.stride, info, k_step);
 #else
         if constexpr (D >= 128) {
+#ifdef HAVE_FANCY_SIMD
+            mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+#else
             mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
+#endif
         } else {
+            // This does not actually work until we fix the K-cache to be quantized to Q8_0_x4 only if D%128 == 0
             mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
         }
 #endif
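
Annotation: the D >= 128 gate matches the block_q8_0_x4 layout, which interleaves four consecutive Q8_0 blocks, i.e. 4*QK8_0 = 128 quants. The supported head sizes of at least 128 (128 and 256) are multiples of 128 and map cleanly onto the repacked K cache; smaller heads fall back to the plain Q8_0_Unpacker which, per the comment above, remains broken until the repacking is gated on D % 128 == 0. A compile-time check sketching the constraint (names hypothetical):

// Annotation sketch: the x4 layout packs 4 consecutive Q8_0 blocks per chunk,
// so the repacked K path assumes the head size is a multiple of 4*QK8_0 = 128.
constexpr int QK8_0_sketch = 32;
constexpr bool fits_q8_0_x4(int D) { return D % (4 * QK8_0_sketch) == 0; }
static_assert( fits_q8_0_x4(128) && fits_q8_0_x4(256), "x4 path head sizes");
static_assert(!fits_q8_0_x4(96), "D = 96 must take the Q8_0_Unpacker fallback");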
@@ -13492,6 +13540,15 @@ struct FlashQKfp32 {
#else
if constexpr (D >= 128) {
switch (nq) {
#ifdef HAVE_FANCY_SIMD
case 1: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
case 3: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
case 4: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
case 5: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
#else
case 1: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
case 3: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
@@ -13499,8 +13556,10 @@ struct FlashQKfp32 {
case 5: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
#endif
}
} else {
// This does not actually work until we fix K-cache to be quantizez to Q8_0_x4 only if D%128 == 0
switch (nq) {
case 1: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
@@ -14226,7 +14285,7 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
             iqk_flash_helper_T< 96, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
         case 128:
             iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
-        case 256
+        case 256:
             iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
         default:
             return false;