q4_K: GEMM with q8_2_X4

This commit is contained in:
Iwan Kawrakow
2025-06-12 12:30:23 +03:00
parent 8de4c019d0
commit 4b8f765870

View File

@@ -719,6 +719,135 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
#endif
// inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
// make_q4_scales(data, utmp);
// const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
// const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);
// accum_mins(mins128, q8, i, c, accd);
// const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
// return MM256_SET_M128I(sc128, sc128);
// }
//
// inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
// d = GGML_FP16_TO_FP32(x[i].d);
// bits.prepare(x[i].qs);
// auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
// scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
// scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
// }
struct Q4Bits_AVX2 {
inline void prepare(const uint8_t * q4, int j) {
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
values[0] = _mm256_and_si256(q4bits, ml);
values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
values[2] = _mm256_and_si256(q4bits, ml);
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
}
__m256i values[4];
const __m256i ml = _mm256_set1_epi8(0xf);
};
struct DequantizerQ4K_AVX2 final : public BaseDequantizer<block_q4_K> {
DequantizerQ4K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
d = GGML_FP16_TO_FP32(x[i].d);
return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs, j);
}
Q4Bits_AVX2 bits;
Scales8K s8k;
};
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
Q8<nrc_y, block_q8_2_x4> q8(info);
Dequantizer deq(vx, bx);
uint32_t utmp[4];
__m256 accd[nrc_y];
__m256 scales[2];
float d8[8*nrc_y];
for (int ix = 0; ix < nrc_x; ++ix) {
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
deq.new_row(ix);
for (int i = 0; i < nb; ++i) {
deq.d = GGML_FP16_TO_FP32(deq.x[i].d);
auto vm = _mm256_cvtph_ps(_mm_set1_epi16(deq.x[i].dmin));
make_q4_scales(deq.x[i].scales, utmp);
auto mins = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(utmp + 2)))));
mins = _mm256_mul_ps(_mm256_set1_ps(-1.f), mins);
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
_mm256_storeu_ps(d8 + 8*iy, dy);
auto m4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4)));
auto m4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4)));
auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(m4_2, m4_1), 16));
accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]);
}
auto all_scales = _mm256_mul_ps(_mm256_set1_ps(deq.d), _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)utmp))));
scales[0] = _mm256_set_m128(_mm256_castps256_ps128(all_scales), _mm256_castps256_ps128(all_scales));
auto scales_h = _mm256_extractf128_ps(all_scales, 1);
scales[1] = _mm256_set_m128(scales_h, scales_h);
for (int j = 0; j < QK_K/128; ++j) {
deq.prepare(i, j);
for (int iy = 0; iy < nrc_y; ++iy) {
const block_q8_2_x4& y = q8.y[iy][2*i+j];
#ifdef z_HAVE_FANCY_SIMD
auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
#else
auto sumi1 = _mm256_maddubs_epi16(deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
auto sumi2 = _mm256_maddubs_epi16(deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
auto sumi3 = _mm256_maddubs_epi16(deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
auto sumi4 = _mm256_maddubs_epi16(deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
#endif
auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
}
}
}
template <int nrc_y>
static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
@@ -1781,6 +1910,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
: etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
: etypeA == GGML_TYPE_Q4_K ? GGML_TYPE_Q8_2_X4
: GGML_TYPE_Q8_K;
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
@@ -1797,7 +1927,8 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
set_functions<DequantizerQ3K>(kernels);
break;
case GGML_TYPE_Q4_K:
set_functions<DequantizerQ4K>(kernels);
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ4K_AVX2, kernels);
//set_functions<DequantizerQ4K>(kernels);
break;
case GGML_TYPE_Q5_K:
set_functions<DequantizerQ5K>(kernels);