From f83e64dcb6aa65e80fa2c816601df2847481f24e Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 17 May 2025 14:32:00 +0300 Subject: [PATCH] Refactor iqk: Factor out GEMM for legacy quants (AVX2/AVX512) --- ggml/src/CMakeLists.txt | 6 +- ggml/src/iqk/iqk_common.h | 103 +++ ggml/src/iqk/iqk_gemm_legacy_quants.cpp | 789 +++++++++++++++++++++ ggml/src/iqk/iqk_gemm_legacy_quants.h | 11 + ggml/src/iqk/iqk_mul_mat.cpp | 870 +----------------------- 5 files changed, 912 insertions(+), 867 deletions(-) create mode 100644 ggml/src/iqk/iqk_gemm_legacy_quants.cpp create mode 100644 ggml/src/iqk/iqk_gemm_legacy_quants.h diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 7648b745..33b873ec 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -260,10 +260,12 @@ if (GGML_IQK_MUL_MAT) add_compile_definitions(GGML_USE_IQK_MULMAT) set(GGML_SOURCES_IQK_MM iqk/iqk_mul_mat.cpp iqk/iqk_flash_attn.cpp - iqk/iqk_gemm_floats.cpp) + iqk/iqk_gemm_floats.cpp + iqk/iqk_gemm_legacy_quants.cpp) set(GGML_HEADERS_IQK_MM iqk/iqk_mul_mat.h iqk/iqk_flash_impl.h - iqk/iqk_gemm_floats.h) + iqk/iqk_gemm_floats.h + iqk/iqk_gemm_legacy_quants.h) if (GGML_IQK_FLASH_ATTENTION) message(STATUS "Enabling IQK Flash Attention kernels") add_compile_definitions(GGML_IQK_FLASH_ATTENTION) diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h index 9d8269b9..02f49641 100644 --- a/ggml/src/iqk/iqk_common.h +++ b/ggml/src/iqk/iqk_common.h @@ -7,6 +7,8 @@ // SPDX-License-Identifier: MIT // +#pragma once + #include "iqk_config.h" #if defined IQK_IMPLEMENT @@ -135,4 +137,105 @@ typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& inf #define IQK_MAX_NY 8 +// ================================================================================================== + +#ifdef __AVX2__ + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +static inline float hsum_float_4(__m128 x) { + x = _mm_add_ps(x, 
_mm_movehl_ps(x, x)); + x = _mm_add_ss(x, _mm_movehdup_ps(x)); + return _mm_cvtss_f32(x); +} +static inline float hsum_float_8(__m256 x) { + return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); +} +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} +static inline float hmax_float_8(__m256 x) { + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4)); + return _mm_cvtss_f32(max4); +} + +static inline __m256 hsum_float_8x8(__m256 * accm) { + for (int i = 0; i < 4; ++i) { + accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31)); + //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)), + // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1))); + } + for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); + return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); +} + +static inline __m128i load_iq4nl_values_128() { + static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; + return _mm_loadu_si128((const __m128i *)kvalues_iq4nl); +} + +static inline __m256i load_iq4nl_values_256() { + auto val128 = load_iq4nl_values_128(); + return MM256_SET_M128I(val128, val128); +} + +static inline __m128i load_iq4k_values_128() { + return 
_mm_loadu_si128((const __m128i *)iq4k_values); +} + +static inline __m256i load_iq4k_values_256() { + auto val128 = load_iq4k_values_128(); + return MM256_SET_M128I(val128, val128); +} + +template struct Q8 { + + constexpr static int nrc_y = nrc; + + Q8(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy); + } + +#ifdef HAVE_FANCY_SIMD + inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); } +#endif + inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); } + inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); } + inline float scale(int iy, int i) const { return y[iy][i].d; } + + const block_q8 * y[nrc_y]; +}; + +template struct Q8_16 { + + constexpr static int nrc_y = nrc; + + Q8_16(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto ptr = (const float *)info.src1_row(iy); + std::memcpy(d + 5*iy, ptr, 5*sizeof(float)); + y[iy] = (const int8_t *)(ptr + 5); + } + } + +#ifdef HAVE_FANCY_SIMD + inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); } +#endif + inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); } + inline float scale(int iy, int k) const { return d[5*iy+k]; } + inline float sum_row(int iy) const { return d[5*iy + 4]; } + inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); } + + float d[5*nrc_y]; + const int8_t * y[nrc_y]; +}; + +#endif + #endif diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp new file mode 100644 index 00000000..ed60afbc --- /dev/null +++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp @@ -0,0 +1,789 @@ +#include "iqk_gemm_legacy_quants.h" + +#ifdef IQK_IMPLEMENT + +#include "ggml-impl.h" + +#define GGML_COMMON_IMPL_C +#include 
"ggml-common.h" + +// +// ============================== Legacy quants +// + +namespace { + +struct DotHelper { + const __m256i m1 = _mm256_set1_epi16(1); +#if defined(__AVX512VNNI__) && defined(__AVX512VL__) + inline __m256i dot(__m256i x, __m256i y) const { + return _mm256_dpbusd_epi32(_mm256_setzero_si256(), x, y); + } +#else + inline __m256i dot(__m256i x, __m256i y) const { + return _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x, y)); + } +#endif +}; + +struct SignedDot { + DotHelper helper; + inline __m256i compute(__m256i x, __m256i y) const { + return helper.dot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x)); + } +}; +struct UnsignedDot { + DotHelper helper; + inline __m256i compute(__m256i x, __m256i y) const { + return helper.dot(x, y); + } +}; + +template struct Sum4 { + Dot dot; + inline __m256i compute(const __m256i * qx, const Q8 * y) const { + const Q8x4 * y4 = (const Q8x4 *)y; + const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 8x block 0 + const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 8x block 1 + const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 8x block 2 + const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 8x block 3 + if constexpr (can_pack) { + const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1 + const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3 + return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3 + } else { + // Note to myself: this is much faster than using _mm256_hadd_epi32() + auto p01 = _mm256_add_epi32(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,1, 0,1, 0,1, 0,1 + auto p23 = _mm256_add_epi32(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,3, 2,3, 2,3, 2,3 + return 
_mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3 + } + } + inline __m256i compute(__m256i x, __m256i y) const { return dot.compute(x, y); } +}; + +template struct Sum4q4 { + inline __m256i compute(const __m256i * qx, const Q8 * y) const { + const Q8x4 * y4 = (const Q8x4 *)y; + auto p0 = _mm256_maddubs_epi16(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 16x block 0 + auto p1 = _mm256_maddubs_epi16(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 16x block 1 + auto p2 = _mm256_maddubs_epi16(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 16x block 2 + auto p3 = _mm256_maddubs_epi16(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 16x block 3 + auto p01 = _mm256_add_epi16(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1, 0,0, 1,1, 0,0, 1,1 + auto p23 = _mm256_add_epi16(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3, 2,2, 3,3, 2,2, 3,3 + auto p0123 = _mm256_add_epi16(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 + return _mm256_madd_epi16(_mm256_set1_epi16(1), p0123); + } + inline __m256i compute(__m256i x, __m256i y) const { return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(x, y)); } +}; + +struct ScaleHelperQ8_0 { + inline __m128 prepare4(const block_q8_0 * y) { + const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y; + return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4->d)); + } + inline __m128 prepare4(__m128 other_scales, const block_q8_0 * y) { + return _mm_mul_ps(other_scales, prepare4(y)); + } + template inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); } + template inline float prepare1(float d, const Q * y) const { return d*prepare1(y); } +}; + +struct ScaleHelperQ_0 { + ggml_half scales8[4]; + template + inline __m128 prepare4(const Q * y) { + for (int j = 0; j < 4; ++j) scales8[j] = 
y[j].d; + return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8)); + } + template + inline __m128 prepare4(__m128 other_scales, const Q * y) { + return _mm_mul_ps(other_scales, prepare4(y)); + } + template inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); } + template inline float prepare1(float d, const Q * y) const { return d*prepare1(y); } +}; + +template +struct ScaleHelperQ_0_1 { /* Pairs each fp16 block scale d (converted to fp32) with the companion term d * (-min_value). */ + ggml_half scales8[4]; + template + inline __m256 prepare4(const Q * y) { /* low 128 bits: d of 4 consecutive blocks; high 128 bits: the same d multiplied by -min_value */ + for (int j = 0; j < 4; ++j) scales8[j] = y[j].d; + auto s4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8)); + return _mm256_set_m128(_mm_mul_ps(s4, min), s4); + } + template + inline __m256 prepare4(__m256 other_scales, const Q * y) { /* element-wise product of other_scales with prepare4(y) */ + return _mm256_mul_ps(other_scales, prepare4(y)); /* was _mm_mul256_ps, which is not an intrinsic; presumably it only built because this member template was never instantiated */ + } + template inline std::pair prepare1(const Q * y) const { /* single-block variant: returns (d, -d*min_value) */ + float d = GGML_FP16_TO_FP32(y->d); + return std::make_pair(d, -d*float(min_value)); + } + std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { /* scales dm.first by y->d and dm.second by y->s */ + return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); + } + const __m128 min = _mm_set1_ps(float(-min_value)); +}; + +//template +//struct ScaleHelperQ_0_2 { +// ggml_bf16_t scales8[4]; +// template +// inline __m256 prepare4(const Q * y) { +// for (int j = 0; j < 4; ++j) scales8[j] = y[j].d; +// auto s4 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales8)), 16)); +// return _mm256_set_m128(_mm_mul_ps(s4, min), s4); +// } +// template +// inline __m256 prepare4(__m256 other_scales, const Q * y) { +// return _mm256_mul_ps(other_scales, prepare4(y)); +// } +// template inline std::pair prepare1(const Q * y) const { +// float d = GGML_BF16_TO_FP32(y->d); +// return std::make_pair(d, -d*float(min_value)); +// } +// std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { +// return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); 
+// } +// const __m128 min = _mm_set1_ps(float(-min_value)); +//}; + +struct ScaleHelperQ8_1 { + template + inline __m256 prepare4(const Q * y) { + const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y; + return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)y4->d)); + } + template + inline __m256 prepare4(__m256 other_scales, const Q * y) { + return _mm256_mul_ps(other_scales, prepare4(y)); + } + template inline std::pair prepare1(const Q * y) const { + return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m)); + } + template inline std::pair prepare1(const std::pair& dm, const Q * y) const { + return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m)); + } + std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { + return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); + } +}; + +struct ScaleHelperQ8_2 { + template + inline __m256 prepare4(const Q * y) { + const block_q8_2_x4 * y4 = (const block_q8_2_x4 *)y; + auto aux = _mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y4->d)); + return _mm256_castsi256_ps(_mm256_slli_epi32(aux, 16)); + } + template + inline __m256 prepare4(__m256 other_scales, const Q * y) { + return _mm256_mul_ps(other_scales, prepare4(y)); + } + template inline std::pair prepare1(const Q * y) const { + return std::make_pair(GGML_BF16_TO_FP32(y->d), GGML_BF16_TO_FP32(y->m)); + } + template inline std::pair prepare1(const std::pair& dm, const Q * y) const { + ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s; + return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s)); + } + std::pair inline prepare1(const std::pair& dm, const block_q8_2 * y) const { + ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s; + return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s)); + } +}; + +struct ScaleHelperQ_1 { + uint32_t scales8[4]; + const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 
0x0d0c, 0x0908, 0x0504, 0x0100); + + template + inline __m256 prepare4(const Q * y) { + for (int j = 0; j < 4; ++j) { + // it is slightly faster to directly dereference (const uint32 *)&y[j].d, but some compilers + // complain that this breaks strict-aliasing rules. + memcpy(scales8 + j, &y[j].d, sizeof(uint32_t)); + } + return _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *)scales8), shuffle)); + } + + template + inline __m256 prepare4(__m256 other_scales, const Q * y) { + return _mm256_mul_ps(other_scales, prepare4(y)); + } + + template inline std::pair prepare1(const Q * y) const { + return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m)); + } + template inline std::pair prepare1(const std::pair& dm, const Q * y) const { + return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m)); + } + std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { + return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); + } +}; + +struct MinusType0 { + inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); } + inline float compute(float d, int) const { return d; } + inline float result(__m256 acc, int) const { return hsum_float_8(acc); } + inline __m256 vresult(__m256 acc, int) const { return acc; } +}; + +template struct MinusType1 { + __m128 accm[nrc_y]; + MinusType1() { for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm_setzero_ps(); } + inline __m256 compute(__m256 dm, int iy) { + const __m128 d = _mm256_castps256_ps128(dm); + const __m128 m = _mm256_extractf128_ps(dm, 1); + accm[iy] = _mm_add_ps(accm[iy], m); + return _mm256_set_m128(d, d); + } + inline float compute(const std::pair& dm, int iy) { + accm[iy] = _mm_add_ps(accm[iy], _mm_set1_ps(dm.second*0.25f)); + return dm.first; + } + inline float result(__m256 acc, int iy) const { + const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); + return 
hsum_float_4(_mm_add_ps(sum, accm[iy])); + } + inline __m256 vresult(__m256 acc, int iy) const { + return _mm256_add_ps(acc, _mm256_insertf128_ps(_mm256_setzero_ps(), accm[iy], 0)); + } +}; + +template struct AccumT { + __m256 acc[nrc_y]; + Minus accm; + AccumT() { for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps(); } + template + inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, const DataInfo& info, int ix) { + auto qx = unp.quants(); + __m256 dall[nrc_y]; + for (int i = 0; i < nb/4; ++i) { + auto other_scales = unp.set_block_4(i); + for (int iy = 0; iy < nrc_y; ++iy) { + auto s12 = scales.prepare4(other_scales, y[iy] + 4*i); + dall[iy] = accm.compute(s12, iy); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto pall = sum.compute(qx, y[iy] + 4*i); + acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]); + } + } + if (!is_multiple_of_4) { + for (int i = 4*(nb/4); i < nb; ++i) { + auto other_scales = unp.set_block(i); + for (int iy = 0; iy < nrc_y; ++iy) { + auto s12 = scales.prepare1(other_scales, y[iy] + i); + auto d = accm.compute(s12, iy); + const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, accm.result(acc[iy], iy)); + } + } + template + inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, __m256 * result) { + auto qx = unp.quants(); + __m256 dall[nrc_y]; + for (int i = 0; i < nb/4; ++i) { + auto other_scales = unp.set_block_4(i); + for (int iy = 0; iy < nrc_y; ++iy) { + auto s12 = scales.prepare4(other_scales, y[iy] + 4*i); + dall[iy] = accm.compute(s12, iy); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto pall = sum.compute(qx, y[iy] + 4*i); + acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]); + } + } + if (!is_multiple_of_4) { + for (int i = 4*(nb/4); i 
< nb; ++i) { + auto other_scales = unp.set_block(i); + for (int iy = 0; iy < nrc_y; ++iy) { + auto s12 = scales.prepare1(other_scales, y[iy] + i); + auto d = accm.compute(s12, iy); + const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); + acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); + } + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + result[iy] = accm.vresult(acc[iy], iy); + } + } +}; + +template +using AccumType0 = AccumT; + +template +using AccumType1 = AccumT, nrc_y, is_multiple_of_4>; + +using Sum4TypeQ80 = Sum4; +using Sum4TypeQ82 = Sum4; + +template +void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) { + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + Scales scales; + for (int ix = 0; ix < nrc_x; ++ix) { + unp.set_row(ix); + AccumType accum; + accum.compute(nb, unp, scales, sum4, y, info, ix); + } +} + +template +void mul_mat_qX_q8_Helper_x2(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) { + GGML_ASSERT(nrc_x%2 == 0); + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + Scales scales; + for (int ix = 0; ix < nrc_x; ix += 2) { + unp.set_row(ix); + AccumType accum; + accum.compute(nb, unp, scales, sum4, y, info, ix); + } +} + +template +void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%Unpacker::block_size() == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + if (nb%4 == 0) { + mul_mat_qX_q8_Helper, ScaleHelperQ8_0, block_q8_0, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } else { + mul_mat_qX_q8_Helper, ScaleHelperQ8_0, block_q8_0, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } +} + +template +void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) { + static_assert(8%nrc_y == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + Unpacker unp(vx, bx); + typename Unpacker::Sum4T 
sum4; + ScaleHelperQ8_0 scales; + __m256 result[8]; + auto store = [&info, &result] (int ix0) { + if constexpr (nrc_y == 1) { + info.store(ix0, 0, hsum_float_8x8(result)); + } + else if constexpr (nrc_y == 2) { + auto value = hsum_float_8x8(result); + auto value1 = _mm256_extractf128_ps(value, 1); + info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88)); + info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd)); + } + else { + float val[8]; + _mm256_storeu_ps(val, hsum_float_8x8(result)); + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + } + }; + if (nb%4 == 0) { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType0 accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } else { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType0 accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } +} + +template +void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%Unpacker::block_size() == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + if (nb%4 == 0) { + mul_mat_qX_q8_Helper, ScaleHelperQ8_1, block_q8_1, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } else { + mul_mat_qX_q8_Helper, ScaleHelperQ8_1, block_q8_1, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } +} + +template +void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%Unpacker::block_size() == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + if (nb%4 == 0) { + mul_mat_qX_q8_Helper, ScaleHelperQ8_2, block_q8_2, nrc_y>( + nb, vx, bx, info, q8.y, nrc_x + ); + } else { + mul_mat_qX_q8_Helper, ScaleHelperQ8_2, block_q8_2, nrc_y>( + nb, vx, bx, info, 
q8.y, nrc_x + ); + } +} + +template +void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) { + static_assert(8%nrc_y == 0); + Q8 q8(info); + int nb = n/Unpacker::block_size(); + Unpacker unp(vx, bx); + typename Unpacker::Sum4T sum4; + ScaleHelperQ8_2 scales; + __m256 result[8]; + auto store = [&info, &result] (int ix0) { + if constexpr (nrc_y == 1) { + info.store(ix0, 0, hsum_float_8x8(result)); + } + else if constexpr (nrc_y == 2) { + auto value = hsum_float_8x8(result); + auto value1 = _mm256_extractf128_ps(value, 1); + info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88)); + info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd)); + } + else { + float val[8]; + _mm256_storeu_ps(val, hsum_float_8x8(result)); + for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); + } + }; + if (nb%4 == 0) { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType1 accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } else { + for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { + for (int ix = 0; ix < 8/nrc_y; ++ix) { + unp.set_row(ix0 + ix); + AccumType1 accum; + accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); + } + store(ix0); + } + } +} + +struct Dequantizer4bit { + const __m256i m4 = _mm256_set1_epi8(0xf); + inline __m256i dequant(const uint8_t * qs) const { + const __m128i aux128 = _mm_loadu_si128((const __m128i *)qs); + return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), m4); + } +}; + +struct Q8_0_Dequantizer { + inline __m256i dequant(const block_q8_0 * x) const { + return _mm256_loadu_si256((const __m256i *)x->qs); + } +}; + +struct Q8_0_1_Dequantizer { + inline __m256i dequant(const block_q8_0 * x) const { + return _mm256_add_epi8(_mm256_set1_epi8(127), _mm256_loadu_si256((const 
__m256i *)x->qs)); + } +}; + +struct Q4_0_Dequantizer { + Dequantizer4bit b4; + const __m256i m8 = _mm256_set1_epi8(-8); + inline __m256i dequant(const block_q4_0 * x) const { + return _mm256_add_epi8(b4.dequant(x->qs), m8); + } +}; + +struct Q4_0_1_Dequantizer { + Dequantizer4bit b4; + inline __m256i dequant(const block_q4_0 * x) const { + return b4.dequant(x->qs); + } +}; + +struct IQ4_NL_Dequantizer { + Dequantizer4bit b4; +#ifdef HAVE_FANCY_SIMD + const __m256i values = load_iq4nl_values_256(); +#else + const __m256i values = load_iq4k_values_256(); +#endif + inline __m256i dequant(const block_iq4_nl * x) const { + return _mm256_shuffle_epi8(values, b4.dequant(x->qs)); + } +}; + +struct Q4_1_Dequantizer { + Dequantizer4bit b4; + inline __m256i dequant(const block_q4_1 * x) const { + return b4.dequant(x->qs); + } +}; + +struct HBitDequantizer { + const __m256i shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); + const __m256i mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + const __m256i minus1 = _mm256_set1_epi64x(-1); + inline __m256i to_bytes(const uint8_t * bits) const { + // Note: Data in all ggml quants is at least 2-byte aligned. 
+ // => we can cast to uint16_t and use or on two consecutive entries + // which is faster than memcpy + const uint16_t * aux16 = (const uint16_t *)bits; + const uint32_t aux32 = aux16[0] | (aux16[1] << 16); + //uint32_t aux32; memcpy(&aux32, bits, sizeof(uint32_t)); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(aux32), shuffle); + bytes = _mm256_or_si256(bytes, mask); + return _mm256_cmpeq_epi8(bytes, minus1); + } +}; + +struct Q5_0_Dequantizer { + Dequantizer4bit b4; + HBitDequantizer hbit; + const __m256i mh = _mm256_set1_epi8((char)0xF0); + inline __m256i dequant(const block_q5_0 * x) const { + const __m256i vqh = _mm256_andnot_si256(hbit.to_bytes(x->qh), mh); + return _mm256_or_si256(b4.dequant(x->qs), vqh); + } +}; + +template +struct Q5_1_Dequantizer { + Dequantizer4bit b4; + HBitDequantizer hbit; + const __m256i mh = _mm256_set1_epi8(0x10); + inline __m256i dequant(const Q5 * x) const { + const __m256i vqh = _mm256_and_si256(hbit.to_bytes(x->qh), mh); + return _mm256_or_si256(b4.dequant(x->qs), vqh); + } +}; +struct Q6_0_1_Dequantizer { + Dequantizer4bit b4; + const __m256i mh = _mm256_set1_epi8(0x30); + const __m256i shift1 = _mm256_set_epi64x(0, 2, 0, 4); + const __m256i shift2 = _mm256_set_epi64x(2, 0, 0, 0); + inline __m256i dequant(const block_q6_0 * x) const { + uint64_t aux64; std::memcpy(&aux64, x->qh, 8); + auto h256 = _mm256_sllv_epi64(_mm256_set1_epi64x(aux64), shift1); + return _mm256_or_si256(b4.dequant(x->qs), _mm256_and_si256(_mm256_srlv_epi64(h256, shift2), mh)); + } +}; + +template +struct Q_Unpacker { + Q_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const Q*)cx_0), bx(bx) {} + + const char * cx_0; + const Q * x; + size_t bx; + + Scales scales; + Dequantizer deq; + + __m256i qx[4]; + + inline const __m256i* quants() const { return qx; } + + inline void set_row(int ix) { x = (const Q*)(cx_0 + ix*bx); } + + inline auto set_block_4(int i) { + for (int j = 0; j < 4; ++j) { + qx[j] = deq.dequant(x + 4*i + j); + } + 
return scales.prepare4(x + 4*i); + } + inline auto set_block(int i) { + qx[0] = deq.dequant(x + i); + return scales.prepare1(x + i); + } +}; + +struct Q8_0_Unpacker final : public Q_Unpacker { + Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ80; + inline static int block_size() { return QK8_0; } +}; +struct Q8_0_1_Unpacker final : public Q_Unpacker, Q8_0_1_Dequantizer> { + Q8_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ82; + inline static int block_size() { return QK8_0; } +}; +struct Q4_0_Unpacker final : public Q_Unpacker { + Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ80; + inline static int block_size() { return QK4_0; } +}; +struct Q4_0_1_Unpacker final : public Q_Unpacker, Q4_0_1_Dequantizer> { + Q4_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + //using Sum4T = Sum4TypeQ82; + using Sum4T = Sum4q4; + inline static int block_size() { return QK4_0; } +}; +#ifdef HAVE_FANCY_SIMD +struct IQ4_NL_Unpacker final : public Q_Unpacker, IQ4_NL_Dequantizer> { + IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ82; + inline static int block_size() { return QK4_NL; } +}; +#else +struct IQ4_NL_Unpacker final : public Q_Unpacker { + IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ80; + inline static int block_size() { return QK4_NL; } +}; +#endif +struct Q5_0_Unpacker final : public Q_Unpacker { + Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ80; + inline static int block_size() { return QK5_0; } +}; +struct Q5_0_1_Unpacker final : public Q_Unpacker, Q5_1_Dequantizer> { + Q5_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ82; + inline static int block_size() { return QK5_0; } +}; +struct Q4_1_Unpacker final : public Q_Unpacker { + Q4_1_Unpacker(const 
void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ82; + inline static int block_size() { return QK4_1; } +}; +struct Q5_1_Unpacker final : public Q_Unpacker> { + Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ82; + inline static int block_size() { return QK5_1; } +}; +struct Q6_0_1_Unpacker final : public Q_Unpacker, Q6_0_1_Dequantizer> { + Q6_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} + using Sum4T = Sum4TypeQ82; + inline static int block_size() { return QK6_0; } +}; + +template void set_functions(std::array& funcs) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + funcs[0] = mul_mat_qX_0_q8_0_T; + funcs[1] = mul_mat_qX_0_q8_0_T; + funcs[2] = mul_mat_qX_0_q8_0_T; + funcs[3] = mul_mat_qX_0_q8_0_T; + funcs[4] = mul_mat_qX_0_q8_0_T; + funcs[5] = mul_mat_qX_0_q8_0_T; + funcs[6] = mul_mat_qX_0_q8_0_T; + funcs[7] = mul_mat_qX_0_q8_0_T; + } + else if constexpr (std::is_same_v || std::is_same_v) { + funcs[0] = mul_mat_qX_1_q8_2_T; + funcs[1] = mul_mat_qX_1_q8_2_T; + funcs[2] = mul_mat_qX_1_q8_2_T; + funcs[3] = mul_mat_qX_1_q8_2_T; + funcs[4] = mul_mat_qX_1_q8_2_T; + funcs[5] = mul_mat_qX_1_q8_2_T; + funcs[6] = mul_mat_qX_1_q8_2_T; + funcs[7] = mul_mat_qX_1_q8_2_T; + } + else if constexpr (std::is_same_v) { +#ifdef HAVE_FANCY_SIMD + funcs[0] = mul_mat_qX_1_q8_2_T; + funcs[1] = mul_mat_qX_1_q8_2_T; + funcs[2] = mul_mat_qX_1_q8_2_T; + funcs[3] = mul_mat_qX_1_q8_2_T; + funcs[4] = mul_mat_qX_1_q8_2_T; + funcs[5] = mul_mat_qX_1_q8_2_T; + funcs[6] = mul_mat_qX_1_q8_2_T; + funcs[7] = mul_mat_qX_1_q8_2_T; +#else + funcs[0] = mul_mat_qX_0_q8_0_T; + funcs[1] = mul_mat_qX_0_q8_0_T; + funcs[2] = mul_mat_qX_0_q8_0_T; + funcs[3] = mul_mat_qX_0_q8_0_T; + funcs[4] = mul_mat_qX_0_q8_0_T; + funcs[5] = mul_mat_qX_0_q8_0_T; + funcs[6] = mul_mat_qX_0_q8_0_T; + funcs[7] = mul_mat_qX_0_q8_0_T; +#endif + } + else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || 
std::is_same_v) { + funcs[0] = mul_mat_qX_1_q8_2_T; + funcs[1] = mul_mat_qX_1_q8_2_T; + funcs[2] = mul_mat_qX_1_q8_2_T; + funcs[3] = mul_mat_qX_1_q8_2_T; + funcs[4] = mul_mat_qX_1_q8_2_T; + funcs[5] = mul_mat_qX_1_q8_2_T; + funcs[6] = mul_mat_qX_1_q8_2_T; + funcs[7] = mul_mat_qX_1_q8_2_T; + } +} + +} // namespace + +bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array& kernels) { + + if (ne00%QK8_0 != 0) return false; + + auto expected_typeB = GGML_TYPE_Q8_2_X4; + + switch (typeA) { + case GGML_TYPE_Q4_0: + set_functions(kernels); + break; + case GGML_TYPE_Q4_1: + set_functions(kernels); + break; + case GGML_TYPE_Q5_0: + set_functions(kernels); + break; + case GGML_TYPE_Q5_1: + set_functions(kernels); + break; + case GGML_TYPE_Q6_0: + set_functions(kernels); + break; + case GGML_TYPE_Q8_0: +#ifdef HAVE_FANCY_SIMD + set_functions(kernels); +#else + set_functions(kernels); + expected_typeB = GGML_TYPE_Q8_0_X4; +#endif + break; + case GGML_TYPE_IQ4_NL: + set_functions(kernels); +#ifndef HAVE_FANCY_SIMD + expected_typeB = GGML_TYPE_Q8_0_X4; +#endif + break; + default: + return false; + } + + return ggml_type(typeB) == expected_typeB; +} + +#endif diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.h b/ggml/src/iqk/iqk_gemm_legacy_quants.h new file mode 100644 index 00000000..dd6d097a --- /dev/null +++ b/ggml/src/iqk/iqk_gemm_legacy_quants.h @@ -0,0 +1,11 @@ +#pragma once + +#include "iqk_common.h" + +#ifdef IQK_IMPLEMENT + +#include + +bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array& kernels); + +#endif diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 2b3b208c..73ba1364 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -21,6 +21,7 @@ #include "iqk_quantize.h" #include "iqk_flash_impl.h" #include "iqk_gemm_floats.h" +#include "iqk_gemm_legacy_quants.h" #define GGML_COMMON_IMPL_C #include "ggml-common.h" @@ -1313,72 +1314,6 @@ const uint64_t keven_signs[128] = 
{ namespace { -inline float hsum_float_4(__m128 x) { - x = _mm_add_ps(x, _mm_movehl_ps(x, x)); - x = _mm_add_ss(x, _mm_movehdup_ps(x)); - return _mm_cvtss_f32(x); -} -inline float hsum_float_8(__m256 x) { - return hsum_float_4(_mm_add_ps(_mm256_castps256_ps128(x), _mm256_extractf128_ps(x, 1))); -} -inline int hsum_i32_8(const __m256i a) { - const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); - const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); - const __m128i sum64 = _mm_add_epi32(hi64, sum128); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} -inline float hmax_float_8(__m256 x) { - __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); - max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4)); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4)); - return _mm_cvtss_f32(max4); -} - -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) - -template struct Q8 { - - constexpr static int nrc_y = nrc; - - Q8(const DataInfo& info) { - for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy); - } - -#ifdef HAVE_FANCY_SIMD - inline __m512i load_quants64(int iy, int i, int j) const { return _mm512_loadu_si512((const __m512i*)y[iy][i].qs + j); } -#endif - inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].qs + j); } - inline __m256i load_bsums(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy][i].bsums); } - inline float scale(int iy, int i) const { return y[iy][i].d; } - - const block_q8 * y[nrc_y]; -}; - -template struct Q8_16 { - - constexpr static int nrc_y = nrc; - - Q8_16(const DataInfo& info) { - for (int iy = 0; iy < nrc_y; ++iy) { - auto ptr = (const float *)info.src1_row(iy); - std::memcpy(d + 5*iy, ptr, 5*sizeof(float)); - y[iy] = (const int8_t *)(ptr + 5); - } - } - -#ifdef 
HAVE_FANCY_SIMD - inline __m512i load_quants64(int iy, int i) const { return _mm512_loadu_si512((const __m512i*)y[iy] + i); } -#endif - inline __m256i load_quants(int iy, int i) const { return _mm256_loadu_si256((const __m256i*)y[iy] + i); } - inline float scale(int iy, int k) const { return d[5*iy+k]; } - inline float sum_row(int iy) const { return d[5*iy + 4]; } - inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 5*iy); } - - float d[5*nrc_y]; - const int8_t * y[nrc_y]; -}; - struct Scales8KBase { template inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const { @@ -1645,25 +1580,6 @@ struct SimpleBits { __m256i values[4]; }; -__m128i inline load_iq4nl_values_128() { - static const uint8_t kvalues_iq4nl[16] = {1, 24, 45, 63, 79, 93, 106, 118, 129, 141, 153, 166, 181, 197, 217, 241}; - return _mm_loadu_si128((const __m128i *)kvalues_iq4nl); -} - -__m256i inline load_iq4nl_values_256() { - auto val128 = load_iq4nl_values_128(); - return MM256_SET_M128I(val128, val128); -} - -__m128i inline load_iq4k_values_128() { - return _mm_loadu_si128((const __m128i *)iq4k_values); -} - -__m256i inline load_iq4k_values_256() { - auto val128 = load_iq4k_values_128(); - return MM256_SET_M128I(val128, val128); -} - #ifdef HAVE_FANCY_SIMD //====================================== Zen4 ================================================== @@ -8462,750 +8378,10 @@ struct DequantizerIQ2XXS final : public BaseDequantizer { const __m256i shuffle = _mm256_set_epi32(7, 5, 3, 1, 7, 5, 3, 1); }; -// -// ============================== Legacy quants -// - -struct DotHelper { - const __m256i m1 = _mm256_set1_epi16(1); -#if defined(__AVX512VNNI__) && defined(__AVX512VL__) - inline __m256i dot(__m256i x, __m256i y) const { - return _mm256_dpbusd_epi32(_mm256_setzero_si256(), x, y); - } -#else - inline __m256i dot(__m256i x, __m256i y) const { - return _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x, y)); - } -#endif -}; - -struct SignedDot { - 
DotHelper helper; - inline __m256i compute(__m256i x, __m256i y) const { - return helper.dot(_mm256_sign_epi8(x, x), _mm256_sign_epi8(y, x)); - } -}; -struct UnsignedDot { - DotHelper helper; - inline __m256i compute(__m256i x, __m256i y) const { - return helper.dot(x, y); - } -}; - -template struct Sum4 { - Dot dot; - inline __m256i compute(const __m256i * qx, const Q8 * y) const { - const Q8x4 * y4 = (const Q8x4 *)y; - const __m256i p0 = dot.compute(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 8x block 0 - const __m256i p1 = dot.compute(qx[1], _mm256_loadu_si256((const __m256i *)y4->qs+1)); // 8x block 1 - const __m256i p2 = dot.compute(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 8x block 2 - const __m256i p3 = dot.compute(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 8x block 3 - if constexpr (can_pack) { - const __m256i p01 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1 - const __m256i p23 = _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3 - return _mm256_madd_epi16(dot.helper.m1, _mm256_packs_epi32(p01, p23)); // 0,1,2,3, 0,1,2,3 - } else { - // Note to myself: this is much faster than using _mm256_hadd_epi32() - auto p01 = _mm256_add_epi32(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,1, 0,1, 0,1, 0,1 - auto p23 = _mm256_add_epi32(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,3, 2,3, 2,3, 2,3 - return _mm256_add_epi32(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,1,2,3, 0,1,2,3 - } - } - inline __m256i compute(__m256i x, __m256i y) const { return dot.compute(x, y); } -}; - -template struct Sum4q4 { - inline __m256i compute(const __m256i * qx, const Q8 * y) const { - const Q8x4 * y4 = (const Q8x4 *)y; - auto p0 = _mm256_maddubs_epi16(qx[0], _mm256_loadu_si256((const __m256i *)y4->qs+0)); // 16x block 0 - auto p1 = _mm256_maddubs_epi16(qx[1], _mm256_loadu_si256((const __m256i 
*)y4->qs+1)); // 16x block 1 - auto p2 = _mm256_maddubs_epi16(qx[2], _mm256_loadu_si256((const __m256i *)y4->qs+2)); // 16x block 2 - auto p3 = _mm256_maddubs_epi16(qx[3], _mm256_loadu_si256((const __m256i *)y4->qs+3)); // 16x block 3 - auto p01 = _mm256_add_epi16(_mm256_unpacklo_epi32(p0, p1), _mm256_unpackhi_epi32(p0, p1)); // 0,0, 1,1, 0,0, 1,1, 0,0, 1,1, 0,0, 1,1 - auto p23 = _mm256_add_epi16(_mm256_unpacklo_epi32(p2, p3), _mm256_unpackhi_epi32(p2, p3)); // 2,2, 3,3, 2,2, 3,3, 2,2, 3,3, 2,2, 3,3 - auto p0123 = _mm256_add_epi16(_mm256_unpacklo_epi64(p01, p23), _mm256_unpackhi_epi64(p01, p23)); // 0,0, 1,1, 2,2, 3,3, 0,0, 1,1, 2,2, 3,3 - return _mm256_madd_epi16(_mm256_set1_epi16(1), p0123); - } - inline __m256i compute(__m256i x, __m256i y) const { return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(x, y)); } -}; - -struct ScaleHelperQ8_0 { - inline __m128 prepare4(const block_q8_0 * y) { - const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y; - return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4->d)); - } - inline __m128 prepare4(__m128 other_scales, const block_q8_0 * y) { - return _mm_mul_ps(other_scales, prepare4(y)); - } - template inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); } - template inline float prepare1(float d, const Q * y) const { return d*prepare1(y); } -}; - -struct ScaleHelperQ_0 { - ggml_half scales8[4]; - template - inline __m128 prepare4(const Q * y) { - for (int j = 0; j < 4; ++j) scales8[j] = y[j].d; - return _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8)); - } - template - inline __m128 prepare4(__m128 other_scales, const Q * y) { - return _mm_mul_ps(other_scales, prepare4(y)); - } - template inline float prepare1(const Q * y) const { return GGML_FP16_TO_FP32(y->d); } - template inline float prepare1(float d, const Q * y) const { return d*prepare1(y); } -}; - -template -struct ScaleHelperQ_0_1 { - ggml_half scales8[4]; - template - inline __m256 prepare4(const Q * y) { - for (int j = 
0; j < 4; ++j) scales8[j] = y[j].d; - auto s4 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)scales8)); - return _mm256_set_m128(_mm_mul_ps(s4, min), s4); - } - template - inline __m256 prepare4(__m256 other_scales, const Q * y) { - return _mm_mul256_ps(other_scales, prepare4(y)); - } - template inline std::pair prepare1(const Q * y) const { - float d = GGML_FP16_TO_FP32(y->d); - return std::make_pair(d, -d*float(min_value)); - } - std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { - return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); - } - const __m128 min = _mm_set1_ps(float(-min_value)); -}; - -//template -//struct ScaleHelperQ_0_2 { -// ggml_bf16_t scales8[4]; -// template -// inline __m256 prepare4(const Q * y) { -// for (int j = 0; j < 4; ++j) scales8[j] = y[j].d; -// auto s4 = _mm_castsi128_ps(_mm_slli_epi16(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)scales8)), 16)); -// return _mm256_set_m128(_mm_mul_ps(s4, min), s4); -// } -// template -// inline __m256 prepare4(__m256 other_scales, const Q * y) { -// return _mm_mul256_ps(other_scales, prepare4(y)); -// } -// template inline std::pair prepare1(const Q * y) const { -// float d = GGML_BF16_TO_FP32(y->d); -// return std::make_pair(d, -d*float(min_value)); -// } -// std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { -// return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); -// } -// const __m128 min = _mm_set1_ps(float(-min_value)); -//}; - -struct ScaleHelperQ8_1 { - template - inline __m256 prepare4(const Q * y) { - const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y; - return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)y4->d)); - } - template - inline __m256 prepare4(__m256 other_scales, const Q * y) { - return _mm256_mul_ps(other_scales, prepare4(y)); - } - template inline std::pair prepare1(const Q * y) const { - return std::make_pair(GGML_FP16_TO_FP32(y->d), 
GGML_FP16_TO_FP32(y->m)); - } - template inline std::pair prepare1(const std::pair& dm, const Q * y) const { - return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m)); - } - std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { - return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); - } -}; - -struct ScaleHelperQ8_2 { - template - inline __m256 prepare4(const Q * y) { - const block_q8_2_x4 * y4 = (const block_q8_2_x4 *)y; - auto aux = _mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)y4->d)); - return _mm256_castsi256_ps(_mm256_slli_epi32(aux, 16)); - } - template - inline __m256 prepare4(__m256 other_scales, const Q * y) { - return _mm256_mul_ps(other_scales, prepare4(y)); - } - template inline std::pair prepare1(const Q * y) const { - return std::make_pair(GGML_BF16_TO_FP32(y->d), GGML_BF16_TO_FP32(y->m)); - } - template inline std::pair prepare1(const std::pair& dm, const Q * y) const { - ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s; - return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s)); - } - std::pair inline prepare1(const std::pair& dm, const block_q8_2 * y) const { - ggml_bf16_t d, s; d.bits = y->d; s.bits = y->s; - return std::make_pair(dm.first*GGML_BF16_TO_FP32(d), dm.second*GGML_BF16_TO_FP32(s)); - } -}; - -struct ScaleHelperQ_1 { - uint32_t scales8[4]; - const __m128i shuffle = _mm_set_epi16(0x0f0e, 0x0b0a, 0x0706, 0x0302, 0x0d0c, 0x0908, 0x0504, 0x0100); - - template - inline __m256 prepare4(const Q * y) { - for (int j = 0; j < 4; ++j) { - // it is slightly faster to directly dereference (const uint32 *)&y[j].d, but some compilers - // complain that this breaks strict-aliasing rules. 
- memcpy(scales8 + j, &y[j].d, sizeof(uint32_t)); - } - return _mm256_cvtph_ps(_mm_shuffle_epi8(_mm_loadu_si128((const __m128i *)scales8), shuffle)); - } - - template - inline __m256 prepare4(__m256 other_scales, const Q * y) { - return _mm256_mul_ps(other_scales, prepare4(y)); - } - - template inline std::pair prepare1(const Q * y) const { - return std::make_pair(GGML_FP16_TO_FP32(y->d), GGML_FP16_TO_FP32(y->m)); - } - template inline std::pair prepare1(const std::pair& dm, const Q * y) const { - return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->m)); - } - std::pair inline prepare1(const std::pair& dm, const block_q8_1 * y) const { - return std::make_pair(dm.first*GGML_FP16_TO_FP32(y->d), dm.second*GGML_FP16_TO_FP32(y->s)); - } -}; - -struct MinusType0 { - inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); } - inline float compute(float d, int) const { return d; } - inline float result(__m256 acc, int) const { return hsum_float_8(acc); } - inline __m256 vresult(__m256 acc, int) const { return acc; } -}; - -template struct MinusType1 { - __m128 accm[nrc_y]; - MinusType1() { for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm_setzero_ps(); } - inline __m256 compute(__m256 dm, int iy) { - const __m128 d = _mm256_castps256_ps128(dm); - const __m128 m = _mm256_extractf128_ps(dm, 1); - accm[iy] = _mm_add_ps(accm[iy], m); - return _mm256_set_m128(d, d); - } - inline float compute(const std::pair& dm, int iy) { - accm[iy] = _mm_add_ps(accm[iy], _mm_set1_ps(dm.second*0.25f)); - return dm.first; - } - inline float result(__m256 acc, int iy) const { - const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); - return hsum_float_4(_mm_add_ps(sum, accm[iy])); - } - inline __m256 vresult(__m256 acc, int iy) const { - return _mm256_add_ps(acc, _mm256_insertf128_ps(_mm256_setzero_ps(), accm[iy], 0)); - } -}; - -template struct AccumT { - __m256 acc[nrc_y]; - Minus accm; - AccumT() { for 
(int iy = 0; iy < nrc_y; ++iy) acc[iy] = _mm256_setzero_ps(); } - template - inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, const DataInfo& info, int ix) { - auto qx = unp.quants(); - __m256 dall[nrc_y]; - for (int i = 0; i < nb/4; ++i) { - auto other_scales = unp.set_block_4(i); - for (int iy = 0; iy < nrc_y; ++iy) { - auto s12 = scales.prepare4(other_scales, y[iy] + 4*i); - dall[iy] = accm.compute(s12, iy); - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto pall = sum.compute(qx, y[iy] + 4*i); - acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]); - } - } - if (!is_multiple_of_4) { - for (int i = 4*(nb/4); i < nb; ++i) { - auto other_scales = unp.set_block(i); - for (int iy = 0; iy < nrc_y; ++iy) { - auto s12 = scales.prepare1(other_scales, y[iy] + i); - auto d = accm.compute(s12, iy); - const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i *)y[iy][i].qs)); - acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); - } - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, accm.result(acc[iy], iy)); - } - } - template - inline void compute(int nb, Unpacker& unp, Scales& scales, Sum& sum, const Q8 ** y, __m256 * result) { - auto qx = unp.quants(); - __m256 dall[nrc_y]; - for (int i = 0; i < nb/4; ++i) { - auto other_scales = unp.set_block_4(i); - for (int iy = 0; iy < nrc_y; ++iy) { - auto s12 = scales.prepare4(other_scales, y[iy] + 4*i); - dall[iy] = accm.compute(s12, iy); - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto pall = sum.compute(qx, y[iy] + 4*i); - acc[iy] = _mm256_fmadd_ps(dall[iy], _mm256_cvtepi32_ps(pall), acc[iy]); - } - } - if (!is_multiple_of_4) { - for (int i = 4*(nb/4); i < nb; ++i) { - auto other_scales = unp.set_block(i); - for (int iy = 0; iy < nrc_y; ++iy) { - auto s12 = scales.prepare1(other_scales, y[iy] + i); - auto d = accm.compute(s12, iy); - const __m256i p0 = sum.compute(qx[0], _mm256_loadu_si256((const __m256i 
*)y[iy][i].qs)); - acc[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(p0), acc[iy]); - } - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - result[iy] = accm.vresult(acc[iy], iy); - } - } -}; - -template -using AccumType0 = AccumT; - -template -using AccumType1 = AccumT, nrc_y, is_multiple_of_4>; - -using Sum4TypeQ80 = Sum4; -using Sum4TypeQ82 = Sum4; - -template -void mul_mat_qX_q8_Helper(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) { - Unpacker unp(vx, bx); - typename Unpacker::Sum4T sum4; - Scales scales; - for (int ix = 0; ix < nrc_x; ++ix) { - unp.set_row(ix); - AccumType accum; - accum.compute(nb, unp, scales, sum4, y, info, ix); - } -} - -template -void mul_mat_qX_q8_Helper_x2(int nb, const void * vx, size_t bx, const DataInfo& info, const Q8 ** y, int nrc_x) { - GGML_ASSERT(nrc_x%2 == 0); - Unpacker unp(vx, bx); - typename Unpacker::Sum4T sum4; - Scales scales; - for (int ix = 0; ix < nrc_x; ix += 2) { - unp.set_row(ix); - AccumType accum; - accum.compute(nb, unp, scales, sum4, y, info, ix); - } -} - -template -void mul_mat_qX_0_q8_0_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n%Unpacker::block_size() == 0); - Q8 q8(info); - int nb = n/Unpacker::block_size(); - if (nb%4 == 0) { - mul_mat_qX_q8_Helper, ScaleHelperQ8_0, block_q8_0, nrc_y>( - nb, vx, bx, info, q8.y, nrc_x - ); - } else { - mul_mat_qX_q8_Helper, ScaleHelperQ8_0, block_q8_0, nrc_y>( - nb, vx, bx, info, q8.y, nrc_x - ); - } -} - -inline __m256 hsum_float_8x8(__m256 * accm) { - for (int i = 0; i < 4; ++i) { - accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31)); - //accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)), - // _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1))); - } - for (int i = 0; i < 2; ++i) accm[i] = 
_mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2])); - return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); -} - -template -void mul_mat_qX_0_q8_0_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) { - static_assert(8%nrc_y == 0); - Q8 q8(info); - int nb = n/Unpacker::block_size(); - Unpacker unp(vx, bx); - typename Unpacker::Sum4T sum4; - ScaleHelperQ8_0 scales; - __m256 result[8]; - auto store = [&info, &result] (int ix0) { - if constexpr (nrc_y == 1) { - info.store(ix0, 0, hsum_float_8x8(result)); - } - else if constexpr (nrc_y == 2) { - auto value = hsum_float_8x8(result); - auto value1 = _mm256_extractf128_ps(value, 1); - info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88)); - info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd)); - } - else { - float val[8]; - _mm256_storeu_ps(val, hsum_float_8x8(result)); - for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); - } - }; - if (nb%4 == 0) { - for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { - for (int ix = 0; ix < 8/nrc_y; ++ix) { - unp.set_row(ix0 + ix); - AccumType0 accum; - accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); - } - store(ix0); - } - } else { - for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { - for (int ix = 0; ix < 8/nrc_y; ++ix) { - unp.set_row(ix0 + ix); - AccumType0 accum; - accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); - } - store(ix0); - } - } -} - - -template -void mul_mat_qX_1_q8_1_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n%Unpacker::block_size() == 0); - Q8 q8(info); - int nb = n/Unpacker::block_size(); - if (nb%4 == 0) { - mul_mat_qX_q8_Helper, ScaleHelperQ8_1, block_q8_1, nrc_y>( - nb, vx, bx, info, q8.y, nrc_x - ); - } else { - mul_mat_qX_q8_Helper, ScaleHelperQ8_1, block_q8_1, nrc_y>( - nb, vx, bx, info, 
q8.y, nrc_x - ); - } -} - -template -void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - assert(n%Unpacker::block_size() == 0); - Q8 q8(info); - int nb = n/Unpacker::block_size(); - if (nb%4 == 0) { - mul_mat_qX_q8_Helper, ScaleHelperQ8_2, block_q8_2, nrc_y>( - nb, vx, bx, info, q8.y, nrc_x - ); - } else { - mul_mat_qX_q8_Helper, ScaleHelperQ8_2, block_q8_2, nrc_y>( - nb, vx, bx, info, q8.y, nrc_x - ); - } -} - -template -void mul_mat_qX_0_q8_2_Tx(int n, const void * vx, size_t bx, const DataInfo& info, int) { - static_assert(8%nrc_y == 0); - Q8 q8(info); - int nb = n/Unpacker::block_size(); - Unpacker unp(vx, bx); - typename Unpacker::Sum4T sum4; - ScaleHelperQ8_2 scales; - __m256 result[8]; - auto store = [&info, &result] (int ix0) { - if constexpr (nrc_y == 1) { - info.store(ix0, 0, hsum_float_8x8(result)); - } - else if constexpr (nrc_y == 2) { - auto value = hsum_float_8x8(result); - auto value1 = _mm256_extractf128_ps(value, 1); - info.store(ix0, 0, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0x88)); - info.store(ix0, 1, _mm_shuffle_ps(_mm256_castps256_ps128(value), value1, 0xdd)); - } - else { - float val[8]; - _mm256_storeu_ps(val, hsum_float_8x8(result)); - for (int iy = 0; iy < nrc_y; ++iy) for (int ix = 0; ix < 8/nrc_y; ++ix) info.store(ix0+ix, iy, val[nrc_y*ix+iy]); - } - }; - if (nb%4 == 0) { - for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { - for (int ix = 0; ix < 8/nrc_y; ++ix) { - unp.set_row(ix0 + ix); - AccumType1 accum; - accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); - } - store(ix0); - } - } else { - for (int ix0 = 0; ix0 < nrc_x; ix0 += 8/nrc_y) { - for (int ix = 0; ix < 8/nrc_y; ++ix) { - unp.set_row(ix0 + ix); - AccumType1 accum; - accum.compute(nb, unp, scales, sum4, q8.y, result + nrc_y*ix); - } - store(ix0); - } - } -} - -struct Dequantizer4bit { - const __m256i m4 = _mm256_set1_epi8(0xf); - inline __m256i dequant(const uint8_t * qs) const { - const __m128i aux128 = 
_mm_loadu_si128((const __m128i *)qs); - return _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(aux128, 4), aux128), m4); - } -}; - -struct Q8_0_Dequantizer { - inline __m256i dequant(const block_q8_0 * x) const { - return _mm256_loadu_si256((const __m256i *)x->qs); - } -}; - -struct Q8_0_1_Dequantizer { - inline __m256i dequant(const block_q8_0 * x) const { - return _mm256_add_epi8(_mm256_set1_epi8(127), _mm256_loadu_si256((const __m256i *)x->qs)); - } -}; - -struct Q4_0_Dequantizer { - Dequantizer4bit b4; - const __m256i m8 = _mm256_set1_epi8(-8); - inline __m256i dequant(const block_q4_0 * x) const { - return _mm256_add_epi8(b4.dequant(x->qs), m8); - } -}; - -struct Q4_0_1_Dequantizer { - Dequantizer4bit b4; - inline __m256i dequant(const block_q4_0 * x) const { - return b4.dequant(x->qs); - } -}; - -struct IQ4_NL_Dequantizer { - Dequantizer4bit b4; -#ifdef HAVE_FANCY_SIMD - const __m256i values = load_iq4nl_values_256(); -#else - const __m256i values = load_iq4k_values_256(); -#endif - inline __m256i dequant(const block_iq4_nl * x) const { - return _mm256_shuffle_epi8(values, b4.dequant(x->qs)); - } -}; - -struct Q4_1_Dequantizer { - Dequantizer4bit b4; - inline __m256i dequant(const block_q4_1 * x) const { - return b4.dequant(x->qs); - } -}; - -struct HBitDequantizer { - const __m256i shuffle = _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202, 0x0101010101010101, 0x0000000000000000); - const __m256i mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); - const __m256i minus1 = _mm256_set1_epi64x(-1); - inline __m256i to_bytes(const uint8_t * bits) const { - // Note: Data in all ggml quants is at least 2-byte aligned. 
- // => we can cast to uint16_t and use or on two consecutive entries - // which is faster than memcpy - const uint16_t * aux16 = (const uint16_t *)bits; - const uint32_t aux32 = aux16[0] | (aux16[1] << 16); - //uint32_t aux32; memcpy(&aux32, bits, sizeof(uint32_t)); - __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(aux32), shuffle); - bytes = _mm256_or_si256(bytes, mask); - return _mm256_cmpeq_epi8(bytes, minus1); - } -}; - -struct Q5_0_Dequantizer { - Dequantizer4bit b4; - HBitDequantizer hbit; - const __m256i mh = _mm256_set1_epi8((char)0xF0); - inline __m256i dequant(const block_q5_0 * x) const { - const __m256i vqh = _mm256_andnot_si256(hbit.to_bytes(x->qh), mh); - return _mm256_or_si256(b4.dequant(x->qs), vqh); - } -}; - -template -struct Q5_1_Dequantizer { - Dequantizer4bit b4; - HBitDequantizer hbit; - const __m256i mh = _mm256_set1_epi8(0x10); - inline __m256i dequant(const Q5 * x) const { - const __m256i vqh = _mm256_and_si256(hbit.to_bytes(x->qh), mh); - return _mm256_or_si256(b4.dequant(x->qs), vqh); - } -}; -struct Q6_0_1_Dequantizer { - Dequantizer4bit b4; - const __m256i mh = _mm256_set1_epi8(0x30); - const __m256i shift1 = _mm256_set_epi64x(0, 2, 0, 4); - const __m256i shift2 = _mm256_set_epi64x(2, 0, 0, 0); - inline __m256i dequant(const block_q6_0 * x) const { - uint64_t aux64; std::memcpy(&aux64, x->qh, 8); - auto h256 = _mm256_sllv_epi64(_mm256_set1_epi64x(aux64), shift1); - return _mm256_or_si256(b4.dequant(x->qs), _mm256_and_si256(_mm256_srlv_epi64(h256, shift2), mh)); - } -}; - -template -struct Q_Unpacker { - Q_Unpacker(const void * vx, size_t bx) : cx_0((const char *)vx), x((const Q*)cx_0), bx(bx) {} - - const char * cx_0; - const Q * x; - size_t bx; - - Scales scales; - Dequantizer deq; - - __m256i qx[4]; - - inline const __m256i* quants() const { return qx; } - - inline void set_row(int ix) { x = (const Q*)(cx_0 + ix*bx); } - - inline auto set_block_4(int i) { - for (int j = 0; j < 4; ++j) { - qx[j] = deq.dequant(x + 4*i + j); - } - 
return scales.prepare4(x + 4*i); - } - inline auto set_block(int i) { - qx[0] = deq.dequant(x + i); - return scales.prepare1(x + i); - } -}; - -struct Q8_0_Unpacker final : public Q_Unpacker { - Q8_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ80; - inline static int block_size() { return QK8_0; } -}; -struct Q8_0_1_Unpacker final : public Q_Unpacker, Q8_0_1_Dequantizer> { - Q8_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ82; - inline static int block_size() { return QK8_0; } -}; -struct Q4_0_Unpacker final : public Q_Unpacker { - Q4_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ80; - inline static int block_size() { return QK4_0; } -}; -struct Q4_0_1_Unpacker final : public Q_Unpacker, Q4_0_1_Dequantizer> { - Q4_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - //using Sum4T = Sum4TypeQ82; - using Sum4T = Sum4q4; - inline static int block_size() { return QK4_0; } -}; -#ifdef HAVE_FANCY_SIMD -struct IQ4_NL_Unpacker final : public Q_Unpacker, IQ4_NL_Dequantizer> { - IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ82; - inline static int block_size() { return QK4_NL; } -}; -#else -struct IQ4_NL_Unpacker final : public Q_Unpacker { - IQ4_NL_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ80; - inline static int block_size() { return QK4_NL; } -}; -#endif -struct Q5_0_Unpacker final : public Q_Unpacker { - Q5_0_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ80; - inline static int block_size() { return QK5_0; } -}; -struct Q5_0_1_Unpacker final : public Q_Unpacker, Q5_1_Dequantizer> { - Q5_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ82; - inline static int block_size() { return QK5_0; } -}; -struct Q4_1_Unpacker final : public Q_Unpacker { - Q4_1_Unpacker(const 
void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ82; - inline static int block_size() { return QK4_1; } -}; -struct Q5_1_Unpacker final : public Q_Unpacker> { - Q5_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ82; - inline static int block_size() { return QK5_1; } -}; -struct Q6_0_1_Unpacker final : public Q_Unpacker, Q6_0_1_Dequantizer> { - Q6_0_1_Unpacker(const void * vx, size_t bx) : Q_Unpacker(vx, bx) {} - using Sum4T = Sum4TypeQ82; - inline static int block_size() { return QK6_0; } -}; - template void MulMat::set_functions(MulMat& m) { - if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { - m.funcs[0] = mul_mat_qX_0_q8_0_T; - m.funcs[1] = mul_mat_qX_0_q8_0_T; - m.funcs[2] = mul_mat_qX_0_q8_0_T; - m.funcs[3] = mul_mat_qX_0_q8_0_T; - m.funcs[4] = mul_mat_qX_0_q8_0_T; - m.funcs[5] = mul_mat_qX_0_q8_0_T; - m.funcs[6] = mul_mat_qX_0_q8_0_T; - m.funcs[7] = mul_mat_qX_0_q8_0_T; - } - else if constexpr (std::is_same_v || std::is_same_v) { - m.funcs[0] = mul_mat_qX_1_q8_2_T; - m.funcs[1] = mul_mat_qX_1_q8_2_T; - m.funcs[2] = mul_mat_qX_1_q8_2_T; - m.funcs[3] = mul_mat_qX_1_q8_2_T; - m.funcs[4] = mul_mat_qX_1_q8_2_T; - m.funcs[5] = mul_mat_qX_1_q8_2_T; - m.funcs[6] = mul_mat_qX_1_q8_2_T; - m.funcs[7] = mul_mat_qX_1_q8_2_T; - } - else if constexpr (std::is_same_v) { -#ifdef HAVE_FANCY_SIMD - m.funcs[0] = mul_mat_qX_1_q8_2_T; - m.funcs[1] = mul_mat_qX_1_q8_2_T; - m.funcs[2] = mul_mat_qX_1_q8_2_T; - m.funcs[3] = mul_mat_qX_1_q8_2_T; - m.funcs[4] = mul_mat_qX_1_q8_2_T; - m.funcs[5] = mul_mat_qX_1_q8_2_T; - m.funcs[6] = mul_mat_qX_1_q8_2_T; - m.funcs[7] = mul_mat_qX_1_q8_2_T; -#else - m.funcs[0] = mul_mat_qX_0_q8_0_T; - m.funcs[1] = mul_mat_qX_0_q8_0_T; - m.funcs[2] = mul_mat_qX_0_q8_0_T; - m.funcs[3] = mul_mat_qX_0_q8_0_T; - m.funcs[4] = mul_mat_qX_0_q8_0_T; - m.funcs[5] = mul_mat_qX_0_q8_0_T; - m.funcs[6] = mul_mat_qX_0_q8_0_T; - m.funcs[7] = mul_mat_qX_0_q8_0_T; -#endif - } - else if 
constexpr (std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) { - m.funcs[0] = mul_mat_qX_1_q8_2_T; - m.funcs[1] = mul_mat_qX_1_q8_2_T; - m.funcs[2] = mul_mat_qX_1_q8_2_T; - m.funcs[3] = mul_mat_qX_1_q8_2_T; - m.funcs[4] = mul_mat_qX_1_q8_2_T; - m.funcs[5] = mul_mat_qX_1_q8_2_T; - m.funcs[6] = mul_mat_qX_1_q8_2_T; - m.funcs[7] = mul_mat_qX_1_q8_2_T; - } - else if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v) { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { m.funcs[0] = mul_mat_qX_K_q8_K_IQ; m.funcs[1] = mul_mat_qX_K_q8_K_IQ; m.funcs[2] = mul_mat_qX_K_q8_K_IQ; @@ -9415,49 +8591,13 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { expected_typeB = GGML_TYPE_Q8_K16; break; case GGML_TYPE_Q4_0: - assert (ne00 % QK4_0 == 0); - MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_2_X4; - break; case GGML_TYPE_Q4_1: - assert (ne00 % QK4_1 == 0); - MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_2_X4; - break; case GGML_TYPE_Q5_0: - assert (ne00 % QK5_0 == 0); - MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_2_X4; - break; case GGML_TYPE_Q5_1: - assert (ne00 % QK5_1 == 0); - MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_2_X4; - break; case GGML_TYPE_Q6_0: - assert (ne00 % QK6_0 == 0); - MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_2_X4; - break; case GGML_TYPE_Q8_0: - assert (ne00 % QK8_0 == 0); -#ifdef HAVE_FANCY_SIMD - MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_2_X4; -#else - MulMat::set_functions(mm); - expected_typeB = GGML_TYPE_Q8_0_X4; -#endif - break; case GGML_TYPE_IQ4_NL: - assert (ne00 % QK4_NL == 0); - MulMat::set_functions(mm); -#ifdef HAVE_FANCY_SIMD - expected_typeB = GGML_TYPE_Q8_2_X4; -#else - expected_typeB = GGML_TYPE_Q8_0_X4; -#endif - break; + return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs); case 
GGML_TYPE_IQ4_NL_R4: assert (ne00 % QK4_NL == 0); mm.funcs[0] = mul_mat_iq4_nl_r4_q8_2<1>;