diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index a9a88c04..b70640c1 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -3126,12 +3126,10 @@ static void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const Data const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4); auto d4 = _mm_loadu_ps(dptr); for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256 - //const uint32_t * aux32 = (const uint32_t *)iq4[ibl].scales; auto scales = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales); - //h_shift.vec = _mm256_add_epi32(_mm256_set1_epi32(-64), _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1)); + h.vec = _mm256_sub_epi8(_mm256_and_si256(scales, _mm256_set1_epi8(-2)), _mm256_set1_epi8(127)); +#ifndef HAVE_FANCY_SIMD h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 2); - scales = _mm256_and_si256(scales, _mm256_set1_epi8(-2)); - h.vec = _mm256_sub_epi8(scales, _mm256_set1_epi8(127)); { __m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[0])))), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[4])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[0]))))); @@ -3141,14 +3139,6 @@ static void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const Data _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[6])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[2]))))); __m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h.val[3])))), _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[7])), _mm_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[3]))))); - //__m256 v1 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set_epi32(0, 0, h.val[4], h.val[0]))), - // _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set_epi32(0, 0, h_shift.val[4], h_shift.val[0])))); - //__m256 v2 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set_epi32(0, 0, h.val[5], h.val[1]))), - // _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set_epi32(0, 0, h_shift.val[5], h_shift.val[1])))); - //__m256 v3 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set_epi32(0, 0, h.val[6], h.val[2]))), - // _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set_epi32(0, 0, h_shift.val[6], h_shift.val[2])))); - //__m256 v4 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set_epi32(0, 0, h.val[7], h.val[3]))), - // _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_set_epi32(0, 0, h_shift.val[7], h_shift.val[3])))); for (int iy = 0; iy < nrc_y; ++iy) { auto m8 = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums); acc[iy] = _mm256_fmadd_ps(v1, _mm256_shuffle_ps(m8, m8, 0x00), acc[iy]); @@ -3157,30 +3147,27 @@ static void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const Data acc[iy] = _mm256_fmadd_ps(v4, _mm256_shuffle_ps(m8, m8, 0xff), acc[iy]); } } +#else + h_shift.vec = _mm256_slli_epi16(_mm256_and_si256(scales, _mm256_set1_epi8(1)), 1); +#endif for (int ib = 0; ib < QK_K/32; ++ib) { -// auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib])); -// //auto ishifts = _mm256_add_epi32(_mm256_set1_epi32(-64), _mm256_cvtepi8_epi32(_mm_set1_epi32((aux32[ib] & 0x01010101) << 1))); -// //auto ishifts = _mm256_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[ib])); -//#ifdef HAVE_FANCY_SIMD -// auto ishifts = _mm256_add_epi32(_mm256_set1_epi32(-64), _mm256_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[ib]))); -//#else -// auto ishifts = _mm256_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[ib])); -//#endif -// auto scales = _mm256_cvtepi32_ps(iscales); -// auto scales_m = _mm256_mul_ps(scales, _mm256_cvtepi32_ps(ishifts)); -// for (int iy = 0; iy < nrc_y; ++iy) { -// float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; -// acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]); -// } +#ifdef HAVE_FANCY_SIMD + auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib])); + auto ishifts = _mm256_add_epi32(_mm256_set1_epi32(-64), _mm256_cvtepi8_epi32(_mm_set1_epi32(h_shift.val[ib]))); + auto scales = _mm256_cvtepi32_ps(iscales); + auto scales_m = _mm256_mul_ps(scales, _mm256_cvtepi32_ps(ishifts)); + for (int iy = 0; iy < nrc_y; ++iy) { + float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib]; + acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]); + } +#endif auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0); auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1); qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits1, m4)); qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(bits2, m4)); qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits1, 4), m4)); qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(_mm256_srli_epi16(bits2, 4), m4)); -#ifdef HAVE_FANCY_SIMD - auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi32(h.val[ib])); -#else +#ifndef HAVE_FANCY_SIMD auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi32(h.val[ib])), s_shuffle); auto s1 = _mm256_sign_epi8(qx[0], qx[0]); auto s2 = _mm256_sign_epi8(qx[1], qx[1]);