diff --git a/ggml/src/iqk/iqk_gemm_1bit.cpp b/ggml/src/iqk/iqk_gemm_1bit.cpp index 499224c4..f7e4058e 100644 --- a/ggml/src/iqk/iqk_gemm_1bit.cpp +++ b/ggml/src/iqk/iqk_gemm_1bit.cpp @@ -1043,6 +1043,492 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI } } +template struct Q8_K64 { + + constexpr static int nrc_y = nrc; + + Q8_K64(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) { + const float * dptr = (const float *)info.src1_row(iy); + std::memcpy(d + 8*iy, dptr, 8*sizeof(float)); + y[iy] = (const int8_t *)(dptr + 8); + } + } + + inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); } + inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 8*iy); } + inline __m128 minus(int iy) const { return _mm_loadu_ps(d + 8*iy + 4); } + + float d[8*nrc_y]; + const int8_t * y[nrc_y]; +}; + +struct DequantizerIQ1BN { + const __m256i m1_8 = _mm256_set1_epi8(1); + static __m256i load_shuffle(int i) { + static const uint8_t data[128] = { + 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 1, 255, 1, 255, 1, 255, 1, 255, 1, 255, 2, 255, 2, 255, 2, 255, 2, 255, 2, 255, 12, 255, + 3, 255, 3, 255, 3, 255, 3, 255, 3, 255, 4, 255, 4, 255, 4, 255, 4, 255, 4, 255, 5, 255, 5, 255, 5, 255, 5, 255, 5, 255, 12, 255, + 6, 255, 6, 255, 6, 255, 6, 255, 6, 255, 7, 255, 7, 255, 7, 255, 7, 255, 7, 255, 8, 255, 8, 255, 8, 255, 8, 255, 8, 255, 12, 255, + 9, 255, 9, 255, 9, 255, 9, 255, 9, 255, 10, 255, 10, 255, 10, 255, 10, 255, 10, 255, 11, 255, 11, 255, 11, 255, 11, 255, 11, 255, 12, 255, + }; + return _mm256_loadu_si256((const __m256i*)data + i); + } + const __m256i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) }; + const __m256i mult[4] = { + _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), + _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), + 
_mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), + _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), + }; + const __m256i m3 = _mm256_set1_epi16(3); +#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__ + const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); +#endif + + IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const { + auto data128 = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes! + auto data = MM256_SET_M128I(data128, data128); + auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[0]), mult[0]), m3); + auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[1]), mult[1]), m3); + auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3); + auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3); +#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__ + v1 = _mm256_permutex2var_epi8(val1, bmask, val2); + v2 = _mm256_permutex2var_epi8(val3, bmask, val4); +#else + v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216); + v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216); +#endif + } + +}; + +template +IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + const int nb = n / QK_IQ1BN; + Q8_K64 q8(info); + DequantizerIQ1BN deq; + __m256i accd[nrc_y]; + __m256i val[4]; + +#ifndef HAVE_FANCY_SIMD + const auto m1_16 = _mm256_set1_epi16(1); +#endif + + const block_iq1_bn * x; + const char * cx0 = (const char *)vx; + float scale; + ggml_half d16; + + for (int ix = 0; ix < nrc_x; ++ix) { + + const char * cx = cx0 + ix*bx; + std::memcpy(&d16, cx, sizeof(d16)); + 
scale = GGML_FP16_TO_FP32(d16); + cx += sizeof(d16); + x = (const block_iq1_bn *)cx; + + if constexpr (nrc_y == 1) { + __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256(); + for (int i = 0; i < nb/2; ++i) { + deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]); + deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); +#ifdef HAVE_FANCY_SIMD + acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1)); + acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3)); +#else + auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)), + _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1))); + auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)), + _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3))); + acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1)); + acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2)); +#endif + } + accd[0] = _mm256_add_epi32(acc1, acc2); + } + else { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256(); + + for (int i = 0; i < nb/2; ++i) { + + deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]); + deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); + + for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD + accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], + val[0], q8.load_quants(iy, i, 0)), + val[1], q8.load_quants(iy, i, 1)), + val[2], q8.load_quants(iy, i, 2)), + val[3], q8.load_quants(iy, i, 3)); +#else + auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)), + _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))); + auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)), + _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3))); + dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, 
dot2)); + accd[iy] = _mm256_add_epi32(dot1, accd[iy]); +#endif + } + } + } + int i = 2*(nb/2); + if (i < nb) { + deq.prepare_iq1bn_quants(x + i, val[0], val[1]); + for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD + accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], + val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1)); +#else + auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)), + _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1)))); + accd[iy] = _mm256_add_epi32(dot, accd[iy]); +#endif + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + auto vd = q8.scale(iy); + auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1)); + auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy)); + info.store(ix, iy, scale*hsum_float_4(sumf)); + } + + } +} + +struct DequantizeIQ2BN final : public BaseDequantizer { + DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} + + IQK_ALWAYS_INLINE void prepare4(int i, __m256i * val) const { + auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs); + auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2); + make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0); + make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2); + } + IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const { + val[0] = _mm256_and_si256(q2_1, mask2); + val[1] = _mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2); + } + IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const { + auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs); + make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val); + } + const __m256i m1_8 = _mm256_set1_epi8(1); + const __m256i mf_8 = _mm256_set1_epi8(16); + const __m256i mask2 = _mm256_set1_epi8(0x03); + const __m256i mask3 = _mm256_set1_epi8(0x30); +}; + +template +IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const 
void * vx, size_t bx, const DataInfo& info, int nrc_x) { + const int nb = n / QK_IQ1BN; + Q8_K64 q8(info); + DequantizeIQ2BN deq(vx, bx); + __m256i accd[nrc_y]; + __m256i val[4]; + +#ifndef HAVE_FANCY_SIMD + const auto m1_16 = _mm256_set1_epi16(1); +#endif + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + if constexpr (nrc_y == 1) { + __m256i acc[2] = {}; + for (int i = 0; i < nb/2; ++i) { + deq.prepare4(i, val); +#ifdef HAVE_FANCY_SIMD + acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)), + val[1], q8.load_quants(0, i, 1)); + acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)), + val[3], q8.load_quants(0, i, 3)); +#else + auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)), + _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1))); + auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)), + _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3))); + acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1)); + acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2)); +#endif + } + accd[0] = _mm256_add_epi32(acc[0], acc[1]); + } + else { + + for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256(); + + for (int i = 0; i < nb/2; ++i) { + deq.prepare4(i, val); + for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD + accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], + val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)), + val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3)); +#else + auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16( + _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)), + _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))), + _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)), + _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3))))); + 
accd[iy] = _mm256_add_epi32(dot, accd[iy]); +#endif + } + } + } + int i = 2*(nb/2); + if (i < nb) { + deq.prepare2(i, val); + for (int iy = 0; iy < nrc_y; ++iy) { +#ifdef HAVE_FANCY_SIMD + accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)), + val[1], q8.load_quants(iy, i/2, 1)); +#else + auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)), + _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1)))); + accd[iy] = _mm256_add_epi32(dot, accd[iy]); +#endif + } + } + + for (int iy = 0; iy < nrc_y; ++iy) { + auto vd = q8.scale(iy); + auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1)); + auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy)); + info.store(ix, iy, deq.d*hsum_float_4(sumf)); + } + } +} + +template +static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + if (nrc_x%4) { + printf("%s: %d is not a multiple of 4\n", __func__, nrc_x); + GGML_ABORT("fatal error"); + } + Q8_16 q8(info); + auto m3 = _mm256_set1_epi8(0x3); + auto m1 = _mm256_set1_epi16(1); + int nb = n / QK_IQ1BN; + __m256i qx[4]; + if constexpr (nrc_y > 4) { + __m256i acc[nrc_y] = {}; + __m128 sum4[nrc_y]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + auto dl = _mm_loadu_ps(dptr); + const uint8_t * iq2l = (const uint8_t *)(dptr + 4); + for (int ib = 0; ib < nb; ++ib) { + auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0); + qx[0] = _mm256_and_si256(bits, m3); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants(iy, 2*ib+0); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), +
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + auto sumf1 = _mm256_cvtepi32_ps(acc[iy]); + auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00))); + s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4); + sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4); + acc[iy] = _mm256_setzero_si256(); + } + for (int ib = 0; ib < nb; ++ib) { + auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1); + qx[0] = _mm256_and_si256(bits, m3); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants(iy, 2*ib+1); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + auto sumf1 = _mm256_cvtepi32_ps(acc[iy]); + auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]); + s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4); + info.store(ix, iy, s4); + acc[iy] = _mm256_setzero_si256(); + } + } + } else { + __m256i acc[2*nrc_y] = {}; + for (int 
ix = 0; ix < nrc_x; ix += 4) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + auto dl = _mm_loadu_ps(dptr); + const uint8_t * iq2l = (const uint8_t *)(dptr + 4); + for (int ib = 0; ib < nb; ++ib) { + auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0); + qx[0] = _mm256_and_si256(bits, m3); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants(iy, 2*ib+0); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); + } + bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1); + qx[0] = _mm256_and_si256(bits, m3); + qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); + qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); + qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants(iy, 2*ib+1); + auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), + _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); + auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), + _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); + acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]); + auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]); + auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), 
_mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00))); + sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4); + sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4); + sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4); + sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4); + info.store(ix, iy, sum4); + acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256(); + } + } + } +} + + +#ifdef HAVE_FANCY_SIMD +template +static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + if (nrc_x%4) { + printf("%s: %d is not a multiple of 4\n", __func__, nrc_x); + GGML_ABORT("fatal error"); + } + if constexpr (nrc_y == 1) { + mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x); + } else { + Q8_16 q8(info); + auto m3 = _mm512_set1_epi8(0x3); + int nb = n / QK_IQ1BN; + __m512i acc[2*nrc_y] = {}; + __m512i qx[8]; + for (int ix = 0; ix < nrc_x/8; ++ix) { + const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx); + const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx); + auto dl = _mm_loadu_ps(dptr1); + auto dh = _mm_loadu_ps(dptr2); + const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4); + const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4); + for (int ib = 0; ib < nb; ++ib) { + auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib); + auto bits_h = _mm512_loadu_si512((const __m512i *)iq2h + ib); + qx[0] = _mm512_and_si512(bits_l, m3); + qx[1] = _mm512_and_si512(bits_h, m3); + qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3); + qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3); + qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3); + qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3); + qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3); + qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3); + for 
(int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants64(iy, ib); + auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy); + sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy); + sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy); + sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)); + acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy); + acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + __m128 sum4; + for (int k = 0; k < 2; ++k) { + const auto& dx = k == 0 ? dl : dh; + auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]); + sum4 = _mm_mul_ps (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00))); + sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4); + sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4); + sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4); + sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4); + info.store(8*ix + 4*k, iy, sum4); + } + acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512(); + } + } + if (int ix = 8*(nrc_x/8); ix < nrc_x) { + const float * dptr = (const float *)((const char *)vx + ix*bx); + auto dl = _mm_loadu_ps(dptr); + const uint8_t * iq2l = (const uint8_t *)(dptr + 4); + for (int ib = 0; ib < nb; ++ib) { + auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib); + qx[0] = _mm512_and_si512(bits_l, m3); + qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3); + qx[2] = 
_mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3); + qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3); + for (int iy = 0; iy < nrc_y; ++iy) { + auto y = q8.load_quants64(iy, ib); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); + acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); + } + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto dy = q8.scale(iy); + auto sumf = _mm512_cvtepi32_ps(acc[iy]); + auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00))); + sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4); + sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4); + sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4); + sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4); + info.store(ix, iy, sum4); + } + } + } +} +#else +template +static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + if (nrc_x%4) { + printf("%s: %d is not a multiple of 4\n", __func__, nrc_x); + GGML_ABORT("fatal error"); + } + mul_mat_iq2_bn_r4_q8_k16_avx2(n, vx, bx, info, nrc_x); +} +#endif + + } // namespace bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array& funcs, mul_mat_t& func16) { @@ -1095,6 +1581,43 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array; #endif break; + case GGML_TYPE_IQ1_BN: + assert (ne00 % QK_IQ1BN == 0); + funcs[0] = mul_mat_iq1bn_q8_K64<1>; + funcs[1] = mul_mat_iq1bn_q8_K64<2>; + funcs[2] = mul_mat_iq1bn_q8_K64<3>; + funcs[3] = mul_mat_iq1bn_q8_K64<4>; + funcs[4] = mul_mat_iq1bn_q8_K64<5>; + 
funcs[5] = mul_mat_iq1bn_q8_K64<6>; + funcs[6] = mul_mat_iq1bn_q8_K64<7>; + funcs[7] = mul_mat_iq1bn_q8_K64<8>; + expected_typeB = GGML_TYPE_Q8_K64; + break; + case GGML_TYPE_IQ2_BN: + assert (ne00 % QK_IQ1BN == 0); + funcs[0] = mul_mat_iq2bn_q8_K64<1>; + funcs[1] = mul_mat_iq2bn_q8_K64<2>; + funcs[2] = mul_mat_iq2bn_q8_K64<3>; + funcs[3] = mul_mat_iq2bn_q8_K64<4>; + funcs[4] = mul_mat_iq2bn_q8_K64<5>; + funcs[5] = mul_mat_iq2bn_q8_K64<6>; + funcs[6] = mul_mat_iq2bn_q8_K64<7>; + funcs[7] = mul_mat_iq2bn_q8_K64<8>; + expected_typeB = GGML_TYPE_Q8_K64; + break; + case GGML_TYPE_IQ2_BN_R4: + assert (ne00 % QK_IQ1BN == 0); + funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>; + funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>; + funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>; + funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>; + funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>; + funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>; + funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>; + funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>; + expected_typeB = GGML_TYPE_Q8_K16; + break; + default: return false; } diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index fbd6c730..03d3afa6 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1185,228 +1185,6 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf #endif // Zen4 or vanilla AVX2 -template -static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - if (nrc_x%4) { - printf("%s: %d is not a multiple of 4\n", __func__, nrc_x); - GGML_ABORT("fatal error"); - } - Q8_16 q8(info); - auto m3 = _mm256_set1_epi8(0x3); - auto m1 = _mm256_set1_epi16(1); - int nb = n / QK_IQ1BN; - __m256i qx[4]; - if constexpr (nrc_y > 4) { - __m256i acc[nrc_y] = {}; - __m128 sum4[nrc_y]; - for (int ix = 0; ix < nrc_x; ix += 4) { - const float * dptr = (const float *)((const char *)vx + ix*bx); - auto dl = _mm_loadu_ps(dptr); - const uint8_t * iq2l = (const uint8_t *)(dptr + 4); - for (int ib = 
0; ib < nb; ++ib) { - auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0); - qx[0] = _mm256_and_si256(bits, m3); - qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); - qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); - qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); - for (int iy = 0; iy < nrc_y; ++iy) { - auto y = q8.load_quants(iy, 2*ib+0); - auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), - _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); - auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), - _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); - acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto dy = q8.scale(iy); - auto sumf1 = _mm256_cvtepi32_ps(acc[iy]); - auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00))); - s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4); - sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4); - acc[iy] = _mm256_setzero_si256(); - } - for (int ib = 0; ib < nb; ++ib) { - auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1); - qx[0] = _mm256_and_si256(bits, m3); - qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); - qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); - qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); - for (int iy = 0; iy < nrc_y; ++iy) { - auto y = q8.load_quants(iy, 2*ib+1); - auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), - _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); - auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), - _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); - acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, 
_mm256_add_epi16(sumi1, sumi2))); - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto dy = q8.scale(iy); - auto sumf1 = _mm256_cvtepi32_ps(acc[iy]); - auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]); - s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4); - info.store(ix, iy, s4); - acc[iy] = _mm256_setzero_si256(); - } - } - } else { - __m256i acc[2*nrc_y] = {}; - for (int ix = 0; ix < nrc_x; ix += 4) { - const float * dptr = (const float *)((const char *)vx + ix*bx); - auto dl = _mm_loadu_ps(dptr); - const uint8_t * iq2l = (const uint8_t *)(dptr + 4); - for (int ib = 0; ib < nb; ++ib) { - auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0); - qx[0] = _mm256_and_si256(bits, m3); - qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); - qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); - qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); - for (int iy = 0; iy < nrc_y; ++iy) { - auto y = q8.load_quants(iy, 2*ib+0); - auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), - _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); - auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), - _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); - acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); - } - bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1); - qx[0] = _mm256_and_si256(bits, m3); - qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3); - qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3); - qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3); - for (int iy = 0; iy < nrc_y; ++iy) { - auto y = q8.load_quants(iy, 2*ib+1); - auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)), - _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55))); 
- auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)), - _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff))); - acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2))); - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto dy = q8.scale(iy); - auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]); - auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]); - auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00))); - sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4); - sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4); - sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4); - sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4); - info.store(ix, iy, sum4); - acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256(); - } - } - } -} - -#ifdef HAVE_FANCY_SIMD -template -static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - if (nrc_x%4) { - printf("%s: %d is not a multiple of 4\n", __func__, nrc_x); - GGML_ABORT("fatal error"); - } - if constexpr (nrc_y == 1) { - mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x); - } else { - Q8_16 q8(info); - auto m3 = _mm512_set1_epi8(0x3); - int nb = n / QK_IQ1BN; - __m512i acc[2*nrc_y] = {}; - __m512i qx[8]; - for (int ix = 0; ix < nrc_x/8; ++ix) { - const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx); - const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx); - auto dl = _mm_loadu_ps(dptr1); - auto dh = _mm_loadu_ps(dptr2); - const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4); - const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4); - for (int ib = 0; ib < nb; ++ib) { - auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib); - auto bits_h = 
_mm512_loadu_si512((const __m512i *)iq2h + ib); - qx[0] = _mm512_and_si512(bits_l, m3); - qx[1] = _mm512_and_si512(bits_h, m3); - qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3); - qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3); - qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3); - qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3); - qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3); - qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3); - for (int iy = 0; iy < nrc_y; ++iy) { - auto y = q8.load_quants64(iy, ib); - auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)); - acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy); - acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy); - sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)); - acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy); - acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy); - sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)); - acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy); - acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy); - sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)); - acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy); - acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy); - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto dy = q8.scale(iy); - __m128 sum4; - for (int k = 0; k < 2; ++k) { - const auto& dx = k == 0 ? 
dl : dh; - auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]); - sum4 = _mm_mul_ps (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00))); - sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4); - sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4); - sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4); - sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4); - info.store(8*ix + 4*k, iy, sum4); - } - acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512(); - } - } - if (int ix = 8*(nrc_x/8); ix < nrc_x) { - const float * dptr = (const float *)((const char *)vx + ix*bx); - auto dl = _mm_loadu_ps(dptr); - const uint8_t * iq2l = (const uint8_t *)(dptr + 4); - for (int ib = 0; ib < nb; ++ib) { - auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib); - qx[0] = _mm512_and_si512(bits_l, m3); - qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3); - qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3); - qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3); - for (int iy = 0; iy < nrc_y; ++iy) { - auto y = q8.load_quants64(iy, ib); - acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00))); - acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55))); - acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa))); - acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff))); - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto dy = q8.scale(iy); - auto sumf = _mm512_cvtepi32_ps(acc[iy]); - auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00))); - sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4); - sum4 = 
_mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4); - sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4); - sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4); - info.store(ix, iy, sum4); - } - } - } -} -#else -template -static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - if (nrc_x%4) { - printf("%s: %d is not a multiple of 4\n", __func__, nrc_x); - GGML_ABORT("fatal error"); - } - mul_mat_iq2_bn_r4_q8_k16_avx2(n, vx, bx, info, nrc_x); -} -#endif - #ifdef HAVE_FANCY_SIMD template static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { @@ -5127,268 +4905,6 @@ static void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const Data } } -template struct Q8_K64 { - - constexpr static int nrc_y = nrc; - - Q8_K64(const DataInfo& info) { - for (int iy = 0; iy < nrc_y; ++iy) { - const float * dptr = (const float *)info.src1_row(iy); - std::memcpy(d + 8*iy, dptr, 8*sizeof(float)); - y[iy] = (const int8_t *)(dptr + 8); - } - } - - inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); } - inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 8*iy); } - inline __m128 minus(int iy) const { return _mm_loadu_ps(d + 8*iy + 4); } - - float d[8*nrc_y]; - const int8_t * y[nrc_y]; -}; - -struct DequantizerIQ1BN { - const __m256i m1_8 = _mm256_set1_epi8(1); - static __m256i load_shuffle(int i) { - static const uint8_t data[128] = { - 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 1, 255, 1, 255, 1, 255, 1, 255, 1, 255, 2, 255, 2, 255, 2, 255, 2, 255, 2, 255, 12, 255, - 3, 255, 3, 255, 3, 255, 3, 255, 3, 255, 4, 255, 4, 255, 4, 255, 4, 255, 4, 255, 5, 255, 5, 255, 5, 255, 5, 255, 5, 255, 12, 255, - 6, 255, 6, 255, 6, 255, 6, 255, 6, 255, 7, 255, 7, 255, 7, 255, 7, 255, 7, 255, 8, 255, 8, 255, 8, 255, 8, 
255, 8, 255, 12, 255, - 9, 255, 9, 255, 9, 255, 9, 255, 9, 255, 10, 255, 10, 255, 10, 255, 10, 255, 10, 255, 11, 255, 11, 255, 11, 255, 11, 255, 11, 255, 12, 255, - }; - return _mm256_loadu_si256((const __m256i*)data + i); - } - const __m256i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) }; - const __m256i mult[4] = { - _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), - _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), - _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), - _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100), - }; - const __m256i m3 = _mm256_set1_epi16(3); -#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__ - const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); -#endif - - IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const { - auto data128 = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes! 
- auto data = MM256_SET_M128I(data128, data128); - auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[0]), mult[0]), m3); - auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[1]), mult[1]), m3); - auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3); - auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3); -#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__ - v1 = _mm256_permutex2var_epi8(val1, bmask, val2); - v2 = _mm256_permutex2var_epi8(val3, bmask, val4); -#else - v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216); - v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216); -#endif - } - -}; - -template -IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - const int nb = n / QK_IQ1BN; - Q8_K64 q8(info); - DequantizerIQ1BN deq; - __m256i accd[nrc_y]; - __m256i val[4]; - -#ifndef HAVE_FANCY_SIMD - const auto m1_16 = _mm256_set1_epi16(1); -#endif - - const block_iq1_bn * x; - const char * cx0 = (const char *)vx; - float scale; - ggml_half d16; - - for (int ix = 0; ix < nrc_x; ++ix) { - - const char * cx = cx0 + ix*bx; - std::memcpy(&d16, cx, sizeof(d16)); - scale = GGML_FP16_TO_FP32(d16); - cx += sizeof(d16); - x = (const block_iq1_bn *)cx; - - if constexpr (nrc_y == 1) { - __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256(); - for (int i = 0; i < nb/2; ++i) { - deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]); - deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); -#ifdef HAVE_FANCY_SIMD - acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1)); - acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3)); -#else - auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], 
q8.load_quants(0, i, 0)), - _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1))); - auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)), - _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3))); - acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1)); - acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2)); -#endif - } - accd[0] = _mm256_add_epi32(acc1, acc2); - } - else { - - for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256(); - - for (int i = 0; i < nb/2; ++i) { - - deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]); - deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]); - - for (int iy = 0; iy < nrc_y; ++iy) { -#ifdef HAVE_FANCY_SIMD - accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], - val[0], q8.load_quants(iy, i, 0)), - val[1], q8.load_quants(iy, i, 1)), - val[2], q8.load_quants(iy, i, 2)), - val[3], q8.load_quants(iy, i, 3)); -#else - auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)), - _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))); - auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)), - _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3))); - dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2)); - accd[iy] = _mm256_add_epi32(dot1, accd[iy]); -#endif - } - } - } - int i = 2*(nb/2); - if (i < nb) { - deq.prepare_iq1bn_quants(x + i, val[0], val[1]); - for (int iy = 0; iy < nrc_y; ++iy) { -#ifdef HAVE_FANCY_SIMD - accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], - val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1)); -#else - auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)), - _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1)))); - accd[iy] = _mm256_add_epi32(dot, accd[iy]); -#endif - } - } - - for (int iy = 0; iy < nrc_y; ++iy) { - auto vd = q8.scale(iy); 
- auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1)); - auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy)); - info.store(ix, iy, scale*hsum_float_4(sumf)); - } - - } -} - -struct DequantizeIQ2BN final : public BaseDequantizer { - DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} - - IQK_ALWAYS_INLINE void prepare4(int i, __m256i * val) const { - auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs); - auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2); - make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0); - make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2); - } - IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const { - val[0] = _mm256_and_si256(q2_1, mask2); - val[1] = _mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2); - } - IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const { - auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs); - make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val); - } - const __m256i m1_8 = _mm256_set1_epi8(1); - const __m256i mf_8 = _mm256_set1_epi8(16); - const __m256i mask2 = _mm256_set1_epi8(0x03); - const __m256i mask3 = _mm256_set1_epi8(0x30); -}; - -template -IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - const int nb = n / QK_IQ1BN; - Q8_K64 q8(info); - DequantizeIQ2BN deq(vx, bx); - __m256i accd[nrc_y]; - __m256i val[4]; - -#ifndef HAVE_FANCY_SIMD - const auto m1_16 = _mm256_set1_epi16(1); -#endif - - for (int ix = 0; ix < nrc_x; ++ix) { - - deq.new_row(ix); - - if constexpr (nrc_y == 1) { - __m256i acc[2] = {}; - for (int i = 0; i < nb/2; ++i) { - deq.prepare4(i, val); -#ifdef HAVE_FANCY_SIMD - acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)), - val[1], q8.load_quants(0, i, 1)); - acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)), 
- val[3], q8.load_quants(0, i, 3)); -#else - auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)), - _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1))); - auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)), - _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3))); - acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1)); - acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2)); -#endif - } - accd[0] = _mm256_add_epi32(acc[0], acc[1]); - } - else { - - for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256(); - - for (int i = 0; i < nb/2; ++i) { - deq.prepare4(i, val); - for (int iy = 0; iy < nrc_y; ++iy) { -#ifdef HAVE_FANCY_SIMD - accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], - val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)), - val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3)); -#else - auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16( - _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)), - _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))), - _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)), - _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3))))); - accd[iy] = _mm256_add_epi32(dot, accd[iy]); -#endif - } - } - } - int i = 2*(nb/2); - if (i < nb) { - deq.prepare2(i, val); - for (int iy = 0; iy < nrc_y; ++iy) { -#ifdef HAVE_FANCY_SIMD - accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)), - val[1], q8.load_quants(iy, i/2, 1)); -#else - auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)), - _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 0)))); - accd[iy] = _mm256_add_epi32(dot, accd[iy]); -#endif - } - } - - for (int iy = 0; iy < nrc_y; ++iy) { - auto vd = q8.scale(iy); - auto sumi = 
_mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1)); - auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy)); - info.store(ix, iy, deq.d*hsum_float_4(sumf)); - } - } -} - template void MulMat::set_functions(MulMat& m) { #ifdef HAVE_FANCY_SIMD m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1; @@ -5445,44 +4961,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ5_K: case GGML_TYPE_IQ6_K: return ggml_type(typeB) == GGML_TYPE_Q8_K ? iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs) : false; - case GGML_TYPE_IQ1_BN: - assert (ne00 % QK_IQ1BN == 0); - mm.funcs[0] = mul_mat_iq1bn_q8_K64<1>; - mm.funcs[1] = mul_mat_iq1bn_q8_K64<2>; - mm.funcs[2] = mul_mat_iq1bn_q8_K64<3>; - mm.funcs[3] = mul_mat_iq1bn_q8_K64<4>; - mm.funcs[4] = mul_mat_iq1bn_q8_K64<5>; - mm.funcs[5] = mul_mat_iq1bn_q8_K64<6>; - mm.funcs[6] = mul_mat_iq1bn_q8_K64<7>; - mm.funcs[7] = mul_mat_iq1bn_q8_K64<8>; - expected_typeB = GGML_TYPE_Q8_K64; - break; - case GGML_TYPE_IQ2_BN: - assert (ne00 % QK_IQ1BN == 0); - mm.funcs[0] = mul_mat_iq2bn_q8_K64<1>; - mm.funcs[1] = mul_mat_iq2bn_q8_K64<2>; - mm.funcs[2] = mul_mat_iq2bn_q8_K64<3>; - mm.funcs[3] = mul_mat_iq2bn_q8_K64<4>; - mm.funcs[4] = mul_mat_iq2bn_q8_K64<5>; - mm.funcs[5] = mul_mat_iq2bn_q8_K64<6>; - mm.funcs[6] = mul_mat_iq2bn_q8_K64<7>; - mm.funcs[7] = mul_mat_iq2bn_q8_K64<8>; - expected_typeB = GGML_TYPE_Q8_K64; - break; - case GGML_TYPE_IQ2_BN_R4: - assert (ne00 % QK_IQ1BN == 0); - mm.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>; - mm.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>; - mm.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>; - mm.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>; - mm.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>; - mm.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>; -//#ifdef HAVE_FANCY_SIMD - mm.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>; - mm.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>; -//#endif - expected_typeB = GGML_TYPE_Q8_K16; - break; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case 
GGML_TYPE_Q5_0: @@ -5824,6 +5302,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) { case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ1_M_R4: + case GGML_TYPE_IQ1_BN: + case GGML_TYPE_IQ2_BN: + case GGML_TYPE_IQ2_BN_R4: return iqk_set_kernels_1bit(ne00, typeA, typeB, mm.funcs, mm.func16); default: