diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp index fbe70dda..fbd22b68 100644 --- a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp +++ b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp @@ -340,14 +340,8 @@ struct DequantizerIQ4KSS final : public BaseDequantizer { struct DequantizerIQ2KL final : public BaseDequantizer { DequantizerIQ2KL(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); } inline __m128i make_scales(int i) const { - //uint16_t aux[8]; - //auto h = x[i].scales_h; - //for (int k = 0; k < 4; ++k) { aux[k+0] = (x[i].scales_l[k] & 0xf) | ((h << 4) & 0x30); h >>= 2; } - //for (int k = 0; k < 4; ++k) { aux[k+4] = (x[i].scales_l[k] >> 4) | ((h << 4) & 0x30); h >>= 2; } - //return _mm_sub_epi16(_mm_loadu_si128((const __m128i *)aux), _mm_set1_epi16(32)); uint32_t aux32; std::memcpy(&aux32, x[i].scales_l, 4); auto scl = _mm_cvtepu8_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(aux32), _mm_set_epi32(0, 0, 4, 0)), _mm_set1_epi8(0xf))); - // 0x000a000800060004 auto sch = _mm_srlv_epi16(_mm_sllv_epi64(_mm_set1_epi16(x[i].scales_h), _mm_set_epi64x(0, 8)), _mm_set1_epi64x(0x000a000800060004)); auto scales128 = _mm_sub_epi16(_mm_or_si128(scl, _mm_and_si128(sch, _mm_set1_epi16(0x30))), _mm_set1_epi16(32)); return scales128; @@ -376,14 +370,6 @@ struct DequantizerIQ2KL final : public BaseDequantizer inline void prepare(int i) { __m512i ql[2], qs[4]; __mmask64 mask[2]; - //// TODO: optimize this - //for (int k = 0; k < 2; ++k) { - // auto b1 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+0); - // auto b2 = _mm_loadu_si128((const __m128i *)x[i].qs+2*k+1); - // auto c1 = MM256_SET_M128I(_mm_srli_epi16(b1, 4), b1); - // auto c2 = MM256_SET_M128I(_mm_srli_epi16(b2, 4), b2); - // ql[k] = _mm512_and_si512(m4, _mm512_inserti32x8(_mm512_castsi256_si512(c1), c2, 1)); - //} auto lbits = _mm512_loadu_si512((const __m512i *)x[i].qs); ql[0] = _mm512_and_si512(lbits, m4); ql[1] = _mm512_and_si512(_mm512_srli_epi16(lbits, 4), m4); @@ 
-1013,6 +999,68 @@ struct DequantizerIQ2KS final : public BaseDequantizer
     const __m128i shift = _mm_set_epi32(0, 0, 4, 0);
 };
 
+struct DequantizerIQ2KL final : public BaseDequantizer { // IQ2KL dequantizer written with 256-bit (__m256i) intrinsics; fills bits.values[] for the caller
+    DequantizerIQ2KL(const void * vx, size_t bx) : BaseDequantizer(vx, bx) { load_values(); } // forwards vx/bx to the base, then builds the values[] lookup tables
+    template // NOTE(review): the template parameter list looks stripped by extraction; presumably a single `typename Q8` parameter - confirm against the real file
+    inline __m256i new_block(int i, const Q8& q8, __m256 * accm) { // per-block setup for block i: loads the qh bits, builds the scales, passes scale * -64 to s8k.accum_mins
+        auto hbits128 = _mm_loadu_si128((const __m128i *)x[i].qh); // unaligned 16-byte load from x[i].qh
+        hbits = MM256_SET_M128I(_mm_srli_epi16(hbits128, 1), hbits128); // one 128-bit half holds qh, the other qh >> 1 (assumes MM256_SET_M128I(hi, lo) - TODO confirm)
+        auto scales128 = make_scales(i); // eight 16-bit scales, already offset by -32
+        auto scales_s = _mm_mullo_epi16(scales128, _mm_set1_epi16(-64)); // every scale multiplied by -64
+        s8k.accum_mins(scales_s, q8, i, d, accm); // accumulation itself is implemented in Scales8KBase, not visible here
+        return MM256_SET_M128I(scales128, scales128); // the same eight scales in both 128-bit halves
+    }
+    inline void prepare(int i, int j) { // fills bits.values[0..3] from chunk pair j of block i, then shifts hbits for the following call
+        __m256i ql[2], mask[2];
+        auto b1 = _mm_loadu_si128((const __m128i *)x[i].qs+2*j+0); // 16-byte chunk number 2*j of x[i].qs
+        auto b2 = _mm_loadu_si128((const __m128i *)x[i].qs+2*j+1); // 16-byte chunk number 2*j+1 of x[i].qs
+        ql[0] = _mm256_and_si256(_mm256_set1_epi8(0xf), MM256_SET_M128I(_mm_srli_epi16(b1, 4), b1)); // 4-bit indices: b1 nibbles in one half, (b1 >> 4) nibbles in the other
+        ql[1] = _mm256_and_si256(_mm256_set1_epi8(0xf), MM256_SET_M128I(_mm_srli_epi16(b2, 4), b2)); // same split for b2
+        mask[0] = _mm256_cmpeq_epi8(_mm256_and_si256(hbits, _mm256_set1_epi8(0x1)), _mm256_set1_epi8(0x1)); // 0xff in every byte whose hbits bit 0 is set
+        mask[1] = _mm256_cmpeq_epi8(_mm256_and_si256(hbits, _mm256_set1_epi8(0x4)), _mm256_set1_epi8(0x4)); // 0xff in every byte whose hbits bit 2 is set
+        for (int k = 0; k < 2; ++k) { // k pairs ql[k] with mask[k]
+            auto v0 = _mm256_shuffle_epi8(values[0], ql[k]); // per-byte lookup of the 4-bit index in table 0
+            auto v1 = _mm256_shuffle_epi8(values[1], ql[k]); // ... table 1
+            auto v2 = _mm256_shuffle_epi8(values[2], ql[k]); // ... table 2
+            auto v3 = _mm256_shuffle_epi8(values[3], ql[k]); // ... table 3
+            auto q1 = _mm256_or_si256(_mm256_and_si256(mask[k], v1), _mm256_andnot_si256(mask[k], v0)); // per byte: v1 where mask[k] is set, otherwise v0
+            auto q2 = _mm256_or_si256(_mm256_and_si256(mask[k], v3), _mm256_andnot_si256(mask[k], v2)); // per byte: v3 where mask[k] is set, otherwise v2
+            auto q1l = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(q1)); // low 16 bytes of q1, zero-extended to 16 bits
+            auto q1h = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(q1, 1)); // high 16 bytes of q1, zero-extended
+            auto q2l = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(q2)); // low 16 bytes of q2, zero-extended
+            auto q2h = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(q2, 1)); // high 16 bytes of q2, zero-extended
+            bits.values[2*k+0] = _mm256_or_si256(q1l, _mm256_slli_epi16(q2l, 8)); // each 16-bit lane: q1 byte in the low 8 bits, q2 byte in the high 8 bits
+            bits.values[2*k+1] = _mm256_or_si256(q1h, _mm256_slli_epi16(q2h, 8)); // same interleave for the upper halves
+        }
+        hbits = _mm256_srli_epi16(hbits, 4); // shift right by 4 so the bit tests above see the next group of qh bits on the following call
+    }
+    inline __m128i make_scales(int i) const { // returns eight 16-bit scales: 4 bits from scales_l | 2 bits from scales_h (placed at bits 4-5), minus 32
+        uint32_t aux32; std::memcpy(&aux32, x[i].scales_l, 4); // copy 4 bytes of x[i].scales_l into an integer
+        auto scl = _mm_cvtepu8_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(aux32), shift), _mm_set1_epi8(0xf))); // low nibbles of the 4 bytes, then their high nibbles (second dword is shifted by 4), widened to 16 bits
+        auto sch = _mm_srlv_epi32(_mm_set1_epi16(x[i].scales_h), _mm_set_epi32(12, 8, 4, 0)); // dword m holds the broadcast scales_h shifted right by 4*m
+        sch = _mm_and_si128(sch, _mm_set1_epi32(0x000c0003)); // keep bits 0-1 in each even 16-bit lane and bits 2-3 in each odd one
+        sch = _mm_mullo_epi16(sch, _mm_set1_epi32(0x00040010)); // multiply by 16 (even lanes) / 4 (odd lanes): both 2-bit fields land on bits 4-5
+        auto scales128 = _mm_sub_epi16(_mm_or_si128(scl, sch), _mm_set1_epi16(32)); // merge low and high parts, then subtract 32
+        return scales128;
+    }
+    void load_values() { // builds the four lookup tables consumed by prepare()
+        static const uint8_t k_values[64] = { // 64 bytes, read below as 4 rows of 16
+            1, 1, 24, 24, 24, 24, 41, 41, 41, 41, 41, 54, 54, 54, 54, 65, 65, 65, 65, 65, 77, 77, 77, 77, 77, 92, 92, 92, 92, 92, 111, 111,
+            41, 77, 1, 54, 77, 111, 24, 41, 65, 77, 92, 1, 65, 77, 111, 41, 54, 65, 77, 92, 24, 41, 54, 65, 77, 1, 41, 65, 92, 111, 41, 77,
+        };
+        for (int k = 0; k < 4; ++k) {
+            auto v128 = _mm_loadu_si128((const __m128i *)k_values + k); // row k: bytes 16*k .. 16*k+15
+            values[k] = MM256_SET_M128I(v128, v128); // the same row in both 128-bit halves
+        }
+    }
+    struct { __m256i values[4]; } bits; // written by prepare()
+    Scales8KBase s8k; // provides accum_mins(), called from new_block()
+
+    __m256i values[4]; // lookup tables filled by load_values()
+    __m256i hbits; // qh bits: loaded in new_block(), shifted in prepare()
+    const __m128i shift = _mm_set_epi32(0, 0, 4, 0); // per-dword right-shift counts for make_scales(): only the second dword is shifted, by 4
+};
+
 struct DequantizerIQ2K final : public BaseDequantizer {
     DequantizerIQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx), iqxk(5, -32), values(load_values()) {}
     template
@@ -1488,8 +1536,7 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
 
                 set_scales_8(all_scales, j, scales);
 
-                if constexpr (std::is_same_v || std::is_same_v ||
-                              std::is_same_v || std::is_same_v) {
                     multiply_add_avx2(deq.bits, scales, j, i, q8, sumi);
                 } else {
                     multiply_add(deq.bits, scales, j, i, q8, sumi);