From 15a8115fcf18f49bda63ff37996db002ca7b0e89 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 12 Oct 2024 17:54:32 +0300 Subject: [PATCH] iq2_ks: WIP --- ggml/src/iqk/iqk_mul_mat.cpp | 88 ++++++++++++++++++++++++++++++------ 1 file changed, 75 insertions(+), 13 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 1b79d3c3..e58a4f05 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -402,14 +402,20 @@ struct ScaleIQ4XS { const __m128i m32 = _mm_set1_epi16(-32); }; -template +template struct BaseDequantizer { BaseDequantizer(const void * vx, size_t bx) : vx(vx), bx(bx) {} inline void new_row(int ix) { if constexpr (per_row_scale) { - const float * dptr = (const float *)((const char *)vx + bx*ix); - d = *dptr; - x = (const Block *)(dptr + 1); + if constexpr (is_f16) { + const ggml_half * dptr = (const ggml_half *)((const char *)vx + bx*ix); + d = GGML_FP16_TO_FP32(*dptr); + x = (const Block *)(dptr + 1); + } else { + const float * dptr = (const float *)((const char *)vx + bx*ix); + d = *dptr; + x = (const Block *)(dptr + 1); + } } else { x = (const Block *)((const char *)vx + bx*ix); } @@ -898,6 +904,58 @@ struct DequantizerIQ2K final : public BaseDequantizer { const __m128i m8 = _mm_set1_epi8(-8); }; +struct DequantizerIQ2KS final : public BaseDequantizer { + DequantizerIQ2KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_values()) {} + template + inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { + prepare(x[i].qs); + auto scales128 = make_scales(x[i].scales, x[i].extra >> 8); + auto shifts = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(_mm_set1_epi8(x[i].extra), hmask), hmask), m5); + auto scales_s = _mm_mullo_epi16(scales128, _mm_cvtepi8_epi16(_mm_add_epi8(m32, shifts))); + s8k.accum_mins(scales_s, q8, i, d, accm); + auto scales256 = MM256_SET_M128I(scales128, scales128); + auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); + scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]); + scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]); + //scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]); + //scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]); + //scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]); + //scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]); + } + inline void prepare(const uint8_t * q2) { + bits.prepare(q2); + bits.values[0] = _mm512_shuffle_epi8(values, bits.values[0]); + bits.values[1] = _mm512_shuffle_epi8(values, bits.values[1]); + bits.values[2] = _mm512_shuffle_epi8(values, bits.values[2]); + bits.values[3] = _mm512_shuffle_epi8(values, bits.values[3]); + } + static inline __m512i load_values() { + static const uint8_t kvalues_iq2nl[16] = {1, 19, 33, 49, 0, 0, 0, 0, 6, 24, 38, 54, 0, 0, 0, 0}; + auto val128 = _mm_loadu_si128((const __m128i *)kvalues_iq2nl); + auto val256 = MM256_SET_M128I(val128, val128); + return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); + } + inline __m128i make_scales(const uint8_t * scales_l, const uint8_t scales_h) const { + const uint16_t * scales = (const uint16_t *)scales_l; + uint32_t aux32 = scales[0] | (scales[1] << 16); + auto scl = _mm_srlv_epi32(_mm_set1_epi32(aux32), shift); + scl = _mm_and_si128(_mm_shuffle_epi8(scl, shuffle), _mm_set1_epi8(0xf)); + auto sch = _mm_set1_epi8(scales_h); + sch = _mm_and_si128(_mm_cmpeq_epi8(_mm_and_si128(sch, hmask), hmask), m16); + return _mm_cvtepi8_epi16(_mm_add_epi8(scl, sch)); + } + Q2Bits bits; + Scales8K s8k; + + const __m512i values; + const __m128i m16 = _mm_set1_epi8(-16); + const __m128i m5 = _mm_set1_epi8(5); + const __m128i m32 = _mm_set1_epi8(-32); + const __m128i hmask = _mm_set1_epi64x(0x8040201008040201); + const __m128i shuffle = _mm_set1_epi64x(0x0703060205010400); + const __m128i shift = _mm_set_epi32(0, 0, 4, 0); +}; + struct DequantizerIQ3K final : public BaseDequantizer { DequantizerIQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(4, -64), values(load_values()) {} template @@ -1107,8 +1165,8 @@ struct DequantizerIQ6K final : public BaseDequantizer { const __m512i permute2 = _mm512_set_epi64(15, 14, 13, 12, 7, 6, 5, 4); }; -struct DequantizerIQ4XXS final : public BaseDequantizer { - DequantizerIQ4XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {} +struct DequantizerIQ4KS final : public BaseDequantizer { + DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {} template inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) { auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales)); @@ -1740,8 +1798,8 @@ struct DequantizerIQ6K final : public BaseDequantizer { const __m256i mh = _mm256_set1_epi8(-128); // to avoid stupid warning about 0x80 overflowing }; -struct DequantizerIQ4XXS final : public BaseDequantizer { - DequantizerIQ4XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {} +struct DequantizerIQ4KS final : public BaseDequantizer { + DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {} template inline __m256i new_block(int i, const Q8& q8, __m256 * accd) { auto scales128 = _mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i *)x[i].scales)); @@ -3751,7 +3809,7 @@ template void MulMat::set_functions(MulMat& m) { std::is_same_v || std::is_same_v || std::is_same_v|| - std::is_same_v) { + std::is_same_v) { m.funcs[0] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[1] = mul_mat_iqX_k_q8_K_AVX512; m.funcs[2] = mul_mat_iqX_k_q8_K_AVX512; @@ -3913,12 +3971,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { break; case GGML_TYPE_IQ4_KS: assert (ne00 % QK_K == 0); - MulMat::set_functions(mm); + MulMat::set_functions(mm); break; case GGML_TYPE_IQ2_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); break; + case GGML_TYPE_IQ2_KS: + assert (ne00 % QK_K == 0); + MulMat::set_functions(mm); + break; case GGML_TYPE_IQ3_K: assert (ne00 % QK_K == 0); MulMat::set_functions(mm); @@ -4809,9 +4871,9 @@ struct DequantizerIQ4XS final : public BaseDequantizer { }; -struct DequantizerIQ4XXS final : public BaseDequantizer { +struct DequantizerIQ4KS final : public BaseDequantizer { - DequantizerIQ4XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {} + DequantizerIQ4KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {} constexpr static int num_blocks() { return 8; } constexpr static bool should_scale_quants() { return false; } @@ -6571,7 +6633,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { MulMat::set_functions(m); break; case GGML_TYPE_IQ4_KS: - MulMat::set_functions(m); + MulMat::set_functions(m); break; case GGML_TYPE_IQ4_K: MulMat::set_functions(m);