Perhaps a slightly better version for IQ2_XXS, IQ3_XXS, IQ3_S GEMV (#524)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-06-13 07:55:57 +03:00
committed by GitHub
parent dc663fe632
commit fb30146ce8

View File

@@ -145,35 +145,6 @@ struct SignHelper {
const __m256i mone = _mm256_set1_epi8(1);
};
// for (int i = 0; i < nb; ++i) {
//
// __m256i sumi[nrc_y], all_scales;
// //for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = _mm256_setzero_si256();
// __m256i mins;
// float dmin = deq.new_block(i, &all_scales, mins);
// for (int iy = 0; iy < nrc_y; ++iy) {
// auto bsums = q8.load_bsums(iy, i);
// auto prod = _mm256_madd_epi16(mins, bsums);
// accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(dmin*q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accd[iy]);
// }
//
// for (int j = 0; j < QK_K/128; ++j) {
// deq.prepare(i, j);
// set_scales_8(&all_scales, j, scales);
// //multiply_add_iq(deq.bits, scales, j, i, q8, sumi);
// multiply_add(deq.bits, scales, j, i, q8, sumi);
// }
// for (int iy = 0; iy < nrc_y; ++iy) {
// const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
// accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
// }
// }
//
// for (int iy = 0; iy < nrc_y; ++iy) {
// info.store(ix, iy, hsum_float_8(accd[iy]));
// }
// }
struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
DequantizerIQ2XXS(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
@@ -221,7 +192,7 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
}
IQK_ALWAYS_INLINE void sign_values(const uint32_t * aux32, __m256i * values) const {
#if defined HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
#if defined z_HAVE_FANCY_SIMD && defined __AVX512VPOPCNTDQ__
esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[3]), _mm_set1_epi32(aux32[1])), values+0);
esh.sign_2_values(MM256_SET_M128I(_mm_set1_epi32(aux32[7]), _mm_set1_epi32(aux32[5])), values+2);
#else
@@ -246,7 +217,11 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
}
inline void prepare(int i, int j, const Q8<1>& q8, __m256i * q8_quants) {
for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
make4(data.val, bits.values, q8_quants);
}
inline void prepare(int i, int j, __m256i * q8_quants) {
Data data; data.vec = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
make4(data.val, bits.values, q8_quants);
}
@@ -526,6 +501,13 @@ struct DequantizerIQ3XXS final : public BaseDequantizer<block_iq3_xxs> {
sign_2_values(signs+0, q8_quants+0);
sign_2_values(signs+4, q8_quants+2);
}
inline void prepare(int i, int j, __m256i * q8_quants) {
auto qs = x[i].qs + 32*j;
const uint16_t * signs = (const uint16_t *)(x[i].qs + QK_K/4) + 8*j;
make4_unsigned(qs, bits.values);
sign_2_values(signs+0, q8_quants+0);
sign_2_values(signs+4, q8_quants+2);
}
constexpr static int minv = 64;
@@ -625,6 +607,10 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
for (int k = 0; k < 4; ++k) q8_quants[k] = q8.load_quants(0, i, 4*j+k);
sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants);
}
inline void prepare(int i, int j, __m256i * q8_quants) {
prepare_unsigned(i, j);
sh.sign_4_values((const uint16_t *)x[i].signs + 8*j, q8_quants);
}
inline void prepare_unsigned(int i, int j) {
auto qs = x[i].qs + 32*j;
@@ -787,15 +773,69 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
}
}
template <typename Dequantizer, int nrc_y>
template <int n_sum>
inline __m256i compute_dot_4(const __m256i * x, const __m256i * y) {
#ifdef HAVE_FANCY_SIMD
auto sumi0 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[0], y[0]);
auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[1], y[1]);
auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[2], y[2]);
auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), x[3], y[3]);
sumi0 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
sumi2 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
#else
auto m1 = _mm256_set1_epi16(1);
if constexpr (n_sum == 2) {
auto sumi0 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[0], y[0]));
auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[1], y[1]));
auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[2], y[2]));
auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(x[3], y[3]));
sumi0 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
sumi2 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
}
else {
auto sumi0 = _mm256_maddubs_epi16(x[0], y[0]);
auto sumi1 = _mm256_maddubs_epi16(x[1], y[1]);
auto sumi2 = _mm256_maddubs_epi16(x[2], y[2]);
auto sumi3 = _mm256_maddubs_epi16(x[3], y[3]);
if constexpr (n_sum == 4) {
sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
sumi2 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
sumi0 = _mm256_madd_epi16(m1, sumi0);
sumi2 = _mm256_madd_epi16(m1, sumi2);
return _mm256_add_epi32(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
}
else {
auto sumi0 = _mm256_maddubs_epi16(x[0], y[0]);
auto sumi1 = _mm256_maddubs_epi16(x[1], y[1]);
auto sumi2 = _mm256_maddubs_epi16(x[2], y[2]);
auto sumi3 = _mm256_maddubs_epi16(x[3], y[3]);
sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi0, sumi1), _mm256_unpackhi_epi32(sumi0, sumi1));
sumi2 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi2, sumi3), _mm256_unpackhi_epi32(sumi2, sumi3));
sumi0 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi0, sumi2), _mm256_unpackhi_epi64(sumi0, sumi2));
return _mm256_madd_epi16(m1, sumi0);
}
}
#endif
}
template <typename Dequantizer, int nrc_y, int n_sum = 2>
static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
static_assert(Dequantizer::num_blocks == 8);
static_assert(n_sum == 2 || n_sum == 4 || n_sum == 8);
#ifdef HAVE_FANCY_SIMD
constexpr bool use_1_row = nrc_y == 1;
#else
constexpr bool use_1_row = nrc_y == 1 && !std::is_same_v<Dequantizer, DequantizerIQ2XXS>;
#endif
const int nb = n / QK_K;
Q8<nrc_y, block_q8_2_x4> q8(info);
Dequantizer deq(vx, bx);
__m256 scales[3];
__m256 accd[nrc_y];
__m256i sumi[4];
__m256i vy[4];
for (int ix = 0; ix < nrc_x; ++ix) {
@@ -806,35 +846,33 @@ static void mul_mat_qX_K_q8_2_IQ_N(int n, const void * vx, size_t bx, const Data
for (int i = 0; i < nb; ++i) {
deq.new_block_f(i, scales);
for (int iy = 0; iy < nrc_y; ++iy) {
auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4)));
auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4)));
auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]);
if constexpr (!use_1_row) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto my1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d + 4)));
auto my2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d + 4)));
auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(my2, my1), 16));
accd[iy] = _mm256_fmadd_ps(scales[2], my, accd[iy]);
}
}
for (int j = 0; j < QK_K/128; ++j) {
deq.prepare(i, j);
auto& values = deq.bits.values;
for (int iy = 0; iy < nrc_y; ++iy) {
auto qs = q8.y[iy][2*i+j].qs;
#ifdef HAVE_FANCY_SIMD
sumi[0] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[0], _mm256_loadu_si256((const __m256i*)qs+0));
sumi[1] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[1], _mm256_loadu_si256((const __m256i*)qs+1));
sumi[2] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[2], _mm256_loadu_si256((const __m256i*)qs+2));
sumi[3] = _mm256_dpbusd_epi32(_mm256_setzero_si256(), values[3], _mm256_loadu_si256((const __m256i*)qs+3));
#else
sumi[0] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[0], _mm256_loadu_si256((const __m256i*)qs+0)));
sumi[1] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[1], _mm256_loadu_si256((const __m256i*)qs+1)));
sumi[2] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[2], _mm256_loadu_si256((const __m256i*)qs+2)));
sumi[3] = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(values[3], _mm256_loadu_si256((const __m256i*)qs+3)));
#endif
sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[0], sumi[1]), _mm256_unpackhi_epi32(sumi[0], sumi[1]));
sumi[2] = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi[2], sumi[3]), _mm256_unpackhi_epi32(sumi[2], sumi[3]));
sumi[0] = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi[0], sumi[2]), _mm256_unpackhi_epi64(sumi[0], sumi[2]));
auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16));
if constexpr (use_1_row) {
for (int k = 0; k < 4; ++k) vy[k] = _mm256_loadu_si256((const __m256i*)q8.y[0][2*i+j].qs+k);
deq.prepare(i, j, vy);
auto sumi = compute_dot_4<2*n_sum>(deq.bits.values, vy);
auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[0][2*i+j].d)), 16));
auto dy = _mm256_set_m128(d4, d4);
accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi[0]), accd[iy]);
accd[0] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi), accd[0]);
} else {
deq.prepare(i, j);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qs = q8.y[iy][2*i+j].qs;
for (int k = 0; k < 4; ++k) vy[k] = _mm256_loadu_si256((const __m256i*)qs+k);
auto sumi = compute_dot_4<n_sum>(deq.bits.values, vy);
auto d4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][2*i+j].d)), 16));
auto dy = _mm256_set_m128(d4, d4);
accd[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales[j], dy), _mm256_cvtepi32_ps(sumi), accd[iy]);
}
}
}
}
@@ -1934,7 +1972,15 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
if (ggml_type(typeA) == GGML_TYPE_IQ3_S) {
if (ggml_type(typeB) == GGML_TYPE_Q8_2_X4) {
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
//IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_IQ_N, DequantizerIQ3S, kernels);
kernels[0] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 1, 8>;
kernels[1] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 2, 8>;
kernels[2] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 3, 8>;
kernels[3] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 4, 8>;
kernels[4] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 5, 8>;
kernels[5] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 6, 8>;
kernels[6] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 7, 8>;
kernels[7] = mul_mat_qX_K_q8_2_IQ_N<DequantizerIQ3S, 8, 8>;
func16 = nullptr;
return true;
}