diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 980a7189..53d5e924 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -3447,7 +3447,11 @@ static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
     auto m03 = _mm256_set1_epi8(0x03);
     static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
     auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+#ifdef HAVE_FANCY_SIMD
     __m256 d4s[nrc_y];
+#else
+    auto m1 = _mm256_set1_epi16(1);
+#endif
     int nbl = n / QK_K;
     __m256 acc[nrc_y] = {};
     __m256i qx[4];
@@ -3475,6 +3479,7 @@ static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     auto bsums = q8.load_bsums(iy, ibl);
                     auto sumi = _mm256_setzero_si256();
+#ifdef HAVE_FANCY_SIMD
                     sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
                     sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
                     sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
@@ -3482,6 +3487,17 @@ static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
                     auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
                     acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
                     d4s[iy] = _mm256_mul_ps(d4, d8);
+#else
+                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+                    sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+                    auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
+                    acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
+                    if constexpr (nrc_y == 1) {
+                        d4 = _mm256_mul_ps(d4, d8);
+                    }
+#endif
                 }
             }
             all_scales1 = _mm256_and_si256(all_scales1, mxf);
@@ -3490,7 +3506,11 @@ static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
             _mm256_storeu_si256((__m256i *)scales+1, all_scales2);
             for (int ib = 0; ib < QK_K/32; ++ib) {
                 auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
+#ifdef HAVE_FANCY_SIMD
                 auto scales = _mm256_cvtepi32_ps(iscales);
+#else
+                auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
+#endif
                 auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib);
                 qx[0] = _mm256_and_si256(lb, m03);
                 qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03);
@@ -3498,12 +3518,26 @@ static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataIn
                 qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03);
                 for (int iy = 0; iy < nrc_y; ++iy) {
                     auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
+#ifdef HAVE_FANCY_SIMD
                     auto sumi = _mm256_setzero_si256();
                     sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
                     sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
                     sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
                     sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
                     acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, d4s[iy]), _mm256_cvtepi32_ps(sumi), acc[iy]);
+#else
+                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
+                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
+                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
+                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
+                    // Quants are in 0...3, so we can add up all of them as int16_t without overflowing
+                    auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
+                    if constexpr (nrc_y == 1) {
+                        acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
+                    } else {
+                        acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
+                    }
+#endif
                 }
             }
         }
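A note on the #else path in the second and third hunks: on AVX512, _mm256_dpwssd_epi32(sumi, s, b) fuses the int16 dot product with the int32 accumulate; the AVX2 replacement _mm256_add_epi32(sumi, _mm256_madd_epi16(s, b)) computes the same pairwise i16*i16 -> i32 sums, just in two instructions. A minimal standalone sketch of that equivalence, checked against a scalar reference (not from the patch; main() and the test values are made up for illustration, build with -mavx2):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
    // Stand-ins for the mins (s1..s4) and the broadcast bsums of the patch.
    int16_t s[16], b[16];
    for (int i = 0; i < 16; ++i) { s[i] = (int16_t)(100*i - 800); b[i] = (int16_t)(37*i - 300); }

    auto sv = _mm256_loadu_si256((const __m256i *)s);
    auto bv = _mm256_loadu_si256((const __m256i *)b);

    // AVX2 replacement used in the patch for _mm256_dpwssd_epi32(sumi, sv, bv):
    // madd multiplies i16 lanes and sums adjacent pairs into i32 lanes.
    auto sumi = _mm256_add_epi32(_mm256_setzero_si256(), _mm256_madd_epi16(sv, bv));

    int32_t lanes[8];
    _mm256_storeu_si256((__m256i *)lanes, sumi);
    for (int l = 0; l < 8; ++l) {
        int32_t ref = s[2*l]*b[2*l] + s[2*l+1]*b[2*l+1];
        printf("lane %d: simd = %d, ref = %d\n", l, (int)lanes[l], ref);
    }
    return 0;
}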
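The comment in the last hunk ("Quants are in 0...3 ...") is the load-bearing part of that change: _mm256_maddubs_epi16 saturates its pairwise i16 sums, but with the unsigned operand masked to 2 bits a single pair is at most 2*3*128 = 768, so even the sum of all four pair-sums (at most 3072) stays far from the int16_t limit before the final _mm256_madd_epi16 against m1 widens it to int32 — the same result one _mm256_dpbusd_epi32 produces on AVX512. A standalone sketch verifying that emulation against a scalar dot product (again not from the patch; the data and main() are illustrative, build with -mavx2):

#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
    uint8_t q[32];  // 2-bit quants, 0..3, like qx[] after masking with m03
    int8_t  y[32];  // signed q8 activations, full int8 range
    for (int i = 0; i < 32; ++i) { q[i] = i & 3; y[i] = (int8_t)(7*i - 128); }

    auto qx = _mm256_loadu_si256((const __m256i *)q);
    auto yv = _mm256_loadu_si256((const __m256i *)y);

    // u8*s8 products summed in adjacent pairs as i16; |pair| <= 2*3*128 = 768,
    // so accumulating several of these in int16_t cannot saturate.
    auto pairs = _mm256_maddubs_epi16(qx, yv);
    // Multiply by 1 and sum adjacent i16 pairs into i32 lanes: together the two
    // steps emulate _mm256_dpbusd_epi32(zero, qx, yv).
    auto sumi = _mm256_madd_epi16(_mm256_set1_epi16(1), pairs);

    int32_t lanes[8];
    _mm256_storeu_si256((__m256i *)lanes, sumi);

    // Scalar reference: each i32 lane covers 4 consecutive byte products.
    for (int l = 0; l < 8; ++l) {
        int32_t ref = 0;
        for (int k = 0; k < 4; ++k) ref += q[4*l + k] * y[4*l + k];
        printf("lane %d: simd = %d, ref = %d\n", l, (int)lanes[l], ref);
    }
    return 0;
}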