diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h
index 60eec8f9..86542a8f 100644
--- a/ggml/src/iqk/iqk_common.h
+++ b/ggml/src/iqk/iqk_common.h
@@ -138,6 +138,27 @@ typedef void (*mul_mat_t)(int n, const void * vx, size_t bx, const DataInfo& inf
 
 #define IQK_MAX_NY 8
 
+#define IQK_SET_MUL_MAT_FUNCTIONS_T(kernel, Dequantizer, funcs) \
+            funcs[0] = kernel<Dequantizer, 1>;\
+            funcs[1] = kernel<Dequantizer, 2>;\
+            funcs[2] = kernel<Dequantizer, 3>;\
+            funcs[3] = kernel<Dequantizer, 4>;\
+            funcs[4] = kernel<Dequantizer, 5>;\
+            funcs[5] = kernel<Dequantizer, 6>;\
+            funcs[6] = kernel<Dequantizer, 7>;\
+            funcs[7] = kernel<Dequantizer, 8>;\
+
+#define IQK_SET_MUL_MAT_FUNCTIONS(kernel, funcs) \
+            funcs[0] = kernel<1>;\
+            funcs[1] = kernel<2>;\
+            funcs[2] = kernel<3>;\
+            funcs[3] = kernel<4>;\
+            funcs[4] = kernel<5>;\
+            funcs[5] = kernel<6>;\
+            funcs[6] = kernel<7>;\
+            funcs[7] = kernel<8>;\
+
+
 // ==================================================================================================
 
 static inline void make_q4_scales(const uint8_t * scales8, uint32_t * aux32) {
@@ -234,6 +255,13 @@ static inline __m256i load_iq4nl_values_256() {
     return MM256_SET_M128I(val128, val128);
 }
 
+#ifdef HAVE_FANCY_SIMD
+static inline __m512i load_iq4nl_values_512() {
+    auto val256 = load_iq4nl_values_256();
+    return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
+}
+#endif
+
 static inline __m128i load_iq4k_values_128() {
     return _mm_loadu_si128((const __m128i *)iq4k_values);
 }
diff --git a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp
index baf1bd1e..0f02ba76 100644
--- a/ggml/src/iqk/iqk_gemm_iqk_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_iqk_quants.cpp
@@ -11,11 +11,6 @@ namespace {
 
 #ifdef HAVE_FANCY_SIMD
 
-__m512i inline load_iq4nl_values_512() {
-    auto val256 = load_iq4nl_values_256();
-    return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
-}
-
 struct IQXKScales {
     IQXKScales(uint8_t shift, int8_t min_val) : eshift(_mm256_set1_epi16(shift)), min(_mm256_set1_epi16(min_val)) {}
     template
diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp
index 57014665..9192dca3 100644
--- a/ggml/src/iqk/iqk_gemm_kquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_kquants.cpp
@@ -285,11 +285,6 @@ struct DequantizerQ6K final : public BaseDequantizer {
 
 };
 
-__m512i inline load_iq4nl_values_512() {
-    auto val256 = load_iq4nl_values_256();
-    return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1);
-}
-
 struct DequantizerIQ4XS final : public BaseDequantizer {
     DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
     template
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
index ed60afbc..1a30c5d8 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -685,70 +685,965 @@ struct Q6_0_1_Unpacker final : public Q_Unpacker
 
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%8 == 0);
+    Q8<nrc_y, block_q8_2_x4> q8(info);
+    auto m4 = _mm512_set1_epi8(0xf);
+    auto values = load_iq4nl_values_512();
+    int nb = n / QK4_NL;
+    __m512 acc[2*nrc_y] = {};
+    __m512i qx[4];
+    float d8[8*nrc_y];
+    auto prepare = [&qx, &m4, &values] (const block_iq4_nl_r4& iq4l, const block_iq4_nl_r4& iq4h) {
+        auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l.d));
+        auto scales1 = _mm256_set_m128(scales128, scales128);
+        scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h.d));
+        auto scales2 = _mm256_set_m128(scales128, scales128);
+        auto scales =
_mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+0)), + _mm256_loadu_si256((const __m256i *)iq4h.qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+1)), + _mm256_loadu_si256((const __m256i *)iq4h.qs+1), 1); + qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4)); + qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4)); + qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4)); + qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4)); + return scales; + }; + auto dot = [&qx] (__m256i y8) { + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + return sumi; + }; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_iq4_nl_r4 * iq4l = (const block_iq4_nl_r4 *)((const char *)vx + (ix+0)*bx); + const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16); + _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux)); + } + for (int k = 0; k < 4; ++k) { + auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); + auto dy = _mm512_set1_ps(d8[8*iy+k]); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); + } + } + } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq4l[ib], iq4h[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s; + auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-64.f), acc[2*iy+1], acc[2*iy+0]); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); + auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); + auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); + info.store(ix+0, iy, sum1); + info.store(ix+4, iy, sum2); + } + } +} +#else +template +static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); + auto m1 = _mm256_set1_epi16(1); + auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); + auto values = MM256_SET_M128I(values128, values128); + int nb = n / QK4_NL; + __m256 acc[nrc_y] = {}; + __m256i qs[4]; + float 
d8[4*nrc_y]; + auto prepare = [&qs, &values, &m4] (const block_iq4_nl_r4& iq4) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4.d)); + auto scales = _mm256_set_m128(scales128, scales128); + auto bits1 = _mm256_loadu_si256((const __m256i *)iq4.qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)iq4.qs+1); + qs[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); + qs[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); + qs[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); + qs[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); + return scales; + }; + auto dot = [&qs, &m1] (__m256i y) { + auto u1 = _mm256_sign_epi8(qs[0], qs[0]); + auto u2 = _mm256_sign_epi8(qs[1], qs[1]); + auto sumi1 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qs[0]))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qs[1])))); + u1 = _mm256_sign_epi8(qs[2], qs[2]); + u2 = _mm256_sign_epi8(qs[3], qs[3]); + auto sumi2 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qs[2]))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qs[3])))); + return _mm256_add_epi32(sumi1, sumi2); + }; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto aux = _mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)), 16); + _mm_storeu_ps(d8+4*iy, _mm_castsi128_ps(aux)); + } + for (int k = 0; k < 4; ++k) { + auto scales = prepare(iq4[4*ib4+k]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq4[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); + ggml_bf16_t d{qy[ib].d}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, sum); + acc[iy] = _mm256_setzero_ps(); + } + } +} +#endif + +inline void prepare_q4_0_quants_avx2(const uint8_t * qs, __m256i * v, const __m256i& m4) { + auto bits1 = _mm256_loadu_si256((const __m256i *)qs+0); + auto bits2 = _mm256_loadu_si256((const __m256i *)qs+1); + auto bits3 = _mm256_loadu_si256((const __m256i *)qs+2); + auto bits4 = _mm256_loadu_si256((const __m256i *)qs+3); + v[0] = _mm256_and_si256(bits1, m4); + v[1] = _mm256_and_si256(bits2, m4); + v[2] = _mm256_and_si256(bits3, m4); + v[3] = _mm256_and_si256(bits4, m4); + v[4] = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4); + v[5] = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4); + v[6] = _mm256_and_si256(_mm256_srli_epi16(bits3, 4), m4); + v[7] = _mm256_and_si256(_mm256_srli_epi16(bits4, 4), m4); +} + +inline __m256i accum_q4_0_quants(const __m256i * v, const int8_t * qs) { + auto y4l = 
_mm_loadu_si128((const __m128i*)qs+0); + auto y4h = _mm_loadu_si128((const __m128i*)qs+1); + auto yl = MM256_SET_M128I(y4l, y4l); + auto yh = MM256_SET_M128I(y4h, y4h); +#ifdef HAVE_FANCY_SIMD + auto sumi = _mm256_setzero_si256(); + sumi = _mm256_dpbusd_epi32(sumi, v[0], _mm256_shuffle_epi32(yl, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, v[1], _mm256_shuffle_epi32(yl, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, v[2], _mm256_shuffle_epi32(yl, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, v[3], _mm256_shuffle_epi32(yl, 0xff)); + sumi = _mm256_dpbusd_epi32(sumi, v[4], _mm256_shuffle_epi32(yh, 0x00)); + sumi = _mm256_dpbusd_epi32(sumi, v[5], _mm256_shuffle_epi32(yh, 0x55)); + sumi = _mm256_dpbusd_epi32(sumi, v[6], _mm256_shuffle_epi32(yh, 0xaa)); + sumi = _mm256_dpbusd_epi32(sumi, v[7], _mm256_shuffle_epi32(yh, 0xff)); +#else + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(v[0], _mm256_shuffle_epi32(yl, 0x00)), + _mm256_maddubs_epi16(v[1], _mm256_shuffle_epi32(yl, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(v[2], _mm256_shuffle_epi32(yl, 0xaa)), + _mm256_maddubs_epi16(v[3], _mm256_shuffle_epi32(yl, 0xff))); + auto sumi3 = _mm256_add_epi16(_mm256_maddubs_epi16(v[4], _mm256_shuffle_epi32(yh, 0x00)), + _mm256_maddubs_epi16(v[5], _mm256_shuffle_epi32(yh, 0x55))); + auto sumi4 = _mm256_add_epi16(_mm256_maddubs_epi16(v[6], _mm256_shuffle_epi32(yh, 0xaa)), + _mm256_maddubs_epi16(v[7], _mm256_shuffle_epi32(yh, 0xff))); + auto sumi = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(_mm256_add_epi16(sumi1, sumi2), _mm256_add_epi16(sumi3, sumi4))); +#endif + return sumi; +} + +template +static void mul_mat_q4_0_r8_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + auto m4 = _mm256_set1_epi8(0xf); + int nb = n / QK4_NL; + __m256i v[8]; + GGML_ASSERT(nb%4 == 0); + if constexpr (nrc_y == 1) { + union { __m256 vec; float val[8]; } helper; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_iq4_nl_r8 * iq4 = (const block_iq4_nl_r8 *)((const char *)vx + ix*bx); + auto acc1 = _mm256_setzero_ps(); + auto acc2 = _mm256_setzero_ps(); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + helper.vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16)); + for (int k = 0; k < 4; ++k) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d)); + prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4); + auto sumi = accum_q4_0_quants(v, q8.y[0][ib4].qs+32*k); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(helper.val[k])); + acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1); + acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(helper.val[k+4]), acc2); + } + } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto qy = (const block_q8_1 *)q8.y[0]; + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d)); + prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4); + auto sumi = accum_q4_0_quants(v, qy[ib].qs); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); + acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1); + acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc2); + } + acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1); + info.store(ix, 0, acc1); + } + } + else { + __m256 acc[nrc_y] = {}; + float d8[8*nrc_y]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_iq4_nl_r8 * iq4 = (const block_iq4_nl_r8 *)((const char *)vx + 
ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + { + __m256 d4[4]; + for (int k = 0; k < 4; ++k) { + d4[k] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d)); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); + _mm256_storeu_ps(d8 + 8*iy, scales); + auto m4 = _mm256_extractf128_ps(scales, 1); + auto m8 = _mm256_set_m128(m4, m4); + auto sumf = _mm256_mul_ps(d4[0], _mm256_shuffle_ps(m8, m8, 0x00)); + sumf = _mm256_fmadd_ps(d4[1], _mm256_shuffle_ps(m8, m8, 0x55), sumf); + sumf = _mm256_fmadd_ps(d4[2], _mm256_shuffle_ps(m8, m8, 0xaa), sumf); + sumf = _mm256_fmadd_ps(d4[3], _mm256_shuffle_ps(m8, m8, 0xff), sumf); + acc[iy] = _mm256_fmadd_ps(sumf, _mm256_set1_ps(-8.f), acc[iy]); + } + } + for (int k = 0; k < 4; ++k) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d)); + prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = accum_q4_0_quants(v, q8.y[iy][ib4].qs+32*k); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d)); + auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f)); + prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4); + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 *)q8.y[iy]; + auto sumi = accum_q4_0_quants(v, qy[ib].qs); + ggml_bf16_t d{qy[ib].d}, s{qy[ib].s}; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[iy]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + info.store(ix, iy, acc[iy]); + acc[iy] = _mm256_setzero_ps(); + } + } + } +} + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + if constexpr (nrc_y == 1) { + mul_mat_q4_0_r8_q8_2_avx2<1>(n, vx, bx, info, nrc_x); + return; + } + GGML_ASSERT(nrc_x%16 == 0); + Q8 q8(info); + auto m4 = _mm512_set1_epi8(0xf); + int nb = n / QK4_NL; + __m512 acc[2*nrc_y] = {}; + __m512i qx[8]; + auto prepare = [&qx, &m4] (const block_iq4_nl_r8& iq4l, const block_iq4_nl_r8& iq4h) { + auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l.d)); + auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h.d)); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + for (int j = 0; j < 4; ++j) { + auto bits = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+j)), + _mm256_loadu_si256((const __m256i *)iq4h.qs+j), 1); + qx[j+0] = _mm512_and_si512(bits, m4); + qx[j+4] = _mm512_and_si512(_mm512_srli_epi16(bits, 4), m4); + } + return scales; + }; + auto dot = [&qx] (const int8_t * qy) { + auto y4l = _mm_loadu_si128((const __m128i*)qy+0); + auto y4h = _mm_loadu_si128((const __m128i*)qy+1); + auto y8l = MM256_SET_M128I(y4l, y4l); + auto y8h = MM256_SET_M128I(y4h, y4h); + auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1); + auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00))); + sumi = 
_mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
+        sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+        return sumi;
+    };
+    float d8[8*nrc_y];
+    for (int ix = 0; ix < nrc_x; ix += 16) {
+        const block_iq4_nl_r8 * iq4l = (const block_iq4_nl_r8 *)((const char *)vx + (ix+0)*bx);
+        const block_iq4_nl_r8 * iq4h = (const block_iq4_nl_r8 *)((const char *)vx + (ix+8)*bx);
+        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16);
+                _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux));
+            }
+            for (int k = 0; k < 4; ++k) {
+                auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto sumi = dot(q8.y[iy][ib4].qs+32*k);
+                    auto dy = _mm512_set1_ps(d8[8*iy+k]);
+                    acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+                    acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
+                }
+            }
+        }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales = prepare(iq4l[ib], iq4h[ib]);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                auto sumi = dot(qy[ib].qs);
+                ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
+                auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
+                acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+                acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto sum = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
+            acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
+            info.store(ix, iy, sum);
+        }
+    }
+}
+#else
+template <int nrc_y>
+static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    mul_mat_q4_0_r8_q8_2_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
+template <int nrc_y>
+static void mul_mat_q5_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%4 == 0);
+    Q8<nrc_y, block_q8_2_x4> q8(info);
+    auto m4 = _mm256_set1_epi8(0xf);
+    auto m5 = _mm256_set1_epi8(0x10);
+#ifndef HAVE_FANCY_SIMD
+    auto m1 = _mm256_set1_epi16(1);
+#endif
+    auto mscale = _mm256_set_m128(_mm_set1_ps(-16.f), _mm_set1_ps(1.f));
+    int nb = n / QK5_0;
+    __m256 acc[nrc_y] = {};
+    __m256i qx[4];
+    float d8[8*nrc_y];
+    auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5) {
+        auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5.d));
+        auto scales = _mm256_set_m128(scales128, scales128);
+        auto bits1 = _mm256_loadu_si256((const __m256i *)iq5.qs+0);
+        auto bits2 = _mm256_loadu_si256((const __m256i *)iq5.qs+1);
+        auto hbits = _mm_loadu_si128((const __m128i *)iq5.qh);
+        auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits);
+        qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5));
+        qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5));
+        qx[2] =
_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5));
+        qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5));
+        return scales;
+    };
+#ifdef HAVE_FANCY_SIMD
+    auto dot = [&qx] (__m256i y) {
+        auto sumi = _mm256_setzero_si256();
+        sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
+        sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+        sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+        sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+        return sumi;
+    };
+#else
+    auto dot = [&qx, &m1] (__m256i y) {
+        auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+                                      _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+        auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+                                      _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+        auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+        return sumi;
+    };
+#endif
+    for (int ix = 0; ix < nrc_x; ix += 4) {
+        const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx);
+        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
+                _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(mscale, scales));
+            }
+            for (int k = 0; k < 4; ++k) {
+                auto scales = prepare(iq5[4*ib4+k]);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
+                    acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+                    acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
+                }
+            }
+        }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales = prepare(iq5[ib]);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+                ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
+                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+                acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+                acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_BF16_TO_FP32(s)), acc[iy]);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
+            info.store(ix, iy, sum);
+            acc[iy] = _mm256_setzero_ps();
+        }
+    }
+}
+
+#ifdef HAVE_FANCY_SIMD
+template <int nrc_y>
+static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    if constexpr (nrc_y == 1) {
+        mul_mat_q5_0_r4_q8_2_avx2<1>(n, vx, bx, info, nrc_x);
+    } else {
+        GGML_ASSERT(nrc_x%8 == 0);
+        Q8<nrc_y, block_q8_2_x4> q8(info);
+        auto m4 = _mm512_set1_epi8(0xf);
+        auto m5 = _mm512_set1_epi8(0x10);
+        int nb = n / QK5_0;
+        __m512 acc[2*nrc_y] = {};
+        __m512i qx[4];
+        float d8[8*nrc_y];
+        auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5l, const block_q5_0_r4& iq5h) {
+            auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l.d));
+            auto scales1 = _mm256_set_m128(scales128, scales128);
+            scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h.d));
+            auto scales2 = _mm256_set_m128(scales128, scales128);
+            auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
+            auto bits1 =
_mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+0)),
+                                           _mm256_loadu_si256((const __m256i *)iq5h.qs+0), 1);
+            auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+1)),
+                                            _mm256_loadu_si256((const __m256i *)iq5h.qs+1), 1);
+            auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l.qh);
+            auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h.qh);
+            auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1);
+            auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2);
+            auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1);
+            qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5));
+            qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5));
+            qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5));
+            qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5));
+            return scales;
+        };
+        auto dot = [&qx] (__m256i y8) {
+            auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
+            auto sumi = _mm512_setzero_si512();
+            sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
+            sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
+            sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
+            sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
+            return sumi;
+        };
+        for (int ix = 0; ix < nrc_x; ix += 8) {
+            const block_q5_0_r4 * iq5l = (const block_q5_0_r4 *)((const char *)vx + (ix+0)*bx);
+            const block_q5_0_r4 * iq5h = (const block_q5_0_r4 *)((const char *)vx + (ix+4)*bx);
+            for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)));
+                }
+                for (int k = 0; k < 4; ++k) {
+                    auto scales = prepare(iq5l[4*ib4+k], iq5h[4*ib4+k]);
+                    for (int iy = 0; iy < nrc_y; ++iy) {
+                        auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+                        auto dy = _mm512_set1_ps(d8[8*iy+k]);
+                        acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+                        acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
+                    }
+                }
+            }
+            for (int ib = 4*(nb/4); ib < nb; ++ib) {
+                auto scales = prepare(iq5l[ib], iq5h[ib]);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto qy = (const block_q8_1 *)q8.y[iy];
+                    auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+                    ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
+                    auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
+                    acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+                    acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-16.f), acc[2*iy+1], acc[2*iy+0]);
+                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
+                auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
+                auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
+                info.store(ix+0, iy, sum1);
+                info.store(ix+4, iy, sum2);
+            }
+        }
+    }
+}
+#else
+template <int nrc_y>
+static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int
nrc_x) {
+    mul_mat_q5_0_r4_q8_2_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
+template <int nrc_y>
+static void mul_mat_q6_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%4 == 0);
+    Q8<nrc_y, block_q8_2_x4> q8(info);
+    auto m4 = _mm256_set1_epi8(0xf);
+    auto m6 = _mm256_set1_epi8(0x30);
+    auto mscale = _mm256_set_m128(_mm_set1_ps(-32.f), _mm_set1_ps(1.f));
+#ifndef HAVE_FANCY_SIMD
+    auto m1 = _mm256_set1_epi16(1);
+#endif
+    int nb = n / QK6_0;
+    __m256 acc[nrc_y] = {};
+    float d8[8*nrc_y];
+    __m256i qx[4];
+    auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6) {
+        auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6.d));
+        auto scales = _mm256_set_m128(scales128, scales128);
+        auto bits1 = _mm256_loadu_si256((const __m256i *)iq6.qs+0);
+        auto bits2 = _mm256_loadu_si256((const __m256i *)iq6.qs+1);
+        auto hbits = _mm256_loadu_si256((const __m256i *)iq6.qh);
+        qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6));
+        qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6));
+        qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6));
+        qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6));
+        return scales;
+    };
+#ifdef HAVE_FANCY_SIMD
+    auto dot = [&qx] (__m256i y) {
+        auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_shuffle_epi32(y, 0x00));
+        sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
+        sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
+        sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
+        return sumi;
+    };
+#else
+    auto dot = [&qx, &m1] (__m256i y) {
+        auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+                                      _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+        auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+                                      _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+        auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
+        return sumi;
+    };
+#endif
+    for (int ix = 0; ix < nrc_x; ix += 4) {
+        const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx);
+        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
+                _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(scales, mscale));
+            }
+            for (int k = 0; k < 4; ++k) {
+                auto scales = prepare(iq6[4*ib4+k]);
+                for (int iy = 0; iy < nrc_y; ++iy) {
+                    auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
+                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
+                    acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+                    acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
+                }
+            }
+        }
+        for (int ib = 4*(nb/4); ib < nb; ++ib) {
+            auto scales = prepare(iq6[ib]);
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto qy = (const block_q8_1 *)q8.y[iy];
+                auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+                ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
+                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
+                acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+                acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-32.f*GGML_BF16_TO_FP32(s)),
acc[iy]); + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)); + info.store(ix, iy, sum); + acc[iy] = _mm256_setzero_ps(); + } + } +} + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + if constexpr (nrc_y == 1) { + mul_mat_q6_0_r4_q8_2_avx2<1>(n, vx, bx, info, nrc_x); + } else { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + auto m4 = _mm512_set1_epi8(0xf); + auto m6 = _mm512_set1_epi8(0x30); + int nb = n / QK6_0; + __m512 acc[2*nrc_y] = {}; + __m512i qx[4]; + float d8[8*nrc_y]; + auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6l, const block_q6_0_r4& iq6h) { + auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l.d)); + auto scales1 = _mm256_set_m128(scales128, scales128); + scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h.d)); + auto scales2 = _mm256_set_m128(scales128, scales128); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+0)), + _mm256_loadu_si256((const __m256i *)iq6h.qs+0), 1); + auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+1)), + _mm256_loadu_si256((const __m256i *)iq6h.qs+1), 1); + auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l.qh); + auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h.qh); + auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1); + qx[0] = _mm512_and_si512(bits1, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6); + qx[1] = _mm512_and_si512(bits2, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6);; + qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(hb, m6); + qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4) | _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6); + return scales; + }; + auto dot = [&qx] (__m256i y8) { + auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); + auto sumi = _mm512_setzero_si512(); + sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + return sumi; + }; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q6_0_r4 * iq6l = (const block_q6_0_r4 *)((const char *)vx + (ix+0)*bx); + const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)); + _mm256_storeu_ps(d8 + 8*iy, scales); + } + for (int k = 0; k < 4; ++k) { + auto scales = prepare(iq6l[4*ib4+k], iq6h[4*ib4+k]); + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); + auto dy = _mm512_set1_ps(d8[8*iy+k]); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); + } + } + } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = prepare(iq6l[ib], iq6h[ib]); + for (int iy = 0; iy < nrc_y; ++iy) { 
+                    auto qy = (const block_q8_1 *)q8.y[iy];
+                    auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
+                    ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
+                    auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
+                    acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
+                    acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
+                }
+            }
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-32.f), acc[2*iy+1], acc[2*iy+0]);
+                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
+                auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
+                auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
+                info.store(ix+0, iy, sum1);
+                info.store(ix+4, iy, sum2);
+            }
+        }
+    }
+}
+#else
+template <int nrc_y>
+static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    mul_mat_q6_0_r4_q8_2_avx2<nrc_y>(n, vx, bx, info, nrc_x);
+}
+#endif
+
+#ifdef HAVE_FANCY_SIMD
+inline __m512i qx_r8_q8_dot_product(const __m512i * qx, const int8_t * y) {
+    auto y4l = _mm_loadu_si128((const __m128i*)y+0);
+    auto y4h = _mm_loadu_si128((const __m128i*)y+1);
+    auto y8l = MM256_SET_M128I(y4l, y4l);
+    auto y8h = MM256_SET_M128I(y4h, y4h);
+    auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
+    auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
+    auto sumi = _mm512_setzero_si512();
+    sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
+    sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
+    sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
+    sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
+    sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
+    sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
+    sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
+    sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
+    return sumi;
+}
+inline __m256i qx_r8_q8_dot_product(const __m256i * qx, const int8_t * y) {
+    auto y4l = _mm_loadu_si128((const __m128i*)y+0);
+    auto y4h = _mm_loadu_si128((const __m128i*)y+1);
+    auto yl = MM256_SET_M128I(y4l, y4l);
+    auto yh = MM256_SET_M128I(y4h, y4h);
+    auto sumi = _mm256_setzero_si256();
+    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00));
+    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55));
+    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa));
+    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff));
+    sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00));
+    sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55));
+    sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa));
+    sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff));
+    return sumi;
+}
+inline __m256i q8_0_r8_dot_product(const uint8_t * x, const int8_t * y, __m256i * qx) {
+    for (int i = 0; i < 8; ++i) {
+        qx[i] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)x+i), _mm256_set1_epi8(127));
+    }
+    return qx_r8_q8_dot_product(qx, y);
+}
+template <int nrc_y>
+static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    GGML_ASSERT(nrc_x%16 == 0);
+    Q8<nrc_y, block_q8_2_x4> q8(info);
+
int nb = n / QK8_0; + if constexpr (nrc_y == 1) { + __m256 acc[2] = {}; + __m256i qx[8]; + float d8[8]; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16); + _mm256_storeu_ps(d8, _mm256_castsi256_ps(aux)); + for (int k = 0; k < 4; ++k) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d)); + auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[4*ib4+k].qs, q8.y[0][ib4].qs+32*k, qx); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[k])); + acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]); + acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[k+4]), acc[1]); + } + } + if (4*(nb/4) < nb) { + auto qy = (const block_q8_1 *)q8.y[0]; + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d)); + auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx); + ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s; + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d))); + acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]); + acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[1]); + } + } + info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0])); + acc[0] = acc[1] = _mm256_setzero_ps(); + } + } else { + __m512 acc[2*nrc_y] = {}; + __m512i qx[8]; + float d8[8*nrc_y]; + for (int ix = 0; ix < nrc_x; ix += 16) { + const block_q8_0_r8 * q8l = (const block_q8_0_r8 *)((const char *)vx + (ix+0)*bx); + const block_q8_0_r8 * q8h = (const block_q8_0_r8 *)((const char *)vx + (ix+8)*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16); + _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux)); + } + for (int k = 0; k < 4; ++k) { + auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[4*ib4+k].d)); + auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[4*ib4+k].d)); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + for (int j = 0; j < 8; ++j) { + qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+j)), + _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1); + qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127)); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = qx_r8_q8_dot_product(qx, q8.y[iy][ib4].qs+32*k); + auto dy = _mm512_set1_ps(d8[8*iy+k]); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); + } + } + } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[ib].d)); + auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[ib].d)); + auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); + for (int j = 0; j < 8; ++j) { + qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)), + _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1); + qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127)); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_1 
*)q8.y[iy]; + auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs); + ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s; + auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); + acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); + acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-127.f), acc[2*iy+1], acc[2*iy+0]); + info.store(ix, iy, sum512); + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); + } + } + } +} +#else +template +static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%8 == 0); + Q8 q8(info); + auto m1 = _mm256_set1_epi16(1); + int nb = n / QK8_0; + __m256 acc[nrc_y] = {}; + float d8[4*nrc_y]; + __m256i qx[4], sx[4]; + auto dot = [&qx, &sx, &m1] (const int8_t * qy) { + auto y128 = _mm_loadu_si128((const __m128i*)qy); + auto y = MM256_SET_M128I(y128, y128); + auto sumi1 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]))) + ); + auto sumi2 = _mm256_add_epi32( + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))), + _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]))) + ); + return _mm256_add_epi32(sumi1, sumi2); + }; + for (int ix = 0; ix < nrc_x; ix += 8) { + const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx); + for (int ib4 = 0; ib4 < nb/4; ++ib4) { + for (int iy = 0; iy < nrc_y; ++iy) { + auto scales = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)), 16)); + _mm_storeu_ps(d8 + 4*iy, scales); + } + for (int k = 0; k < 4; ++k) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d)); + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+j); + sx[j] = _mm256_sign_epi8(qx[j], qx[j]); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = dot(q8.y[iy][ib4].qs+32*k); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4+j); + sx[j] = _mm256_sign_epi8(qx[j], qx[j]); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto sumi = dot(q8.y[iy][ib4].qs+32*k+16); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k])); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + } + } + for (int ib = 4*(nb/4); ib < nb; ++ib) { + auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d)); + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+j); + sx[j] = _mm256_sign_epi8(qx[j], qx[j]); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const block_q8_2 *)q8.y[iy]; + auto sumi = dot(qy[ib].qs); + auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d}))); + acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]); + } + for (int j = 0; j < 4; ++j) { + qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+4+j); + sx[j] = _mm256_sign_epi8(qx[j], qx[j]); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto qy = (const 
block_q8_2 *)q8.y[iy];
+                auto sumi = dot(qy[ib].qs+16);
+                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d})));
+                acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
+            }
+        }
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, acc[iy]);
+            acc[iy] = _mm256_setzero_ps();
+        }
+    }
+}
+#endif
+
 template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
     if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) {
-        funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
-        funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;
-        funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
-        funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;
-        funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;
-        funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;
-        funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;
-        funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;
+        IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0_T, Dequantizer, funcs)
     } else if constexpr (std::is_same_v || std::is_same_v) {
-        funcs[0] = mul_mat_qX_1_q8_2_T<Dequantizer, 1>;
-        funcs[1] = mul_mat_qX_1_q8_2_T<Dequantizer, 2>;
-        funcs[2] = mul_mat_qX_1_q8_2_T<Dequantizer, 3>;
-        funcs[3] = mul_mat_qX_1_q8_2_T<Dequantizer, 4>;
-        funcs[4] = mul_mat_qX_1_q8_2_T<Dequantizer, 5>;
-        funcs[5] = mul_mat_qX_1_q8_2_T<Dequantizer, 6>;
-        funcs[6] = mul_mat_qX_1_q8_2_T<Dequantizer, 7>;
-        funcs[7] = mul_mat_qX_1_q8_2_T<Dequantizer, 8>;
+        IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
     } else if constexpr (std::is_same_v) {
 #ifdef HAVE_FANCY_SIMD
-        funcs[0] = mul_mat_qX_1_q8_2_T<Dequantizer, 1>;
-        funcs[1] = mul_mat_qX_1_q8_2_T<Dequantizer, 2>;
-        funcs[2] = mul_mat_qX_1_q8_2_T<Dequantizer, 3>;
-        funcs[3] = mul_mat_qX_1_q8_2_T<Dequantizer, 4>;
-        funcs[4] = mul_mat_qX_1_q8_2_T<Dequantizer, 5>;
-        funcs[5] = mul_mat_qX_1_q8_2_T<Dequantizer, 6>;
-        funcs[6] = mul_mat_qX_1_q8_2_T<Dequantizer, 7>;
-        funcs[7] = mul_mat_qX_1_q8_2_T<Dequantizer, 8>;
+        IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
 #else
-        funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
-        funcs[1] = mul_mat_qX_0_q8_0_T<Dequantizer, 2>;
-        funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
-        funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;
-        funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;
-        funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;
-        funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;
-        funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;
+        IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0_T, Dequantizer, funcs)
 #endif
     } else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) {
-        funcs[0] = mul_mat_qX_1_q8_2_T<Dequantizer, 1>;
-        funcs[1] = mul_mat_qX_1_q8_2_T<Dequantizer, 2>;
-        funcs[2] = mul_mat_qX_1_q8_2_T<Dequantizer, 3>;
-        funcs[3] = mul_mat_qX_1_q8_2_T<Dequantizer, 4>;
-        funcs[4] = mul_mat_qX_1_q8_2_T<Dequantizer, 5>;
-        funcs[5] = mul_mat_qX_1_q8_2_T<Dequantizer, 6>;
-        funcs[6] = mul_mat_qX_1_q8_2_T<Dequantizer, 7>;
-        funcs[7] = mul_mat_qX_1_q8_2_T<Dequantizer, 8>;
+        IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_2_T, Dequantizer, funcs)
     }
 }
 
 } // namespace
 
-bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
+bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
 
     if (ne00%QK8_0 != 0) return false;
     auto expected_typeB = GGML_TYPE_Q8_2_X4;
 
+    func16 = nullptr;
+
     switch (typeA) {
         case GGML_TYPE_Q4_0:
             set_functions(kernels);
@@ -779,6 +1674,24 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels) {
             break;
+        case GGML_TYPE_Q4_0_R8:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_0_r8_q8_2, kernels)
+#ifdef HAVE_FANCY_SIMD
+            func16 = mul_mat_q4_0_r8_q8_2<16>;
+#endif
+            break;
+        case GGML_TYPE_Q5_0_R4:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_0_r4_q8_2, kernels)
+            break;
+        case GGML_TYPE_Q6_0_R4:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_0_r4_q8_2, kernels)
+            break;
+        case GGML_TYPE_Q8_0_R8:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_0_r8_q8_2, kernels)
+            break;
+        case GGML_TYPE_IQ4_NL_R4:
+            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_nl_r4_q8_2, kernels)
+            break;
         default:
             return false;
     }
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.h b/ggml/src/iqk/iqk_gemm_legacy_quants.h
index dd6d097a..7e37ddad 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.h
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.h
@@ -6,6 +6,6 @@
 
 #include <array>
 
-bool
iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array& kernels); +bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array& kernels, mul_mat_t& func16); #endif diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 03d3afa6..108f8bfc 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -734,11 +734,6 @@ struct Q2Bits { BlockPermuter perm; }; -__m512i inline load_iq4nl_values_512() { - auto val256 = load_iq4nl_values_256(); - return _mm512_inserti32x8(_mm512_castsi256_si512(val256), val256, 1); -} - struct HighBit5 { inline void apply(const uint8_t * h, Q4Bits& bits) { auto hbits256 = _mm256_loadu_si256((const __m256i *)h); @@ -1185,934 +1180,6 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf #endif // Zen4 or vanilla AVX2 -#ifdef HAVE_FANCY_SIMD -template -static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%8 == 0); - Q8 q8(info); - auto m4 = _mm512_set1_epi8(0xf); - auto values = load_iq4nl_values_512(); - int nb = n / QK4_NL; - __m512 acc[2*nrc_y] = {}; - __m512i qx[4]; - float d8[8*nrc_y]; - auto prepare = [&qx, &m4, &values] (const block_iq4_nl_r4& iq4l, const block_iq4_nl_r4& iq4h) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l.d)); - auto scales1 = _mm256_set_m128(scales128, scales128); - scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h.d)); - auto scales2 = _mm256_set_m128(scales128, scales128); - auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1); - auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+0)), - _mm256_loadu_si256((const __m256i *)iq4h.qs+0), 1); - auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+1)), - _mm256_loadu_si256((const __m256i *)iq4h.qs+1), 1); - qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4)); - qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4)); - qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4)); - qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4)); - return scales; - }; - auto dot = [&qx] (__m256i y8) { - auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1); - auto sumi = _mm512_setzero_si512(); - sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - return sumi; - }; - for (int ix = 0; ix < nrc_x; ix += 8) { - const block_iq4_nl_r4 * iq4l = (const block_iq4_nl_r4 *)((const char *)vx + (ix+0)*bx); - const block_iq4_nl_r4 * iq4h = (const block_iq4_nl_r4 *)((const char *)vx + (ix+4)*bx); - for (int ib4 = 0; ib4 < nb/4; ++ib4) { - for (int iy = 0; iy < nrc_y; ++iy) { - auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16); - _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux)); - } - for (int k = 0; k < 4; ++k) { - auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]); - for (int iy = 0; iy < nrc_y; ++iy) { - auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); - auto dy = _mm512_set1_ps(d8[8*iy+k]); - acc[2*iy+0] = 
_mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]); - } - } - } - for (int ib = 4*(nb/4); ib < nb; ++ib) { - auto scales = prepare(iq4l[ib], iq4h[ib]); - for (int iy = 0; iy < nrc_y; ++iy) { - auto qy = (const block_q8_1 *)q8.y[iy]; - auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs)); - ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s; - auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d)); - acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]); - acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]); - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-64.f), acc[2*iy+1], acc[2*iy+0]); - acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps(); - auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1)); - auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3)); - info.store(ix+0, iy, sum1); - info.store(ix+4, iy, sum2); - } - } -} -#else -template -static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); - Q8 q8(info); - auto m4 = _mm256_set1_epi8(0xf); - auto m1 = _mm256_set1_epi16(1); - auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values); - auto values = MM256_SET_M128I(values128, values128); - int nb = n / QK4_NL; - __m256 acc[nrc_y] = {}; - __m256i qs[4]; - float d8[4*nrc_y]; - auto prepare = [&qs, &values, &m4] (const block_iq4_nl_r4& iq4) { - auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4.d)); - auto scales = _mm256_set_m128(scales128, scales128); - auto bits1 = _mm256_loadu_si256((const __m256i *)iq4.qs+0); - auto bits2 = _mm256_loadu_si256((const __m256i *)iq4.qs+1); - qs[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); - qs[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); - qs[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); - qs[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); - return scales; - }; - auto dot = [&qs, &m1] (__m256i y) { - auto u1 = _mm256_sign_epi8(qs[0], qs[0]); - auto u2 = _mm256_sign_epi8(qs[1], qs[1]); - auto sumi1 = _mm256_add_epi32( - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qs[0]))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qs[1])))); - u1 = _mm256_sign_epi8(qs[2], qs[2]); - u2 = _mm256_sign_epi8(qs[3], qs[3]); - auto sumi2 = _mm256_add_epi32( - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qs[2]))), - _mm256_madd_epi16(m1, _mm256_maddubs_epi16(u2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qs[3])))); - return _mm256_add_epi32(sumi1, sumi2); - }; - for (int ix = 0; ix < nrc_x; ix += 4) { - const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx); - for (int ib4 = 0; ib4 < nb/4; ++ib4) { - for (int iy = 0; iy < nrc_y; ++iy) { - auto aux = _mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)), 16); - _mm_storeu_ps(d8+4*iy, _mm_castsi128_ps(aux)); - } - for (int k = 0; k < 4; ++k) { - auto scales = prepare(iq4[4*ib4+k]); - for (int iy = 0; iy < nrc_y; ++iy) { - auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k)); 
-                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
-                    acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-                }
-            }
-        }
-        for (int ib = 4*(nb/4); ib < nb; ++ib) {
-            auto scales = prepare(iq4[ib]);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto qy = (const block_q8_1 *)q8.y[iy];
-                auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
-                ggml_bf16_t d{qy[ib].d};
-                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
-                acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
-            info.store(ix, iy, sum);
-            acc[iy] = _mm256_setzero_ps();
-        }
-    }
-}
-#endif
-
-inline void prepare_q4_0_quants_avx2(const uint8_t * qs, __m256i * v, const __m256i& m4) {
-    auto bits1 = _mm256_loadu_si256((const __m256i *)qs+0);
-    auto bits2 = _mm256_loadu_si256((const __m256i *)qs+1);
-    auto bits3 = _mm256_loadu_si256((const __m256i *)qs+2);
-    auto bits4 = _mm256_loadu_si256((const __m256i *)qs+3);
-    v[0] = _mm256_and_si256(bits1, m4);
-    v[1] = _mm256_and_si256(bits2, m4);
-    v[2] = _mm256_and_si256(bits3, m4);
-    v[3] = _mm256_and_si256(bits4, m4);
-    v[4] = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4);
-    v[5] = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4);
-    v[6] = _mm256_and_si256(_mm256_srli_epi16(bits3, 4), m4);
-    v[7] = _mm256_and_si256(_mm256_srli_epi16(bits4, 4), m4);
-}
-
-inline __m256i accum_q4_0_quants(const __m256i * v, const int8_t * qs) {
-    auto y4l = _mm_loadu_si128((const __m128i*)qs+0);
-    auto y4h = _mm_loadu_si128((const __m128i*)qs+1);
-    auto yl = MM256_SET_M128I(y4l, y4l);
-    auto yh = MM256_SET_M128I(y4h, y4h);
-#ifdef HAVE_FANCY_SIMD
-    auto sumi = _mm256_setzero_si256();
-    sumi = _mm256_dpbusd_epi32(sumi, v[0], _mm256_shuffle_epi32(yl, 0x00));
-    sumi = _mm256_dpbusd_epi32(sumi, v[1], _mm256_shuffle_epi32(yl, 0x55));
-    sumi = _mm256_dpbusd_epi32(sumi, v[2], _mm256_shuffle_epi32(yl, 0xaa));
-    sumi = _mm256_dpbusd_epi32(sumi, v[3], _mm256_shuffle_epi32(yl, 0xff));
-    sumi = _mm256_dpbusd_epi32(sumi, v[4], _mm256_shuffle_epi32(yh, 0x00));
-    sumi = _mm256_dpbusd_epi32(sumi, v[5], _mm256_shuffle_epi32(yh, 0x55));
-    sumi = _mm256_dpbusd_epi32(sumi, v[6], _mm256_shuffle_epi32(yh, 0xaa));
-    sumi = _mm256_dpbusd_epi32(sumi, v[7], _mm256_shuffle_epi32(yh, 0xff));
-#else
-    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(v[0], _mm256_shuffle_epi32(yl, 0x00)),
-                                  _mm256_maddubs_epi16(v[1], _mm256_shuffle_epi32(yl, 0x55)));
-    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(v[2], _mm256_shuffle_epi32(yl, 0xaa)),
-                                  _mm256_maddubs_epi16(v[3], _mm256_shuffle_epi32(yl, 0xff)));
-    auto sumi3 = _mm256_add_epi16(_mm256_maddubs_epi16(v[4], _mm256_shuffle_epi32(yh, 0x00)),
-                                  _mm256_maddubs_epi16(v[5], _mm256_shuffle_epi32(yh, 0x55)));
-    auto sumi4 = _mm256_add_epi16(_mm256_maddubs_epi16(v[6], _mm256_shuffle_epi32(yh, 0xaa)),
-                                  _mm256_maddubs_epi16(v[7], _mm256_shuffle_epi32(yh, 0xff)));
-    auto sumi = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(_mm256_add_epi16(sumi1, sumi2), _mm256_add_epi16(sumi3, sumi4)));
-#endif
-    return sumi;
-}
-
-template <int nrc_y>
-static void mul_mat_q4_0_r8_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%8 == 0);
-    Q8 q8(info);
-    auto m4 = _mm256_set1_epi8(0xf);
-    int nb = n / QK4_NL;
-    __m256i v[8];
-    GGML_ASSERT(nb%4 == 0);
-    if constexpr (nrc_y == 1) {
-        union { __m256 vec; float val[8]; } helper;
-        for (int ix = 0; ix < nrc_x; ix += 8) {
-            const block_iq4_nl_r8 * iq4 = (const block_iq4_nl_r8 *)((const char *)vx + ix*bx);
-            auto acc1 = _mm256_setzero_ps();
-            auto acc2 = _mm256_setzero_ps();
-            for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-                helper.vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16));
-                for (int k = 0; k < 4; ++k) {
-                    auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d));
-                    prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4);
-                    auto sumi = accum_q4_0_quants(v, q8.y[0][ib4].qs+32*k);
-                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(helper.val[k]));
-                    acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1);
-                    acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(helper.val[k+4]), acc2);
-                }
-            }
-            for (int ib = 4*(nb/4); ib < nb; ++ib) {
-                auto qy = (const block_q8_1 *)q8.y[0];
-                auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
-                prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
-                auto sumi = accum_q4_0_quants(v, qy[ib].qs);
-                ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
-                acc1 = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc1);
-                acc2 = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc2);
-            }
-            acc1 = _mm256_fmadd_ps(acc2, _mm256_set1_ps(-8.f), acc1);
-            info.store(ix, 0, acc1);
-        }
-    }
-    else {
-        __m256 acc[nrc_y] = {};
-        float d8[8*nrc_y];
-        for (int ix = 0; ix < nrc_x; ix += 8) {
-            const block_iq4_nl_r8 * iq4 = (const block_iq4_nl_r8 *)((const char *)vx + ix*bx);
-            for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-                {
-                    __m256 d4[4];
-                    for (int k = 0; k < 4; ++k) {
-                        d4[k] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d));
-                    }
-                    for (int iy = 0; iy < nrc_y; ++iy) {
-                        auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
-                        _mm256_storeu_ps(d8 + 8*iy, scales);
-                        auto m4 = _mm256_extractf128_ps(scales, 1);
-                        auto m8 = _mm256_set_m128(m4, m4);
-                        auto sumf = _mm256_mul_ps(d4[0], _mm256_shuffle_ps(m8, m8, 0x00));
-                        sumf = _mm256_fmadd_ps(d4[1], _mm256_shuffle_ps(m8, m8, 0x55), sumf);
-                        sumf = _mm256_fmadd_ps(d4[2], _mm256_shuffle_ps(m8, m8, 0xaa), sumf);
-                        sumf = _mm256_fmadd_ps(d4[3], _mm256_shuffle_ps(m8, m8, 0xff), sumf);
-                        acc[iy] = _mm256_fmadd_ps(sumf, _mm256_set1_ps(-8.f), acc[iy]);
-                    }
-                }
-                for (int k = 0; k < 4; ++k) {
-                    auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[4*ib4+k].d));
-                    prepare_q4_0_quants_avx2(iq4[4*ib4+k].qs, v, m4);
-                    for (int iy = 0; iy < nrc_y; ++iy) {
-                        auto sumi = accum_q4_0_quants(v, q8.y[iy][ib4].qs+32*k);
-                        auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
-                        acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-                    }
-                }
-            }
-            for (int ib = 4*(nb/4); ib < nb; ++ib) {
-                auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ib].d));
-                auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-8.f));
-                prepare_q4_0_quants_avx2(iq4[ib].qs, v, m4);
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto qy = (const block_q8_1 *)q8.y[iy];
-                    auto sumi = accum_q4_0_quants(v, qy[ib].qs);
-                    ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
-                    acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-                    acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[iy]);
-                }
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                info.store(ix, iy, acc[iy]);
-                acc[iy] = _mm256_setzero_ps();
-            }
-        }
-    }
-}
-
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    if constexpr (nrc_y == 1) {
-        mul_mat_q4_0_r8_q8_2_avx2<1>(n, vx, bx, info, nrc_x);
-        return;
-    }
-    GGML_ASSERT(nrc_x%16 == 0);
-    Q8 q8(info);
-    auto m4 = _mm512_set1_epi8(0xf);
-    int nb = n / QK4_NL;
-    __m512 acc[2*nrc_y] = {};
-    __m512i qx[8];
-    auto prepare = [&qx, &m4] (const block_iq4_nl_r8& iq4l, const block_iq4_nl_r8& iq4h) {
-        auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4l.d));
-        auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4h.d));
-        auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
-        for (int j = 0; j < 4; ++j) {
-            auto bits = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l.qs+j)),
-                    _mm256_loadu_si256((const __m256i *)iq4h.qs+j), 1);
-            qx[j+0] = _mm512_and_si512(bits, m4);
-            qx[j+4] = _mm512_and_si512(_mm512_srli_epi16(bits, 4), m4);
-        }
-        return scales;
-    };
-    auto dot = [&qx] (const int8_t * qy) {
-        auto y4l = _mm_loadu_si128((const __m128i*)qy+0);
-        auto y4h = _mm_loadu_si128((const __m128i*)qy+1);
-        auto y8l = MM256_SET_M128I(y4l, y4l);
-        auto y8h = MM256_SET_M128I(y4h, y4h);
-        auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
-        auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
-        auto sumi = _mm512_setzero_si512();
-        sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
-        sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
-        sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
-        sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
-        sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
-        sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
-        sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
-        sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
-        return sumi;
-    };
-    float d8[8*nrc_y];
-    for (int ix = 0; ix < nrc_x; ix += 16) {
-        const block_iq4_nl_r8 * iq4l = (const block_iq4_nl_r8 *)((const char *)vx + (ix+0)*bx);
-        const block_iq4_nl_r8 * iq4h = (const block_iq4_nl_r8 *)((const char *)vx + (ix+8)*bx);
-        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16);
-                _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux));
-            }
-            for (int k = 0; k < 4; ++k) {
-                auto scales = prepare(iq4l[4*ib4+k], iq4h[4*ib4+k]);
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto sumi = dot(q8.y[iy][ib4].qs+32*k);
-                    auto dy = _mm512_set1_ps(d8[8*iy+k]);
-                    acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-                    acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
-                }
-            }
-        }
-        for (int ib = 4*(nb/4); ib < nb; ++ib) {
-            auto scales = prepare(iq4l[ib], iq4h[ib]);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto qy = (const block_q8_1 *)q8.y[iy];
-                auto sumi = dot(qy[ib].qs);
-                ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-                auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
-                acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-                acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            auto sum = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
-            acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
-            info.store(ix, iy, sum);
-        }
-    }
-}
-#else
-template <int nrc_y>
-static void mul_mat_q4_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    mul_mat_q4_0_r8_q8_2_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-}
-#endif
-
-template <int nrc_y>
-static void mul_mat_q5_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%4 == 0);
-    Q8 q8(info);
-    auto m4 = _mm256_set1_epi8(0xf);
-    auto m5 = _mm256_set1_epi8(0x10);
-#ifndef HAVE_FANCY_SIMD
-    auto m1 = _mm256_set1_epi16(1);
-#endif
-    auto mscale = _mm256_set_m128(_mm_set1_ps(-8.f), _mm_set1_ps(1.f));
-    int nb = n / QK5_0;
-    __m256 acc[nrc_y] = {};
-    __m256i qx[4];
-    float d8[8*nrc_y];
-    auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5) {
-        auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5.d));
-        auto scales = _mm256_set_m128(scales128, scales128);
-        auto bits1 = _mm256_loadu_si256((const __m256i *)iq5.qs+0);
-        auto bits2 = _mm256_loadu_si256((const __m256i *)iq5.qs+1);
-        auto hbits = _mm_loadu_si128((const __m128i *)iq5.qh);
-        auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 1), hbits);
-        qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 4), m5));
-        qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hb, 2), m5));
-        qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hb, m5));
-        qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hb, 2), m5));;
-        return scales;
-    };
-#ifdef HAVE_FANCY_SIMD
-    auto dot = [&qx] (__m256i y) {
-        auto sumi = _mm256_setzero_si256();
-        sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
-        sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
-        sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
-        sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
-        return sumi;
-    };
-#else
-    auto dot = [&qx, &m1] (__m256i y) {
-        auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
-                                      _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
-        auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
-                                      _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
-        auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
-        return sumi;
-    };
-#endif
-    for (int ix = 0; ix < nrc_x; ix += 4) {
-        const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx);
-        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
-                _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(mscale, scales));
-            }
-            for (int k = 0; k < 4; ++k) {
-                auto scales = prepare(iq5[4*ib4+k]);
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
-                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
-                    acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-                    acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
-                }
-            }
-        }
-        for (int ib = 4*(nb/4); ib < nb; ++ib) {
-            auto scales = prepare(iq5[ib]);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto qy = (const block_q8_1 *)q8.y[iy];
-                auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
-                ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
-                acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-                acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-8.f*GGML_BF16_TO_FP32(s)), acc[iy]);
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
-            info.store(ix, iy, sum);
-            acc[iy] = _mm256_setzero_ps();
-        }
-    }
-}
-
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    if constexpr (nrc_y == 1) {
-        mul_mat_q5_0_r4_q8_2_avx2<1>(n, vx, bx, info, nrc_x);
-    } else {
-        GGML_ASSERT(nrc_x%8 == 0);
-        Q8 q8(info);
-        auto m4 = _mm512_set1_epi8(0xf);
-        auto m5 = _mm512_set1_epi8(0x10);
-        int nb = n / QK5_0;
-        __m512 acc[2*nrc_y] = {};
-        __m512i qx[4];
-        float d8[8*nrc_y];
-        auto prepare = [&qx, &m4, &m5] (const block_q5_0_r4& iq5l, const block_q5_0_r4& iq5h) {
-            auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5l.d));
-            auto scales1 = _mm256_set_m128(scales128, scales128);
-            scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq5h.d));
-            auto scales2 = _mm256_set_m128(scales128, scales128);
-            auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
-            auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+0)),
-                    _mm256_loadu_si256((const __m256i *)iq5h.qs+0), 1);
-            auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq5l.qs+1)),
-                    _mm256_loadu_si256((const __m256i *)iq5h.qs+1), 1);
-            auto hbits1 = _mm_loadu_si128((const __m128i *)iq5l.qh);
-            auto hbits2 = _mm_loadu_si128((const __m128i *)iq5h.qh);
-            auto hb1 = MM256_SET_M128I(_mm_srli_epi16(hbits1, 1), hbits1);
-            auto hb2 = MM256_SET_M128I(_mm_srli_epi16(hbits2, 1), hbits2);
-            auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hb1), hb2, 1);
-            qx[0] = _mm512_or_si512(_mm512_and_si512(bits1, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 4), m5));
-            qx[1] = _mm512_or_si512(_mm512_and_si512(bits2, m4), _mm512_and_si512(_mm512_slli_epi16(hb, 2), m5));
-            qx[2] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4), _mm512_and_si512(hb, m5));
-            qx[3] = _mm512_or_si512(_mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4), _mm512_and_si512(_mm512_srli_epi16(hb, 2), m5));
-            return scales;
-        };
-        auto dot = [&qx] (__m256i y8) {
-            auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
-            auto sumi = _mm512_setzero_si512();
-            sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
-            sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
-            sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
-            sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
-            return sumi;
-        };
-        for (int ix = 0; ix < nrc_x; ix += 8) {
-            const block_q5_0_r4 * iq5l = (const block_q5_0_r4 *)((const char *)vx + (ix+0)*bx);
-            const block_q5_0_r4 * iq5h = (const block_q5_0_r4 *)((const char *)vx + (ix+4)*bx);
-            for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16)));
-                }
-                for (int k = 0; k < 4; ++k) {
-                    auto scales = prepare(iq5l[4*ib4+k], iq5h[4*ib4+k]);
-                    for (int iy = 0; iy < nrc_y; ++iy) {
-                        auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
-                        auto dy = _mm512_set1_ps(d8[8*iy+k]);
-                        acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-                        acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
-                    }
-                }
-            }
-            for (int ib = 4*(nb/4); ib < nb; ++ib) {
-                auto scales = prepare(iq5l[ib], iq5h[ib]);
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto qy = (const block_q8_1 *)q8.y[iy];
-                    auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
-                    ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-                    auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
-                    acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-                    acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
-                }
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-8.f), acc[2*iy+1], acc[2*iy+0]);
-                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
-                auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
-                auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
-                info.store(ix+0, iy, sum1);
-                info.store(ix+4, iy, sum2);
-            }
-        }
-    }
-}
-#else
-template <int nrc_y>
-static void mul_mat_q5_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    mul_mat_q5_0_r4_q8_2_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-}
-#endif
-
-template <int nrc_y>
-static void mul_mat_q6_0_r4_q8_2_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%4 == 0);
-    Q8 q8(info);
-    auto m4 = _mm256_set1_epi8(0xf);
-    auto m6 = _mm256_set1_epi8(0x30);
-    auto mscale = _mm256_set_m128(_mm_set1_ps(-16.f), _mm_set1_ps(1.f));
-#ifndef HAVE_FANCY_SIMD
-    auto m1 = _mm256_set1_epi16(1);
-#endif
-    int nb = n / QK6_0;
-    __m256 acc[nrc_y] = {};
-    float d8[8*nrc_y];
-    __m256i qx[4];
-    auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6) {
-        auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6.d));
-        auto scales = _mm256_set_m128(scales128, scales128);
-        auto bits1 = _mm256_loadu_si256((const __m256i *)iq6.qs+0);
-        auto bits2 = _mm256_loadu_si256((const __m256i *)iq6.qs+1);
-        auto hbits = _mm256_loadu_si256((const __m256i *)iq6.qh);
-        qx[0] = _mm256_or_si256(_mm256_and_si256(bits1, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), m6));
-        qx[1] = _mm256_or_si256(_mm256_and_si256(bits2, m4), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), m6));
-        qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4), _mm256_and_si256(hbits, m6));
-        qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m6));
-        return scales;
-    };
-#ifdef HAVE_FANCY_SIMD
-    auto dot = [&qx] (__m256i y) {
-        auto sumi = _mm256_dpbusd_epi32(_mm256_setzero_si256(), qx[0], _mm256_shuffle_epi32(y, 0x00));
-        sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
-        sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
-        sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
-        return sumi;
-    };
-#else
-    auto dot = [&qx, &m1] (__m256i y) {
-        auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
-                                      _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
-        auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
-                                      _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
-        auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
-        return sumi;
-    };
-#endif
-    for (int ix = 0; ix < nrc_x; ix += 4) {
-        const block_q6_0_r4 * iq6 = (const block_q6_0_r4 *)((const char *)vx + ix*bx);
-        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
-                _mm256_storeu_ps(d8 + 8*iy, _mm256_mul_ps(scales, mscale));
-            }
-            for (int k = 0; k < 4; ++k) {
-                auto scales = prepare(iq6[4*ib4+k]);
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
-                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[8*iy+k]));
-                    acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-                    acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[8*iy+k+4]), acc[iy]);
-                }
-            }
-        }
-        for (int ib = 4*(nb/4); ib < nb; ++ib) {
-            auto scales = prepare(iq6[ib]);
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto qy = (const block_q8_1 *)q8.y[iy];
-                auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
-                ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
-                acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-                acc[iy] = _mm256_fmadd_ps(scales, _mm256_set1_ps(-16.f*GGML_BF16_TO_FP32(s)), acc[iy]);
-            }
-        }
-
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
-            info.store(ix, iy, sum);
-            acc[iy] = _mm256_setzero_ps();
-        }
-    }
-}
-
-#ifdef HAVE_FANCY_SIMD
-template <int nrc_y>
-static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    if constexpr (nrc_y == 1) {
-        mul_mat_q6_0_r4_q8_2_avx2<1>(n, vx, bx, info, nrc_x);
-    } else {
-        GGML_ASSERT(nrc_x%8 == 0);
-        Q8 q8(info);
-        auto m4 = _mm512_set1_epi8(0xf);
-        auto m6 = _mm512_set1_epi8(0x30);
-        int nb = n / QK6_0;
-        __m512 acc[2*nrc_y] = {};
-        __m512i qx[4];
-        float d8[8*nrc_y];
-        auto prepare = [&qx, &m4, &m6] (const block_q6_0_r4& iq6l, const block_q6_0_r4& iq6h) {
-            auto scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6l.d));
-            auto scales1 = _mm256_set_m128(scales128, scales128);
-            scales128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6h.d));
-            auto scales2 = _mm256_set_m128(scales128, scales128);
-            auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
-            auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+0)),
-                    _mm256_loadu_si256((const __m256i *)iq6h.qs+0), 1);
-            auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq6l.qs+1)),
-                    _mm256_loadu_si256((const __m256i *)iq6h.qs+1), 1);
-            auto hbits1 = _mm256_loadu_si256((const __m256i *)iq6l.qh);
-            auto hbits2 = _mm256_loadu_si256((const __m256i *)iq6h.qh);
-            auto hb = _mm512_inserti32x8(_mm512_castsi256_si512(hbits1), hbits2, 1);
-            qx[0] = _mm512_and_si512(bits1, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 4), m6);
-            qx[1] = _mm512_and_si512(bits2, m4) | _mm512_and_si512(_mm512_slli_epi16(hb, 2), m6);;
-            qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4) | _mm512_and_si512(hb, m6);
-            qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4) | _mm512_and_si512(_mm512_srli_epi16(hb, 2), m6);
-            return scales;
-        };
-        auto dot = [&qx] (__m256i y8) {
-            auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
-            auto sumi = _mm512_setzero_si512();
-            sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
-            sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
-            sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
-            sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
-            return sumi;
-        };
-        for (int ix = 0; ix < nrc_x; ix += 8) {
-            const block_q6_0_r4 * iq6l = (const block_q6_0_r4 *)((const char *)vx + (ix+0)*bx);
-            const block_q6_0_r4 * iq6h = (const block_q6_0_r4 *)((const char *)vx + (ix+4)*bx);
-            for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto scales = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16));
-                    _mm256_storeu_ps(d8 + 8*iy, scales);
-                }
-                for (int k = 0; k < 4; ++k) {
-                    auto scales = prepare(iq6l[4*ib4+k], iq6h[4*ib4+k]);
-                    for (int iy = 0; iy < nrc_y; ++iy) {
-                        auto sumi = dot(_mm256_loadu_si256((const __m256i*)q8.y[iy][ib4].qs+k));
-                        auto dy = _mm512_set1_ps(d8[8*iy+k]);
-                        acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-                        acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
-                    }
-                }
-            }
-            for (int ib = 4*(nb/4); ib < nb; ++ib) {
-                auto scales = prepare(iq6l[ib], iq6h[ib]);
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto qy = (const block_q8_1 *)q8.y[iy];
-                    auto sumi = dot(_mm256_loadu_si256((const __m256i*)qy[ib].qs));
-                    ggml_bf16_t d{qy[ib].d}, s{qy[ib].s};
-                    auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
-                    acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-                    acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
-                }
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-16.f), acc[2*iy+1], acc[2*iy+0]);
-                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
-                auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 0), _mm512_extractf32x4_ps(sum512, 1));
-                auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(sum512, 2), _mm512_extractf32x4_ps(sum512, 3));
-                info.store(ix+0, iy, sum1);
-                info.store(ix+4, iy, sum2);
-            }
-        }
-    }
-}
-#else
-template <int nrc_y>
-static void mul_mat_q6_0_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    mul_mat_q6_0_r4_q8_2_avx2<nrc_y>(n, vx, bx, info, nrc_x);
-}
-#endif
-
-#ifdef HAVE_FANCY_SIMD
-inline __m512i qx_r8_q8_dot_product(const __m512i * qx, const int8_t * y) {
-    auto y4l = _mm_loadu_si128((const __m128i*)y+0);
-    auto y4h = _mm_loadu_si128((const __m128i*)y+1);
-    auto y8l = MM256_SET_M128I(y4l, y4l);
-    auto y8h = MM256_SET_M128I(y4h, y4h);
-    auto yl = _mm512_inserti32x8(_mm512_castsi256_si512(y8l), y8l, 1);
-    auto yh = _mm512_inserti32x8(_mm512_castsi256_si512(y8h), y8h, 1);
-    auto sumi = _mm512_setzero_si512();
-    sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x00)));
-    sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0x55)));
-    sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xaa)));
-    sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(yl, _MM_PERM_ENUM(0xff)));
-    sumi = _mm512_dpbusd_epi32(sumi, qx[4], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x00)));
-    sumi = _mm512_dpbusd_epi32(sumi, qx[5], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0x55)));
-    sumi = _mm512_dpbusd_epi32(sumi, qx[6], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xaa)));
-    sumi = _mm512_dpbusd_epi32(sumi, qx[7], _mm512_shuffle_epi32(yh, _MM_PERM_ENUM(0xff)));
-    return sumi;
-}
-inline __m256i qx_r8_q8_dot_product(const __m256i * qx, const int8_t * y) {
-    auto y4l = _mm_loadu_si128((const __m128i*)y+0);
-    auto y4h = _mm_loadu_si128((const __m128i*)y+1);
-    auto yl = MM256_SET_M128I(y4l, y4l);
-    auto yh = MM256_SET_M128I(y4h, y4h);
-    auto sumi = _mm256_setzero_si256();
-    sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(yl, 0x00));
-    sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(yl, 0x55));
-    sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(yl, 0xaa));
-    sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(yl, 0xff));
-    sumi = _mm256_dpbusd_epi32(sumi, qx[4], _mm256_shuffle_epi32(yh, 0x00));
-    sumi = _mm256_dpbusd_epi32(sumi, qx[5], _mm256_shuffle_epi32(yh, 0x55));
-    sumi = _mm256_dpbusd_epi32(sumi, qx[6], _mm256_shuffle_epi32(yh, 0xaa));
-    sumi = _mm256_dpbusd_epi32(sumi, qx[7], _mm256_shuffle_epi32(yh, 0xff));
-    return sumi;
-}
-inline __m256i q8_0_r8_dot_product(const uint8_t * x, const int8_t * y, __m256i * qx) {
-    for (int i = 0; i < 8; ++i) {
-        qx[i] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)x+i), _mm256_set1_epi8(127));
-    }
-    return qx_r8_q8_dot_product(qx, y);
-}
-template <int nrc_y>
-static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%16 == 0);
-    Q8 q8(info);
-    int nb = n / QK8_0;
-    if constexpr (nrc_y == 1) {
-        __m256 acc[2] = {};
-        __m256i qx[8];
-        float d8[8];
-        for (int ix = 0; ix < nrc_x; ix += 8) {
-            const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
-            for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-                auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[0][ib4].d)), 16);
-                _mm256_storeu_ps(d8, _mm256_castsi256_ps(aux));
-                for (int k = 0; k < 4; ++k) {
-                    auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
-                    auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[4*ib4+k].qs, q8.y[0][ib4].qs+32*k, qx);
-                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[k]));
-                    acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
-                    acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(d8[k+4]), acc[1]);
-                }
-            }
-            if (4*(nb/4) < nb) {
-                auto qy = (const block_q8_1 *)q8.y[0];
-                for (int ib = 4*(nb/4); ib < nb; ++ib) {
-                    auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
-                    auto sumi = q8_0_r8_dot_product((const uint8_t *)iq8[ib].qs, qy[ib].qs, qx);
-                    ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s;
-                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(d)));
-                    acc[0] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[0]);
-                    acc[1] = _mm256_fmadd_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(s)), acc[1]);
-                }
-            }
-            info.store(ix, 0, _mm256_fmadd_ps(_mm256_set1_ps(-127.f), acc[1], acc[0]));
-            acc[0] = acc[1] = _mm256_setzero_ps();
-        }
-    } else {
-        __m512 acc[2*nrc_y] = {};
-        __m512i qx[8];
-        float d8[8*nrc_y];
-        for (int ix = 0; ix < nrc_x; ix += 16) {
-            const block_q8_0_r8 * q8l = (const block_q8_0_r8 *)((const char *)vx + (ix+0)*bx);
-            const block_q8_0_r8 * q8h = (const block_q8_0_r8 *)((const char *)vx + (ix+8)*bx);
-            for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto aux = _mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)q8.y[iy][ib4].d)), 16);
-                    _mm256_storeu_ps(d8+8*iy, _mm256_castsi256_ps(aux));
-                }
-                for (int k = 0; k < 4; ++k) {
-                    auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[4*ib4+k].d));
-                    auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[4*ib4+k].d));
-                    auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
-                    for (int j = 0; j < 8; ++j) {
-                        qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[4*ib4+k].qs+j)),
-                                _mm256_loadu_si256((const __m256i *)q8h[4*ib4+k].qs+j), 1);
-                        qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127));
-                    }
-                    for (int iy = 0; iy < nrc_y; ++iy) {
-                        auto sumi = qx_r8_q8_dot_product(qx, q8.y[iy][ib4].qs+32*k);
-                        auto dy = _mm512_set1_ps(d8[8*iy+k]);
-                        acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-                        acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(d8[8*iy+k+4]), acc[2*iy+1]);
-                    }
-                }
-            }
-            for (int ib = 4*(nb/4); ib < nb; ++ib) {
-                auto scales1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8l[ib].d));
-                auto scales2 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)q8h[ib].d));
-                auto scales = _mm512_insertf32x8(_mm512_castps256_ps512(scales1), scales2, 1);
-                for (int j = 0; j < 8; ++j) {
-                    qx[j] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8l[ib].qs+j)),
-                            _mm256_loadu_si256((const __m256i *)q8h[ib].qs+j), 1);
-                    qx[j] = _mm512_add_epi8(qx[j], _mm512_set1_epi8(127));
-                }
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto qy = (const block_q8_1 *)q8.y[iy];
-                    auto sumi = qx_r8_q8_dot_product(qx, qy[ib].qs);
-                    ggml_bf16_t d, s; d.bits = qy[ib].d; s.bits = qy[ib].s;
-                    auto dy = _mm512_set1_ps(GGML_BF16_TO_FP32(d));
-                    acc[2*iy+0] = _mm512_fmadd_ps(_mm512_mul_ps(scales, dy), _mm512_cvtepi32_ps(sumi), acc[2*iy+0]);
-                    acc[2*iy+1] = _mm512_fmadd_ps(scales, _mm512_set1_ps(GGML_BF16_TO_FP32(s)), acc[2*iy+1]);
-                }
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto sum512 = _mm512_fmadd_ps(_mm512_set1_ps(-127.f), acc[2*iy+1], acc[2*iy+0]);
-                info.store(ix, iy, sum512);
-                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_ps();
-            }
-        }
-    }
-}
-#else
-template <int nrc_y>
-static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%8 == 0);
-    Q8 q8(info);
-    auto m1 = _mm256_set1_epi16(1);
-    int nb = n / QK8_0;
-    __m256 acc[nrc_y] = {};
-    float d8[4*nrc_y];
-    __m256i qx[4], sx[4];
-    auto dot = [&qx, &sx, &m1] (const int8_t * qy) {
-        auto y128 = _mm_loadu_si128((const __m128i*)qy);
-        auto y = MM256_SET_M128I(y128, y128);
-        auto sumi1 = _mm256_add_epi32(
-            _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]))),
-            _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])))
-        );
-        auto sumi2 = _mm256_add_epi32(
-            _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]))),
-            _mm256_madd_epi16(m1, _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])))
-        );
-        return _mm256_add_epi32(sumi1, sumi2);
-    };
-    for (int ix = 0; ix < nrc_x; ix += 8) {
-        const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
-        for (int ib4 = 0; ib4 < nb/4; ++ib4) {
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto scales = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][ib4].d)), 16));
-                _mm_storeu_ps(d8 + 4*iy, scales);
-            }
-            for (int k = 0; k < 4; ++k) {
-                auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*ib4+k].d));
-                for (int j = 0; j < 4; ++j) {
-                    qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+j);
-                    sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
-                }
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto sumi = dot(q8.y[iy][ib4].qs+32*k);
-                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
-                    acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-                }
-                for (int j = 0; j < 4; ++j) {
-                    qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*ib4+k].qs+4+j);
-                    sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
-                }
-                for (int iy = 0; iy < nrc_y; ++iy) {
-                    auto sumi = dot(q8.y[iy][ib4].qs+32*k+16);
-                    auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+k]));
-                    acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-                }
-            }
-        }
-        for (int ib = 4*(nb/4); ib < nb; ++ib) {
-            auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ib].d));
-            for (int j = 0; j < 4; ++j) {
-                qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+j);
-                sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto qy = (const block_q8_2 *)q8.y[iy];
-                auto sumi = dot(qy[ib].qs);
-                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d})));
-                acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-            }
-            for (int j = 0; j < 4; ++j) {
-                qx[j] = _mm256_loadu_si256((const __m256i *)iq8[ib].qs+4+j);
-                sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
-            }
-            for (int iy = 0; iy < nrc_y; ++iy) {
-                auto qy = (const block_q8_2 *)q8.y[iy];
-                auto sumi = dot(qy[ib].qs+16);
-                auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(GGML_BF16_TO_FP32(ggml_bf16_t{qy[ib].d})));
-                acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
-            }
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            info.store(ix, iy, acc[iy]);
-            acc[iy] = _mm256_setzero_ps();
-        }
-    }
-}
-#endif
-
 template <int nrc_y>
 static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%8 == 0);
@@ -4968,19 +4035,12 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
         case GGML_TYPE_Q6_0:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_IQ4_NL:
-            return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs);
+        case GGML_TYPE_Q4_0_R8:
+        case GGML_TYPE_Q5_0_R4:
+        case GGML_TYPE_Q6_0_R4:
+        case GGML_TYPE_Q8_0_R8:
         case GGML_TYPE_IQ4_NL_R4:
-            assert (ne00 % QK4_NL == 0);
-            mm.funcs[0] = mul_mat_iq4_nl_r4_q8_2<1>;
-            mm.funcs[1] = mul_mat_iq4_nl_r4_q8_2<2>;
-            mm.funcs[2] = mul_mat_iq4_nl_r4_q8_2<3>;
-            mm.funcs[3] = mul_mat_iq4_nl_r4_q8_2<4>;
-            mm.funcs[4] = mul_mat_iq4_nl_r4_q8_2<5>;
-            mm.funcs[5] = mul_mat_iq4_nl_r4_q8_2<6>;
-            mm.funcs[6] = mul_mat_iq4_nl_r4_q8_2<7>;
-            mm.funcs[7] = mul_mat_iq4_nl_r4_q8_2<8>;
-            expected_typeB = GGML_TYPE_Q8_2_X4;
-            break;
+            return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, mm.funcs, mm.func16);
         case GGML_TYPE_IQ4_XS_R8:
            assert (ne00 % QK_K == 0);
            mm.funcs[0] = mul_mat_iq4_xs_r8_q8_k<1>;
@@ -5248,57 +4308,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
 #endif
            expected_typeB = GGML_TYPE_Q8_K;
            break;
-        case GGML_TYPE_Q4_0_R8:
-            assert (ne00 % QK4_NL == 0);
-            mm.funcs[0] = mul_mat_q4_0_r8_q8_2<1>;
-            mm.funcs[1] = mul_mat_q4_0_r8_q8_2<2>;
-            mm.funcs[2] = mul_mat_q4_0_r8_q8_2<3>;
-            mm.funcs[3] = mul_mat_q4_0_r8_q8_2<4>;
-            mm.funcs[4] = mul_mat_q4_0_r8_q8_2<5>;
-            mm.funcs[5] = mul_mat_q4_0_r8_q8_2<6>;
-            mm.funcs[6] = mul_mat_q4_0_r8_q8_2<7>;
-            mm.funcs[7] = mul_mat_q4_0_r8_q8_2<8>;
-#ifdef HAVE_FANCY_SIMD
-            mm.func16 = mul_mat_q4_0_r8_q8_2<16>;
-#endif
-            expected_typeB = GGML_TYPE_Q8_2_X4;
-            break;
-        case GGML_TYPE_Q5_0_R4:
-            assert (ne00 % QK4_NL == 0);
-            mm.funcs[0] = mul_mat_q5_0_r4_q8_2<1>;
-            mm.funcs[1] = mul_mat_q5_0_r4_q8_2<2>;
-            mm.funcs[2] = mul_mat_q5_0_r4_q8_2<3>;
-            mm.funcs[3] = mul_mat_q5_0_r4_q8_2<4>;
-            mm.funcs[4] = mul_mat_q5_0_r4_q8_2<5>;
-            mm.funcs[5] = mul_mat_q5_0_r4_q8_2<6>;
-            mm.funcs[6] = mul_mat_q5_0_r4_q8_2<7>;
-            mm.funcs[7] = mul_mat_q5_0_r4_q8_2<8>;
-            expected_typeB = GGML_TYPE_Q8_2_X4;
-            break;
-        case GGML_TYPE_Q6_0_R4:
-            assert (ne00 % QK4_NL == 0);
-            mm.funcs[0] = mul_mat_q6_0_r4_q8_2<1>;
-            mm.funcs[1] = mul_mat_q6_0_r4_q8_2<2>;
-            mm.funcs[2] = mul_mat_q6_0_r4_q8_2<3>;
-            mm.funcs[3] = mul_mat_q6_0_r4_q8_2<4>;
-            mm.funcs[4] = mul_mat_q6_0_r4_q8_2<5>;
-            mm.funcs[5] = mul_mat_q6_0_r4_q8_2<6>;
-            mm.funcs[6] = mul_mat_q6_0_r4_q8_2<7>;
-            mm.funcs[7] = mul_mat_q6_0_r4_q8_2<8>;
-            expected_typeB = GGML_TYPE_Q8_2_X4;
-            break;
-        case GGML_TYPE_Q8_0_R8:
-            assert (ne00 % QK4_NL == 0);
-            mm.funcs[0] = mul_mat_q8_0_r8_q8_2<1>;
-            mm.funcs[1] = mul_mat_q8_0_r8_q8_2<2>;
-            mm.funcs[2] = mul_mat_q8_0_r8_q8_2<3>;
-            mm.funcs[3] = mul_mat_q8_0_r8_q8_2<4>;
-            mm.funcs[4] = mul_mat_q8_0_r8_q8_2<5>;
-            mm.funcs[5] = mul_mat_q8_0_r8_q8_2<6>;
-            mm.funcs[6] = mul_mat_q8_0_r8_q8_2<7>;
-            mm.funcs[7] = mul_mat_q8_0_r8_q8_2<8>;
-            expected_typeB = GGML_TYPE_Q8_2_X4;
-            break;
-        case GGML_TYPE_IQ1_S:
         case GGML_TYPE_IQ1_S_R4:
         case GGML_TYPE_IQ1_M_R4: