diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index cd7ff843..44a5df4e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1200,7 +1200,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float      = quantize_row_iq2_s,
         .from_float_ref  = (ggml_from_float_t)quantize_row_iq2_s_ref,
         .vec_dot         = ggml_vec_dot_iq2_s_q8_K,
+#ifdef __AVX2__
+        .vec_dot_type    = GGML_TYPE_Q8_2_X4,
+#else
         .vec_dot_type    = GGML_TYPE_Q8_K,
+#endif
         .nrows           = 1,
         .row_meta_size   = 0,
     },
diff --git a/ggml/src/iqk/iqk_gemm_iquants.cpp b/ggml/src/iqk/iqk_gemm_iquants.cpp
index eb79bb12..7f30ce83 100644
--- a/ggml/src/iqk/iqk_gemm_iquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_iquants.cpp
@@ -384,13 +384,16 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
 
     constexpr static int num_blocks = 16;
 
-    inline __m256i load_scales(int i) {
-        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
-        auto tmp = _mm_loadl_epi64((const __m128i *)x[i].scales);
+    static inline __m256i make_scales(const uint8_t * scales) {
+        auto tmp = _mm_loadl_epi64((const __m128i *)scales);
         auto all = _mm_and_si128(_mm_unpacklo_epi8(tmp, _mm_srli_epi16(tmp, 4)), _mm_set1_epi8(0xf));
         auto scales8 = _mm_or_si128(_mm_slli_epi16(all, 1), _mm_set1_epi8(1));
         return _mm256_cvtepi8_epi16(scales8);
     }
+    inline __m256i load_scales(int i) {
+        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+        return make_scales(x[i].scales);
+    }
     inline static void prepare_scales(const __m256i& all, __m256i * scales) {
         auto scales_l = _mm256_castsi256_si128(all);
         auto scales_h = _mm256_extractf128_si256(all, 1);
@@ -445,6 +448,38 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
         q8_quants[2] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+2), sh.make_signs(signs[4] | (signs[5] << 16)));
         q8_quants[3] = _mm256_sign_epi8(q8.load_quants(0, i, 4*j+3), sh.make_signs(signs[6] | (signs[7] << 16)));
     }
+    static inline void prepare(const uint8_t * qs, const uint8_t * qh, const uint16_t * signs, const SignHelper& sh, __m256i * values) {
+        auto idx_shift = _mm256_set_epi32(2, 4, 6, 8, 2, 4, 6, 8);
+        auto idx_mask = _mm256_set1_epi32(0x300);
+        make2(qs+0, qh+0, idx_shift, idx_mask, values+0);
+        make2(qs+8, qh+2, idx_shift, idx_mask, values+2);
+        values[0] = _mm256_sign_epi8(values[0], sh.make_signs(signs[0] | (signs[1] << 16)));
+        values[1] = _mm256_sign_epi8(values[1], sh.make_signs(signs[2] | (signs[3] << 16)));
+        values[2] = _mm256_sign_epi8(values[2], sh.make_signs(signs[4] | (signs[5] << 16)));
+        values[3] = _mm256_sign_epi8(values[3], sh.make_signs(signs[6] | (signs[7] << 16)));
+    }
+    inline void prepare_signed(int i, int j, __m256i * us, __m256i * s) {
+        auto qs = x[i].qs + 16*j;
+        auto qh = x[i].qh + 4*j;
+        const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j;
+        make2(qs+0, qh+0, idx_shift, idx_mask, us+0);
+        make2(qs+8, qh+2, idx_shift, idx_mask, us+2);
+        s[0] = _mm256_sign_epi8(s[0], sh.make_signs(signs[0] | (signs[1] << 16)));
+        s[1] = _mm256_sign_epi8(s[1], sh.make_signs(signs[2] | (signs[3] << 16)));
+        s[2] = _mm256_sign_epi8(s[2], sh.make_signs(signs[4] | (signs[5] << 16)));
+        s[3] = _mm256_sign_epi8(s[3], sh.make_signs(signs[6] | (signs[7] << 16)));
+    }
+    inline void prepare_signed(int i, int j, __m256i * us) {
+        auto qs = x[i].qs + 16*j;
+        auto qh = x[i].qh + 4*j;
+        const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/8) + 8*j;
+        make2(qs+0, qh+0, idx_shift, idx_mask, us+0);
+        make2(qs+8, qh+2, idx_shift, idx_mask, us+2);
+        bits.values[0] = _mm256_sign_epi8(us[0], sh.make_signs(signs[0] | (signs[1] << 16)));
+        bits.values[1] = _mm256_sign_epi8(us[1], sh.make_signs(signs[2] | (signs[3] << 16)));
+        bits.values[2] = _mm256_sign_epi8(us[2], sh.make_signs(signs[4] | (signs[5] << 16)));
+        bits.values[3] = _mm256_sign_epi8(us[3], sh.make_signs(signs[6] | (signs[7] << 16)));
+    }
 
     constexpr static int minv = 43;
 
@@ -2058,6 +2093,200 @@ static void mul_mat_iq2_xs_q8_2_X4(int n, const void * vx, size_t bx, const Data
     }
 }
 
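+// Repacks 8 interleaved IQ2_S rows into Q8_0_R8: every group of 32 dequantized values
+// is requantized to int8 with a per-group fp16 scale (the group's max |value| maps to 127),
+// matching the Q8_0_R8 repacking that the type dispatch at the end of this diff selects
+// for IQ2_S when nrc_y >= 32.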
+void iqk_convert_iq2_s_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_iq2_s * x8[8];
+
+    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
+
+    float all_s[64];
+
+    uint32_t block[8];
+
+    union { __m256i vec; int16_t val[16]; } helper;
+    __m256i qx[8];
+
+    SignHelper sh;
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_s *)((const char *)vx + (ix + k)*bx);
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
+                helper.vec = DequantizerIQ2S::make_scales(x8[k][i].scales);
+                DequantizerIQ2S::prepare(x8[k][i].qs+ 0, x8[k][i].qh+0, (const uint16_t *)(x8[k][i].qs + QK_K/8) + 0, sh, qx+0);
+                DequantizerIQ2S::prepare(x8[k][i].qs+16, x8[k][i].qh+4, (const uint16_t *)(x8[k][i].qs + QK_K/8) + 8, sh, qx+4);
+                for (int ib32 = 0; ib32 < 8; ++ib32) {
+                    auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+                    auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+                    q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(helper.val[2*ib32+0]));
+                    q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(helper.val[2*ib32+1]));
+                    auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l);
+                    auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h);
+                    auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h);
+                    auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1)));
+                    auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+                    auto max4  = _mm_cvtepi32_ps(imax4);
+                    max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+                    max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+                    float max = _mm_cvtss_f32(max4) / 127;
+                    all_s[8*ib32+k] = d*max;
+                    if (max > 1e-9f) {
+                        auto scale = _mm256_set1_ps(1/max);
+                        auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
+                        auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
+                        auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
+                        auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
+                        i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+                        i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+                        i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+                        i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+                        i0 = _mm256_packs_epi32(i0, i1);
+                        i2 = _mm256_packs_epi32(i2, i3);
+                        i0 = _mm256_packs_epi16(i0, i2);
+                        i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+                        _mm256_storeu_si256((__m256i *)block, i0);
+                    } else {
+                        _mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256());
+                    }
+                    auto qs = (uint32_t *)y[ib32].qs;
+                    for (int l = 0; l < 4; ++l) {
+                        qs[8*l + k +  0] = block[l + 0];
+                        qs[8*l + k + 32] = block[l + 4];
+                    }
+                }
+            }
+            for (int ib32 = 0; ib32 < 8; ++ib32) {
+                _mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT));
+            }
+            y += QK_K/32;
+        }
+    }
+}
+
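+// GEMM/GEMV kernel for IQ2_S against Q8_2_X4 activations. The combined scale of a
+// 32-value block is 0.125 * d(x) * block_scale(x) * d(y); the 16-bit y scales are
+// widened to fp32 by shifting them into the upper half of a 32-bit float (bf16 style).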
+template <int nrc_y>
+static void mul_mat_iq2_s_q8_2_X4(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    assert(n % QK_K == 0);
+    const int nb = n / QK_K;
+
+    Q8<nrc_y, block_q8_2_x4> q8(info);
+
+    DequantizerIQ2S deq(vx, bx);
+
+    __m256  accd[nrc_y];
+    __m256  scales[2];
+    float   d8[8*nrc_y];
+    __m256i us[4];
+
+    uint8_t k_shuff[32] = {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15};
+    auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
+
+    for (int ix = 0; ix < nrc_x; ++ix) {
+
+        for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
+
+        deq.new_row(ix);
+
+        for (int i = 0; i < nb; ++i) {
+
+            deq.d = 0.125f * GGML_FP16_TO_FP32(deq.x[i].d);
+            auto vd = _mm256_set1_ps(deq.d);
+            auto sc16 = _mm256_shuffle_epi8(DequantizerIQ2S::make_scales(deq.x[i].scales), shuff);
+            scales[0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(sc16))));
+            scales[1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(sc16, 1))));
+            for (int iy = 0; iy < nrc_y; ++iy) {
+                auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
+                auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
+                auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
+                if constexpr (nrc_y == 1) {
+                    auto dyh = _mm256_extractf128_ps(dy, 1);
+                    scales[0] = _mm256_mul_ps(scales[0], _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)));
+                    scales[1] = _mm256_mul_ps(scales[1], _mm256_set_m128(dyh, dyh));
+                } else {
+                    _mm256_storeu_ps(d8 + 8*iy, dy);
+                }
+            }
+
+            for (int j = 0; j < QK_K/128; ++j) {
+
+                if constexpr (nrc_y == 1) {
+                    auto qs = q8.y[0][2*i+j].qs;
+                    for (int k = 0; k < 4; ++k) us[k] = _mm256_loadu_si256((const __m256i*)qs+k);
+                    deq.prepare_signed(i, j, deq.bits.values, us);
+#ifdef HAVE_FANCY_SIMD
+                    auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[0], us[0]);
+                    auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[1], us[1]);
+                    auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[2], us[2]);
+                    auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[3], us[3]);
+                    sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+                    sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+                    sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+#else
+                    auto sumi1 = _mm256_maddubs_epi16(deq.bits.values[0], us[0]);
+                    auto sumi2 = _mm256_maddubs_epi16(deq.bits.values[1], us[1]);
+                    auto sumi3 = _mm256_maddubs_epi16(deq.bits.values[2], us[2]);
+                    auto sumi4 = _mm256_maddubs_epi16(deq.bits.values[3], us[3]);
+                    sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2)));
+                    sumi3 = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4)));
+                    sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+#endif
+                    accd[0] = _mm256_fmadd_ps(scales[j], _mm256_cvtepi32_ps(sumi1), accd[0]);
+                }
+                else {
+                    deq.prepare_signed(i, j, us);
+
+                    for (int iy = 0; iy < nrc_y; ++iy) {
+                        auto qs = q8.y[iy][2*i+j].qs;
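+                        // _mm256_dpbusd_epi32 / _mm256_maddubs_epi16 need an unsigned first
+                        // operand, so us[] keeps the unsigned grid values and the IQ2_S signs
+                        // are transferred onto the q8 values with _mm256_sign_epi8
+                        // (x*y == |x| * (sign(x)*y)).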
+#ifdef HAVE_FANCY_SIMD
+                        //  0...31
+                        auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0]));
+                        // 32...63
+                        auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1]));
+                        // 64...95
+                        auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2]));
+                        // 96...127
+                        auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3]));
+                        // 0...3, 32...35, 4...7, 36...39, 16...19, 48...51, 20...23, 52...55
+
+                        // 8...11, 40...43, 12...15, 44...47, 24...27, 56...59, 28...31, 60...63
+                        // b0 b2 b0 b2 b1 b3 b1 b3
+                        sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+                        // same as above + 64, so
+                        // b4 b6, b4 b6 b5 b7 b5 b7
+                        sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+                        // b0 b2 b4 b6 b1 b3 b5 b7
+
+                        // b0 b2 b4 b6 b1 b3 b5 b7
+                        sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+#else
+                        auto sumi1 = _mm256_maddubs_epi16(us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0]));
+                        auto sumi2 = _mm256_maddubs_epi16(us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1]));
+                        auto sumi3 = _mm256_maddubs_epi16(us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2]));
+                        auto sumi4 = _mm256_maddubs_epi16(us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3]));
+                        sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
+                        sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
+                        sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
+                        sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
+#endif
+                        auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
+                        auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
+                        accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
+                    }
+                }
+
+            }
+
+        }
+
+        for (int iy = 0; iy < nrc_y; ++iy) {
+            info.store(ix, iy, hsum_float_8(accd[iy]));
+        }
+
+    }
+}
+
 void iqk_convert_iq3_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
     GGML_ASSERT(n%QK_K == 0);
 
@@ -2200,6 +2429,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array= 32 ? GGML_TYPE_F32 : type;
         case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
+        case GGML_TYPE_IQ2_S  : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ3_S  : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ1_S  : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;