diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index f47f92ee..04b1e36e 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -3960,7 +3960,8 @@ static void mul_mat_bf16_r16_bf16(int n, const void * vx, size_t bx, const DataI
 #endif
 
 template <int nrc_y, typename Q8>
-IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8& q8, __m256i shuff,
+//IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8& q8, __m256i shuff,
+inline void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8& q8, __m256i shuff,
         __m256i * isum, int16_t min) {
     auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
     auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
@@ -4008,6 +4009,46 @@ IQK_ALWAYS_INLINE void iq234_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8
     }
 }
 
+template <int nrc_y, typename Q8>
+inline void iq2345_k_accum_mins(int ibl, __m256i i8scales1, __m256i i8scales2, const Q8& q8, __m256i shuff,
+        __m256i extra, __m256i * isum, int8_t min, int8_t delta) {
+    auto mask = _mm256_set_epi64x(0x0808080808080808, 0x0404040404040404, 0x0202020202020202, 0x0101010101010101);
+    auto vdelta = _mm256_set1_epi8(delta);
+    auto vmin = _mm256_set1_epi8(min);
+    auto min1 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(extra, mask), mask)));
+    auto min2 = _mm256_add_epi8(vmin, _mm256_and_si256(vdelta, _mm256_cmpeq_epi8(_mm256_and_si256(_mm256_srli_epi16(extra, 4), mask), mask)));
+    auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
+    auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
+    auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
+    auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(i8scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
+    auto m1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
+    auto m2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
+    auto m3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
+    auto m4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(min2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
+    auto s1 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 0), _mm256_extracti128_si256(m1, 0)),
+                                 MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0))); // blocks 0, 1, 8, 9
+    auto s2 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m3, 1), _mm256_extracti128_si256(m1, 1)),
+                                 MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1))); // blocks 2, 3, 10, 11
+    auto s3 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 0), _mm256_extracti128_si256(m2, 0)),
+                                 MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0))); // blocks 4, 5, 12, 13
+    auto s4 = _mm256_mullo_epi16(MM256_SET_M128I(_mm256_extracti128_si256(m4, 1), _mm256_extracti128_si256(m2, 1)),
+                                 MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1))); // blocks 6, 7, 14, 15
+    for (int iy = 0; iy < nrc_y; ++iy) {
+        auto bsums = q8.load_bsums(iy, ibl);
+#ifdef HAVE_FANCY_SIMD
+        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s1, _mm256_shuffle_epi32(bsums, 0x00));
+        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s2, _mm256_shuffle_epi32(bsums, 0x55));
+        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s3, _mm256_shuffle_epi32(bsums, 0xaa));
+        isum[iy] = _mm256_dpwssd_epi32(isum[iy], s4, _mm256_shuffle_epi32(bsums, 0xff));
+#else
+        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
+        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
+        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
+        isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
+#endif
+    }
+}
+
 template <int nrc_y>
 static void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     GGML_ASSERT(nrc_x%4 == 0);
@@ -4285,8 +4326,8 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
         values[0] = MM256_SET_M128I(val1, val1);
         values[1] = MM256_SET_M128I(val2, val2);
 #ifdef HAVE_FANCY_SIMD
-        values[0] = _mm256_add_epi8(values[0], _mm256_set1_epi8(127));
-        values[1] = _mm256_add_epi8(values[1], _mm256_set1_epi8(127));
+        values[0] = _mm256_sub_epi8(values[0], _mm256_set1_epi8(-128));
+        values[1] = _mm256_sub_epi8(values[1], _mm256_set1_epi8(-128));
 #endif
     }
 #ifdef HAVE_FANCY_SIMD
@@ -4295,6 +4336,7 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
 #else
     auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
 #endif
+    auto m128 = _mm256_set1_epi8(-128);
     int nbl = n / QK_K;
     __m256 acc[nrc_y] = {};
     __m256i qx[4];
@@ -4316,7 +4358,11 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
         _mm256_storeu_si256((__m256i *)stored_scales+1, i8scales2);
         __m256i isum[nrc_y] = {};
#ifdef HAVE_FANCY_SIMD
-        iq234_k_accum_mins<nrc_y>(ibl, i8scales1, i8scales2, q8, shuff, isum, -127);
+        if constexpr (nrc_y == 1) {
+            iq234_k_accum_mins<nrc_y>(ibl, i8scales1, i8scales2, q8, shuff, isum, -128);
+        } else {
+            iq2345_k_accum_mins<nrc_y>(ibl, i8scales1, i8scales2, q8, shuff, extra, isum, -128, 2);
+        }
 #endif
         for (int ib = 0; ib < QK_K/32; ++ib) {
 #ifdef HAVE_FANCY_SIMD
@@ -4328,8 +4374,6 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
             auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
             auto hbits = _mm_loadu_si128((const __m128i *)iq5[ibl].qh+ib);
             auto hb = MM256_SET_M128I(_mm_srli_epi16(hbits, 2), hbits);
-            auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
-            shift = _mm256_shuffle_epi8(shift, shift_shuffle);
             qx[0] = _mm256_and_si256(lbits1, m4);
             qx[1] = _mm256_and_si256(lbits2, m4);
             qx[2] = _mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4);
@@ -4352,31 +4396,41 @@ static void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataI
 //            qx[2] = _mm256_add_epi8(qx[2], shift);
 //            qx[3] = _mm256_add_epi8(qx[3], shift);
 //#else
-            auto qh = _mm256_and_si256(_mm256_slli_epi16(hb, 7), _mm256_set1_epi8(-128));
+            auto qh = _mm256_and_si256(_mm256_slli_epi16(hb, 7), m128);
             auto q5vl = _mm256_or_si256(qx[0], qh);
-            auto q5vh = _mm256_or_si256(qx[0], _mm256_xor_si256(qh, _mm256_set1_epi8(-128)));
+            auto q5vh = _mm256_or_si256(qx[0], _mm256_xor_si256(qh, m128));
             qx[0] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
-            qx[0] = _mm256_add_epi8(qx[0], shift);
-            qh = _mm256_and_si256(_mm256_slli_epi16(hb, 3), _mm256_set1_epi8(-128));
+            qh = _mm256_and_si256(_mm256_slli_epi16(hb, 3), m128);
             q5vl = _mm256_or_si256(qx[1], qh);
-            q5vh = _mm256_or_si256(qx[1], _mm256_xor_si256(qh, _mm256_set1_epi8(-128)));
+            q5vh = _mm256_or_si256(qx[1], _mm256_xor_si256(qh, m128));
             qx[1] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
-            qx[1] = _mm256_add_epi8(qx[1], shift);
-            qh = _mm256_and_si256(_mm256_slli_epi16(hb, 6), _mm256_set1_epi8(-128));
+            qh = _mm256_and_si256(_mm256_slli_epi16(hb, 6), m128);
             q5vl = _mm256_or_si256(qx[2], qh);
-            q5vh = _mm256_or_si256(qx[2], _mm256_xor_si256(qh, _mm256_set1_epi8(-128)));
+            q5vh = _mm256_or_si256(qx[2], _mm256_xor_si256(qh, m128));
             qx[2] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
-            qx[2] = _mm256_add_epi8(qx[2], shift);
-            qh = _mm256_and_si256(_mm256_slli_epi16(hb, 2), _mm256_set1_epi8(-128));
+            qh = _mm256_and_si256(_mm256_slli_epi16(hb, 2), m128);
             q5vl = _mm256_or_si256(qx[3], qh);
-            q5vh = _mm256_or_si256(qx[3], _mm256_xor_si256(qh, _mm256_set1_epi8(-128)));
+            q5vh = _mm256_or_si256(qx[3], _mm256_xor_si256(qh, m128));
             qx[3] = _mm256_or_si256(_mm256_shuffle_epi8(values[0], q5vl), _mm256_shuffle_epi8(values[1], q5vh));
-            qx[3] = _mm256_add_epi8(qx[3], shift);
+            if constexpr (nrc_y == 1) {
+                auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
+                shift = _mm256_shuffle_epi8(shift, shift_shuffle);
+                qx[0] = _mm256_add_epi8(qx[0], shift);
+                qx[1] = _mm256_add_epi8(qx[1], shift);
+                qx[2] = _mm256_add_epi8(qx[2], shift);
+                qx[3] = _mm256_add_epi8(qx[3], shift);
+            }
 #ifndef HAVE_FANCY_SIMD
+            auto shift = _mm256_and_si256(ms, _mm256_slli_epi16(extra, 1)); extra = _mm256_srli_epi16(extra, 1);
+            shift = _mm256_shuffle_epi8(shift, shift_shuffle);
+            qx[0] = _mm256_add_epi8(qx[0], shift);
+            qx[1] = _mm256_add_epi8(qx[1], shift);
+            qx[2] = _mm256_add_epi8(qx[2], shift);
+            qx[3] = _mm256_add_epi8(qx[3], shift);
             auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
             auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
             auto s3 = _mm256_sign_epi8(qx[2], qx[2]);