diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index b72ce2e1..97a4ca1b 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -8492,6 +8492,192 @@ void mul_mat_qX_1_q8_2_T(int n, const void * vx, size_t bx, const DataInfo& info
     }
 }
 
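+// Q8_0 x Q8_0 matrix-vector kernel for a single activation column (nrc_y = 1).
+// The activations are assumed to be repacked as block_q8_0_x4, i.e. groups of 4 blocks whose
+// fp16 scales are stored contiguously, so all 4 scales convert with a single _mm_cvtph_ps.
+// _mm256_maddubs_epi16 wants an unsigned first operand, so we pass |x| and move the sign of
+// x onto y with _mm256_sign_epi8 before the multiply-add.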
+void mul_mat_q8_0_q8_x4_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    Q8<1, block_q8_0> q8(info);
+    int nblock = n/QK8_0;
+    auto cx = (const char *)vx;
+    auto y4 = (const block_q8_0_x4 *)q8.y[0];
+    ggml_half d4[4];
+    __m256 dot[4];
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        auto q8x = (const block_q8_0 *)cx;
+        auto acc = _mm256_setzero_ps();
+        for (int ib4 = 0; ib4 < nblock/4; ++ib4) {
+            auto d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4[ib4].d));
+            auto d4y = _mm256_set_m128(d4y_128, d4y_128);
+            for (int k = 0; k < 4; ++k) {
+                d4[k] = q8x[k].d;
+                auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
+                auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)y4[ib4].qs + k), qx));
+                dot[k] = _mm256_cvtepi32_ps(_mm256_madd_epi16(_mm256_set1_epi16(1), p));
+            }
+            auto d4x_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)d4));
+            auto d4x = _mm256_set_m128(d4x_128, d4x_128);
+            auto d4xy = _mm256_mul_ps(d4x, d4y);
+            acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xy, d4xy, 0x00), dot[0], acc);
+            acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xy, d4xy, 0x55), dot[1], acc);
+            acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xy, d4xy, 0xaa), dot[2], acc);
+            acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xy, d4xy, 0xff), dot[3], acc);
+            q8x += 4;
+        }
+        if (int ib0 = 4*(nblock/4); ib0 < nblock) {
+            auto y = q8.y[0] + ib0;
+            for (int k = 0; k < nblock - ib0; ++k) {
+                auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
+                auto p = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)y[k].qs), qx));
+                p = _mm256_madd_epi16(_mm256_set1_epi16(1), p);
+                auto dxy = GGML_FP16_TO_FP32(q8x[k].d)*GGML_FP16_TO_FP32(y[k].d);
+                acc = _mm256_fmadd_ps(_mm256_set1_ps(dxy), _mm256_cvtepi32_ps(p), acc);
+            }
+        }
+        info.store(ix, 0, hsum_float_8(acc));
+        cx += bx;
+    }
+}
+
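+// The same dot product for two activation columns at once (nrc_y = 2). pl/ph hold the block
+// sums against the low/high column; the unpacklo/unpackhi_epi64 + add folds them into one
+// vector laid out as (l, l, h, h) per 128-bit half, so each fmadd advances both columns.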
+void mul_mat_q8_0_q8_x4_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    Q8<2, block_q8_0> q8(info);
+    int nblock = n/QK8_0;
+    auto cx = (const char *)vx;
+    auto y4l = (const block_q8_0_x4 *)q8.y[0];
+    auto y4h = (const block_q8_0_x4 *)q8.y[1];
+    ggml_half d4[4];
+    __m256i dot[4];
+    __m256 acc[4] = {};
+    for (int ix = 0; ix < nrc_x; ++ix) {
+        auto q8x = (const block_q8_0 *)cx;
+        for (int ib4 = 0; ib4 < nblock/4; ++ib4) {
+            auto d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4l[ib4].d));
+            auto d4yl = _mm256_set_m128(d4y_128, d4y_128);
+            d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4h[ib4].d));
+            auto d4yh = _mm256_set_m128(d4y_128, d4y_128);
+            for (int k = 0; k < 4; ++k) {
+                d4[k] = q8x[k].d;
+                auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
+                auto ux = _mm256_sign_epi8(qx, qx);
+                auto pl = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)y4l[ib4].qs + k), qx));
+                auto ph = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)y4h[ib4].qs + k), qx));
+                pl = _mm256_madd_epi16(_mm256_set1_epi16(1), pl);
+                ph = _mm256_madd_epi16(_mm256_set1_epi16(1), ph);
+                dot[k] = _mm256_add_epi32(_mm256_unpacklo_epi64(pl, ph), _mm256_unpackhi_epi64(pl, ph));
+            }
+            auto d4x_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)d4));
+            auto d4x = _mm256_set_m128(d4x_128, d4x_128);
+            auto d4xyl = _mm256_mul_ps(d4x, d4yl);
+            auto d4xyh = _mm256_mul_ps(d4x, d4yh);
+            acc[0] = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x00), _mm256_cvtepi32_ps(dot[0]), acc[0]);
+            acc[1] = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x55), _mm256_cvtepi32_ps(dot[1]), acc[1]);
+            acc[2] = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xaa), _mm256_cvtepi32_ps(dot[2]), acc[2]);
+            acc[3] = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xff), _mm256_cvtepi32_ps(dot[3]), acc[3]);
+            q8x += 4;
+        }
+        if (int ib0 = 4*(nblock/4); ib0 < nblock) {
+            auto yl = q8.y[0] + ib0;
+            auto yh = q8.y[1] + ib0;
+            for (int k = 0; k < nblock - ib0; ++k) {
+                auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
+                auto ux = _mm256_sign_epi8(qx, qx);
+                auto pl = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yl[k].qs), qx));
+                auto ph = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yh[k].qs), qx));
+                pl = _mm256_madd_epi16(_mm256_set1_epi16(1), pl);
+                ph = _mm256_madd_epi16(_mm256_set1_epi16(1), ph);
+                auto p = _mm256_add_epi32(_mm256_unpacklo_epi64(pl, ph), _mm256_unpackhi_epi64(pl, ph));
+                auto d = GGML_FP16_TO_FP32(q8x[k].d);
+                auto dxyl = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yl[k].d));
+                auto dxyh = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yh[k].d));
+                acc[k] = _mm256_fmadd_ps(_mm256_shuffle_ps(dxyl, dxyh, 0x00), _mm256_cvtepi32_ps(p), acc[k]);
+            }
+        }
+        acc[0] = _mm256_add_ps(acc[0], acc[1]);
+        acc[2] = _mm256_add_ps(acc[2], acc[3]);
+        acc[0] = _mm256_add_ps(acc[0], acc[2]);
+        auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[0]), _mm256_extractf128_ps(acc[0], 1));
+        float daux[4]; _mm_storeu_ps(daux, sum);
+        info.store(ix, 0, daux[0] + daux[1]);
+        info.store(ix, 1, daux[2] + daux[3]);
+        acc[0] = acc[1] = acc[2] = acc[3] = _mm256_setzero_ps();
+        cx += bx;
+    }
+}
+
+//void mul_mat_q8_0_q8_x4_2x2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+//    Q8<2, block_q8_0> q8(info);
+//    int nblock = n/QK8_0;
+//    auto cx = (const char *)vx;
+//    auto y4l = (const block_q8_0_x4 *)q8.y[0];
+//    auto y4h = (const block_q8_0_x4 *)q8.y[1];
+//    ggml_half d4[4];
+//    __m256i dot[4];
+//    for (int ix = 0; ix < nrc_x; ix += 2) {
+//        auto q8xl = (const block_q8_0 *)cx; cx += bx;
+//        auto q8xh = (const block_q8_0 *)cx; cx += bx;
+//        auto accl = _mm256_setzero_ps();
+//        auto acch = _mm256_setzero_ps();
+//        for (int ib4 = 0; ib4 < nblock/4; ++ib4) {
+//            auto d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4l[ib4].d));
+//            auto d4yl = _mm256_set_m128(d4y_128, d4y_128);
+//            d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4h[ib4].d));
+//            auto d4yh = _mm256_set_m128(d4y_128, d4y_128);
+//            for (int k = 0; k < 4; ++k) {
+//                d4[2*k+0] = q8xl[k].d;
+//                d4[2*k+1] = q8xh[k].d;
+//                auto qxl = _mm256_loadu_si256((const __m256i *)q8xl[k].qs);
+//                auto qxh = _mm256_loadu_si256((const __m256i *)q8xh[k].qs);
+//                auto uxl = _mm256_sign_epi8(qxl, qxl);
+//                auto uxh = _mm256_sign_epi8(qxh, qxh);
+//                auto qyl = _mm256_loadu_si256((const __m256i*)y4l[ib4].qs + k);
+//                auto qyh = _mm256_loadu_si256((const __m256i*)y4h[ib4].qs + k);
+//                auto pll = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxl, _mm256_sign_epi8(qyl, qxl)));
+//                auto plh = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxl, _mm256_sign_epi8(qyh, qxl)));
+//                auto phl = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxh, _mm256_sign_epi8(qyl, qxh)));
+//                auto phh = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxh, _mm256_sign_epi8(qyh, qxh)));
+//                auto pyl = _mm256_add_epi32(_mm256_unpacklo_epi32(pll, phl), _mm256_unpackhi_epi32(pll, phl));
+//                auto pyh = _mm256_add_epi32(_mm256_unpacklo_epi32(plh, phh), _mm256_unpackhi_epi32(plh, phh));
+//                // ll, hl, lh, hh, ll, hl, lh, hh
+//                dot[k] = _mm256_add_epi32(_mm256_unpacklo_epi64(pyl, pyh), _mm256_unpackhi_epi64(pyl, pyh));
+//            }
+//            auto d4x_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)d4));
+//            auto d4x = _mm256_set_m128(d4x_128, d4x_128);
+//            auto d4xyl = _mm256_mul_ps(d4x, d4yl);
+//            auto d4xyh = _mm256_mul_ps(d4x, d4yh);
+//            acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x00), _mm256_cvtepi32_ps(dot[0]), acc);
+//            acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x55), _mm256_cvtepi32_ps(dot[1]), acc);
+//            acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xaa), _mm256_cvtepi32_ps(dot[2]), acc);
+//            acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xff), _mm256_cvtepi32_ps(dot[3]), acc);
+//            q8x += 4;
+//        }
+//        if (int ib0 = 4*(nblock/4); ib0 < nblock) {
+//            auto yl = q8.y[0] + ib0;
+//            auto yh = q8.y[1] + ib0;
+//            for (int k = 0; k < nblock - ib0; ++k) {
+//                auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs);
+//                auto ux = _mm256_sign_epi8(qx, qx);
+//                auto pl = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yl[k].qs), qx));
+//                auto ph = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yh[k].qs), qx));
+//                pl = _mm256_madd_epi16(_mm256_set1_epi16(1), pl);
+//                ph = _mm256_madd_epi16(_mm256_set1_epi16(1), ph);
+//                auto p = _mm256_add_epi32(_mm256_unpacklo_epi64(pl, ph), _mm256_unpackhi_epi64(pl, ph));
+//                auto d = GGML_FP16_TO_FP32(q8x[k].d);
+//                auto dxyl = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yl[k].d));
+//                auto dxyh = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yh[k].d));
+//                acc = _mm256_fmadd_ps(_mm256_shuffle_ps(dxyl, dxyh, 0x00), _mm256_cvtepi32_ps(p), acc);
+//            }
+//        }
+//        auto sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
+//        float daux[4]; _mm_storeu_ps(daux, sum);
+//        info.store(ix, 0, daux[0] + daux[1]);
+//        info.store(ix, 1, daux[2] + daux[3]);
+//    }
+//}
+
 struct Dequantizer4bit {
     const __m256i m4 = _mm256_set1_epi8(0xf);
     inline __m256i dequant(const uint8_t * qs) const {
@@ -9156,6 +9342,16 @@ void mul_mat_q80_q80_T(int n, const void * vx, size_t bx, const DataInfo& info,
     }
 }
 
 template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
+    //if (std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
+    //    m.funcs[0] = mul_mat_q8_0_q8_x4_1;
+    //    m.funcs[1] = mul_mat_q8_0_q8_x4_2;
+    //    m.funcs[2] = mul_mat_qX_0_q8_0_T<Dequantizer, 3>;
+    //    m.funcs[3] = mul_mat_qX_0_q8_0_T<Dequantizer, 4>;
+    //    m.funcs[4] = mul_mat_qX_0_q8_0_T<Dequantizer, 5>;
+    //    m.funcs[5] = mul_mat_qX_0_q8_0_T<Dequantizer, 6>;
+    //    m.funcs[6] = mul_mat_qX_0_q8_0_T<Dequantizer, 7>;
+    //    m.funcs[7] = mul_mat_qX_0_q8_0_T<Dequantizer, 8>;
+    //}
     if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> || std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
         m.funcs[0] = mul_mat_qX_0_q8_0_T<Dequantizer, 1>;
@@ -16080,12 +16276,15 @@ struct FlashMS {
         }
         return F16::reduce_max(vk);
     }
-    static inline __m256 apply_mask(int l, const char * mask, __m256 val, __m256 vinf) {
-        auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
-        m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
-        auto m256 = _mm256_cvtepi16_epi32(m128);
-        auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
-        return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
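+    // The mask is now assumed to hold fp16 additive biases (0 or -inf), so a cvtph + add
+    // replaces the old cmpeq/select against vinf (the previous code is kept below for reference).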
+    static inline __m256 apply_mask(int l, const char * mask, __m256 val, [[maybe_unused]] __m256 vinf) {
+        return _mm256_add_ps(val, _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)mask+l)));
+        //auto m128 = _mm_loadu_si128((const __m128i *)mask+l);
+        //m128 = _mm_cmpeq_epi16(m128, _mm_setzero_si128());
+        //auto m256 = _mm256_cvtepi16_epi32(m128);
+        //auto mf = _mm256_castsi256_ps(_mm256_or_si256(m256, _mm256_slli_epi32(m256, 16)));
+        //return _mm256_or_ps(_mm256_and_ps(mf, val), _mm256_andnot_ps(mf, vinf));
     }
 #ifdef __AVX512F__
     static inline __m512 apply_mask(int l, const char * mask, __m512 val, __m512 vinf) {
@@ -16622,6 +16821,8 @@ struct FlashQKfp32 {
 #ifdef HAVE_FANCY_SIMD
             MAKE_FUNCS(mul_mat_qX_1_q8_2_T