Much better FA TG with q8_0 KV cache

Just repack it even for TG, but do the repacking for k_step rows at a
time, not for the whole K tensor (see the sketch below).
Iwan Kawrakow
2025-04-28 11:26:28 +03:00
parent 802d4de1b5
commit 20d50172d0
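
For context: prompt processing already repacks the q8_0 K cache to the row-interleaved q8_0_r8 layout, but for TG (token generation) repacking the whole K tensor used to cost more than the faster GEMM kernel saved. Repacking only the k_step rows of the current block keeps the scratch buffer small enough for the stack, so the cost is amortized away. A rough sketch of the resulting loop structure (illustrative only; the literal code is in the compute_helper_q hunk at the end of the diff):

    block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8];  // scratch for one k_step slice
    for (int k1 = 0; k1 < nk1/k_step; ++k1) {
        // repack just this block of k_step K rows to q8_0_r8
        HelperQ80R8<Dk, k_step>::repack(k_step, kh.data, kh.stride, q8r8);
        // ... K*Q matmul and V accumulation against the repacked slice ...
        kh.next_block();
        vh.next_block();
    }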

@@ -6640,6 +6640,84 @@ static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataI
     }
 }
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    GGML_ASSERT(n%32 == 0);
+    __m512i qx[4];
+    __m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {};
+    float dy[nrc_y];
+    int32_t sy[nrc_y];
+    const int8_t * q8y[nrc_y];
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        auto dptr = (const float *)info.src1_row(iy);
+        dy[iy] = dptr[0];
+        auto iptr = (const int32_t *)(dptr + 1);
+        sy[iy] = -64*iptr[0];
+        q8y[iy] = (const int8_t *)(dptr + 2);
+    }
+    const int8_t * q8x[8];
+    float dx[8];
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int kx = 0; kx < 8; ++kx) {
+            auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
+            dx[kx] = dptr[0];
+            q8x[kx] = (const int8_t *)(dptr + 2);
+        }
+        for (int i = 0; i < n/32; ++i) {
+            for (int kx = 0; kx < 4; ++kx) {
+                qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)),
+                                            _mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1);
+            }
+            auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]);
+            auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]);
+            auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]);
+            auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]);
+            qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128));
+            qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128));
+            qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128));
+            qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128));
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
+                auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
+                if constexpr (nrc_y <= 4) {
+                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+                } else {
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+                }
+            }
+        }
+        auto scales_x = _mm256_loadu_ps(dx);
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            if constexpr (nrc_y <= 4) {
+                auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy]));
+                auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1));
+                auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3));
+                auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
+                info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
+                info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
+                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
+            } else {
+                acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy]));
+                auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1));
+                auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3));
+                auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
+                info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
+                info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
+                acc[iy] = _mm512_setzero_si512();
+            }
+        }
+    }
+}
+#endif
 template <int nrc_y>
 static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
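
A note on the sign trick in mul_mat_q8_KV_q8_KV_8 above: _mm512_dpbusd_epi32 multiplies unsigned bytes by signed bytes, while the K-cache quants are signed. XOR-ing the transposed qx registers with -128 adds 128 to every x byte, so each dot product is overcounted by 128 times the sum of the y quants. That sum is the int32 stored right after the float scale of a q8_KV row (iptr[0] above); sy[iy] = -64*iptr[0] is broadcast into all sixteen 32-bit lanes, and since every output element later sums two 128-bit lanes, the effective correction is -128*sum(y), exactly cancelling the offset. A scalar sketch of what one output element computes (for exposition only; the helper name is mine):

    // q8_KV row layout, as implied by the pointer arithmetic above:
    // float scale, int32 sum of the quants, then the int8 quants.
    static float q8_KV_dot_ref(int n, const float * xrow, const float * yrow) {
        const int8_t * qx = (const int8_t *)(xrow + 2);
        const int8_t * qy = (const int8_t *)(yrow + 2);
        int32_t sumi = 0;
        for (int j = 0; j < n; ++j) {
            sumi += (qx[j] + 128) * qy[j];          // what dpbusd accumulates
        }
        int32_t sy = -128 * ((const int32_t *)(yrow + 1))[0];
        return xrow[0] * yrow[0] * (sumi + sy);     // == d_x * d_y * sum qx[j]*qy[j]
    }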
@@ -15697,18 +15775,14 @@ struct HelperQ80R8 : public BaseHelper<step> {
         Base::stride = (D/QK8_0)*sizeof(block_q8_0);
     }
 
-    static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
-        static_assert(D%QK8_0 == 0);
-        GGML_ASSERT(nk%8 == 0);
+    static void repack(int nk, const char * q8_data, int q8_stride, block_q8_0_r8 * y) {
         constexpr int nblock = D/QK8_0;
-        std::vector<block_q8_0_r8> result(nblock * nk/8);
-        auto y = result.data();
         const block_q8_0 * x8[8];
 #ifdef __ARM_NEON
         int8x16x2_t m0, m1, m2, m3;
 #endif
         for (int row = 0; row < nk; row += 8) {
-            for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride);
+            for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8_data + (row + k)*q8_stride);
             for (int ib = 0; ib < nblock; ++ib) {
                 for (int k = 0; k < 8; ++k) y[ib].d[k] = x8[k][ib].d;
 #ifdef __AVX2__
@@ -15790,6 +15864,15 @@ struct HelperQ80R8 : public BaseHelper<step> {
             }
             y += nblock;
         }
     }
+    static std::vector<block_q8_0_r8> repack(int nk, const HelperQ80<D, step>& q8) {
+        static_assert(D%QK8_0 == 0);
+        GGML_ASSERT(nk%8 == 0);
+        constexpr int nblock = D/QK8_0;
+        std::vector<block_q8_0_r8> result(nblock * nk/8);
+        auto y = result.data();
+        repack(nk, q8.data, q8.stride, y);
+        return result;
+    }
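
The refactor above splits the old allocating repack() in two: a new overload that writes into caller-provided memory, so the flash-attention loop can reuse a small stack buffer for every k_step block without touching the heap, and the original std::vector-returning wrapper, which now simply allocates and delegates. Hypothetical call sites for both (q8_helper stands for any HelperQ80<D, step> instance):

    // no allocation: fill a caller-owned buffer, e.g. on the stack
    block_q8_0_r8 buf[D/QK8_0 * k_step/8];
    HelperQ80R8<D, k_step>::repack(k_step, kh.data, kh.stride, buf);

    // allocating convenience wrapper, unchanged for existing callers
    std::vector<block_q8_0_r8> packed = HelperQ80R8<D, k_step>::repack(nk, q8_helper);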
@@ -16837,10 +16920,15 @@ struct FlashQKfp32 {
         if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
         MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
 #else
-#ifdef HAVE_FANCY_SIMD
-        if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
-#endif
         if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
+#ifdef HAVE_FANCY_SIMD
+        if constexpr (D%32 == 0 && k_step%8 == 0) {
+            if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16);
+            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq);
+        } else {
+            if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
+        }
+#endif
         MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
 #endif
     }
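
The if constexpr guard added here mirrors the new kernel's runtime preconditions: mul_mat_q8_KV_q8_KV_8 asserts n%32 == 0 and nrc_x%8 == 0, which at this call site correspond to D%32 == 0 and k_step%8 == 0. When either condition fails, dispatch keeps the previous mul_mat_q8_KV_q8_KV path, and the trailing MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq) remains the fallback.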
@@ -17016,6 +17104,27 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
                       const float * q, const char * mask, float * qkv,
                       float * M, float * S, char * qptr) {
     auto q8 = (typename KHelper::block_q8 *)qptr;
+    if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80<Dk, k_step>>) {
+        if (nq1 == q_step) {
+            fms.init_qstep();
+            kh.reset_block();
+            vh.reset_block();
+            block_q8_0_r8 q8r8[Dk/QK8_0 * k_step/8];
+            HelperQ80R8<Dk, k_step> khr8((const char *)q8r8, Dk/QK8_0*sizeof(block_q8_0));
+            HelperQ80<Dk, QK8_0>::convert(q_step, stride_q, q, q8);
+            auto mr = mask;
+            for (int k1 = 0; k1 < nk1/k_step; ++k1) {
+                HelperQ80R8<Dk, k_step>::repack(k_step, kh.data, kh.stride, q8r8);
+                KQHelper::mul_mask_kq(khr8, stride_m, q8, mr, fms);
+                fqkv.accumulate_qkv(vh, fms);
+                kh.next_block();
+                vh.next_block();
+                mr += k_step*sizeof(ggml_half);
+            }
+            fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
+            return;
+        }
+    }
 #if FA_TIMING
     Perf perf(false);
 #endif
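
For a sense of the footprint of the q8r8 scratch buffer in this fast path: assuming block_q8_0_r8 packs 8 rows of one 32-quant block (8 ggml_half scales plus 8*32 int8 quants, i.e. 272 bytes), a typical head size Dk = 128 with k_step = 32 needs

    (Dk/QK8_0) * (k_step/8) = 4 * 4 = 16 blocks  ->  16 * 272 = 4352 bytes

of stack, which is small enough that the per-block repack avoids heap allocation entirely.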