q6_K dequantizing GEMM

This commit is contained in:
Iwan Kawrakow
2025-06-13 15:22:03 +03:00
parent 066ed4fd11
commit 853d581de0
3 changed files with 92 additions and 1 deletions

View File

@@ -1036,7 +1036,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q6_K,
.from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref,
.vec_dot = ggml_vec_dot_q6_K_q8_K,
#ifdef __AVX2__
.vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_K,
#endif
.nrows = 1,
.row_meta_size = 0,
},

View File

@@ -6,6 +6,7 @@
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#include "ggml-quants.h"
#ifdef __x86_64__
@@ -1982,6 +1983,90 @@ void iqk_convert_q5_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int
}
}
void iqk_convert_q6_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
const block_q6_K * x8[8];
block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
float all_s[64];
uint32_t block[8];
__m256i values[8];
auto ml = _mm256_set1_epi8(0x0f);
auto mh = _mm256_set1_epi8(0x30);
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q6_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
float d = GGML_FP16_TO_FP32(x8[k][i].d);
auto ql = x8[k][i].ql;
auto qh = x8[k][i].qh;
for (int i128 = 0; i128 < 2; ++i128) {
auto lbits1 = _mm256_loadu_si256((const __m256i *)ql + 2*i128 + 0);
auto lbits2 = _mm256_loadu_si256((const __m256i *)ql + 2*i128 + 1);
auto hbits = _mm256_loadu_si256((const __m256i *)qh + i128);
values[4*i128+0] = _mm256_or_si256(_mm256_and_si256(lbits1, ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
values[4*i128+1] = _mm256_or_si256(_mm256_and_si256(lbits2, ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
values[4*i128+2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), ml), _mm256_and_si256(hbits, mh));
values[4*i128+3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), ml), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));
}
for (int ib32 = 0; ib32 < 8; ++ib32) {
// We have two blocks of 16 with different scales
// We multiply the quants with the scales, find the max value, and convert to 8-bit quants with a single block scale.
auto q8 = _mm256_add_epi8(values[ib32], _mm256_set1_epi8(-32));
auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q8));
auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8, 1));
q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(x8[k][i].scales[2*ib32+0]));
q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(x8[k][i].scales[2*ib32+1]));
auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l);
auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h);
auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h);
auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1)));
auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
auto max4 = _mm_cvtepi32_ps(imax4);
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
float max = _mm_cvtss_f32(max4) / 127;
all_s[8*ib32+k] = d*max;
if (max > 1e-9f) {
auto scale = _mm256_set1_ps(1/max);
auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
i0 = _mm256_packs_epi32(i0, i1);
i2 = _mm256_packs_epi32(i2, i3);
i0 = _mm256_packs_epi16(i0, i2);
i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
_mm256_storeu_si256((__m256i *)block, i0);
} else {
_mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256());
}
auto qs = (uint32_t *)y[ib32].qs;
for (int l = 0; l < 4; ++l) {
qs[8*l + k + 0] = block[l + 0];
qs[8*l + k + 32] = block[l + 4];
}
}
}
for (int ib32 = 0; ib32 < 8; ++ib32) {
_mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT));
}
y += QK_K/32;
}
}
}
} // namespace
@@ -2066,6 +2151,7 @@ bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, voi
switch (ggml_type(type)) {
case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q6_K: iqk_convert_q6_k_q8_0_r8(n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;

View File

@@ -245,6 +245,7 @@ struct MulMat {
case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q6_K : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
default: break;
}
#else
@@ -347,7 +348,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
//case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
//case GGML_TYPE_Q6_K:
case GGML_TYPE_Q6_K:
//case GGML_TYPE_IQ4_XS:
//case GGML_TYPE_Q2_K_R4:
//case GGML_TYPE_Q3_K_R4: