Faster CPU prompt processing for Q4_K and Q5_K (#525)

* q4_K: dequantize to q8_1_r8 for batch >= 32

This takes prompt processing from 186 t/s to 268 t/s.

* q4_K: GEMM with q8_2_X4

* q5_K: GEMM with q8_2_X4 and repack to q8_1_r8

* Remove the scales; they are not needed
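
For batches of 32 or more rows, the Q4_K/Q5_K rows are first repacked into an 8-row interleaved Q8_1 layout and the GEMM then runs on the repacked data. As a quick reference, the repacked block as it appears in the diff below (QK8_1 is 32):

typedef struct {
    ggml_half d[16];       // d[0..7]: per-row block scales, d[8..15]: per-row min scales
    int8_t    qs[8*QK8_1]; // quants of 8 rows, interleaved in groups of 4 bytes
} block_q8_1_r8;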

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Kawrakow
Committed by: GitHub
Date: 2025-06-13 07:58:15 +03:00
Commit: b7768e203f (parent ed868d928c)
5 changed files with 391 additions and 5 deletions

@@ -976,7 +976,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q4_K,
.from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref,
.vec_dot = ggml_vec_dot_q4_K_q8_K,
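// On AVX2 builds the activations are quantized to Q8_2_X4 for the new GEMM path; other builds keep Q8_K.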
#ifdef __AVX2__
.vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_K,
#endif
.nrows = 1,
.row_meta_size = 0,
},
@@ -1002,7 +1006,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.from_float = quantize_row_q5_K,
.from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref,
.vec_dot = ggml_vec_dot_q5_K_q8_K,
#ifdef __AVX2__
.vec_dot_type = GGML_TYPE_Q8_2_X4,
#else
.vec_dot_type = GGML_TYPE_Q8_K,
#endif
.nrows = 1,
.row_meta_size = 0,
},

@@ -719,6 +719,147 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
#endif
// inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
// make_q4_scales(data, utmp);
// const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
// const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);
// accum_mins(mins128, q8, i, c, accd);
// const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
// return MM256_SET_M128I(sc128, sc128);
// }
//
// inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
// d = GGML_FP16_TO_FP32(x[i].d);
// bits.prepare(x[i].qs);
// auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
// scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
// scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
// }
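// Unpacks 64 bytes of 4-bit quants (one half of a 256-value superblock, selected by j)
// into four 32-byte vectors of values in 0..15.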
struct Q4Bits_AVX2 {
inline void prepare(const uint8_t * q4, int j) {
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
values[0] = _mm256_and_si256(q4bits, ml);
values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
values[2] = _mm256_and_si256(q4bits, ml);
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
}
__m256i values[4];
const __m256i ml = _mm256_set1_epi8(0xf);
};
struct DequantizerQ4K_AVX2 final : public BaseDequantizer<block_q4_K> {
DequantizerQ4K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs, j);
}
Q4Bits_AVX2 bits;
};
struct DequantizerQ5K_AVX2 final : public BaseDequantizer<block_q5_K> {
DequantizerQ5K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs, j);
hbits = j == 0 ? _mm256_loadu_si256((const __m256i *)x[i].qh) : _mm256_srli_epi16(hbits, 4);
apply_hbits();
}
inline void apply_hbits() {
bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));
bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
}
const __m256i mh = _mm256_set1_epi8(0x10);
Q4Bits_AVX2 bits;
__m256i hbits;
};
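// GEMM of Q4_K/Q5_K rows against Q8_2_X4 activations.
// Per 256-value superblock: make_q4_scales unpacks the 8 6-bit block scales and mins,
// the mins are applied through the per-block activation sums kept in y.d[4..7], and the
// quant dot products use _mm256_dpbusd_epi32 when HAVE_FANCY_SIMD is defined, otherwise
// maddubs + madd. The 16-bit entries in y.d are bf16 (promoted to float by a 16-bit shift).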
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
Q8<nrc_y, block_q8_2_x4> q8(info);
Dequantizer deq(vx, bx);
uint32_t utmp[4];
__m256 accd[nrc_y];
__m256 scales[2];
float d8[8*nrc_y];
for (int ix = 0; ix < nrc_x; ++ix) {
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
deq.new_row(ix);
for (int i = 0; i < nb; ++i) {
deq.d = GGML_FP16_TO_FP32(deq.x[i].d);
auto vm = _mm256_cvtph_ps(_mm_set1_epi16(deq.x[i].dmin));
make_q4_scales(deq.x[i].scales, utmp);
auto mins = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(utmp + 2)))));
mins = _mm256_mul_ps(_mm256_set1_ps(-1.f), mins);
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
_mm256_storeu_ps(d8 + 8*iy, dy);
auto m4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4)));
auto m4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4)));
auto my = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(m4_2, m4_1), 16));
accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]);
}
auto all_scales = _mm256_mul_ps(_mm256_set1_ps(deq.d), _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)utmp))));
scales[0] = _mm256_set_m128(_mm256_castps256_ps128(all_scales), _mm256_castps256_ps128(all_scales));
auto scales_h = _mm256_extractf128_ps(all_scales, 1);
scales[1] = _mm256_set_m128(scales_h, scales_h);
for (int j = 0; j < QK_K/128; ++j) {
deq.prepare(i, j);
for (int iy = 0; iy < nrc_y; ++iy) {
const block_q8_2_x4& y = q8.y[iy][2*i+j];
#ifdef HAVE_FANCY_SIMD
auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
#else
auto sumi1 = _mm256_maddubs_epi16(deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
auto sumi2 = _mm256_maddubs_epi16(deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
auto sumi3 = _mm256_maddubs_epi16(deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
auto sumi4 = _mm256_maddubs_epi16(deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
#endif
auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
}
}
}
template <int nrc_y>
static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
@@ -1702,6 +1843,146 @@ static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const Data
}
}
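// 8 interleaved rows of Q8_1: per 32-value block, 8 row scales, 8 row min scales, then the quants of all 8 rows.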
typedef struct {
ggml_half d[16];
int8_t qs[8*QK8_1];
} block_q8_1_r8;
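// Repacks 8 Q4_K rows at a time into block_q8_1_r8: the Q4_K block scales and mins are
// folded into the per-block fp16 d values, so no separate scale data is needed, and the
// quants of the 8 rows are interleaved in 4-byte groups.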
void iqk_convert_q4_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
const block_q4_K * x8[8];
block_q8_1_r8 * y = (block_q8_1_r8 *)vy;
ggml_half dh[16];
uint16_t all_ls[128];
uint32_t utmp[4];
const uint8_t * u8 = (const uint8_t *)utmp;
uint32_t block[8];
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q4_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
dh[k+0] = x8[k][i].d;
dh[k+8] = x8[k][i].dmin;
make_q4_scales(x8[k][i].scales, utmp);
auto qs = x8[k][i].qs;
for (int ib64 = 0; ib64 < 4; ++ib64) {
all_ls[8*(2*ib64 + 0) + k ] = u8[2*ib64+0];
all_ls[8*(2*ib64 + 1) + k ] = u8[2*ib64+1];
all_ls[8*(2*ib64 + 0) + k + 64] = u8[2*ib64+8];
all_ls[8*(2*ib64 + 1) + k + 64] = u8[2*ib64+9];
auto bits = _mm256_loadu_si256((const __m256i *)qs+ib64);
auto values1 = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
auto values2 = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
_mm256_storeu_si256((__m256i *)block, values1);
auto q8 = (uint32_t *)y[2*ib64+0].qs;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
_mm256_storeu_si256((__m256i *)block, values2);
q8 = (uint32_t *)y[2*ib64+1].qs;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
}
}
auto vd = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+0));
auto vm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+1));
vm = _mm256_mul_ps(_mm256_set1_ps(-1.f), vm);
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d+0, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32 + 8);
iscales32 = _mm256_cvtepi16_epi32(iscales16);
scales = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d+1, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
}
y += QK_K/32;
}
}
}
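// Same as the Q4_K conversion above, except that the fifth bit is merged in from qh before the quants are stored.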
void iqk_convert_q5_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
const block_q5_K * x8[8];
block_q8_1_r8 * y = (block_q8_1_r8 *)vy;
ggml_half dh[16];
uint16_t all_ls[128];
uint32_t utmp[4];
const uint8_t * u8 = (const uint8_t *)utmp;
uint32_t block[8];
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q5_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
dh[k+0] = x8[k][i].d;
dh[k+8] = x8[k][i].dmin;
make_q4_scales(x8[k][i].scales, utmp);
auto qs = x8[k][i].qs;
auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh);
for (int ib64 = 0; ib64 < 4; ++ib64) {
all_ls[8*(2*ib64 + 0) + k ] = u8[2*ib64+0];
all_ls[8*(2*ib64 + 1) + k ] = u8[2*ib64+1];
all_ls[8*(2*ib64 + 0) + k + 64] = u8[2*ib64+8];
all_ls[8*(2*ib64 + 1) + k + 64] = u8[2*ib64+9];
auto bits = _mm256_loadu_si256((const __m256i *)qs+ib64);
auto values1 = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
auto values2 = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
values1 = _mm256_or_si256(values1, _mm256_and_si256(_mm256_set1_epi8(0x10), _mm256_slli_epi16(hbits, 4)));
values2 = _mm256_or_si256(values2, _mm256_and_si256(_mm256_set1_epi8(0x10), _mm256_slli_epi16(hbits, 3)));
hbits = _mm256_srli_epi16(hbits, 2);
_mm256_storeu_si256((__m256i *)block, values1);
auto q8 = (uint32_t *)y[2*ib64+0].qs;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
_mm256_storeu_si256((__m256i *)block, values2);
q8 = (uint32_t *)y[2*ib64+1].qs;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
}
}
auto vd = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+0));
auto vm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+1));
vm = _mm256_mul_ps(_mm256_set1_ps(-1.f), vm);
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d+0, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32 + 8);
iscales32 = _mm256_cvtepi16_epi32(iscales16);
scales = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d+1, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
}
y += QK_K/32;
}
}
}
} // namespace
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
@@ -1710,6 +1991,7 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
: etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
: etypeA == GGML_TYPE_Q4_K || etypeA == GGML_TYPE_Q5_K ? GGML_TYPE_Q8_2_X4
: GGML_TYPE_Q8_K;
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
@@ -1726,10 +2008,12 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
set_functions<DequantizerQ3K>(kernels);
break;
case GGML_TYPE_Q4_K:
set_functions<DequantizerQ4K>(kernels);
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ4K_AVX2, kernels);
//set_functions<DequantizerQ4K>(kernels);
break;
case GGML_TYPE_Q5_K:
set_functions<DequantizerQ5K>(kernels);
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ5K_AVX2, kernels);
//set_functions<DequantizerQ5K>(kernels);
break;
case GGML_TYPE_Q6_K:
set_functions<DequantizerQ6K>(kernels);
@@ -1778,6 +2062,15 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
}
bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
switch (ggml_type(type)) {
case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;
}
#else
// --------------------------------- __aarch64__ --------------------------------------

@@ -10,4 +10,6 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
void iqk_gemm_q8kv_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step);
bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x);
#endif

@@ -1615,6 +1615,81 @@ static void mul_mat_q8_0_r8_q8_2(int n, const void * vx, size_t bx, const DataIn
}
#endif
typedef struct {
ggml_half d[16];
uint8_t qs[256];
} block_q8_1_r8;
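// GEMM of 8 repacked Q8_1 rows (block_q8_1_r8) against Q8_2_X4 activations: for each group of
// four 32-value blocks the row min scales (d[8..15]) are first applied through the activation
// block sums, then the 8-row dot products are accumulated and scaled by the row scales (d[0..7])
// times the activation block scale.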
template <int nrc_y>
static void mul_mat_q8_1_r8_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_2_x4> q8(info);
int nb = n / QK8_0;
__m256 acc[nrc_y] = {};
float d8[4*nrc_y];
__m256i qx[4];
auto dot = [&qx] (const int8_t * qy) {
auto y128 = _mm_loadu_si128((const __m128i*)qy);
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
return sumi;
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
return _mm256_add_epi32(_mm256_madd_epi16(_mm256_set1_epi16(1), sumi1), _mm256_madd_epi16(_mm256_set1_epi16(1), sumi2));
#endif
};
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q8_1_r8 * iq8 = (const block_q8_1_r8 *)((const char *)vx + ix*bx);
for (int i4 = 0; i4 < nb/4; ++i4) {
{
__m256 mx[4];
for (int ib32 = 0; ib32 < 4; ++ib32) mx[ib32] = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*i4+ib32].d+1));
for (int iy = 0; iy < nrc_y; ++iy) {
auto scales = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)q8.y[iy][i4].d)), 16));
_mm_storeu_ps(d8 + 4*iy + 0, scales);
auto bsums4 = _mm_castsi128_ps(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][i4].d+4))), 16));
auto bsums = _mm256_set_m128(bsums4, bsums4);
acc[iy] = _mm256_fmadd_ps(mx[0], _mm256_shuffle_ps(bsums, bsums, 0x00), acc[iy]);
acc[iy] = _mm256_fmadd_ps(mx[1], _mm256_shuffle_ps(bsums, bsums, 0x55), acc[iy]);
acc[iy] = _mm256_fmadd_ps(mx[2], _mm256_shuffle_ps(bsums, bsums, 0xaa), acc[iy]);
acc[iy] = _mm256_fmadd_ps(mx[3], _mm256_shuffle_ps(bsums, bsums, 0xff), acc[iy]);
}
}
for (int ib32 = 0; ib32 < 4; ++ib32) {
auto scales = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[4*i4+ib32].d));
for (int j = 0; j < 4; ++j) {
qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*i4+ib32].qs+j);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = dot(q8.y[iy][i4].qs+32*ib32);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+ib32]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
for (int j = 0; j < 4; ++j) {
qx[j] = _mm256_loadu_si256((const __m256i *)iq8[4*i4+ib32].qs+4+j);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = dot(q8.y[iy][i4].qs+32*ib32+16);
auto d4d8 = _mm256_mul_ps(scales, _mm256_set1_ps(d8[4*iy+ib32]));
acc[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi), acc[iy]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
}
}
}
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
if constexpr (std::is_same_v<Dequantizer, Q4_0_Unpacker> || std::is_same_v<Dequantizer, Q5_0_Unpacker> ||
std::is_same_v<Dequantizer, Q8_0_Unpacker>) {
@@ -1694,6 +1769,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
case GGML_TYPE_IQ4_NL_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_nl_r4_q8_2, kernels)
break;
case GGML_TYPE_Q8_1: // Note: we are misusing the Q8_1 type for Q8_1_R8
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_1_r8_q8_2, kernels)
break;
default:
return false;
}

@@ -243,6 +243,8 @@ struct MulMat {
case GGML_TYPE_IQ3_XXS: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ3_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
case GGML_TYPE_IQ1_S : return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
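// With 32 or more activation rows it pays off to repack Q4_K/Q5_K to q8_1_r8 (stored under the Q8_1 type) first.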
case GGML_TYPE_Q4_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
case GGML_TYPE_Q5_K : return nrc_y >= 32 ? GGML_TYPE_Q8_1 : type;
default: break;
}
#else
@@ -283,6 +285,7 @@ struct MulMat {
case GGML_TYPE_Q5_K_R4:
case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_Q4_0_R8:
case GGML_TYPE_Q8_0_R8:
@@ -318,6 +321,7 @@ struct MulMat {
case GGML_TYPE_Q8_0_R8:
case GGML_TYPE_Q8_KV:
case GGML_TYPE_Q8_KV_R8:
case GGML_TYPE_Q8_1:
case GGML_TYPE_Q8_K_R8: return 8;
case GGML_TYPE_BF16_R16: return 16;
default: return 1;
@@ -341,8 +345,8 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
// return iqk_set_kernels_float(ne00, typeA, typeB, mm.funcs);
//case GGML_TYPE_Q2_K:
//case GGML_TYPE_Q3_K:
//case GGML_TYPE_Q4_K:
//case GGML_TYPE_Q5_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
//case GGML_TYPE_Q6_K:
//case GGML_TYPE_IQ4_XS:
//case GGML_TYPE_Q2_K_R4:
@@ -354,7 +358,7 @@ bool iqk_convert_repack(int typeA, int n, const void * vx, size_t bx, void * vy,
//case GGML_TYPE_Q8_K_R8:
//case GGML_TYPE_Q8_KV:
//case GGML_TYPE_Q8_KV_R8:
// return iqk_set_kernels_kquants(ne00, typeA, typeB, mm.funcs, mm.func16);
return iqk_convert_kquants_q8X_r8(typeA, n, vx, bx, vy, nrc_x);
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
@@ -790,6 +794,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_Q4_0_R8:
case GGML_TYPE_Q5_0_R4: