From 20d50172d011229eac971d28aa4dca922b0246a8 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 28 Apr 2025 11:26:28 +0300 Subject: [PATCH] Much better FA TG with q8_0 KV cache Just repack it even for TG. But do the repacking for k_step rows, not the whole K tensor. --- ggml/src/iqk/iqk_mul_mat.cpp | 127 ++++++++++++++++++++++++++++++++--- 1 file changed, 118 insertions(+), 9 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index a4973fc1..5f916584 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6640,6 +6640,84 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI } } +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + GGML_ASSERT(n%32 == 0); + __m512i qx[4]; + __m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {}; + float dy[nrc_y]; + int32_t sy[nrc_y]; + const int8_t * q8y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) { + auto dptr = (const float *)info.src1_row(iy); + dy[iy] = dptr[0]; + auto iptr = (const int32_t *)(dptr + 1); + sy[iy] = -64*iptr[0]; + q8y[iy] = (const int8_t *)(dptr + 2); + } + const int8_t * q8x[8]; + float dx[8]; + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int kx = 0; kx < 8; ++kx) { + auto dptr = (const float *)((const char *)vx + (ix+kx)*bx); + dx[kx] = dptr[0]; + q8x[kx] = (const int8_t *)(dptr + 2); + } + for (int i = 0; i < n/32; ++i) { + for (int kx = 0; kx < 4; ++kx) { + qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)), + _mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1); + } + auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]); + auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]); + auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]); + auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]); + qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128)); + qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128)); + qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128)); + qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128)); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i); + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1); + if constexpr (nrc_y <= 4) { + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } else { + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + } + auto scales_x = _mm256_loadu_ps(dx); + for (int iy = 0; iy < nrc_y; ++iy) { + if constexpr (nrc_y <= 4) { + auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy])); + auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1)); + auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3)); + auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy])); + info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1))); + info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2))); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512(); + } else { + acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy])); + auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1)); + auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3)); + auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy])); + info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1))); + info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2))); + acc[iy] = _mm512_setzero_si512(); + } + } + } +} +#endif + template static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); @@ -15697,18 +15775,14 @@ struct HelperQ80R8 : public BaseHelper { Base::stride = (D/QK8_0)*sizeof(block_q8_0); } - static std::vector repack(int nk, const HelperQ80& q8) { - static_assert(D%QK8_0 == 0); - GGML_ASSERT(nk%8 == 0); + static void repack(int nk, const char * q8_data, int q8_stride, block_q8_0_r8 * y) { constexpr int nblock = D/QK8_0; - std::vector result(nblock * nk/8); - auto y = result.data(); const block_q8_0 * x8[8]; #ifdef __ARM_NEON int8x16x2_t m0, m1, m2, m3; #endif for (int row = 0; row < nk; row += 8) { - for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride); + for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8_data + (row + k)*q8_stride); for (int ib = 0; ib < nblock; ++ib) { for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d; #ifdef __AVX2__ @@ -15790,6 +15864,15 @@ struct HelperQ80R8 : public BaseHelper { } y += nblock; } + } + + static std::vector repack(int nk, const HelperQ80& q8) { + static_assert(D%QK8_0 == 0); + GGML_ASSERT(nk%8 == 0); + constexpr int nblock = D/QK8_0; + std::vector result(nblock * nk/8); + auto y = result.data(); + repack(nk, q8.data, q8.stride, y); return result; } @@ -16837,10 +16920,15 @@ struct FlashQKfp32 { if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1); MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); #else -#ifdef HAVE_FANCY_SIMD - if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); -#endif if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1); +#ifdef HAVE_FANCY_SIMD + if constexpr (D%32 == 0 && k_step%8 == 0) { + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq); + } else { + if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16); + } +#endif MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq); #endif } @@ -17016,6 +17104,27 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, const float * q, const char * mask, float * qkv, float * M, float * S, char * qptr) { auto q8 = (typename KHelper::block_q8 *)qptr; + if constexpr (q_step > 1 && std::is_same_v>) { + if (nq1 == q_step) { + fms.init_qstep(); + kh.reset_block(); + vh.reset_block(); + block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8]; + HelperQ80R8 khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0)); + HelperQ80::convert(q_step, stride_q, q, q8); + auto mr = mask; + for (int k1 = 0; k1 < nk1/k_step; ++k1) { + HelperQ80R8::repack(k_step, kh.data, kh.stride, q8r8); + KQHelper::mul_mask_kq(khr8, stride_m, q8, mr, fms); + fqkv.accumulate_qkv(vh, fms); + kh.next_block(); + vh.next_block(); + mr += k_step*sizeof(ggml_half); + } + fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S); + return; + } + } #if FA_TIMING Perf perf(false); #endif