From e10f7d1f10e114c151bf11dac9e1862ad1bf55d3 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 15 Jun 2025 10:37:12 +0300 Subject: [PATCH] q3_K: repack to q8_k_r8 instead of q8_0_r8 With that we hit 360 t/s for LlaMA-3.1-8B on a Ryzen-7950X. q8_k_r8 is 386 t/s, so for a batch size of 512 repacking costs ~7% of the time taken by the actual GEMM. --- ggml/src/ggml-common.h | 3 +- ggml/src/ggml.c | 6 +- ggml/src/iqk/iqk_gemm_kquants.cpp | 128 +++++++++++++++++++++++------- ggml/src/iqk/iqk_mul_mat.cpp | 2 +- ggml/src/iqk/iqk_quantize.cpp | 34 +++++--- 5 files changed, 125 insertions(+), 48 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 5fe27b29..2bfe5d39 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -386,10 +386,11 @@ static_assert(sizeof(block_q6_k_r4) == 4*sizeof(ggml_half) + QK_K/4 + 3*QK_K, "w // This is only used for intermediate quantization and dot products typedef struct { float d; // delta + float sum; // sum of quants in the entire block int8_t qs[QK_K]; // quants int16_t bsums[QK_K/16]; // sum of quants in groups of 16 } block_q8_K; -static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); +static_assert(sizeof(block_q8_K) == 2*sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); typedef struct { float d; // delta int8_t qs[64]; // quants diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 96d581cc..4d8dedf0 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -950,11 +950,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q3_K, .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref, .vec_dot = ggml_vec_dot_q3_K_q8_K, -#ifdef __AVX2__ - .vec_dot_type = GGML_TYPE_Q8_2_X4, -#else .vec_dot_type = GGML_TYPE_Q8_K, -#endif .nrows = 1, .row_meta_size = 0, }, @@ -1071,7 +1067,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_q8_k_r8, .from_float_ref = (ggml_from_float_t) quantize_row_q8_k_r8_ref, .vec_dot = vec_dot_q8_k_r8_q8_k, - .vec_dot_type = GGML_TYPE_Q8_KR8, + .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, .row_meta_size = 0, }, diff --git a/ggml/src/iqk/iqk_gemm_kquants.cpp b/ggml/src/iqk/iqk_gemm_kquants.cpp index c8f3020c..c173b4dc 100644 --- a/ggml/src/iqk/iqk_gemm_kquants.cpp +++ b/ggml/src/iqk/iqk_gemm_kquants.cpp @@ -1845,8 +1845,7 @@ static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataIn auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))); acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]); #ifdef HAVE_FANCY_SIMD - auto bsums = (const float *)q8.y[iy][ibl].bsums; - acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(bsums[0]), acc[iy]); + acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(q8.y[iy][ibl].sum), acc[iy]); #endif isum[iy] = _mm256_setzero_si256(); } @@ -2236,27 +2235,6 @@ void iqk_convert_q6_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int } } -//struct DequantizerQ3K final : public BaseDequantizer { -// DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {} -// -// template -// inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) { -// d = GGML_FP16_TO_FP32(x[i].d); -// hbits.load(x[i].hmask); -// process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales); -// } -// inline void prepare(int i, int j) { -// bits.prepare(x[i].qs, j); -// hbits.apply(bits, j == 0); -// } -// -// Q2Bits bits; -// HighBit3 hbits; -// ScaleQ3 sc3; -// -// const __m128i m32 = _mm_set1_epi8(-32); -//}; - void iqk_convert_q3_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { GGML_ASSERT(n%QK_K == 0); GGML_ASSERT(nrc_x%8 == 0); @@ -2348,6 +2326,97 @@ void iqk_convert_q3_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int } } +void iqk_convert_q3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) { + GGML_ASSERT(n%QK_K == 0); + GGML_ASSERT(nrc_x%8 == 0); + + int nb = n/QK_K; + + const block_q3_K * x8[8]; + + block_q8_k_r8 * y = (block_q8_k_r8 *)vy; + + uint32_t block[8]; + __m256i values[8]; + + ScaleQ3 sc3; + auto ml = _mm256_set1_epi8(0x03); + auto mh = _mm256_set1_epi8(0x04); + + union { __m256i vec; int16_t val[16]; } helper; + + for (int ix = 0; ix < nrc_x; ix += 8) { + for (int k = 0; k < 8; ++k) x8[k] = (const block_q3_K *)((const char *)vx + (ix + k)*bx); + for (int i = 0; i < nb; ++i) { + for (int k = 0; k < 8; ++k) { + float d = GGML_FP16_TO_FP32(x8[k][i].d); + auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].hmask); + helper.vec = _mm256_cvtepi8_epi16(sc3.make_scales((const uint16_t *)x8[k][i].scales)); + auto max_i16 = _mm256_setzero_si256(); + for (int i128 = 0; i128 < 2; ++i128) { + auto q2bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs + i128); + values[4*i128+0] = _mm256_and_si256(q2bits, ml); + values[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml); + values[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml); + values[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml); + values[4*i128+0] = _mm256_or_si256(values[4*i128+0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)); + values[4*i128+1] = _mm256_or_si256(values[4*i128+1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)); + values[4*i128+2] = _mm256_or_si256(values[4*i128+2], _mm256_and_si256(hbits, mh)); + values[4*i128+3] = _mm256_or_si256(values[4*i128+3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh)); + values[4*i128+0] = _mm256_sub_epi8(values[4*i128+0], mh); + values[4*i128+1] = _mm256_sub_epi8(values[4*i128+1], mh); + values[4*i128+2] = _mm256_sub_epi8(values[4*i128+2], mh); + values[4*i128+3] = _mm256_sub_epi8(values[4*i128+3], mh); + hbits = _mm256_srli_epi16(hbits, 4); + + for (int l = 0; l < 4; ++l) { + auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(values[4*i128+l])); + auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[4*i128+l], 1)); + q16_l = _mm256_mullo_epi16(_mm256_set1_epi16(helper.val[8*i128+2*l+0]), q16_l); + q16_h = _mm256_mullo_epi16(_mm256_set1_epi16(helper.val[8*i128+2*l+1]), q16_h); + max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_l, q16_l)); + max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_h, q16_h)); + } + } + auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1))); + auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1)); + auto max4 = _mm_cvtepi32_ps(imax4); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + float dnew = std::max(1.f, _mm_cvtss_f32(max4) / 127); + d *= dnew; + y[i].d[k] = GGML_FP32_TO_FP16(d); + auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f); + for (int ib32 = 0; ib32 < 8; ++ib32) { + auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(values[ib32])); + auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[ib32], 1)); + q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(helper.val[2*ib32+0])); + q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(helper.val[2*ib32+1])); + auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l)); + auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1)); + auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h)); + auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1)); + i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST)); + i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST)); + i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST)); + i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST)); + i0 = _mm256_packs_epi32(i0, i1); + i2 = _mm256_packs_epi32(i2, i3); + i0 = _mm256_packs_epi16(i0, i2); + i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7)); + _mm256_storeu_si256((__m256i *)block, i0); + + auto qs = (uint32_t *)y[i].qs + 64*ib32; + for (int l = 0; l < 8; ++l) { + qs[8*l + k] = block[l]; + } + } + } + } + y += nb; + } +} + } // namespace @@ -2355,10 +2424,11 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array(kernels); break; case GGML_TYPE_Q3_K: - //set_functions(kernels); - IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_2_X4_T, DequantizerQ3K_AVX2, kernels); + set_functions(kernels); + //IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_2_X4_T, DequantizerQ3K_AVX2, kernels); break; case GGML_TYPE_Q4_K: IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ4K_AVX2, kernels); @@ -2434,7 +2504,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; - case GGML_TYPE_Q3_K : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type; + case GGML_TYPE_Q3_K : return nrc_y >= 32 ? GGML_TYPE_Q8_K_R8 : type; case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type; case GGML_TYPE_Q6_K : return nrc_y >= 64 ? GGML_TYPE_Q8_0_R8 : type; diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 2eb53d1c..9261d02e 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -2831,6 +2831,8 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { const __m256 mul = _mm256_set1_ps( id ); xx = xb; int8_t * q8 = y[i].qs; + int block_sum_i32 = 0; + float block_sum_f32 = 0; for (int ib = 0; ib < QK_K/32; ++ib) { __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8; @@ -2844,13 +2846,15 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { __m256i i1 = _mm256_cvtps_epi32(v1); __m256i i2 = _mm256_cvtps_epi32(v2); __m256i i3 = _mm256_cvtps_epi32(v3); - if constexpr (q8_type > 0) { + if constexpr (q8_type == 1) { int bsum = hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); auto bs = (float *)y[i].bsums; bs[ib] = d*bsum; + block_sum_f32 += bs[ib]; } else { y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1)); y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3)); + block_sum_i32 += y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]; } i0 = _mm256_packs_epi32( i0, i1 ); i2 = _mm256_packs_epi32( i2, i3 ); @@ -2859,12 +2863,17 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { _mm256_storeu_si256((__m256i *)q8, i0); q8 += 32; } - if constexpr (q8_type == 2) { - auto bs = (float *)y[i].bsums; - float sum = 0; - for (int ib = 0; ib < QK_K/32; ++ib) sum += bs[ib]; - bs[0] = sum; + if constexpr (q8_type == 1) { + y[i].sum = block_sum_f32; + } else { + y[i].sum = d*block_sum_i32; } + //if constexpr (q8_type == 2) { + // auto bs = (float *)y[i].bsums; + // float sum = 0; + // for (int ib = 0; ib < QK_K/32; ++ib) sum += bs[ib]; + // bs[0] = sum; + //} } #else for (int i = 0; i < nb; i++) { @@ -2890,9 +2899,9 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { int v = nearest_int(iscale*x[j]); y[i].qs[j] = MIN(127, v); } - if constexpr (q8_type > 0) { + float d = 1/iscale; + if constexpr (q8_type == 1) { auto bs = (float *)y[i].bsums; - float d = 1/iscale; float sum = 0; for (int j = 0; j < QK_K/32; ++j) { int sum = 0; @@ -2902,19 +2911,20 @@ void iqk_quantize_row_q8_K_T(const float * x, void * vy, int64_t k) { bs[j] = d*sum; sum += bs[j]; } - if constexpr (q8_type == 2) { - bs[0] = sum; - } + y[i].sum = sum; } else { + int tot = 0; for (int j = 0; j < QK_K/16; ++j) { int sum = 0; for (int ii = 0; ii < 16; ++ii) { sum += y[i].qs[j*16 + ii]; } y[i].bsums[j] = sum; + tot += sum; } + y[i].sum = d*tot; } - y[i].d = 1/iscale; + y[i].d = d; x += QK_K; } #endif