From 4d3ecb58520a24d5fac047c67817ab69af343368 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 1 Oct 2024 13:18:35 +0300 Subject: [PATCH] Be able to use IQ4_NL for KV cache on AVX2/Zen4 --- ggml/src/iqk/iqk_mul_mat.cpp | 92 ++++++++++++++++++++++++++++++++++-- 1 file changed, 88 insertions(+), 4 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 1183246b..9b83b3f2 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -542,9 +542,13 @@ struct SimpleBits { __m256i values[4]; }; -__m256i inline load_iq4nl_values_256() { +__m128i inline load_iq4nl_values_128() { static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; - auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq4nl); + return _mm_loadu_si128((const __m128i *)kvalues_iq4nl); +} + +__m256i inline load_iq4nl_values_256() { + auto val128 = load_iq4nl_values_128(); return MM256_SET_M128I(val128, val128); } @@ -7176,6 +7180,48 @@ struct HelperQ41 final : public BaseHelper { #endif }; +template +struct HelperIQ4NL final : public BaseHelper { + using Base = BaseHelper; + using block_q8 = block_q8_1; + HelperIQ4NL(const char * data, int stride) : Base(data, stride) {} + + // Needed for v * softmax(k * q) + inline void load(int l1, int i, F16::Data& v1, F16::Data& v2) const { + int j = F16::block_size*i; + auto dl = (const block_iq4_nl *)Base::lblock(l1) + j/QK4_0; +#ifdef __aarch64__ + auto vd = F16::set1(*(const float16_t *)&dl->d); + auto q = vld1q_u8(dl->qs); + q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask); + q = vaddq_s8(q, m8); + v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q)))); + v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q)))); +#else + auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); + auto q = _mm_loadu_si128((const __m128i *)dl->qs); +#ifdef HAVE_FANCY_SIMD + auto ql = _mm_shuffle_epi8(values, _mm_and_si128(q, mask)); + auto qh = _mm_shuffle_epi8(values, _mm_and_si128(_mm_srli_epi16(q, 4), mask)); + v1 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(ql))); + v2 = _mm512_mul_ps(vd, _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(qh))); +#else + if (j%QK4_0) q = _mm_srli_epi16(q, 4); + auto q16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(values, _mm_and_si128(q, mask))); + v1 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16)))); + v2 = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16, 1)))); +#endif +#endif + } + +#ifdef __AVX2__ + const __m128i mask = _mm_set1_epi8(0xf); + const __m128i values = _mm_loadu_si128((const __m128i *)iq4k_values); +#else + const uint8x16_t mask = vdupq_n_u8(0xf); +#endif +}; + template struct FlashMS { // Something goes wrong when storing and manipulating K*Q as fp16. @@ -7698,6 +7744,14 @@ struct FlashQKfp32 { mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); #else mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); +#endif + } + else if constexpr (std::is_same_v>) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; +#ifdef __aarch64__ + mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); +#else + mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); #endif } else { @@ -7796,6 +7850,28 @@ struct FlashQKfp32 { case 5: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; case 6: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; case 7: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; +#endif + } + } + else if constexpr (std::is_same_v>) { + DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; + switch (nq) { +#ifdef __aarch64__ + case 1: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_1_q8_1(D, kh.block, kh.stride, info, k_step); break; +#else + case 1: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; #endif } } @@ -7938,7 +8014,7 @@ struct FlashAttn { void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv, const float * q, const char * mask, float * qkv) { if constexpr (std::is_same_v> || std::is_same_v> || - std::is_same_v>) { + std::is_same_v> || std::is_same_v>) { compute_helper_q>( kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv); } else { @@ -8279,6 +8355,10 @@ inline void iqk_flash_helper_T(KHelper& kh, ggml_type type_v, HelperQ41 vh(v, stride_v); iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); } break; + case GGML_TYPE_IQ4_NL: { + HelperIQ4NL vh(v, stride_v); + iqk_flash_helper(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv); + } break; default: break; } } @@ -8306,16 +8386,20 @@ inline void iqk_flash_helper_T(ggml_type type_k, ggml_type type_v, HelperQ41 kh(k, stride_k); iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); } break; + case GGML_TYPE_IQ4_NL: { + HelperIQ4NL kh(k, stride_k); + iqk_flash_helper_T(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv); + } break; default: break; } } inline bool flash_attn_is_supported(ggml_type type) { - if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1) return true; #ifdef __AVX512BF16__ if (type == GGML_TYPE_BF16) return true; #endif + if (type == GGML_TYPE_F16 || type == GGML_TYPE_Q8_0 || type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL) return true; return false; } }