From bcacf33350d47335f6d488d601406b7f015a70cd Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 19 Apr 2025 09:12:50 +0300 Subject: [PATCH] WIP --- ggml/src/iqk/iqk_mul_mat.cpp | 82 +++++------------------------------- 1 file changed, 10 insertions(+), 72 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 83861d37..d75e4d0c 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -8654,6 +8654,14 @@ void mul_mat_q8_0_q8_x4_2x2(int n, const void * vx, size_t bx, const DataInfo& i acc[0] = _mm256_add_ps(acc[0], acc[1]); // x1,y0 x1,y0, x1,y1, x1,y1, x1,y0 x1,y0, x1,y1, x1,y1, acc[2] = _mm256_add_ps(acc[2], acc[3]); + //auto sum1 = _mm_add_ps(_mm256_castps256_ps128(acc[0]), _mm256_extractf128_ps(acc[0], 1)); + //auto sum2 = _mm_add_ps(_mm256_castps256_ps128(acc[2]), _mm256_extractf128_ps(acc[2], 1)); + //auto sum = _mm_hadd_ps(sum1, sum2); + //_mm_storeu_ps(daux, sum); + //info.store(ix+0, 0, daux[0]); + //info.store(ix+0, 1, daux[1]); + //info.store(ix+1, 0, daux[2]); + //info.store(ix+1, 1, daux[3]); auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[0]), _mm256_extractf128_ps(acc[0], 1)); _mm_storeu_ps(daux, sum); info.store(ix+0, 0, daux[0] + daux[1]); @@ -8666,76 +8674,6 @@ void mul_mat_q8_0_q8_x4_2x2(int n, const void * vx, size_t bx, const DataInfo& i } } -//void mul_mat_q8_0_q8_x4_2x2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { -// Q8<2, block_q8_0> q8(info); -// int nblock = n/QK8_0; -// auto cx = (const char *)vx; -// auto y4l = (const block_q8_0_x4 *)q8.y[0]; -// auto y4h = (const block_q8_0_x4 *)q8.y[1]; -// ggml_half d4[4]; -// __m256i dot[4]; -// for (int ix = 0; ix < nrc_x; ix += 2) { -// auto q8xl = (const block_q8_0 *)cx; cx += bx; -// auto q8xh = (const block_q8_0 *)cx; cx += bx; -// auto accl = _mm256_setzero_ps(); -// auto acch = _mm256_setzero_ps(); -// for (int ib4 = 0; ib4 < nblock/4; ++ib4) { -// auto d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4l[ib4].d)); -// auto d4yl = _mm256_set_m128(d4y_128, d4y_128); -// d4y_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)y4h[ib4].d)); -// auto d4yh = _mm256_set_m128(d4y_128, d4y_128); -// for (int k = 0; k < 4; ++k) { -// d4[2*k+0] = q8xl[k].d; -// d4[2*k+1] = q8xh[k].d; -// auto qxl = _mm256_loadu_si256((const __m256i *)q8xl[k].qs); -// auto qxh = _mm256_loadu_si256((const __m256i *)q8xh[k].qs); -// auto uxl = _mm256_sign_epi8(qxl, qxl); -// auto uxh = _mm256_sign_epi8(qxh, qxh); -// auto qyl = _mm256_loadu_si256((const __m256i*)y4l[ib4].qs + k); -// auto qyh = _mm256_loadu_si256((const __m256i*)y4h[ib4].qs + k); -// auto pll = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxl, _mm256_sign_epi8(qyl, qxl))); -// auto plh = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxl, _mm256_sign_epi8(qyh, qxl))); -// auto phl = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxh, _mm256_sign_epi8(qyl, qxh))); -// auto phh = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(uxh, _mm256_sign_epi8(qyh, qxh))); -// auto pyl = _mm256_add_epi32(_mm256_unpacklo_epi32(pll, phl), _mm256_unpackhi_epi32(pll, phl)); -// auto pyh = _mm256_add_epi32(_mm256_unpacklo_epi32(plh, phh), _mm256_unpackhi_epi32(plh, phh)); -// // ll, hl, lh, hh, ll, hl, lh, hh -// dot[k] = _mm256_add_epi32(_mm256_unpacklo_epi64(pyl, pyh), _mm256_unpackhi_epi64(pyl, pyh)); -// } -// auto d4x_128 = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)d4)); -// auto d4x = _mm256_set_m128(d4x_128, d4x_128); -// auto d4xyl = _mm256_mul_ps(d4x, d4yl); -// auto d4xyh = _mm256_mul_ps(d4x, d4yh); -// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x00), _mm256_cvtepi32_ps(dot[0]), acc); -// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0x55), _mm256_cvtepi32_ps(dot[1]), acc); -// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xaa), _mm256_cvtepi32_ps(dot[2]), acc); -// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(d4xyl, d4xyh, 0xff), _mm256_cvtepi32_ps(dot[3]), acc); -// q8x += 4; -// } -// if (int ib0 = 4*(nblock/4); ib0 < nblock) { -// auto yl = q8.y[0] + ib0; -// auto yh = q8.y[1] + ib0; -// for (int k = 0; k < nblock - ib0; ++k) { -// auto qx = _mm256_loadu_si256((const __m256i *)q8x[k].qs); -// auto ux = _mm256_sign_epi8(qx, qx); -// auto pl = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yl[k].qs), qx)); -// auto ph = _mm256_maddubs_epi16(ux, _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)yh[k].qs), qx)); -// pl = _mm256_madd_epi16(_mm256_set1_epi16(1), pl); -// ph = _mm256_madd_epi16(_mm256_set1_epi16(1), ph); -// auto p = _mm256_add_epi32(_mm256_unpacklo_epi64(pl, ph), _mm256_unpackhi_epi64(pl, ph)); -// auto d = GGML_FP16_TO_FP32(q8x[k].d); -// auto dxyl = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yl[k].d)); -// auto dxyh = _mm256_set1_ps(d*GGML_FP16_TO_FP32(yh[k].d)); -// acc = _mm256_fmadd_ps(_mm256_shuffle_ps(dxyl, dxyh, 0x00), _mm256_cvtepi32_ps(p), acc); -// } -// } -// auto sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1)); -// float daux[4]; _mm_storeu_ps(daux, sum); -// info.store(ix, 0, daux[0] + daux[1]); -// info.store(ix, 1, daux[2] + daux[3]); -// } -//} - struct Dequantizer4bit { const __m256i m4 = _mm256_set1_epi8(0xf); inline __m256i dequant(const uint8_t * qs) const { @@ -16877,8 +16815,8 @@ struct FlashQKfp32 { #ifdef HAVE_FANCY_SIMD MAKE_FUNCS(mul_mat_qX_1_q8_2_T