From 7623f769c095907f0f43e7ea2150aed94202c9c6 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 14 Jan 2025 10:21:11 +0200 Subject: [PATCH] Slightly faster FA for Q8_0 KV cache --- ggml/src/iqk/iqk_mul_mat.cpp | 65 ++++++++++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 3 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 5783bf0a..901b2dee 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6755,10 +6755,10 @@ struct Q_Unpacker { } }; -struct Q8_0_x4_Unpacker { +struct Q8_0_x4_Unpacker_256 { using Sum4T = Sum4TypeQ80; inline static int block_size() { return QK8_0; } - Q8_0_x4_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} + Q8_0_x4_Unpacker_256(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} const char * cx_0; const block_q8_0_x4 * x; @@ -6784,6 +6784,44 @@ struct Q8_0_x4_Unpacker { } }; +#ifdef HAVE_FANCY_SIMD +struct Q8_0_x4_Unpacker_512 { + using Sum4T = Sum4TypeQ81; + inline static int block_size() { return QK8_0; } + Q8_0_x4_Unpacker_512(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {} + + const char * cx_0; + const block_q8_0_x4 * x; + size_t bx; + const __m128 min = _mm_set1_ps(-128.f); + + __m256i qx[4]; + + inline const __m256i* quants() const { return qx; } + + inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); } + + inline auto set_block_4(int i) { + auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d)); + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j); + qx[j] = _mm256_xor_si256(qx[j], _mm256_set1_epi8(0x80)); + } + return _mm256_set_m128(_mm_mul_ps(scales, min), scales); + } + inline auto set_block(int i) { + auto q8 = (const block_q8_0 *)(x + i); + qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs); + qx[0] = _mm256_xor_si256(qx[0], 
_mm256_set1_epi8(0x80)); + float d = GGML_FP16_TO_FP32(q8->d); + return std::make_pair(d, -128.f*d); + } +}; +using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_512; +#else +using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_256; +#endif + struct Q8_0_Unpacker final : public Q_Unpacker { Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} using Sum4T = Sum4TypeQ80; @@ -12660,7 +12698,12 @@ void quantize_row_q8_1(const float * x, block_q8_1 * y, int k) { template struct HelperQ80 final : public BaseHelper { using Base = BaseHelper; +#ifdef HAVE_FANCY_SIMD + //using block_q8 = block_q8_1; + using block_q8 = block_q8_1; +#else using block_q8 = block_q8_0; +#endif HelperQ80(const char * data, int stride) : Base(data, stride) {} // Needed for v * softmax(k * q) @@ -13406,8 +13449,13 @@ struct FlashQKfp32 { mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); #else if constexpr (D >= 128) { +#ifdef HAVE_FANCY_SIMD + mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); +#else mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); +#endif } else { + // This does not actually work until we fix K-cache to be quantized to Q8_0_x4 only if D%128 == 0 mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); } #endif @@ -13492,6 +13540,15 @@ struct FlashQKfp32 { #else if constexpr (D >= 128) { switch (nq) { +#ifdef HAVE_FANCY_SIMD + case 1: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; +#else case 1: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; case 2: mul_mat_qX_0_q8_0_T(D, 
kh.block, kh.stride, info, k_step); break; case 3: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; @@ -13499,8 +13556,10 @@ struct FlashQKfp32 { case 5: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; case 6: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; case 7: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; +#endif } } else { + // This does not actually work until we fix K-cache to be quantized to Q8_0_x4 only if D%128 == 0 switch (nq) { case 1: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; case 2: mul_mat_qX_0_q8_0_T(D, kh.block, kh.stride, info, k_step); break; @@ -14226,7 +14285,7 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k iqk_flash_helper_T< 96, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; case 128: iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; - case 256 + case 256: iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break; default: return false;