mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 15:44:10 +00:00
Slightly faster FA for Q8_0 KV cache
This commit is contained in:
@@ -6755,10 +6755,10 @@ struct Q_Unpacker {
|
||||
}
|
||||
};
|
||||
|
||||
struct Q8_0_x4_Unpacker {
|
||||
struct Q8_0_x4_Unpacker_256 {
|
||||
using Sum4T = Sum4TypeQ80;
|
||||
inline static int block_size() { return QK8_0; }
|
||||
Q8_0_x4_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {}
|
||||
Q8_0_x4_Unpacker_256(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {}
|
||||
|
||||
const char * cx_0;
|
||||
const block_q8_0_x4 * x;
|
||||
@@ -6784,6 +6784,44 @@ struct Q8_0_x4_Unpacker {
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
struct Q8_0_x4_Unpacker_512 {
|
||||
using Sum4T = Sum4TypeQ81;
|
||||
inline static int block_size() { return QK8_0; }
|
||||
Q8_0_x4_Unpacker_512(const void * vx, size_t bx) : cx_0((const char *)vx), x((const block_q8_0_x4 *)cx_0), bx(bx) {}
|
||||
|
||||
const char * cx_0;
|
||||
const block_q8_0_x4 * x;
|
||||
size_t bx;
|
||||
const __m128 min = _mm_set1_ps(-128.f);
|
||||
|
||||
__m256i qx[4];
|
||||
|
||||
inline const __m256i* quants() const { return qx; }
|
||||
|
||||
inline void set_row(int ix) { x = (const block_q8_0_x4 *)(cx_0 + ix*bx); }
|
||||
|
||||
inline auto set_block_4(int i) {
|
||||
auto scales = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)x[i].d));
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
qx[j] = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
|
||||
qx[j] = _mm256_xor_si256(qx[j], _mm256_set1_epi8(0x80));
|
||||
}
|
||||
return _mm256_set_m128(_mm_mul_ps(scales, min), scales);
|
||||
}
|
||||
inline auto set_block(int i) {
|
||||
auto q8 = (const block_q8_0 *)(x + i);
|
||||
qx[0] = _mm256_loadu_si256((const __m256i *)q8->qs);
|
||||
qx[0] = _mm256_xor_si256(qx[0], _mm256_set1_epi8(0x80));
|
||||
float d = GGML_FP16_TO_FP32(q8->d);
|
||||
return std::make_pair(d, -128.f*d);
|
||||
}
|
||||
};
|
||||
using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_512;
|
||||
#else
|
||||
using Q8_0_x4_Unpacker = Q8_0_x4_Unpacker_256;
|
||||
#endif
|
||||
|
||||
struct Q8_0_Unpacker final : public Q_Unpacker<block_q8_0, ScaleHelperQ_0, Q8_0_Dequantizer> {
|
||||
Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {}
|
||||
using Sum4T = Sum4TypeQ80;
|
||||
@@ -12660,7 +12698,12 @@ void quantize_row_q8_1(const float * x, block_q8_1 * y, int k) {
|
||||
template <int D, int step>
|
||||
struct HelperQ80 final : public BaseHelper<step> {
|
||||
using Base = BaseHelper<step>;
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
//using block_q8 = block_q8_1;
|
||||
using block_q8 = block_q8_1;
|
||||
#else
|
||||
using block_q8 = block_q8_0;
|
||||
#endif
|
||||
HelperQ80(const char * data, int stride) : Base(data, stride) {}
|
||||
|
||||
// Needed for v * softmax(k * q)
|
||||
@@ -13406,8 +13449,13 @@ struct FlashQKfp32 {
|
||||
mul_mat_qX_0_q8_0<DequantizerQ80_x4, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#else
|
||||
if constexpr (D >= 128) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#else
|
||||
mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#endif
|
||||
} else {
|
||||
// This does not actually work until we fix K-cache to be quantizez to Q8_0_x4 only if D%128 == 0
|
||||
mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
}
|
||||
#endif
|
||||
@@ -13492,6 +13540,15 @@ struct FlashQKfp32 {
|
||||
#else
|
||||
if constexpr (D >= 128) {
|
||||
switch (nq) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
case 1: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 4: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 5: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_1_q8_1_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
#else
|
||||
case 1: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
@@ -13499,8 +13556,10 @@ struct FlashQKfp32 {
|
||||
case 5: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_0_q8_0_T<Q8_0_x4_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
#endif
|
||||
}
|
||||
} else {
|
||||
// This does not actually work until we fix K-cache to be quantizez to Q8_0_x4 only if D%128 == 0
|
||||
switch (nq) {
|
||||
case 1: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_0_q8_0_T<Q8_0_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
@@ -14226,7 +14285,7 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
|
||||
iqk_flash_helper_T< 96, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 128:
|
||||
iqk_flash_helper_T<128, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
case 256
|
||||
case 256:
|
||||
iqk_flash_helper_T<256, 8, 32>(nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv, q, ck, cv, cm, scale, softcap, qkv); break;
|
||||
default:
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user