diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index ddb15ae6..f91acdad 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -1110,11 +1110,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .from_float               = quantize_row_iq2_xs,
         .from_float_ref           = (ggml_from_float_t)quantize_row_iq2_xs_ref,
         .vec_dot                  = ggml_vec_dot_iq2_xs_q8_K,
-#ifdef __AVX2__
-        .vec_dot_type             = GGML_TYPE_Q8_2_X4,
-#else
         .vec_dot_type             = GGML_TYPE_Q8_K,
-#endif
         .nrows                    = 1,
         .row_meta_size            = 0,
     },
diff --git a/ggml/src/iqk/iqk_gemm_iquants.cpp b/ggml/src/iqk/iqk_gemm_iquants.cpp
index 59590b61..92bb9508 100644
--- a/ggml/src/iqk/iqk_gemm_iquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_iquants.cpp
@@ -1839,6 +1839,62 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
     }
 }
 
+inline float convert_to_q8_k_r8(int k, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
+    auto max_i16 = _mm256_setzero_si256();
+    for (int ib32 = 0; ib32 < 8; ++ib32) {
+        auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+        auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+        q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(scales[2*ib32+0]));
+        q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(scales[2*ib32+1]));
+        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_l, q16_l));
+        max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_h, q16_h));
+    }
+    auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
+    auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
+    auto max4 = _mm_cvtepi32_ps(imax4);
+    max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
+    max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
+    bool needs_scaling = true;
+    float dnew = _mm_cvtss_f32(max4) / 127;
+    if (dnew < 1.f) {
+        dnew = 1.f; needs_scaling = false;
+    }
+    auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
+    for (int ib32 = 0; ib32 < 8; ++ib32) {
+        auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
+        auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
+        q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(scales[2*ib32+0]));
+        q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(scales[2*ib32+1]));
+        if (needs_scaling) {
+            auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
+            auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
+            auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
+            auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
+            i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
+            i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
+            i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
+            i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
+            i0 = _mm256_packs_epi32(i0, i1);
+            i2 = _mm256_packs_epi32(i2, i3);
+            i0 = _mm256_packs_epi16(i0, i2);
+            i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
+            _mm256_storeu_si256((__m256i *)block, i0);
+        } else {
+            // 0, 1, 2, 3, 4, 5, 6, 7, 8, 16, 17, 18, 19, 20, 21, 22, 23, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31
+            auto i0 = _mm256_packs_epi16(q16_l, q16_h);
+            auto i0_l = _mm256_castsi256_si128(i0);
+            auto i0_h = _mm256_extracti128_si256(i0, 1);
+            _mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
+            _mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
+        }
+        auto qs = (uint32_t *)q8_k + 64*ib32;
+        for (int l = 0; l < 8; ++l) {
+            qs[8*l + k] = block[l];
+        }
+    }
+    return dnew;
+}
+
 void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
     GGML_ASSERT(n%QK_K == 0);
     GGML_ASSERT(nrc_x%8 == 0);
@@ -1888,6 +1944,50 @@ void iqk_convert_iq2_xxs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, i
     }
 }
 
+void iqk_convert_iq2_xs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
+    GGML_ASSERT(n%QK_K == 0);
+    GGML_ASSERT(nrc_x%8 == 0);
+
+    int nb = n/QK_K;
+
+    const block_iq2_xs * x8[8];
+
+    block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
+
+    uint32_t block[8];
+
+    union { __m256i vec; int16_t val[16]; } helper;
+    __m256i qx[8];
+
+#if !(defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__)
+    DequantizerIQ2XS::Helper sign_helper;
+#endif
+
+    for (int ix = 0; ix < nrc_x; ix += 8) {
+        for (int k = 0; k < 8; ++k) x8[k] = (const block_iq2_xs *)((const char *)vx + (ix + k)*bx);
+        for (int i = 0; i < nb; ++i) {
+            for (int k = 0; k < 8; ++k) {
+                float d = 0.125f * GGML_FP16_TO_FP32(x8[k][i].d);
+                helper.vec = DequantizerIQ2XS::make_scales(x8[k][i].scales);
+                auto q2l = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+0);
+                auto q2h = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+1);
+                DequantizerIQ2XS::make4(q2l, _mm256_set1_epi16(511), qx+0);
+                DequantizerIQ2XS::make4(q2h, _mm256_set1_epi16(511), qx+4);
+#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
+                DequantizerIQ2XS::sign_values_popcnt(q2l, qx+0);
+                DequantizerIQ2XS::sign_values_popcnt(q2h, qx+4);
+#else
+                DequantizerIQ2XS::sign_values_helper(q2l, sign_helper, qx+0);
+                DequantizerIQ2XS::sign_values_helper(q2h, sign_helper, qx+4);
+#endif
+                float dnew = convert_to_q8_k_r8(k, qx, helper.val, block, y[i].qs);
+                y[i].d[k] = GGML_FP32_TO_FP16(d*dnew);
+            }
+        }
+        y += nb;
+    }
+}
+
 void iqk_convert_iq2_xs_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
     GGML_ASSERT(n%QK_K == 0);
     GGML_ASSERT(nrc_x%8 == 0);
@@ -2505,14 +2605,14 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array= 32 ? GGML_TYPE_F32 : type;
         case GGML_TYPE_IQ4_KT : return nrc_y >= 32 ? GGML_TYPE_F32 : type;
         case GGML_TYPE_IQ2_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
-        case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
-        case GGML_TYPE_IQ2_S  : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+        case GGML_TYPE_IQ2_XS : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type;
+        case GGML_TYPE_IQ2_S  : return nrc_y >= 16 ? GGML_TYPE_Q8_K_R8 : type;
         case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ3_S  : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
         case GGML_TYPE_IQ1_S  : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
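
Note on the conversion logic in the patch above: the AVX2 code in `convert_to_q8_k_r8()` is easier to follow against a scalar equivalent. The sketch below is illustrative only; the function name `convert_to_q8_k_r8_scalar` and its `values[]` input (the 256 sign-applied, sub-block-scale-multiplied quants of row `k`, which is what the `qx`/`scales` pair encodes) are assumptions made for readability, not part of the patch. The interleaved layout mirrors the `qs[8*l + k] = block[l]` store in the intrinsics version.

```cpp
// Illustrative scalar sketch (not the patch's kernel): what convert_to_q8_k_r8()
// computes for one row k of a block_q8_k_r8 super-block.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

// values : 256 int16 quants of row k, signs and sub-block scales already applied (hypothetical input)
// q8_k   : the 8*256 interleaved int8 quants of one block_q8_k_r8
// returns the extra scale dnew that the caller folds into the row scale: y[i].d[k] = d*dnew
static float convert_to_q8_k_r8_scalar(int k, const int16_t * values, int8_t * q8_k) {
    int amax = 0;
    for (int j = 0; j < 256; ++j) amax = std::max(amax, values[j] < 0 ? -values[j] : (int)values[j]);
    bool  needs_scaling = true;
    float dnew = amax/127.f;
    if (dnew < 1.f) { dnew = 1.f; needs_scaling = false; }   // values already fit into int8, keep them as-is
    for (int ib32 = 0; ib32 < 8; ++ib32) {                   // 8 sub-blocks of 32 quants
        int8_t tmp[32];
        for (int j = 0; j < 32; ++j) {
            int q = values[32*ib32 + j];
            if (needs_scaling) q = (int)std::nearbyint(q/dnew);
            tmp[j] = (int8_t)std::max(-128, std::min(127, q));
        }
        // row k owns every 8th 4-byte group of the interleaved storage:
        // uint32 slot = 64*ib32 + 8*l + k, matching qs[8*l + k] = block[l] in the AVX2 code
        for (int l = 0; l < 8; ++l) {
            std::memcpy(q8_k + 4*(64*ib32 + 8*l + k), tmp + 4*l, 4);
        }
    }
    return dnew;
}
```

Either way, each row of each 256-wide super-block ends up with a single fp16 scale `d*dnew`, with the iq2_xs sub-block scales folded into the int8 quants, which is what allows the IQ2_XS GEMM to reuse the Q8_K_R8 path for large batch sizes.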