Be able to use IQ4_NL for KV cache on ARM_NEON

This commit is contained in:
Iwan Kawrakow
2024-10-01 14:43:33 +03:00
parent 4d3ecb5852
commit 09789d017f

View File

@@ -6251,7 +6251,6 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
int32x4_t accd[nrc_y];
const auto m1 = vdupq_n_u8(1);
const auto mask2 = vdupq_n_s8(3);
for (int ix = 0; ix < nrc_x; ++ix) {
@@ -7181,10 +7180,14 @@ struct HelperQ41 final : public BaseHelper<step> {
};
template <int D, int step>
struct HelperIQ4NL final : public BaseHelper<step> {
struct HelperIQ4nl final : public BaseHelper<step> {
using Base = BaseHelper<step>;
#ifdef __aarch64__
using block_q8 = block_q8_0;
#else
using block_q8 = block_q8_1;
HelperIQ4NL(const char * data, int stride) : Base(data, stride) {}
#endif
HelperIQ4nl(const char * data, int stride) : Base(data, stride), values(vld1q_s8(iq4k_values)) {}
// Needed for v * softmax(k * q)
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
@@ -7194,7 +7197,7 @@ struct HelperIQ4NL final : public BaseHelper<step> {
auto vd = F16::set1(*(const float16_t *)&dl->d);
auto q = vld1q_u8(dl->qs);
q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
q = vaddq_s8(q, m8);
q = vqtbl1q_s8(values, q);
v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q))));
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q))));
#else
@@ -7214,11 +7217,12 @@ struct HelperIQ4NL final : public BaseHelper<step> {
#endif
}
#ifdef __AVX2__
#ifdef __aarch64__
const uint8x16_t mask = vdupq_n_u8(0xf);
const int8x16_t values;
#else
const __m128i mask = _mm_set1_epi8(0xf);
const __m128i values = _mm_loadu_si128((const __m128i *)iq4k_values);
#else
const uint8x16_t mask = vdupq_n_u8(0xf);
#endif
};
@@ -7746,10 +7750,10 @@ struct FlashQKfp32 {
mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
#endif
}
else if constexpr (std::is_same_v<KHelper, HelperIQ4NL<D, k_step>>) {
else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
#ifdef __aarch64__
mul_mat_qX_1_q8_1<DequantizerQ41, q_step>(D, kh.block, kh.stride, info, k_step);
mul_mat_qX_0_q8_0<DequantizerIQ4NL, q_step>(D, kh.block, kh.stride, info, k_step);
#else
mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
#endif
@@ -7853,17 +7857,17 @@ struct FlashQKfp32 {
#endif
}
}
else if constexpr (std::is_same_v<KHelper, HelperIQ4NL<D, k_step>>) {
else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
switch (nq) {
#ifdef __aarch64__
case 1: mul_mat_qX_1_q8_1<DequantizerQ41, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_1_q8_1<DequantizerQ41, 2>(D, kh.block, kh.stride, info, k_step); break;
case 3: mul_mat_qX_1_q8_1<DequantizerQ41, 3>(D, kh.block, kh.stride, info, k_step); break;
case 4: mul_mat_qX_1_q8_1<DequantizerQ41, 4>(D, kh.block, kh.stride, info, k_step); break;
case 5: mul_mat_qX_1_q8_1<DequantizerQ41, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_1_q8_1<DequantizerQ41, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_1_q8_1<DequantizerQ41, 7>(D, kh.block, kh.stride, info, k_step); break;
case 1: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 2>(D, kh.block, kh.stride, info, k_step); break;
case 3: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 3>(D, kh.block, kh.stride, info, k_step); break;
case 4: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 4>(D, kh.block, kh.stride, info, k_step); break;
case 5: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 5>(D, kh.block, kh.stride, info, k_step); break;
case 6: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 6>(D, kh.block, kh.stride, info, k_step); break;
case 7: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 7>(D, kh.block, kh.stride, info, k_step); break;
#else
case 1: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
case 2: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
@@ -8014,7 +8018,7 @@ struct FlashAttn {
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float * qkv) {
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4NL<D, k_step>>) {
std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
} else {
@@ -8356,7 +8360,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_IQ4_NL: {
HelperIQ4NL<D, k_step> vh(v, stride_v);
HelperIQ4nl<D, k_step> vh(v, stride_v);
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
} break;
default: break;
@@ -8387,7 +8391,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
case GGML_TYPE_IQ4_NL: {
HelperIQ4NL<D, k_step> kh(k, stride_k);
HelperIQ4nl<D, k_step> kh(k, stride_k);
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
} break;
default: break;