From d355ff997b2a2525eae447e53cff89a9c8de4a43 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 17 May 2025 15:45:15 +0300 Subject: [PATCH] Refactor iqk: fix AVX2 --- ggml/src/iqk/iqk_common.h | 63 +++++++++++++++++++++++++++++++++ ggml/src/iqk/iqk_mul_mat.cpp | 68 +----------------------------------- 2 files changed, 64 insertions(+), 67 deletions(-) diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h index ed2142c8..19416386 100644 --- a/ggml/src/iqk/iqk_common.h +++ b/ggml/src/iqk/iqk_common.h @@ -291,6 +291,69 @@ struct BaseDequantizer { float d; }; +template +static inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { + if (j == 0) { +#ifdef HAVE_FANCY_SIMD + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2))); + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3))); + } +#else + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); + const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); + const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2))); + const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3))); + sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4)); + } +#endif + } else { +#ifdef HAVE_FANCY_SIMD + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6))); + sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7))); + } +#else + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); + const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); + const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6))); + const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7))); + sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3)); + sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4)); + } +#endif + } +} + +template +static inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { + __m256i p[4]; + if (j == 0) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + for (int k = 0; k < 4; ++k) { + auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]); + p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, k), bits.values[k]))); + } + sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p[0], p[1]), _mm256_add_epi32(p[2], p[3])); + } + } else { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + for (int k = 0; k < 4; ++k) { + auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]); + p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, 4+k), bits.values[k]))); + } + sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[0], p[2])); + sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[1], p[3])); + } + } +} + #endif #endif diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index b9c14a44..a1f7de6b 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -1438,69 +1438,6 @@ inline void set_scales_16(const __m256i& all_scales, __m256i * scales) { scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3)); } -template -inline void multiply_add(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { - if (j == 0) { -#ifdef HAVE_FANCY_SIMD - for (int iy = 0; iy < Q8::nrc_y; ++iy) { - sumi[iy] = _mm256_dpwssd_epi32(_mm256_setzero_si256(), scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); - sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); - sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2))); - sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3))); - } -#else - for (int iy = 0; iy < Q8::nrc_y; ++iy) { - const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 0))); - const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 1))); - const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 2))); - const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 3))); - sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p1, p3), _mm256_add_epi32(p2, p4)); - } -#endif - } else { -#ifdef HAVE_FANCY_SIMD - for (int iy = 0; iy < Q8::nrc_y; ++iy) { - sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); - sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); - sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6))); - sumi[iy] = _mm256_dpwssd_epi32(sumi[iy], scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7))); - } -#else - for (int iy = 0; iy < Q8::nrc_y; ++iy) { - const __m256i p1 = _mm256_madd_epi16(scales[0], _mm256_maddubs_epi16(bits.values[0], q8.load_quants(iy, i, 4))); - const __m256i p2 = _mm256_madd_epi16(scales[1], _mm256_maddubs_epi16(bits.values[1], q8.load_quants(iy, i, 5))); - const __m256i p3 = _mm256_madd_epi16(scales[2], _mm256_maddubs_epi16(bits.values[2], q8.load_quants(iy, i, 6))); - const __m256i p4 = _mm256_madd_epi16(scales[3], _mm256_maddubs_epi16(bits.values[3], q8.load_quants(iy, i, 7))); - sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p1, p3)); - sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p2, p4)); - } -#endif - } -} - -template -inline void multiply_add_avx2(const Bits& bits, const __m256i * scales, int j, int i, const Q8& q8, __m256i * sumi) { - __m256i p[4]; - if (j == 0) { - for (int iy = 0; iy < Q8::nrc_y; ++iy) { - for (int k = 0; k < 4; ++k) { - auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]); - p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, k), bits.values[k]))); - } - sumi[iy] = _mm256_add_epi32(_mm256_add_epi32(p[0], p[1]), _mm256_add_epi32(p[2], p[3])); - } - } else { - for (int iy = 0; iy < Q8::nrc_y; ++iy) { - for (int k = 0; k < 4; ++k) { - auto s = _mm256_sign_epi8(bits.values[k], bits.values[k]); - p[k] = _mm256_madd_epi16(scales[k], _mm256_maddubs_epi16(s, _mm256_sign_epi8(q8.load_quants(iy, i, 4+k), bits.values[k]))); - } - sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[0], p[2])); - sumi[iy] = _mm256_add_epi32(sumi[iy], _mm256_add_epi32(p[1], p[3])); - } - } -} - struct SignHelper { inline __m256i make_signs(uint32_t sign_bits) const { auto aux256 = _mm256_set1_epi32(sign_bits); @@ -8209,10 +8146,7 @@ template void MulMat::set_functions(MulMat& m) { m.funcs[7] = mul_mat_qX_K_q8_K_AVX512; } #else - if constexpr (std::is_same_v || - std::is_same_v || - std::is_same_v || - std::is_same_v|| + if constexpr (std::is_same_v|| std::is_same_v|| std::is_same_v|| std::is_same_v||