diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp index 57038d0c..166c6274 100644 --- a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp +++ b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp @@ -337,6 +337,105 @@ struct DequantizerIQ4KSS final : public BaseDequantizer { }; }; +struct DequantizerIQ2KL final : public BaseDequantizer { + DequantizerIQ2KL(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); } + inline __m128i make_scales(int i) const { + //uint16_t aux[8]; + //auto h = x[i].scales_h; + //for (int k = 0; k < 4; ++k) { aux[k+0] = (x[i].scales_l[k] & 0xf) | ((h << 4) & 0x30); h >>= 2; } + //for (int k = 0; k < 4; ++k) { aux[k+4] = (x[i].scales_l[k] >> 4) | ((h << 4) & 0x30); h >>= 2; } + //return _mm_sub_epi16(_mm_loadu_si128((const __m128i *)aux), _mm_set1_epi16(32)); + uint32_t aux32; std::memcpy(&aux32, x[i].scales_l, 4); + auto scl = _mm_cvtepu8_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(aux32), _mm_set_epi32(0, 0, 4, 0)), _mm_set1_epi8(0xf))); + // 0x000a000800060004 + auto sch = _mm_srlv_epi16(_mm_sllv_epi64(_mm_set1_epi16(x[i].scales_h), _mm_set_epi64x(0, 8)), _mm_set1_epi64x(0x000a000800060004)); + auto scales128 = _mm_sub_epi16(_mm_or_si128(scl, _mm_and_si128(sch, _mm_set1_epi16(0x30))), _mm_set1_epi16(32)); + return scales128; + } + template + inline void compute_block(int i, const Q8& q8, __m512 * acc) { + auto scales128 = make_scales(i); + auto mins128 = _mm_mullo_epi16(scales128, _mm_set1_epi16(-64)); + auto mins = MM256_SET_M128I(_mm_shuffle_epi8(mins128, s8k.shuffles[1]), _mm_shuffle_epi8(mins128, s8k.shuffles[0])); + auto scales256 = MM256_SET_M128I(scales128, scales128); + auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1); + __m512i scales[4]; + for (int k = 0; k < 4; ++k) scales[k] = _mm512_shuffle_epi8(all_scales, shuffles[k]); + prepare(i); + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto q8s = q8.load_bsums(iy, i); + auto prod = _mm256_madd_epi16(mins, q8s); + auto sumi = _mm512_inserti32x8(_mm512_setzero_si512(), prod, 0); + for (int k = 0; k < 4; ++k) { + auto p = _mm512_maddubs_epi16(bits.values[k], q8.load_quants64(iy, i, k)); + sumi = _mm512_dpwssd_epi32(sumi, p, scales[k]); + } + acc[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), acc[iy]); + } + } + inline void prepare(int i) { + __m512i ql[2], qs[4]; + __mmask64 mask[2]; + // TODO: optimize this + for (int k = 0; k < 2; ++k) { + auto b1 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+0); + auto b2 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+1); + auto c1 = MM256_SET_M128I(_mm_srli_epi16(b1, 4), b1); + auto c2 = MM256_SET_M128I(_mm_srli_epi16(b2, 4), b2); + ql[k] = _mm512_and_si512(m4, _mm512_inserti32x8(_mm512_castsi256_si512(c1), c2, 1)); + } + auto h128 = _mm_loadu_si128((const __m128i *)x[i].qh); + auto h256 = MM256_SET_M128I(_mm_srli_epi16(h128, 1), h128); + auto h512 = _mm512_inserti32x8(_mm512_castsi256_si512(h256), _mm256_srli_epi16(h256, 2), 1); + mask[0] = _mm512_cmpeq_epi8_mask(_mm512_and_si512(h512, m01), m01); + mask[1] = _mm512_cmpeq_epi8_mask(_mm512_and_si512(h512, m10), m10); + + for (int k = 0; k < 2; ++k) { + // qs[0]: even quants when hbits is not set (so pair index is in 0...15) + // qs[1]: even quants when hbits is set (so pair index is in 16...31) + // qs[2]: odd quants when hbits is not set (so pair index is in 0...15) + // qs[3]: odd quants when hbits is set (so pair index is in 16...31) + // if we blend qs[0] and qs[1] with the hbit mask, we get the correct even quants -> q1 + // if we blend qs[2] and qs[3] with the hbit mask, we get the correct odd quants -> q2 + // If we convert q1 and q2 to int16_t, shift q2 left by 8 bits, and or them, we get the quants in the correct order + for (int l = 0; l < 4; ++l) qs[l] = _mm512_shuffle_epi8(values[l], ql[k]); + auto q1 = _mm512_mask_blend_epi8(mask[k], qs[0], qs[1]); + auto q2 = _mm512_mask_blend_epi8(mask[k], qs[2], qs[3]); + auto q1l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q1)); + auto q1h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q1, 1)); + auto q2l = _mm512_cvtepu8_epi16(_mm512_castsi512_si256(q2)); + auto q2h = _mm512_cvtepu8_epi16(_mm512_extracti32x8_epi32(q2, 1)); + bits.values[2*k+0] = _mm512_or_si512(q1l, _mm512_slli_epi16(q2l, 8)); + bits.values[2*k+1] = _mm512_or_si512(q1h, _mm512_slli_epi16(q2h, 8)); + } + } + void load_values() { + static const uint8_t k_values[64] = { + 1, 1, 24, 24, 24, 24, 41, 41, 41, 41, 41, 54, 54, 54, 54, 65, 65, 65, 65, 65, 77, 77, 77, 77, 77, 92, 92, 92, 92, 92, 111, 111, + 41, 77, 1, 54, 77, 111, 24, 41, 65, 77, 92, 1, 65, 77, 111, 41, 54, 65, 77, 92, 24, 41, 54, 65, 77, 1, 41, 65, 92, 111, 41, 77, + }; + for (int k = 0; k < 4; ++k) { + auto v128 = _mm_loadu_si128((const __m128i *)k_values + k); + auto v256 = MM256_SET_M128I(v128, v128); + values[k] = _mm512_inserti32x8(_mm512_castsi256_si512(v256), v256, 1); + } + } + + struct { __m512i values[4]; } bits; + Scales8KBase s8k; + const __m512i m01 = _mm512_set1_epi8(0x01); + const __m512i m10 = _mm512_set1_epi8(0x10); + const __m512i m4 = _mm512_set1_epi8(0xf); + __m512i values[4]; + const __m512i shuffles[4] = { + _mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1), + _mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1), + }; +}; + + struct DequantizerIQ4KS final : public BaseDequantizer { DequantizerIQ4KS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {} template @@ -1383,7 +1482,8 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf set_scales_8(all_scales, j, scales); - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v void set_functions(std::array& funcs) { #ifdef HAVE_FANCY_SIMD if constexpr (std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { @@ -2916,6 +3017,12 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array(kernels); break; + case GGML_TYPE_IQ2_KL: + set_functions(kernels); +#ifdef HAVE_FANCY_SIMD + func16 = mul_mat_iqX_k_q8_K_AVX512_new; +#endif + break; case GGML_TYPE_IQ3_KS: set_functions(kernels); break; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 0054f6cb..820a0aad 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -424,6 +424,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy, return iqk_convert_iquants_q80_r8(typeA, n, vx, bx, vy, nrc_x); case GGML_TYPE_IQ2_KS: case GGML_TYPE_IQ2_K: + //case GGML_TYPE_IQ2_KL: case GGML_TYPE_IQ3_KS: case GGML_TYPE_IQ3_K: case GGML_TYPE_IQ4_KSS: @@ -827,14 +828,15 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ3_XXS_R4: case GGML_TYPE_IQ3_S_R4: return iqk_set_kernels_iquants(ne00, typeA, typeB, mm.funcs, mm.func16); - case GGML_TYPE_IQ3_KS: - case GGML_TYPE_IQ4_KS: - case GGML_TYPE_IQ5_KS: - case GGML_TYPE_IQ4_KSS: - case GGML_TYPE_IQ2_K: case GGML_TYPE_IQ2_KS: + case GGML_TYPE_IQ2_K: + case GGML_TYPE_IQ2_KL: + case GGML_TYPE_IQ3_KS: case GGML_TYPE_IQ3_K: + case GGML_TYPE_IQ4_KSS: + case GGML_TYPE_IQ4_KS: case GGML_TYPE_IQ4_K: + case GGML_TYPE_IQ5_KS: case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ6_K: case GGML_TYPE_IQ2_K_R4: