mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-06 12:00:29 +00:00
Add q6_0 to CPU flash attention
Disappointing result: for LlaMA-3.2-1B, q6_0 K- and V-cache gives about the same PPL as q8_0 K-cache and q4_0 V-cache, while needing the exact same RAM. I.e., what was the point?
This commit is contained in:
@@ -7242,6 +7242,61 @@ struct HelperIQ4nl final : public BaseHelper<step> {
|
||||
#endif
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperQ60 final : public BaseHelper<step> {
|
||||
#ifdef __aarch64__
|
||||
using block_q8 = block_q8_0;
|
||||
#else
|
||||
using block_q8 = block_q8_1;
|
||||
#endif
|
||||
using Base = BaseHelper<step>;
|
||||
HelperQ60(const char * data, int stride) : Base(data, stride) {}
|
||||
|
||||
// Needed for v * softmax(k * q)
|
||||
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
|
||||
int j = F16::block_size*i;
|
||||
auto dl = (const block_q6_0 *)Base::lblock(l1) + j/QK6_0;
|
||||
#ifdef __aarch64__
|
||||
// TODO
|
||||
auto vd = F16::set1(*(const float16_t *)&dl->d);
|
||||
auto q = vld1q_u8(dl->qs);
|
||||
q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
|
||||
q = vaddq_s8(q, m8);
|
||||
v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q))));
|
||||
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q))));
|
||||
#else
|
||||
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
|
||||
auto bl = _mm_loadu_si128((const __m128i *)dl->qs);
|
||||
uint64_t aux64; std::memcpy(&aux64, dl->qh, 8);
|
||||
auto bh = _mm_set_epi64x(aux64, aux64 << 4);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
auto ql = _mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32);
|
||||
auto qh = _mm_add_epi8(_mm_or_si128(_mm_and_si128(_mm_srli_epi16(bl, 4), mask_l), _mm_and_si128(_mm_srli_epi16(bh, 2), mask_h)), m32);
|
||||
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
|
||||
v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
|
||||
#else
|
||||
if (j%QK4_0) {
|
||||
bl = _mm_srli_epi16(bl, 4);
|
||||
bh = _mm_srli_epi16(bh, 2);
|
||||
}
|
||||
auto q16 = _mm256_cvtepi8_epi16(_mm_add_epi8(_mm_or_si128(_mm_and_si128(bl, mask_l), _mm_and_si128(bh, mask_h)), m32));
|
||||
v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
|
||||
v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
const __m128i mask_l = _mm_set1_epi8(0x0f);
|
||||
const __m128i mask_h = _mm_set1_epi8(0x30);
|
||||
const __m128i m32 = _mm_set1_epi8(-32);
|
||||
#else
|
||||
const uint8x16_t mask_l = vdupq_n_u8(0x0f);
|
||||
const uint8x16_t mask_h = vdupq_n_u8(0x30);
|
||||
const int8x16_t m32 = vdupq_n_s8(-32);
|
||||
#endif
|
||||
};
|
||||
|
||||
template <int q_step, int k_step>
|
||||
struct FlashMS {
|
||||
// Something goes wrong when storing and manipulating K*Q as fp16.
|
||||
@@ -7772,6 +7827,14 @@ struct FlashQKfp32 {
|
||||
mul_mat_qX_0_q8_0<DequantizerIQ4NL, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#else
|
||||
mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#endif
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
|
||||
#ifdef __aarch64__
|
||||
mul_mat_qX_0_q8_0<DequantizerIQ4NL, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#else
|
||||
mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#endif
|
||||
}
|
||||
else {
|
||||
@@ -7892,6 +7955,28 @@ struct FlashQKfp32 {
|
||||
case 5: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
|
||||
switch (nq) {
|
||||
#ifdef __aarch64__
|
||||
case 1: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 4: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 4>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 5: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
#else
|
||||
case 1: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 4: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 5: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_1_q8_1_T<Q6_0_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -8034,7 +8119,8 @@ struct FlashAttn {
|
||||
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float * qkv) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
|
||||
std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ60<D, k_step>>) {
|
||||
compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
|
||||
} else {
|
||||
@@ -8379,6 +8465,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
|
||||
HelperIQ4nl<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
case GGML_TYPE_Q6_0: {
|
||||
HelperQ60<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
@@ -8410,6 +8500,10 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
HelperIQ4nl<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
case GGML_TYPE_Q6_0: {
|
||||
HelperQ60<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
default: break;
|
||||
}
|
||||
|
||||
@@ -8419,7 +8513,8 @@ inline bool flash_attn_is_supported(ggml_type type) {
|
||||
#ifdef __AVX512BF16__
|
||||
if (type == GGML_TYPE_BF16) return true;
|
||||
#endif
|
||||
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL) return true;
|
||||
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 ||
|
||||
type == GGML_TYPE_Q6_0 || type == GGML_TYPE_IQ4_NL) return true;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user