Be able to use IQ4_NL for KV cache on AVX2/Zen4

This commit is contained in:
Iwan Kawrakow
2024-10-01 13:18:35 +03:00
parent 8457a26f83
commit 4d3ecb5852

View File

@@ -542,9 +542,13 @@ struct SimpleBits {
__m256i values[4];
};
__m256i inline load_iq4nl_values_256() {
__m128i inline load_iq4nl_values_128() {
static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241};
auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
return _mm_loadu_si128((const __m128i *)kvalues_iq4nl);
}
__m256i inline load_iq4nl_values_256() {
auto val128 = load_iq4nl_values_128();
return MM256_SET_M128I(val128, val128);
}
@@ -7176,6 +7180,48 @@ struct HelperQ41 final : public BaseHelper<step> {
#endif
};
template <int D, int step>
struct HelperIQ4NL final : public BaseHelper<step> {
using Base = BaseHelper<step>;
using block_q8 = block_q8_1;
HelperIQ4NL(const char * data, int stride) : Base(data, stride) {}
// Needed for v * softmax(k * q)
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
int j = F16::block_size*i;
auto dl = (const block_iq4_nl *)Base::lblock(l1) + j/QK4_0;
#ifdef __aarch64__
auto vd = F16::set1(*(const float16_t *)&dl->d);
auto q = vld1q_u8(dl->qs);
q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
q = vaddq_s8(q, m8);
v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q))));
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q))));
#else
auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d));
auto q = _mm_loadu_si128((const __m128i *)dl->qs);
#ifdef HAVE_FANCY_SIMD
auto ql = _mm_shuffle_epi8(values, _mm_and_si128(q, mask));
auto qh = _mm_shuffle_epi8(values, _mm_and_si128(_mm_srli_epi16(q, 4), mask));
v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql)));
v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh)));
#else
if (j%QK4_0) q = _mm_srli_epi16(q, 4);
auto q16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(values, _mm_and_si128(q, mask)));
v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16))));
v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1))));
#endif
#endif
}
#ifdef __AVX2__
const __m128i mask = _mm_set1_epi8(0xf);
const __m128i values = _mm_loadu_si128((const __m128i *)iq4k_values);
#else
const uint8x16_t mask = vdupq_n_u8(0xf);
#endif
};
template <int q_step, int k_step>
struct FlashMS {
// Something goes wrong when storing and manipulating K*Q as fp16.
@@ -7698,6 +7744,14 @@ struct FlashQKfp32 {
mul_mat_qX_1_q8_1<DequantizerQ41, q_step>(D, kh.block, kh.stride, info, k_step);
#else
mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperIQ4NL<D, k_step>>) {
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
#ifdef __aarch64__
mul_mat_qX_1_q8_1<DequantizerQ41, q_step>(D, kh.block, kh.stride, info, k_step);
#else
mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
#endif
}
else {
@@ -7796,6 +7850,28 @@ struct FlashQKfp32 {
case 5: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
#endif
}
}
else if constexpr (std::is_same_v<KHelper, HelperIQ4NL<D, k_step>>) {
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
switch (nq) {
#ifdef __aarch64__
case 1: mul_mat_qX_1_q8_1<DequantizerQ41, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_1_q8_1<DequantizerQ41, 2>(D, kh.block, kh.stride, info, k_step); break;
case 3: mul_mat_qX_1_q8_1<DequantizerQ41, 3>(D, kh.block, kh.stride, info, k_step); break;
case 4: mul_mat_qX_1_q8_1<DequantizerQ41, 4>(D, kh.block, kh.stride, info, k_step); break;
case 5: mul_mat_qX_1_q8_1<DequantizerQ41, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_1_q8_1<DequantizerQ41, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_1_q8_1<DequantizerQ41, 7>(D, kh.block, kh.stride, info, k_step); break;
#else
case 1: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
case 3: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 3>(D, kh.block, kh.stride, info, k_step); break;
case 4: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 4>(D, kh.block, kh.stride, info, k_step); break;
case 5: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 7>(D, kh.block, kh.stride, info, k_step); break;
#endif
}
}
@@ -7938,7 +8014,7 @@ struct FlashAttn {
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
std::is_same_v<KHelper, HelperQ80<D, k_step>>) {
std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4NL<D, k_step>>) {
compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
} else {
@@ -8279,6 +8355,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
HelperQ41<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_IQ4_NL: {
HelperIQ4NL<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
default: break;
}
}
@@ -8306,16 +8386,20 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
HelperQ41<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_IQ4_NL: {
HelperIQ4NL<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
default: break;
}
}
inline bool flash_attn_is_supported(ggml_type type) {
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1) return true;
#ifdef __AVX512BF16__
if (type == GGML_TYPE_BF16) return true;
#endif
if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL) return true;
return false;
}
}