mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-04 02:50:01 +00:00
Be able to use IQ4_NL for KV cache on ARM_NEON
This commit is contained in:
@@ -6251,7 +6251,6 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
|
||||
|
||||
int32x4_t accd[nrc_y];
|
||||
|
||||
const auto m1 = vdupq_n_u8(1);
|
||||
const auto mask2 = vdupq_n_s8(3);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
@@ -7181,10 +7180,14 @@ struct HelperQ41 final : public BaseHelper<step> {
|
||||
};
|
||||
|
||||
template <int D, int step>
|
||||
struct HelperIQ4NL final : public BaseHelper<step> {
|
||||
struct HelperIQ4nl final : public BaseHelper<step> {
|
||||
using Base = BaseHelper<step>;
|
||||
#ifdef __aarch64__
|
||||
using block_q8 = block_q8_0;
|
||||
#else
|
||||
using block_q8 = block_q8_1;
|
||||
HelperIQ4NL(const char * data, int stride) : Base(data, stride) {}
|
||||
#endif
|
||||
HelperIQ4nl(const char * data, int stride) : Base(data, stride), values(vld1q_s8(iq4k_values)) {}
|
||||
|
||||
// Needed for v * softmax(k * q)
|
||||
inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const {
|
||||
@@ -7194,7 +7197,7 @@ struct HelperIQ4NL final : public BaseHelper<step> {
|
||||
auto vd = F16::set1(*(const float16_t *)&dl->d);
|
||||
auto q = vld1q_u8(dl->qs);
|
||||
q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask);
|
||||
q = vaddq_s8(q, m8);
|
||||
q = vqtbl1q_s8(values, q);
|
||||
v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q))));
|
||||
v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q))));
|
||||
#else
|
||||
@@ -7214,11 +7217,12 @@ struct HelperIQ4NL final : public BaseHelper<step> {
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef __AVX2__
|
||||
#ifdef __aarch64__
|
||||
const uint8x16_t mask = vdupq_n_u8(0xf);
|
||||
const int8x16_t values;
|
||||
#else
|
||||
const __m128i mask = _mm_set1_epi8(0xf);
|
||||
const __m128i values = _mm_loadu_si128((const __m128i *)iq4k_values);
|
||||
#else
|
||||
const uint8x16_t mask = vdupq_n_u8(0xf);
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -7746,10 +7750,10 @@ struct FlashQKfp32 {
|
||||
mul_mat_qX_1_q8_1_T<Q4_1_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#endif
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperIQ4NL<D, k_step>>) {
|
||||
else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
|
||||
#ifdef __aarch64__
|
||||
mul_mat_qX_1_q8_1<DequantizerQ41, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
mul_mat_qX_0_q8_0<DequantizerIQ4NL, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#else
|
||||
mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, q_step>(D, kh.block, kh.stride, info, k_step);
|
||||
#endif
|
||||
@@ -7853,17 +7857,17 @@ struct FlashQKfp32 {
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperIQ4NL<D, k_step>>) {
|
||||
else if constexpr (std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
|
||||
DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr};
|
||||
switch (nq) {
|
||||
#ifdef __aarch64__
|
||||
case 1: mul_mat_qX_1_q8_1<DequantizerQ41, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_1_q8_1<DequantizerQ41, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_1_q8_1<DequantizerQ41, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 4: mul_mat_qX_1_q8_1<DequantizerQ41, 4>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 5: mul_mat_qX_1_q8_1<DequantizerQ41, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_1_q8_1<DequantizerQ41, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_1_q8_1<DequantizerQ41, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 1: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 3: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 3>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 4: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 4>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 5: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 5>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 6: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 6>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 7: mul_mat_qX_0_q8_0<DequantizerIQ4NL, 7>(D, kh.block, kh.stride, info, k_step); break;
|
||||
#else
|
||||
case 1: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 1>(D, kh.block, kh.stride, info, k_step); break;
|
||||
case 2: mul_mat_qX_1_q8_1_T<IQ4_NL_Unpacker, 2>(D, kh.block, kh.stride, info, k_step); break;
|
||||
@@ -8014,7 +8018,7 @@ struct FlashAttn {
|
||||
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
|
||||
const float * q, const char * mask, float * qkv) {
|
||||
if constexpr (std::is_same_v<KHelper, HelperQ40<D, k_step>> || std::is_same_v<KHelper, HelperQ41<D, k_step>> ||
|
||||
std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4NL<D, k_step>>) {
|
||||
std::is_same_v<KHelper, HelperQ80<D, k_step>> || std::is_same_v<KHelper, HelperIQ4nl<D, k_step>>) {
|
||||
compute_helper_q<D, q_step, k_step, KHelper, VHelper, FlashQKfp32<D, q_step, k_step>>(
|
||||
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv);
|
||||
} else {
|
||||
@@ -8356,7 +8360,7 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_NL: {
|
||||
HelperIQ4NL<D, k_step> vh(v, stride_v);
|
||||
HelperIQ4nl<D, k_step> vh(v, stride_v);
|
||||
iqk_flash_helper<D, q_step, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
default: break;
|
||||
@@ -8387,7 +8391,7 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
case GGML_TYPE_IQ4_NL: {
|
||||
HelperIQ4NL<D, k_step> kh(k, stride_k);
|
||||
HelperIQ4nl<D, k_step> kh(k, stride_k);
|
||||
iqk_flash_helper_T<D, q_step, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv);
|
||||
} break;
|
||||
default: break;
|
||||
|
||||
Reference in New Issue
Block a user