// ik_llama.cpp/ggml/src/iqk/iqk_gemm_kquants.cpp
#include "iqk_gemm_kquants.h"
#ifdef IQK_IMPLEMENT
#include "ggml-impl.h"
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#include "ggml-quants.h"
#ifdef __x86_64__
namespace {
// Handles q4_K and q5_K scales/mins
struct Scales8K {
template <typename Q8>
inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
make_q4_scales(data, utmp);
const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);
accum_mins(mins128, q8, i, c, accd);
const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
return MM256_SET_M128I(sc128, sc128);
}
#ifdef HAVE_FANCY_SIMD
template <typename Q8>
inline __m512i process_mins_and_scales_64(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
auto scales = process_mins_and_scales(data, c, i, q8, accd);
return _mm512_inserti32x8(_mm512_castsi256_si512(scales), scales, 1);
}
#endif
template <typename Q8>
inline void accum_mins(const __m128i& mins128, const Q8& q8, int i, float c, __m256 * accd) const {
base.accum_mins(mins128, q8, i, c, accd);
}
#ifdef HAVE_FANCY_SIMD
const __m512i shuffles512[2] = {
_mm512_set_epi64(0x0706070607060706, 0x0302030203020302, 0x0706070607060706, 0x0302030203020302,
0x0504050405040504, 0x0100010001000100, 0x0504050405040504, 0x0100010001000100),
_mm512_set_epi64(0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a, 0x0f0e0f0e0f0e0f0e, 0x0b0a0b0a0b0a0b0a,
0x0d0c0d0c0d0c0d0c, 0x0908090809080908, 0x0d0c0d0c0d0c0d0c, 0x0908090809080908)
};
#endif
Scales8KBase base;
uint32_t utmp[4];
};
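// Helpers for quant types with 16 blocks of 16 per super-block (q2_K, q3_K, q6_K):
// process_mins_16() folds the per-block min/offset contribution into the accumulators via the
// q8 block sums, prepare_scales_16() spreads the 16 16-bit scales into two 256-bit registers.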
template <typename Q8>
inline void process_mins_16(const __m256i& all_scales, const Q8& q8, int i, float d, __m256 * accm) {
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
const __m256i prod = _mm256_madd_epi16(all_scales, q8.load_bsums(iy, i));
accm[iy] = _mm256_fmadd_ps(_mm256_set1_ps(d * q8.scale(iy, i)), _mm256_cvtepi32_ps(prod), accm[iy]);
}
}
inline void prepare_scales_16(const __m256i& all_scales, __m256i * scales) {
const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
scales[0] = MM256_SET_M128I(l_scales, l_scales);
scales[1] = MM256_SET_M128I(h_scales, h_scales);
}
// Handles q3_K scales
struct ScaleQ3 {
inline __m128i make_scales(const uint16_t * s8) const {
const uint16_t * scales16 = (const uint16_t *)s8;
uint32_t aux0 = scales16[0] | (scales16[1] << 16);
uint32_t aux1 = scales16[2] | (scales16[3] << 16);
uint32_t aux2 = scales16[4] | (scales16[5] << 16);
__m128i scales128 = _mm_set_epi32(
((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030),
((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030),
(aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030),
(aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030));
return _mm_add_epi8(scales128, m32);
}
const __m128i m32 = _mm_set1_epi8(-32);
};
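// Expands 16 8-bit block scales into two __m512i of 16-bit values, used by the AVX512
// dequantizers for q2_K, q3_K and q6_K below.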
struct Scale16 {
inline void make_scales(const __m128i& scales8, __m512i * scales) const {
auto all_scales8 = MM256_SET_M128I(scales8, scales8);
auto scales1 = _mm256_shuffle_epi8(all_scales8, shuffle1);
auto scales2 = _mm256_shuffle_epi8(all_scales8, shuffle2);
scales[0] = _mm512_cvtepi8_epi16(scales1);
scales[1] = _mm512_cvtepi8_epi16(scales2);
}
template <typename Q8>
inline void process_mins_and_scales(int i, float c, const __m128i& mins8, const __m128i& scales8,
const Q8& q8, __m256 * accm, __m512i * scales) const {
process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, c, accm);
make_scales(scales8, scales);
}
const __m256i shuffle1 = _mm256_set_epi32(0x07070707, 0x03030303, 0x06060606, 0x02020202,
0x05050505, 0x01010101, 0x04040404, 0x00000000);
const __m256i shuffle2 = _mm256_set_epi32(0x0f0f0f0f, 0x0b0b0b0b, 0x0e0e0e0e, 0x0a0a0a0a,
0x0d0d0d0d, 0x09090909, 0x0c0c0c0c, 0x08080808);
};
template <typename Q8>
inline void process_mins_and_scales_16(const __m128i& scales128, const Q8& q8, int i, float d,
__m256 * accm, __m256i * scales) {
const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
process_mins_16(all_scales, q8, i, d, accm);
prepare_scales_16(all_scales, scales);
}
inline __m256i get_scale_shuffle_8(int i) {
return _mm256_set1_epi16((2*i) | ((2*i+1) << 8));
}
inline void set_scales_8(const __m256i& all_scales, int j, __m256i * scales) {
scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+0));
scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+1));
scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+2));
scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_8(4*j+3));
}
inline __m256i get_scale_shuffle_16(int i) {
static const uint8_t k_shuffle[128] = {
0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
};
return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
}
inline void set_scales_16(const __m256i& all_scales, __m256i * scales) {
scales[0] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(0));
scales[1] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(1));
scales[2] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(2));
scales[3] = _mm256_shuffle_epi8(all_scales, get_scale_shuffle_16(3));
}
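// Handles iq4_xs scales: each of the 8 blocks of 32 has a 6-bit scale split into a 4-bit low
// part (scales_l) and a 2-bit high part (scales_h), stored with a bias of 32. Roughly the
// scalar equivalent (following the iq4_xs reference layout) for block ib is
//     ls = (scales_l[ib/2] >> 4*(ib%2)) & 0xf;
//     lh = (scales_h >> 2*ib) & 3;
//     scale = (ls | (lh << 4)) - 32;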
struct ScaleIQ4XS {
inline __m128i make_scales(const uint32_t scales_l, const uint16_t scales_h) {
uint32_t tmp32 = scales_h | (scales_h << 14);
const __m128i sh = _mm_slli_epi16(_mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(tmp32), hshift), hmask), 4);
const __m128i sl = _mm_and_si128(_mm_srlv_epi32(_mm_set1_epi32(scales_l), lshift), lmask);
return _mm_add_epi16(_mm_or_si128(sh, _mm_cvtepi8_epi16(_mm_shuffle_epi8(sl, lshuffle))), m32);
}
const __m128i hshift = _mm_set_epi32(12, 8, 4, 0);
const __m128i lshift = _mm_set_epi32(4, 0, 4, 0);
const __m128i hmask = _mm_set1_epi16(0x03);
const __m128i lmask = _mm_set1_epi8(0xf);
const __m128i lshuffle = _mm_set_epi32(0x07030602, 0x05010400, 0x07030602, 0x05010400);
const __m128i m32 = _mm_set1_epi16(-32);
};
#ifdef HAVE_FANCY_SIMD
//====================================== Zen4 ==================================================
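// HighBit5/HighBit3 OR the 5th (q5_K) or 3rd (q3_K) bit from the separately stored high-bit
// plane into the 4-bit/2-bit quants already unpacked into Q4Bits/Q2Bits, so the full quant
// value ends up in a single byte per weight.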
struct HighBit5 {
inline void apply(const uint8_t * h, Q4Bits& bits) {
auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);
bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh));
bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));
bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(hbits, mh));
bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));
}
const __m512i mh = _mm512_set1_epi8(0x10);
};
struct HighBit3 {
inline void apply(const uint8_t * h, Q2Bits& bits) {
auto hbits256 = _mm256_loadu_si256((const __m256i *)h);
auto hbits = _mm512_inserti32x8(_mm512_castsi256_si512(hbits256), _mm256_srli_epi16(hbits256, 1), 1);
bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh));
bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_and_si512(hbits, mh));
bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh));
bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_and_si512(_mm512_srli_epi16(hbits, 4), mh));
}
const __m512i mh = _mm512_set1_epi8(0x04);
};
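// Dot product of one 256-weight super-block for the Zen4 path: _mm512_dpbusd_epi32 multiplies
// the (unsigned) dequantized x values with the (signed) q8 quants, the per-block scales are
// applied with _mm512_dpwssd_epi32 on the packed 16-bit partial sums, and the result is
// accumulated into accd[iy] as float with the combined scale d*q8.scale(iy, i).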
template <typename Q8>
inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1));
const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2));
const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3));
auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
}
struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
bits.prepare(x[i].qs);
const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
sc16.process_mins_and_scales(i, -GGML_FP16_TO_FP32(x[i].dmin), mins8, scales8, q8, accm, scales);
}
Q2Bits bits;
Scale16 sc16;
const __m128i m4 = _mm_set1_epi8(0xf);
};
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
bits.prepare(x[i].qs);
hbits.apply(x[i].hmask, bits);
auto scales128 = sc3.make_scales((const uint16_t *)x[i].scales);
sc16.process_mins_and_scales(i, -4.f*d, scales128, scales128, q8, accm, scales);
}
Q2Bits bits;
HighBit3 hbits;
ScaleQ3 sc3;
Scale16 sc16;
const __m128i m4 = _mm_set1_epi8(0xf);
const __m128i m32 = _mm_set1_epi8(-32);
};
struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
bits.prepare(x[i].qs);
auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
}
Q4Bits bits;
Scales8K s8k;
};
struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
bits.prepare(x[i].qs);
hbits.apply(x[i].qh, bits);
auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
}
Q4Bits bits;
HighBit5 hbits;
Scales8K s8k;
};
struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m512i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
bits.prepare64(x[i].ql);
add_high_bits(x[i].qh, bits);
auto scales128 = _mm_loadu_si128((const __m128i *)x[i].scales);
sc16.process_mins_and_scales(i, -32.f*d, scales128, scales128, q8, accm, scales);
}
inline void add_high_bits(const uint8_t * qh, Q4Bits& bits) const {
auto hbits = _mm512_loadu_si512((const __m512i *)qh);
auto tmp1 = _mm512_and_si512(_mm512_slli_epi16(hbits, 4), mh);
auto tmp2 = _mm512_and_si512(_mm512_slli_epi16(hbits, 2), mh);
bits.values[0] = _mm512_or_si512(bits.values[0], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));
bits.values[2] = _mm512_or_si512(bits.values[2], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));
tmp1 = _mm512_and_si512(hbits, mh);
tmp2 = _mm512_and_si512(_mm512_srli_epi16(hbits, 2), mh);
bits.values[1] = _mm512_or_si512(bits.values[1], _mm512_permutex2var_epi64(tmp1, bits.perm.permute1, tmp2));
bits.values[3] = _mm512_or_si512(bits.values[3], _mm512_permutex2var_epi64(tmp1, bits.perm.permute2, tmp2));
}
Q4Bits bits;
HighBit3 hbits;
Scale16 sc16;
const __m512i mh = _mm512_set1_epi8(0x30);
};
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_512()) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
prepare(x[i].qs);
auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);
s8k.accum_mins(scales128, q8, i, -128.f*d, accd);
auto scales256 = MM256_SET_M128I(scales128, scales128);
auto all_scales = _mm512_inserti32x8(_mm512_castsi256_si512(scales256), scales256, 1);
scales[0] = _mm512_shuffle_epi8(all_scales, shuffles[0]);
scales[1] = _mm512_shuffle_epi8(all_scales, shuffles[1]);
scales[2] = _mm512_shuffle_epi8(all_scales, shuffles[2]);
scales[3] = _mm512_shuffle_epi8(all_scales, shuffles[3]);
}
inline void prepare(const uint8_t * q4) {
bits.prepare64(q4);
// We now have in bits.values[0]: 0...15, 32...47, 64...79, 96...111
// bits.values[1]: 16...31, 48...63, 80...95, 112...127
// etc.
auto tmp = _mm512_permutex2var_epi64(bits.values[0], permute1, bits.values[1]);
bits.values[1] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[0], permute2, bits.values[1]));
bits.values[0] = _mm512_shuffle_epi8(values, tmp);
tmp = _mm512_permutex2var_epi64(bits.values[2], permute1, bits.values[3]);
bits.values[3] = _mm512_shuffle_epi8(values, _mm512_permutex2var_epi64(bits.values[2], permute2, bits.values[3]));
bits.values[2] = _mm512_shuffle_epi8(values, tmp);
}
Q4Bits bits;
Scales8KBase s8k;
ScaleIQ4XS siq4;
const __m512i values;
const __m512i permute1 = _mm512_set_epi64(11, 10, 3, 2, 9, 8, 1, 0);
const __m512i permute2 = _mm512_set_epi64(15, 14, 7, 6, 13, 12, 5, 4);
const __m512i shuffles[4] = {
_mm512_inserti32x8(_mm512_set1_epi16(0x0100), _mm256_set1_epi16(0x0302), 1),
_mm512_inserti32x8(_mm512_set1_epi16(0x0504), _mm256_set1_epi16(0x0706), 1),
_mm512_inserti32x8(_mm512_set1_epi16(0x0908), _mm256_set1_epi16(0x0b0a), 1),
_mm512_inserti32x8(_mm512_set1_epi16(0x0d0c), _mm256_set1_epi16(0x0f0e), 1),
};
};
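// AVX512 kernel for a single RHS column (nrc_y == 1). Two dequantizers work on adjacent
// super-blocks (k_nx = 2), presumably to improve instruction-level parallelism; a trailing odd
// block, if any, is handled after the main loop.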
template <typename Dequantizer>
static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
constexpr int k_nx = 2;
Q8<1> q8(info);
Dequantizer deq1(vx, bx);
Dequantizer deq2(vx, bx);
Dequantizer * deq[k_nx];
deq[0] = &deq1;
deq[1] = &deq2;
__m512i scales[2*k_nx];
for (int ix = 0; ix < nrc_x; ++ix) {
auto accd = _mm512_setzero_ps();
auto accm = _mm256_setzero_ps();
for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_row(ix);
for (int i = 0; i < nb/k_nx; ++i) {
for (int kx = 0; kx < k_nx; ++kx) deq[kx]->new_block(k_nx*i+kx, q8, &accm, scales+2*kx);
for (int kx = 0; kx < k_nx; ++kx) {
compute_block(0, k_nx*i+kx, deq[kx]->d, q8, deq[kx]->bits.values, scales+2*kx, &accd);
}
}
if (2*(nb/2) < nb) {
int i0 = 2*(nb/2);
deq[0]->new_block(i0, q8, &accm, scales);
compute_block(0, i0, deq[0]->d, q8, deq[0]->bits.values, scales, &accd);
}
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
}
}
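// Generic AVX512 kernel for the k-quants whose dequantizer fills two scale registers: one
// dequantized super-block is reused across all nrc_y RHS columns.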
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
Q8<nrc_y> q8(info);
Dequantizer deq(vx, bx);
__m256 accm[nrc_y];
__m512 accd[nrc_y];
__m512i scales[2];
for (int ix = 0; ix < nrc_x; ++ix) {
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();
deq.new_row(ix);
for (int i = 0; i < nb; ++i) {
deq.new_block(i, q8, accm, scales);
for (int iy = 0; iy < nrc_y; ++iy) {
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[0], q8.load_quants64(iy, i, 0));
const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[1], q8.load_quants64(iy, i, 1));
const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[2], q8.load_quants64(iy, i, 2));
const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), deq.bits.values[3], q8.load_quants64(iy, i, 3));
auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
}
}
}
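// AVX512 kernel variant for dequantizers that fill four scale registers per super-block (such
// as DequantizerIQ4XS above): dot products are formed with _mm512_maddubs_epi16 and the 16-bit
// scales are applied per 32-quant block via _mm512_dpwssd_epi32.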
template <typename Dequantizer, int nrc_y>
static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
Q8<nrc_y> q8(info);
Dequantizer deq(vx, bx);
__m256 accm[nrc_y];
__m512 accd[nrc_y];
__m512i scales[4];
for (int ix = 0; ix < nrc_x; ++ix) {
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm512_setzero_ps();
for (int iy = 0; iy < nrc_y; ++iy) accm[iy] = _mm256_setzero_ps();
deq.new_row(ix);
for (int i = 0; i < nb; ++i) {
deq.new_block(i, q8, accm, scales);
for (int iy = 0; iy < nrc_y; ++iy) {
const __m512i p1 = _mm512_maddubs_epi16(deq.bits.values[0], q8.load_quants64(iy, i, 0));
const __m512i p2 = _mm512_maddubs_epi16(deq.bits.values[1], q8.load_quants64(iy, i, 1));
const __m512i p3 = _mm512_maddubs_epi16(deq.bits.values[2], q8.load_quants64(iy, i, 2));
const __m512i p4 = _mm512_maddubs_epi16(deq.bits.values[3], q8.load_quants64(iy, i, 3));
auto sumi = _mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_dpwssd_epi32(_mm512_setzero_si512(),
p1, scales[0]), p2, scales[1]), p3, scales[2]), p4, scales[3]);
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(deq.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
}
}
}
#else
//====================================== AVX2 ==================================================
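// The AVX2 HighBit5/HighBit3 process one 128-quant half of a super-block at a time: apply()
// ORs the relevant high bits into the unpacked quants and, when do_shift is set (first half),
// shifts hbits down by 4 so the second half can reuse the same register.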
struct HighBit5 {
inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
inline void apply(Q4Bits& bits, bool do_shift) {
bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));
bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
if (do_shift) {
hbits = _mm256_srli_epi16(hbits, 4);
}
}
const __m256i mh = _mm256_set1_epi8(0x10);
__m256i hbits;
};
struct HighBit3 {
inline void load(const uint8_t * h) { hbits = _mm256_loadu_si256((const __m256i *)h); }
inline void apply(Q2Bits& bits, bool do_shift) {
bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
if (do_shift) {
hbits = _mm256_srli_epi16(hbits, 4);
}
}
const __m256i mh = _mm256_set1_epi8(0x04);
__m256i hbits;
};
template <typename Q8>
inline void compute_block(int iy, int i, float d, const Q8& q8, const __m512i * values, const __m512i * scales, __m512 * accd) {
const __m512i p1 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[0], q8.load_quants64(iy, i, 0));
const __m512i p2 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[1], q8.load_quants64(iy, i, 1));
const __m512i p3 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[2], q8.load_quants64(iy, i, 2));
const __m512i p4 = _mm512_dpbusd_epi32(_mm512_setzero_si512(), values[3], q8.load_quants64(iy, i, 3));
auto sumi = _mm512_dpwssd_epi32(_mm512_setzero_si512(), scales[0], _mm512_packs_epi32(p1, p2));
sumi = _mm512_dpwssd_epi32(sumi, scales[1], _mm512_packs_epi32(p3, p4));
accd[iy] = _mm512_fmadd_ps(_mm512_set1_ps(d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi), accd[iy]);
}
struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
DequantizerQ2K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
process_mins_16(_mm256_cvtepi8_epi16(mins8), q8, i, -GGML_FP16_TO_FP32(x[i].dmin), accm);
prepare_scales_16(_mm256_cvtepi8_epi16(scales8), scales);
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs, j);
}
Q2Bits bits;
const __m128i m4 = _mm_set1_epi8(0xf);
};
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
DequantizerQ3K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
hbits.load(x[i].hmask);
process_mins_and_scales_16(sc3.make_scales((const uint16_t *)x[i].scales), q8, i, -4.f*d, accm, scales);
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs, j);
hbits.apply(bits, j == 0);
}
Q2Bits bits;
HighBit3 hbits;
ScaleQ3 sc3;
const __m128i m32 = _mm_set1_epi8(-32);
};
struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
DequantizerQ4K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
d = GGML_FP16_TO_FP32(x[i].d);
return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs, j);
}
Q4Bits bits;
Scales8K s8k;
};
struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
DequantizerQ5K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
d = GGML_FP16_TO_FP32(x[i].d);
hbits.load(x[i].qh);
return s8k.process_mins_and_scales(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs, j);
hbits.apply(bits, j == 0);
}
Q4Bits bits;
HighBit5 hbits;
Scales8K s8k;
};
struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
DequantizerQ6K(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
template <typename Q8>
inline void new_block(int i, const Q8& q8, __m256 * accm, __m256i * scales) {
d = GGML_FP16_TO_FP32(x[i].d);
process_mins_and_scales_16(_mm_loadu_si128((const __m128i *)x[i].scales), q8, i, -32.f*d, accm, scales);
}
inline void prepare(int i, int j) {
bits.prepare64(x[i].ql, j);
auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);
bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));
}
Q4Bits bits;
const __m256i mh = _mm256_set1_epi8(0x30);
};
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
DequantizerIQ4XS(const void * vx, size_t bx) : BaseDequantizer(vx, bx), values(load_iq4nl_values_256()) {}
template <typename Q8>
inline __m256i new_block(int i, const Q8& q8, __m256 * accd) {
d = GGML_FP16_TO_FP32(x[i].d);
auto scales128 = siq4.make_scales(*(const uint32_t *)x[i].scales_l, x[i].scales_h);
s8k.accum_mins(scales128, q8, i, -128.f*d, accd);
return MM256_SET_M128I(scales128, scales128);
}
inline void prepare(int i, int j) {
bits.prepare16(x[i].qs, j);
bits.values[0] = _mm256_shuffle_epi8(values, bits.values[0]);
bits.values[1] = _mm256_shuffle_epi8(values, bits.values[1]);
bits.values[2] = _mm256_shuffle_epi8(values, bits.values[2]);
bits.values[3] = _mm256_shuffle_epi8(values, bits.values[3]);
}
Q4Bits bits;
Scales8K s8k;
ScaleIQ4XS siq4;
const __m256i values;
};
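// Generic AVX2 kernel for quant types with 8 blocks of 32 (q4_K, q5_K, iq4_xs): new_block()
// returns all 8 scales in one __m256i and the super-block is processed in two halves of 128
// quants via prepare()/set_scales_8()/multiply_add().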
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
Q8<nrc_y> q8(info);
Dequantizer deq(vx, bx);
__m256 accd[nrc_y];
__m256i scales[4];
for (int ix = 0; ix < nrc_x; ++ix) {
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
deq.new_row(ix);
for (int i = 0; i < nb; ++i) {
auto all_scales = deq.new_block(i, q8, accd);
__m256i sumi[nrc_y];
for (int j = 0; j < QK_K/128; ++j) {
deq.prepare(i, j);
set_scales_8(all_scales, j, scales);
multiply_add(deq.bits, scales, j, i, q8, sumi);
}
for (int iy = 0; iy < nrc_y; ++iy) {
const __m256 vd = _mm256_set1_ps(deq.d*q8.scale(iy, i));
accd[iy] = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
}
}
}
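// Generic AVX2 kernel for quant types with 16 blocks of 16 (q2_K, q3_K, q6_K): new_block()
// fills two __m256i with the 16 scales, which set_scales_16() expands per 128-quant half.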
template <typename Dequantizer, int nrc_y>
static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
const int nb = n/QK_K;
Q8<nrc_y> q8(info);
__m256i all_scales[2];
__m256i scales[4];
__m256 accd[nrc_y];
Dequantizer deq(vx, bx);
for (int ix = 0; ix < nrc_x; ++ix) {
deq.new_row(ix);
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
deq.new_block(i, q8, accd, all_scales);
__m256i sumi[nrc_y];
for (int j = 0; j < QK_K/128; ++j) {
deq.prepare(i, j);
set_scales_16(all_scales[j], scales);
multiply_add(deq.bits, scales, j, i, q8, sumi);
}
for (int iy = 0; iy < nrc_y; ++iy) {
accd[iy] = _mm256_fmadd_ps(_mm256_set1_ps(deq.d*q8.scale(iy, i)), _mm256_cvtepi32_ps(sumi[iy]), accd[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
}
}
}
#endif
// inline __m256i process_mins_and_scales(const uint8_t * data, float c, int i, const Q8& q8, __m256 * accd) {
// make_q4_scales(data, utmp);
// const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));
// const __m128i mins128 = _mm256_extracti128_si256(mins_and_scales, 1);
// accum_mins(mins128, q8, i, c, accd);
// const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0);
// return MM256_SET_M128I(sc128, sc128);
// }
//
// inline void new_block(int i, const Q8& q8, __m256 * accd, __m512i * scales) {
// d = GGML_FP16_TO_FP32(x[i].d);
// bits.prepare(x[i].qs);
// auto all_scales = s8k.process_mins_and_scales_64(x[i].scales, -GGML_FP16_TO_FP32(x[i].dmin), i, q8, accd);
// scales[0] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[0]);
// scales[1] = _mm512_shuffle_epi8(all_scales, s8k.shuffles512[1]);
// }
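// Q4Bits_AVX2 unpacks 128 4-bit quants (one half of a super-block, 64 bytes) into four
// 256-bit registers of byte values in 0...15.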
struct Q4Bits_AVX2 {
inline void prepare(const uint8_t * q4, int j) {
auto q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+0);
values[0] = _mm256_and_si256(q4bits, ml);
values[1] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
q4bits = _mm256_loadu_si256((const __m256i*)q4 + 2*j+1);
values[2] = _mm256_and_si256(q4bits, ml);
values[3] = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), ml);
}
__m256i values[4];
const __m256i ml = _mm256_set1_epi8(0xf);
};
struct DequantizerQ4K_AVX2 final : public BaseDequantizer<block_q4_K> {
DequantizerQ4K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs, j);
}
Q4Bits_AVX2 bits;
};
struct DequantizerQ5K_AVX2 final : public BaseDequantizer<block_q5_K> {
DequantizerQ5K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs, j);
hbits = j == 0 ? _mm256_loadu_si256((const __m256i *)x[i].qh) : _mm256_srli_epi16(hbits, 4);
apply_hbits();
}
inline void apply_hbits() {
bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 3), mh));
bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
}
const __m256i mh = _mm256_set1_epi8(0x10);
Q4Bits_AVX2 bits;
__m256i hbits;
};
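// AVX2 kernel for q4_K/q5_K against activations packed as block_q8_2_x4 (four 32-quant blocks
// per struct). The first four 16-bit entries of d hold the block scales (converted to float by
// shifting into the high half of an fp32); the last four appear to hold the per-block integer
// quant sums, which are used here to fold in the q4_K/q5_K mins.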
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
Q8<nrc_y, block_q8_2_x4> q8(info);
Dequantizer deq(vx, bx);
uint32_t utmp[4];
__m256 accd[nrc_y];
__m256 scales[2];
float d8[8*nrc_y];
for (int ix = 0; ix < nrc_x; ++ix) {
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
deq.new_row(ix);
for (int i = 0; i < nb; ++i) {
deq.d = GGML_FP16_TO_FP32(deq.x[i].d);
auto vm = _mm256_cvtph_ps(_mm_set1_epi16(deq.x[i].dmin));
make_q4_scales(deq.x[i].scales, utmp);
auto mins = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)(utmp + 2)))));
mins = _mm256_mul_ps(_mm256_set1_ps(-1.f), mins);
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
_mm256_storeu_ps(d8 + 8*iy, dy);
auto m4_1 = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d+4)));
auto m4_2 = _mm_cvtepi16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d+4)));
auto myi = MM256_SET_M128I(m4_2, m4_1);
auto my = _mm256_mul_ps(dy, _mm256_cvtepi32_ps(myi));
accd[iy] = _mm256_fmadd_ps(my, mins, accd[iy]);
}
auto all_scales = _mm256_mul_ps(_mm256_set1_ps(deq.d), _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i *)utmp))));
scales[0] = _mm256_set_m128(_mm256_castps256_ps128(all_scales), _mm256_castps256_ps128(all_scales));
auto scales_h = _mm256_extractf128_ps(all_scales, 1);
scales[1] = _mm256_set_m128(scales_h, scales_h);
for (int j = 0; j < QK_K/128; ++j) {
deq.prepare(i, j);
for (int iy = 0; iy < nrc_y; ++iy) {
const block_q8_2_x4& y = q8.y[iy][2*i+j];
#ifdef HAVE_FANCY_SIMD
auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
#else
auto sumi1 = _mm256_maddubs_epi16(deq.bits.values[0], _mm256_loadu_si256((const __m256i*)y.qs+0));
auto sumi2 = _mm256_maddubs_epi16(deq.bits.values[1], _mm256_loadu_si256((const __m256i*)y.qs+1));
auto sumi3 = _mm256_maddubs_epi16(deq.bits.values[2], _mm256_loadu_si256((const __m256i*)y.qs+2));
auto sumi4 = _mm256_maddubs_epi16(deq.bits.values[3], _mm256_loadu_si256((const __m256i*)y.qs+3));
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
#endif
auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
}
}
}
struct DequantizerQ6K_AVX2 final : public BaseDequantizer<block_q6_K> {
DequantizerQ6K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
inline void prepare(int i, int j) {
auto lbits1 = _mm256_loadu_si256((const __m256i *)x[i].ql + 2*j+0);
auto lbits2 = _mm256_loadu_si256((const __m256i *)x[i].ql + 2*j+1);
auto hbits = _mm256_loadu_si256((const __m256i *)x[i].qh + j);
bits.values[0] = _mm256_or_si256(_mm256_and_si256(lbits1, bits.ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
bits.values[1] = _mm256_or_si256(_mm256_and_si256(lbits2, bits.ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
bits.values[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), bits.ml), _mm256_and_si256(hbits, mh));
bits.values[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), bits.ml), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));
}
inline void prepare_signed(int i, int j, __m256i * us) {
prepare(i, j);
for (int k = 0; k < 4; ++k) {
bits.values[k] = _mm256_add_epi8(bits.values[k], _mm256_set1_epi8(-32));
us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
}
}
inline __m256i make_scales(int i) const {
return _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)x[i].scales));
}
const __m256i mh = _mm256_set1_epi8(0x30);
Q4Bits_AVX2 bits;
};
struct SimpleBits {
__m256i values[4];
};
struct DequantizerQ3K_AVX2 final : public BaseDequantizer<block_q3_K> {
DequantizerQ3K_AVX2(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
inline void prepare(int i, int j) {
hbits = j == 0 ? _mm256_loadu_si256((const __m256i *)x[i].hmask) : _mm256_srli_epi16(hbits, 4);
auto q2bits = _mm256_loadu_si256((const __m256i *)x[i].qs + j);
bits.values[0] = _mm256_and_si256(q2bits, ml);
bits.values[1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
bits.values[2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
bits.values[3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
bits.values[0] = _mm256_or_si256(bits.values[0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
bits.values[1] = _mm256_or_si256(bits.values[1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
bits.values[2] = _mm256_or_si256(bits.values[2], _mm256_and_si256(hbits, mh));
bits.values[3] = _mm256_or_si256(bits.values[3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
//bits.values[0] = _mm256_sub_epi8(bits.values[0], _mm256_xor_si256(mh, _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh)));
//bits.values[1] = _mm256_sub_epi8(bits.values[1], _mm256_xor_si256(mh, _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh)));
//bits.values[2] = _mm256_sub_epi8(bits.values[2], _mm256_xor_si256(mh, _mm256_and_si256(hbits, mh)));
//bits.values[3] = _mm256_sub_epi8(bits.values[3], _mm256_xor_si256(mh, _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh)));
}
inline void prepare_signed(int i, int j, __m256i * us) {
prepare(i, j);
for (int k = 0; k < 4; ++k) {
bits.values[k] = _mm256_sub_epi8(bits.values[k], mh);
us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
}
//for (int k = 0; k < 4; ++k) {
// us[k] = _mm256_sign_epi8(bits.values[k], bits.values[k]);
//}
}
inline __m256i make_scales(int i) const {
return _mm256_cvtepi8_epi16(sc3.make_scales((const uint16_t *)x[i].scales));
}
ScaleQ3 sc3;
__m256i hbits;
SimpleBits bits;
const __m256i ml = _mm256_set1_epi8(3);
const __m256i mh = _mm256_set1_epi8(4);
};
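// Same scheme as above, but for quant types whose values are signed after dequantization
// (q3_K, q6_K): prepare_signed() subtracts the offset and the sign is moved onto the q8 side
// with _mm256_sign_epi8, so the unsigned*signed multiply-add instructions can still be used.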
template <typename Dequantizer, int nrc_y>
static void mul_mat_qY_K_q8_2_X4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
Q8<nrc_y, block_q8_2_x4> q8(info);
Dequantizer deq(vx, bx);
__m256 accd[nrc_y];
__m256 scales[2];
float d8[8*nrc_y];
__m256i us[4];
uint8_t k_shuff[32] = {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15, 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15};
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
for (int ix = 0; ix < nrc_x; ++ix) {
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_ps();
deq.new_row(ix);
for (int i = 0; i < nb; ++i) {
deq.d = GGML_FP16_TO_FP32(deq.x[i].d);
auto vd = _mm256_set1_ps(deq.d);
auto sc16 = _mm256_shuffle_epi8(deq.make_scales(i), shuff);
scales[0] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(sc16))));
scales[1] = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(sc16, 1))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4_1 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+0].d)));
auto d4_2 = _mm_cvtepu16_epi32(_mm_loadl_epi64((const __m128i *)(q8.y[iy][2*i+1].d)));
auto dy = _mm256_castsi256_ps(_mm256_slli_epi32(MM256_SET_M128I(d4_2, d4_1), 16));
if constexpr (nrc_y == 1) {
auto dyh = _mm256_extractf128_ps(dy, 1);
scales[0] = _mm256_mul_ps(scales[0], _mm256_set_m128(_mm256_castps256_ps128(dy), _mm256_castps256_ps128(dy)));
scales[1] = _mm256_mul_ps(scales[1], _mm256_set_m128(dyh, dyh));
} else {
_mm256_storeu_ps(d8 + 8*iy, dy);
}
}
for (int j = 0; j < QK_K/128; ++j) {
deq.prepare_signed(i, j, us);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qs = q8.y[iy][2*i+j].qs;
#ifdef HAVE_FANCY_SIMD
// 0...31
auto sumi1 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0]));
// 32...63
auto sumi2 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1]));
// 64...95
auto sumi3 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2]));
// 96...127
auto sumi4 = _mm256_dpbusd_epi32(_mm256_setzero_si256(), us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3]));
// 0...3, 32...35, 4...7, 36...39, 16...19, 48...51, 20...23, 52...55 +
// 8...11, 40...43, 12...15, 44...47, 24...27, 56...59, 28...31, 60...63
// b0 b2 b0 b2 b1 b3 b1 b3
sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
// same as above + 64, so
// b4 b6, b4 b6 b5 b7 b5 b7
sumi3 = _mm256_add_epi32(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
// b0 b2 b4 b6 b1 b3 b5 b7 +
// b0 b2 b4 b6 b1 b3 b5 b7
sumi1 = _mm256_add_epi32(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
#else
auto sumi1 = _mm256_maddubs_epi16(us[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+0), deq.bits.values[0]));
auto sumi2 = _mm256_maddubs_epi16(us[1], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+1), deq.bits.values[1]));
auto sumi3 = _mm256_maddubs_epi16(us[2], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+2), deq.bits.values[2]));
auto sumi4 = _mm256_maddubs_epi16(us[3], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i*)qs+3), deq.bits.values[3]));
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi1, sumi2), _mm256_unpackhi_epi32(sumi1, sumi2));
sumi3 = _mm256_add_epi16(_mm256_unpacklo_epi32(sumi3, sumi4), _mm256_unpackhi_epi32(sumi3, sumi4));
sumi1 = _mm256_add_epi16(_mm256_unpacklo_epi64(sumi1, sumi3), _mm256_unpackhi_epi64(sumi1, sumi3));
sumi1 = _mm256_madd_epi16(_mm256_set1_epi16(1), sumi1);
#endif
if constexpr (nrc_y > 1) {
auto dy4 = _mm_loadu_ps(d8 + 8*iy + 4*j);
auto d4d8 = _mm256_mul_ps(scales[j], _mm256_set_m128(dy4, dy4));
accd[iy] = _mm256_fmadd_ps(d4d8, _mm256_cvtepi32_ps(sumi1), accd[iy]);
} else {
accd[iy] = _mm256_fmadd_ps(scales[j], _mm256_cvtepi32_ps(sumi1), accd[iy]);
}
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
}
}
}
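// AVX2 kernel for iq4_xs repacked with 8 interleaved rows (iq4_xs_r8) against q8_K: each
// 32-quant block of the 8 rows is processed together, with the q8 values broadcast via
// _mm256_shuffle_epi32 so one register of RHS data is reused for all 8 rows.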
template <int nrc_y>
static void mul_mat_iq4_xs_r8_q8_k_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto m30 = _mm256_set1_epi8(0x30);
auto m32 = _mm256_set1_epi8(32);
#ifndef HAVE_FANCY_SIMD
auto s_shuffle = _mm256_set_epi64x(0x0f0e0f0e0d0c0d0c, 0x0b0a0b0a09080908, 0x0706070605040504, 0x0302030201000100);
auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
auto values = MM256_SET_M128I(values128, values128);
#else
auto values = load_iq4nl_values_256();
#endif
int nbl = n / QK_K;
using helper_t = union { __m256i vec[2]; uint64_t val[8]; };
helper_t h;
__m256 acc[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_iq4_xs_r8 * iq4 = (const block_iq4_xs_r8 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
auto slbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
auto sl1 = _mm256_and_si256(slbits, m4);
auto sl2 = _mm256_and_si256(_mm256_srli_epi16(slbits, 4), m4);
auto shbits = _mm_loadu_si128((const __m128i*)iq4[ibl].scales_h);
auto sh = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
h.vec[0] = _mm256_sub_epi8(_mm256_or_si256(sl1, _mm256_and_si256(_mm256_slli_epi16(sh, 4), m30)), m32);
h.vec[1] = _mm256_sub_epi8(_mm256_or_si256(sl2, _mm256_and_si256(sh, m30)), m32);
__m256i isum[nrc_y] = {};
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
auto iscales = _mm256_cvtepi8_epi32(_mm_set1_epi64x(h.val[ib]));
auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
auto scales_m = _mm256_mul_ps(scales, _mm256_set1_ps(-128.f));
for (int iy = 0; iy < nrc_y; ++iy) {
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[iy] = _mm256_fmadd_ps(scales_m, _mm256_set1_ps(m8), acc[iy]);
}
#else
auto iscales = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_set1_epi64x(h.val[ib])), s_shuffle);
#endif
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+1);
qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1));
qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4)));
qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2));
qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4)));
#ifndef HAVE_FANCY_SIMD
auto s1 = _mm256_sign_epi8(qx[0], qx[0]);
auto s2 = _mm256_sign_epi8(qx[1], qx[1]);
auto s3 = _mm256_sign_epi8(qx[2], qx[2]);
auto s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+0);
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)),
_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
isum[iy] = _mm256_add_epi32(isum[iy], sumi);
#endif
}
bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+2);
bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+4*ib+3);
qx[0] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits1));
qx[1] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits1, 4)));
qx[2] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, bits2));
qx[3] = _mm256_shuffle_epi8(values, _mm256_and_si256(m4, _mm256_srli_epi16(bits2, 4)));
#ifndef HAVE_FANCY_SIMD
s1 = _mm256_sign_epi8(qx[0], qx[0]);
s2 = _mm256_sign_epi8(qx[1], qx[1]);
s3 = _mm256_sign_epi8(qx[2], qx[2]);
s4 = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+2*ib+1);
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
auto sumi1 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto sumi2 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto sumi3 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto sumi4 = _mm256_maddubs_epi16(s4, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
auto sumi = _mm256_add_epi32(_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi1), _mm256_madd_epi16(iscales, sumi2)),
_mm256_add_epi32(_mm256_madd_epi16(iscales, sumi3), _mm256_madd_epi16(iscales, sumi4)));
isum[iy] = _mm256_add_epi32(isum[iy], sumi);
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
}
}
}
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
mul_mat_iq4_xs_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
return;
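// Note: the early return above makes the AVX512-specific implementation below unreachable;
// the AVX2 kernel is used instead.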
if constexpr (nrc_y == 1){
mul_mat_iq4_xs_r8_q8_k_avx2<1>(n, vx, bx, info, nrc_x);
} else {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm512_set1_epi8(0xf);
auto values = load_iq4nl_values_512();
int nbl = n / QK_K;
using helper_t = union { __m512i vec; uint32_t val[16]; };
helper_t h;
__m512 acc[nrc_y] = {};
__m512i isum[nrc_y] = {};
__m512i qx[4];
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_iq4_xs_r8 * iq4l = (const block_iq4_xs_r8 *)((const char *)vx + (ix+0)*bx);
const block_iq4_xs_r8 * iq4h = (const block_iq4_xs_r8 *)((const char *)vx + (ix+4)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4l[ibl].d));
auto dh = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq4h[ibl].d));
auto d4 = _mm512_insertf32x8(_mm512_castps256_ps512(_mm256_set_m128(dl, dl)), _mm256_set_m128(dh, dh), 1);
auto d4x64 = _mm512_mul_ps(d4, _mm512_set1_ps(-64.f));
auto slbits_l = _mm_loadu_si128((const __m128i *)iq4l[ibl].scales_l);
auto shbits_l = _mm_loadu_si128((const __m128i *)iq4h[ibl].scales_l);
auto sl_l = MM256_SET_M128I(_mm_srli_epi16(slbits_l, 4), slbits_l);
auto sh_l = MM256_SET_M128I(_mm_srli_epi16(shbits_l, 4), shbits_l);
auto slb = _mm512_and_si512(_mm512_inserti32x8(_mm512_castsi256_si512(sl_l), sh_l, 1), m4);
auto aux64 = (const uint64_t *)iq4l[ibl].scales_h;
auto slbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
aux64 = (const uint64_t *)iq4h[ibl].scales_h;
auto shbits_h = _mm_set_epi64x(aux64[0] >> 2, aux64[0]);
auto sl_h = MM256_SET_M128I(slbits_h, _mm_slli_epi16(slbits_h, 4));
auto sh_h = MM256_SET_M128I(shbits_h, _mm_slli_epi16(shbits_h, 4));
auto shb = _mm512_and_si512(_mm512_inserti32x8(_mm512_castsi256_si512(sl_h), sh_h, 1), _mm512_set1_epi8(0x30));
h.vec = _mm512_sub_epi8(_mm512_or_si512(slb, shb), _mm512_set1_epi8(32));
for (int ib = 0; ib < QK_K/32; ++ib) {
auto iscales = _mm512_cvtepi8_epi32(_mm_blend_epi32(_mm_set1_epi32(h.val[ib+0]), _mm_set1_epi32(h.val[ib+8]), 0x0c));
auto scales = _mm512_cvtepi32_ps(iscales);
auto scales_m = _mm512_mul_ps(scales, d4x64);
auto bits1 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+0)),
_mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+0), 1);
auto bits2 = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)iq4l[ibl].qs+2*ib+1)),
_mm256_loadu_si256((const __m256i *)iq4h[ibl].qs+2*ib+1), 1);
qx[0] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits1, m4));
qx[1] = _mm512_shuffle_epi8(values, _mm512_and_si512(bits2, m4));
qx[2] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits1, 4), m4));
qx[3] = _mm512_shuffle_epi8(values, _mm512_and_si512(_mm512_srli_epi16(bits2, 4), m4));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y8 = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y8), y8, 1);
auto sumi = _mm512_setzero_si512();
sumi = _mm512_dpbusd_epi32(sumi, qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
sumi = _mm512_dpbusd_epi32(sumi, qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
sumi = _mm512_dpbusd_epi32(sumi, qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
sumi = _mm512_dpbusd_epi32(sumi, qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
isum[iy] = _mm512_add_epi32(isum[iy], _mm512_mullo_epi32(iscales, sumi));
float m8 = ((const float *)q8.y[iy][ibl].bsums)[ib];
acc[iy] = _mm512_fmadd_ps(scales_m, _mm512_set1_ps(m8), acc[iy]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm512_fmadd_ps(_mm512_mul_ps(d4, _mm512_set1_ps(q8.scale(iy, ibl))), _mm512_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm512_setzero_si512();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum1 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 0), _mm512_extractf32x4_ps(acc[iy], 1));
auto sum2 = _mm_add_ps(_mm512_extractf32x4_ps(acc[iy], 2), _mm512_extractf32x4_ps(acc[iy], 3));
info.store(ix+0, iy, sum1);
info.store(ix+4, iy, sum2);
acc[iy] = _mm512_setzero_ps();
}
}
}
}
#else
template <int nrc_y>
static void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
mul_mat_iq4_xs_r8_q8_k_avx2<nrc_y>(n, vx, bx, info, nrc_x);
}
#endif
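// Kernel for q2_K repacked with 4 interleaved rows (q2_k_r4) against q8_K. The per-block mins
// (upper nibble of the packed scales) are folded in first via the q8 block sums; the 2-bit
// quants are then multiplied block by block with the lower-nibble scales.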
template <int nrc_y>
static void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mxf = _mm256_set1_epi8(0xf);
auto m03 = _mm256_set1_epi8(0x03);
static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifdef HAVE_FANCY_SIMD
__m256i isum[nrc_y] = {};
#else
auto m1 = _mm256_set1_epi16(1);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i qx[4];
int8_t scales[64];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq2[ibl].d));
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dm), _mm256_castps256_ps128(dm));
auto m4 = _mm256_set_m128(_mm256_extractf128_ps(dm, 1), _mm256_extractf128_ps(dm, 1));
m4 = _mm256_mul_ps(m4, _mm256_set1_ps(-1.f));
auto all_scales1 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+0);
auto all_scales2 = _mm256_loadu_si256((const __m256i *)iq2[ibl].scales+1);
auto scales1 = _mm256_and_si256(_mm256_srli_epi16(all_scales1, 4), mxf);
auto scales2 = _mm256_and_si256(_mm256_srli_epi16(all_scales2, 4), mxf);
{
auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums(iy, ibl);
auto sumi = _mm256_setzero_si256();
#ifdef HAVE_FANCY_SIMD
sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
#else
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
auto d8 = _mm256_set1_ps(q8.scale(iy, ibl));
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(m4, d8), _mm256_cvtepi32_ps(sumi), acc[iy]);
if constexpr (nrc_y == 1) {
d4 = _mm256_mul_ps(d4, d8);
}
#endif
}
}
all_scales1 = _mm256_and_si256(all_scales1, mxf);
all_scales2 = _mm256_and_si256(all_scales2, mxf);
_mm256_storeu_si256((__m256i *)scales+0, all_scales1);
_mm256_storeu_si256((__m256i *)scales+1, all_scales2);
for (int ib = 0; ib < QK_K/32; ++ib) {
auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
#ifndef HAVE_FANCY_SIMD
auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
#endif
auto lb = _mm256_loadu_si256((const __m256i *)iq2[ibl].qs+ib);
qx[0] = _mm256_and_si256(lb, m03);
qx[1] = _mm256_and_si256(_mm256_srli_epi16(lb, 2), m03);
qx[2] = _mm256_and_si256(_mm256_srli_epi16(lb, 4), m03);
qx[3] = _mm256_and_si256(_mm256_srli_epi16(lb, 6), m03);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
// Quants are in 0...3, so we can add up all of them as int16_t without overflowing
auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
#endif
}
}
#ifdef HAVE_FANCY_SIMD
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm256_setzero_si256();
}
#endif
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
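// Kernel for q3_K repacked with 4 interleaved rows (q3_k_r4) against q8_K. The 6-bit scales
// are reassembled from scales_l/scales_h with a bias of 32, and the constant -4 offset of the
// 3-bit quants is handled through the q8 block sums rather than per weight.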
template <int nrc_y>
static void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto m30 = _mm256_set1_epi8(0x30);
auto m32 = _mm256_set1_epi8(32);
auto m03 = _mm256_set1_epi8(0x03);
auto m04 = _mm256_set1_epi8(0x04);
static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifdef HAVE_FANCY_SIMD
__m256i isum[nrc_y];
#else
auto m1 = _mm256_set1_epi16(1);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i qx[4];
int8_t scales[64];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q3_k_r4 * iq3 = (const block_q3_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq3[ibl].d));
auto d4 = _mm256_set_m128(dl, dl);
#ifndef HAVE_FANCY_SIMD
if constexpr (nrc_y == 1) {
d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
}
#endif
auto slb = _mm256_loadu_si256((const __m256i *)iq3[ibl].scales_l);
auto shbits = _mm_loadu_si128((const __m128i *)iq3[ibl].scales_h);
auto shb = MM256_SET_M128I(_mm_srli_epi16(shbits, 2), shbits);
auto scales1 = _mm256_sub_epi8(_mm256_or_si256(_mm256_and_si256(slb, m4), _mm256_and_si256(_mm256_slli_epi16(shb, 4), m30)), m32);
auto scales2 = _mm256_sub_epi8(_mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(slb, 4), m4), _mm256_and_si256(shb, m30)), m32);
_mm256_storeu_si256((__m256i *)scales+0, scales1);
_mm256_storeu_si256((__m256i *)scales+1, scales2);
{
#ifndef HAVE_FANCY_SIMD
auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-4.f));
#endif
auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 0)), shuff); // blocks 0, 1, 2, 3 for each row
auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales1, 1)), shuff); // blocks 4, 5, 6, 7 for each row
auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 0)), shuff); // blocks 8, 9, 10, 11 for each row
auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(scales2, 1)), shuff); // blocks 12, 13, 14, 15 for each row
auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
#ifdef HAVE_FANCY_SIMD
s1 = _mm256_mullo_epi16(s1, _mm256_set1_epi16(-4));
s2 = _mm256_mullo_epi16(s2, _mm256_set1_epi16(-4));
s3 = _mm256_mullo_epi16(s3, _mm256_set1_epi16(-4));
s4 = _mm256_mullo_epi16(s4, _mm256_set1_epi16(-4));
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums(iy, ibl);
auto sumi = _mm256_setzero_si256();
#ifdef HAVE_FANCY_SIMD
sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
isum[iy] = sumi;
#else
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
#endif
}
}
for (int ib = 0; ib < QK_K/32; ++ib) {
auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 8*ib)));
#ifndef HAVE_FANCY_SIMD
auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
#endif
auto lb = _mm256_loadu_si256((const __m256i *)iq3[ibl].qs+ib);
auto hbits = _mm_loadu_si128((const __m128i *)iq3[ibl].qh+ib);
auto hb = MM256_SET_M128I(hbits, _mm_slli_epi16(hbits, 4));
qx[0] = _mm256_or_si256(_mm256_and_si256(lb, m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 2)));
qx[1] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 2), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 3)));
qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 4), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 4)));
qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lb, 6), m03), _mm256_and_si256(m04, _mm256_srli_epi16(hb, 5)));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
// Quants are in 0...7, so we can add up all of them as int16_t without overflowing
auto sumi = _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
#endif
}
}
#ifdef HAVE_FANCY_SIMD
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
}
#endif
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
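// Accumulates the per-block minimum correction for 4 interleaved rows:
// m4 carries -dmin for each row, so this adds -dmin * block_min * (activation block sum)
// for all 8 blocks of the super-block.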
template <int nrc_y>
inline void process_min_r4_b32(int ibl, __m256 m4, __m256i mins, const Q8<nrc_y, block_q8_K>& q8, __m256 * acc) {
auto mins_l = _mm256_castsi256_si128(mins);
auto mins_h = _mm256_extracti128_si256(mins, 1);
auto aux1 = _mm_unpacklo_epi32(mins_l, mins_h);
auto aux2 = _mm_unpackhi_epi32(mins_l, mins_h);
auto ic1 = _mm256_cvtepi8_epi32(aux1);
auto ic2 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux1, 0xee));
auto ic3 = _mm256_cvtepi8_epi32(aux2);
auto ic4 = _mm256_cvtepi8_epi32(_mm_shuffle_epi32(aux2, 0xee));
if constexpr (nrc_y == 1) {
auto bs = _mm256_loadu_ps((const float *)q8.y[0][ibl].bsums);
auto sumf = _mm256_mul_ps(_mm256_cvtepi32_ps(ic1), _mm256_shuffle_ps(bs, bs, 0x00));
sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic2), _mm256_shuffle_ps(bs, bs, 0x55), sumf);
sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic3), _mm256_shuffle_ps(bs, bs, 0xaa), sumf);
sumf = _mm256_fmadd_ps(_mm256_cvtepi32_ps(ic4), _mm256_shuffle_ps(bs, bs, 0xff), sumf);
acc[0] = _mm256_fmadd_ps(m4, sumf, acc[0]);
} else {
auto c1 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic1));
auto c2 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic2));
auto c3 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic3));
auto c4 = _mm256_mul_ps(m4, _mm256_cvtepi32_ps(ic4));
for (int iy = 0; iy < nrc_y; ++iy) {
auto bs = _mm256_loadu_ps((const float *)q8.y[iy][ibl].bsums);
acc[iy] = _mm256_fmadd_ps(c1, _mm256_shuffle_ps(bs, bs, 0x00), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c2, _mm256_shuffle_ps(bs, bs, 0x55), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c3, _mm256_shuffle_ps(bs, bs, 0xaa), acc[iy]);
acc[iy] = _mm256_fmadd_ps(c4, _mm256_shuffle_ps(bs, bs, 0xff), acc[iy]);
}
}
}
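// Matmul for q4_k_r4 (4 interleaved q4_K rows) with q8_K activations.
// The 6-bit block scales and mins are unpacked from scales_l/scales_h, the min term is
// handled via process_min_r4_b32, and the 4-bit quants are dotted per 32-quant block.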
template <int nrc_y>
static void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = _mm256_set1_epi8(0xf);
auto m3 = _mm256_set1_epi8(0x30);
int nbl = n / QK_K;
union { __m256i vec; uint32_t val[8]; } hd;
__m256 acc[nrc_y] = {};
__m256i isum[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq4[ibl].d));
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
auto lbits = _mm256_loadu_si256((const __m256i *)iq4[ibl].scales_l);
auto hbits128 = _mm_loadu_si128((const __m128i *)iq4[ibl].scales_h);
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m3));
auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m3));
process_min_r4_b32(ibl, m4, mins, q8, acc);
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
#else
auto aux = _mm_set1_epi32(hd.val[ib]);
aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
auto scales_d = MM256_SET_M128I(aux, aux);
#endif
auto bits1 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+0);
auto bits2 = _mm256_loadu_si256((const __m256i *)iq4[ibl].qs+2*ib+1);
qx[0] = _mm256_and_si256(bits1, mf);
qx[1] = _mm256_and_si256(bits2, mf);
qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits1, 4), mf);
qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits2, 4), mf);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_madd_epi16(scales_d, _mm256_add_epi16(sumi1, sumi2)));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm256_setzero_si256();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
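// Matmul for q5_k_r4: same scheme as the q4_k_r4 kernel above, with the 5th quant bit
// merged in from qh.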
template <int nrc_y>
static void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = _mm256_set1_epi8(0xf);
auto m10 = _mm256_set1_epi8(0x10);
auto m30 = _mm256_set1_epi8(0x30);
int nbl = n / QK_K;
union { __m256i vec; uint32_t val[8]; } hd;
__m256 acc[nrc_y] = {};
__m256i isum[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq5[ibl].d));
auto d4 = _mm256_set_m128(_mm256_castps256_ps128(dl), _mm256_castps256_ps128(dl));
auto m4 = _mm256_mul_ps(_mm256_set1_ps(-1.0f), _mm256_set_m128(_mm256_extractf128_ps(dl, 1), _mm256_extractf128_ps(dl, 1)));
auto lbits = _mm256_loadu_si256((const __m256i *)iq5[ibl].scales_l);
auto hbits128 = _mm_loadu_si128((const __m128i *)iq5[ibl].scales_h);
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
hd.vec = _mm256_or_si256(_mm256_and_si256(lbits, mf), _mm256_and_si256(hbits, m30));
auto mins = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits, 4), mf), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), m30));
process_min_r4_b32(ibl, m4, mins, q8, acc);
for (int ib = 0; ib < QK_K/32; ++ib) {
#ifdef HAVE_FANCY_SIMD
auto scales_d = _mm256_cvtepi8_epi32(_mm_set1_epi32(hd.val[ib]));
#else
auto aux = _mm_set1_epi32(hd.val[ib]);
aux = _mm_cvtepu8_epi16(_mm_unpacklo_epi8(aux, aux));
auto scales_d = MM256_SET_M128I(aux, aux);
#endif
auto lbits1 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+0);
auto lbits2 = _mm256_loadu_si256((const __m256i *)iq5[ibl].qs+2*ib+1);
auto hbits128 = _mm_loadu_si128((const __m128i*)iq5[ibl].qh + ib);
auto hbits = MM256_SET_M128I(hbits128, _mm_slli_epi16(hbits128, 4));
qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, mf), _mm256_and_si256(m10, hbits));
qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 2)));
qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 1)));
qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), mf), _mm256_and_si256(m10, _mm256_srli_epi16(hbits, 3)));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(scales_d, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
// To avoid overflow, we can only add up to 4 q5 x q8 products.
auto sumi = _mm256_add_epi32(_mm256_madd_epi16(scales_d, sumi1), _mm256_madd_epi16(scales_d, sumi2));
isum[iy] = _mm256_add_epi32(isum[iy], sumi);
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
isum[iy] = _mm256_setzero_si256();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
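// Matmul for q6_k_r4 (4 interleaved q6_K rows) with q8_K activations.
// Quants are 0...63; the -32 offset is folded into the q8 block sums up front
// (scales multiplied by -32), analogous to the q3_k_r4 kernel.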
template <int nrc_y>
static void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = _mm256_set1_epi8(0xf);
auto m3 = _mm256_set1_epi8(0x30);
static const uint8_t k_shuff[32] = {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15, 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15};
auto shuff = _mm256_loadu_si256((const __m256i *)k_shuff);
#ifdef HAVE_FANCY_SIMD
__m256i isum[nrc_y];
#else
auto m1 = _mm256_set1_epi16(1);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q6_k_r4 * iq6 = (const block_q6_k_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto dl = _mm_cvtph_ps(_mm_loadl_epi64((const __m128i *)iq6[ibl].d));
auto d4 = _mm256_set_m128(dl, dl);
#ifndef HAVE_FANCY_SIMD
if constexpr (nrc_y == 1) {
d4 = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(0, ibl)));
}
#endif
{
#ifndef HAVE_FANCY_SIMD
auto min = _mm256_mul_ps(d4, _mm256_set1_ps(-32.f));
#endif
auto t1 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+0)), shuff); // blocks 0, 1, 2, 3 for each row
auto t2 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+1)), shuff); // blocks 4, 5, 6, 7 for each row
auto t3 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+2)), shuff); // blocks 8, 9, 10, 11 for each row
auto t4 = _mm256_shuffle_epi8(_mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i *)iq6[ibl].scales+3)), shuff); // blocks 12, 13, 14, 15 for each row
auto s1 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 0), _mm256_extracti128_si256(t1, 0)); // blocks 0, 1, 8, 9
auto s2 = MM256_SET_M128I(_mm256_extracti128_si256(t3, 1), _mm256_extracti128_si256(t1, 1)); // blocks 2, 3, 10, 11
auto s3 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 0), _mm256_extracti128_si256(t2, 0)); // blocks 4, 5, 12, 13
auto s4 = MM256_SET_M128I(_mm256_extracti128_si256(t4, 1), _mm256_extracti128_si256(t2, 1)); // blocks 6, 7, 14, 15
#ifdef HAVE_FANCY_SIMD
s1 = _mm256_mullo_epi16(s1, _mm256_set1_epi16(-32));
s2 = _mm256_mullo_epi16(s2, _mm256_set1_epi16(-32));
s3 = _mm256_mullo_epi16(s3, _mm256_set1_epi16(-32));
s4 = _mm256_mullo_epi16(s4, _mm256_set1_epi16(-32));
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums(iy, ibl);
auto sumi = _mm256_setzero_si256();
#ifdef HAVE_FANCY_SIMD
sumi = _mm256_dpwssd_epi32(sumi, s1, _mm256_shuffle_epi32(bsums, 0x00));
sumi = _mm256_dpwssd_epi32(sumi, s2, _mm256_shuffle_epi32(bsums, 0x55));
sumi = _mm256_dpwssd_epi32(sumi, s3, _mm256_shuffle_epi32(bsums, 0xaa));
sumi = _mm256_dpwssd_epi32(sumi, s4, _mm256_shuffle_epi32(bsums, 0xff));
isum[iy] = sumi;
#else
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s1, _mm256_shuffle_epi32(bsums, 0x00)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s2, _mm256_shuffle_epi32(bsums, 0x55)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s3, _mm256_shuffle_epi32(bsums, 0xaa)));
sumi = _mm256_add_epi32(sumi, _mm256_madd_epi16(s4, _mm256_shuffle_epi32(bsums, 0xff)));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(min, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(min, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
#endif
}
}
const uint32_t * scales = (const uint32_t *)iq6[ibl].scales;
for (int ib = 0; ib < QK_K/32; ++ib) {
auto iscales = _mm256_cvtepi8_epi32(_mm_loadl_epi64((const __m128i *)(scales + 2*ib)));
#ifndef HAVE_FANCY_SIMD
auto scales = _mm256_mul_ps(d4, _mm256_cvtepi32_ps(iscales));
#endif
auto lbits1 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+0);
auto lbits2 = _mm256_loadu_si256((const __m256i *)iq6[ibl].ql+2*ib+1);
auto hbits = _mm256_loadu_si256((const __m256i *)iq6[ibl].qh+ib);
qx[0] = _mm256_or_si256(_mm256_and_si256(lbits1, m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 4)));
qx[1] = _mm256_or_si256(_mm256_and_si256(lbits2, m4), _mm256_and_si256(m3, hbits));
qx[2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), m4), _mm256_and_si256(m3, _mm256_slli_epi16(hbits, 2)));
qx[3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), m4), _mm256_and_si256(m3, _mm256_srli_epi16(hbits, 2)));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i*)q8.y[iy][ibl].qs+ib);
#ifdef HAVE_FANCY_SIMD
auto sumi = _mm256_setzero_si256();
sumi = _mm256_dpbusd_epi32(sumi, qx[0], _mm256_shuffle_epi32(y, 0x00));
sumi = _mm256_dpbusd_epi32(sumi, qx[1], _mm256_shuffle_epi32(y, 0x55));
sumi = _mm256_dpbusd_epi32(sumi, qx[2], _mm256_shuffle_epi32(y, 0xaa));
sumi = _mm256_dpbusd_epi32(sumi, qx[3], _mm256_shuffle_epi32(y, 0xff));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_mullo_epi32(iscales, sumi));
#else
auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
_mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
_mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
// Quants are in 0...63, so we can add at most 4 products as int16_t to be sure of no int16_t overflow
auto sumi = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
if constexpr (nrc_y == 1) {
acc[iy] = _mm256_fmadd_ps(scales, _mm256_cvtepi32_ps(sumi), acc[iy]);
} else {
acc[iy] = _mm256_fmadd_ps(_mm256_mul_ps(scales, _mm256_set1_ps(q8.scale(iy, ibl))), _mm256_cvtepi32_ps(sumi), acc[iy]);
}
#endif
}
}
#ifdef HAVE_FANCY_SIMD
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
}
#endif
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1));
acc[iy] = _mm256_setzero_ps();
info.store(ix+0, iy, sum);
}
}
}
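// Fills the kernel table for the plain (non-interleaved) k-quants with the
// Dequantizer-templated implementations, picking the AVX512 or AVX2 variants.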
template <typename Dequantizer> void set_functions(std::array<mul_mat_t, IQK_MAX_NY>& funcs) {
#ifdef HAVE_FANCY_SIMD
if constexpr (std::is_same_v<Dequantizer, DequantizerIQ4XS>) {
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_iqX_k_q8_K_AVX512, Dequantizer, funcs)
} else {
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_AVX512, Dequantizer, funcs)
funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
}
#else
if constexpr (std::is_same_v<Dequantizer, DequantizerQ2K> ||
std::is_same_v<Dequantizer, DequantizerQ3K> ||
std::is_same_v<Dequantizer, DequantizerQ6K>) {
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_K_T, Dequantizer, funcs)
} else {
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, Dequantizer, funcs)
}
#endif
}
// HAVE_FANCY_SIMD should only be defined when defined(__AVX512VNNI__) && defined(__AVX512VL__)
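// Matmul for q8_k_r8 (8 interleaved q8_K rows) with q8_K activations.
// On the AVX512 path the row quants are biased by +127 so that _mm256_dpbusd_epi32
// (unsigned x signed) can be used; since sum_j (q_j+127)*y_j = sum_j q_j*y_j + 127*(sum_j y_j),
// the bias is compensated with the -127 * (activation block sum) term below.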
template <int nrc_y>
static void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
int nbl = n / QK_K;
__m256 acc[nrc_y] = {};
__m256i isum[nrc_y] = {};
__m256i qx[4];
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)iq8[ibl].d));
for (int ib = 0; ib < QK_K/16; ++ib) {
qx[0] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+0);
qx[1] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+1);
qx[2] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+2);
qx[3] = _mm256_loadu_si256((const __m256i *)iq8[ibl].qs+4*ib+3);
#ifndef HAVE_FANCY_SIMD
auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
#else
qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127));
qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127));
qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127));
qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127));
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8.y[iy][ibl].qs+ib);
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
isum[iy] = _mm256_dpbusd_epi32(isum[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
#else
auto sumi1 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0])));
auto sumi2 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1])));
auto sumi3 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2])));
auto sumi4 = _mm256_madd_epi16(m1, _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3])));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi1, sumi2));
isum[iy] = _mm256_add_epi32(isum[iy], _mm256_add_epi32(sumi3, sumi4));
#endif
}
}
#ifdef HAVE_FANCY_SIMD
auto m4 = _mm256_mul_ps(d4, _mm256_set1_ps(-127.f));
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto d4y = _mm256_mul_ps(d4, _mm256_set1_ps(q8.scale(iy, ibl)));
acc[iy] = _mm256_fmadd_ps(d4y, _mm256_cvtepi32_ps(isum[iy]), acc[iy]);
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_fmadd_ps(m4, _mm256_set1_ps(q8.y[iy][ibl].sum), acc[iy]);
#endif
isum[iy] = _mm256_setzero_si256();
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = _mm256_setzero_ps();
}
}
}
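// Matmul for Q8_KV: plain int8 rows, each preceded by a small header (float scale and,
// apparently, the int32 sum of the quants). Four rows are processed at a time by
// interleaving 32-byte chunks with unpack{lo,hi}; on AVX512 the +127 bias trick from
// above is used, compensated with the precomputed activation row sums.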
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
GGML_ASSERT(n%32 == 0);
__m256i qx[4];
#ifndef HAVE_FANCY_SIMD
__m256i sx[4];
auto m1 = _mm256_set1_epi16(1);
#endif
__m256i acc[nrc_y] = {};
float dy[nrc_y];
#ifdef HAVE_FANCY_SIMD
int32_t sy[nrc_y];
#endif
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
#ifdef HAVE_FANCY_SIMD
auto iptr = (const int32_t *)(dptr + 1);
sy[iy] = -127*iptr[0];
#endif
q8y[iy] = (const int8_t *)(dptr + 2);
}
const int8_t * q8x[4];
float dx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
for (int kx = 0; kx < 4; ++kx) {
auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
dx[kx] = dptr[0];
q8x[kx] = (const int8_t *)(dptr + 2);
}
for (int i = 0; i < n/32; ++i) {
for (int kx = 0; kx < 4; ++kx) qx[kx] = _mm256_loadu_si256((const __m256i *)q8x[kx] + i);
auto t0 = _mm256_unpacklo_epi32(qx[0], qx[1]);
auto t1 = _mm256_unpacklo_epi32(qx[2], qx[3]);
auto t2 = _mm256_unpackhi_epi32(qx[0], qx[1]);
auto t3 = _mm256_unpackhi_epi32(qx[2], qx[3]);
#ifdef HAVE_FANCY_SIMD
qx[0] = _mm256_add_epi8(_mm256_unpacklo_epi64(t0, t1), _mm256_set1_epi8(127));
qx[1] = _mm256_add_epi8(_mm256_unpackhi_epi64(t0, t1), _mm256_set1_epi8(127));
qx[2] = _mm256_add_epi8(_mm256_unpacklo_epi64(t2, t3), _mm256_set1_epi8(127));
qx[3] = _mm256_add_epi8(_mm256_unpackhi_epi64(t2, t3), _mm256_set1_epi8(127));
#else
qx[0] = _mm256_unpacklo_epi64(t0, t1); sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
qx[1] = _mm256_unpackhi_epi64(t0, t1); sx[1] = _mm256_sign_epi8(qx[1], qx[1]);
qx[2] = _mm256_unpacklo_epi64(t2, t3); sx[2] = _mm256_sign_epi8(qx[2], qx[2]);
qx[3] = _mm256_unpackhi_epi64(t2, t3); sx[3] = _mm256_sign_epi8(qx[3], qx[3]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
#else
auto dot1 = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto dot2 = _mm256_maddubs_epi16(sx[1], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto dot3 = _mm256_maddubs_epi16(sx[2], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto dot4 = _mm256_maddubs_epi16(sx[3], _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
auto dot12 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot1), _mm256_madd_epi16(m1, dot2));
auto dot34 = _mm256_add_epi32(_mm256_madd_epi16(m1, dot3), _mm256_madd_epi16(m1, dot4));
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(dot12, dot34));
#endif
}
}
auto scales_x = _mm_loadu_ps(dx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(acc[iy]), _mm256_extracti128_si256(acc[iy], 1));
#ifdef HAVE_FANCY_SIMD
sumi = _mm_add_epi32(sumi, _mm_set1_epi32(sy[iy]));
#endif
auto scale = _mm_mul_ps(scales_x, _mm_set1_ps(dy[iy]));
info.store(ix, iy, _mm_mul_ps(scale, _mm_cvtepi32_ps(sumi)));
acc[iy] = _mm256_setzero_si256();
}
}
}
// HAVE_FANCY_SIMD should only be defined when defined(__AVX512VNNI__) && defined(__AVX512VL__)
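// Same as mul_mat_q8_KV_q8_KV, but with 8 interleaved Q8_KV rows whose header holds
// the 8 float row scales.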
template <int nrc_y>
static void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%32 == 0);
GGML_ASSERT(nrc_x%8 == 0);
#ifndef HAVE_FANCY_SIMD
auto m1 = _mm256_set1_epi16(1);
#endif
int nb = n / 16;
__m256i acc[nrc_y] = {};
__m256i qx[4];
float dy[nrc_y];
#ifdef HAVE_FANCY_SIMD
float sy[nrc_y];
#endif
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
#ifdef HAVE_FANCY_SIMD
auto iptr = (const int32_t *)(dptr + 1);
sy[iy] = -127*iptr[0];
#endif
q8y[iy] = (const int8_t *)(dptr + 2);
}
for (int ix = 0; ix < nrc_x; ix += 8) {
auto dptr = (const float *)((const char *)vx + ix*bx);
auto dx = _mm256_loadu_ps(dptr);
auto q8x = (const int8_t *)(dptr + 8);
for (int ib = 0; ib < nb; ++ib) { // Blocks of 16 for 8 interleaved rows
qx[0] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+0);
qx[1] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+1);
qx[2] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+2);
qx[3] = _mm256_loadu_si256((const __m256i *)q8x+4*ib+3);
#ifndef HAVE_FANCY_SIMD
auto s0 = _mm256_sign_epi8(qx[0], qx[0]);
auto s1 = _mm256_sign_epi8(qx[1], qx[1]);
auto s2 = _mm256_sign_epi8(qx[2], qx[2]);
auto s3 = _mm256_sign_epi8(qx[3], qx[3]);
#else
qx[0] = _mm256_add_epi8(qx[0], _mm256_set1_epi8(127));
qx[1] = _mm256_add_epi8(qx[1], _mm256_set1_epi8(127));
qx[2] = _mm256_add_epi8(qx[2], _mm256_set1_epi8(127));
qx[3] = _mm256_add_epi8(qx[3], _mm256_set1_epi8(127));
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
auto y128 = _mm_loadu_si128((const __m128i*)q8y[iy]+ib);
auto y = MM256_SET_M128I(y128, y128);
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[0], _mm256_shuffle_epi32(y, 0x00));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[1], _mm256_shuffle_epi32(y, 0x55));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[2], _mm256_shuffle_epi32(y, 0xaa));
acc[iy] = _mm256_dpbusd_epi32(acc[iy], qx[3], _mm256_shuffle_epi32(y, 0xff));
#else
auto sumi1 = _mm256_maddubs_epi16(s0, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x00), qx[0]));
auto sumi2 = _mm256_maddubs_epi16(s1, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0x55), qx[1]));
auto sumi3 = _mm256_maddubs_epi16(s2, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xaa), qx[2]));
auto sumi4 = _mm256_maddubs_epi16(s3, _mm256_sign_epi8(_mm256_shuffle_epi32(y, 0xff), qx[3]));
auto sumi12 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi1), _mm256_madd_epi16(m1, sumi2));
auto sumi34 = _mm256_add_epi32(_mm256_madd_epi16(m1, sumi3), _mm256_madd_epi16(m1, sumi4));
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_add_epi32(sumi12, sumi34));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto scale = _mm256_mul_ps(dx, _mm256_set1_ps(dy[iy]));
#ifdef HAVE_FANCY_SIMD
acc[iy] = _mm256_add_epi32(acc[iy], _mm256_set1_epi32(sy[iy]));
#endif
info.store(ix, iy, _mm256_mul_ps(scale, _mm256_cvtepi32_ps(acc[iy])));
acc[iy] = _mm256_setzero_si256();
}
}
}
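// 8 interleaved rows of q8_1-style blocks: d[0..7] hold the per-row block scales and
// d[8..15] the companion values (the converters below store the negated min term there),
// followed by the 8 x 32 interleaved quants.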
typedef struct {
ggml_half d[16];
int8_t qs[8*QK8_1];
} block_q8_1_r8;
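// Repacks 8 q2_K rows into q8_k_r8: each super-block is dequantized to float
// (including the per-block mins), re-quantized with a single per-row scale, and the
// resulting int8 quants are interleaved 8 rows at a time.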
void iqk_convert_q2_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
const block_q2_K * x8[8];
block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
float f_values[QK_K];
uint32_t block[8];
__m256i xv[4];
auto ml = _mm256_set1_epi8(0x03);
auto sign_bit = _mm256_set1_ps(-0.0f);
auto perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q2_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
auto vd = _mm256_set1_ps(GGML_FP16_TO_FP32(x8[k][i].d));
auto vm = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x8[k][i].dmin)), _mm256_set1_ps(-1.f));
auto block_max = _mm256_setzero_ps();
for (int i128 = 0; i128 < 2; ++i128) {
auto bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs+i128);
xv[0] = _mm256_and_si256(bits, ml);
xv[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), ml);
xv[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), ml);
xv[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), ml);
for (int l = 0; l < 4; ++l) {
auto q1 = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(xv[l]));
auto q2 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(xv[l], 1));
q1 = _mm256_mullo_epi16(q1, _mm256_set1_epi16(x8[k][i].scales[8*i128 + 2*l + 0] & 0xf));
q2 = _mm256_mullo_epi16(q2, _mm256_set1_epi16(x8[k][i].scales[8*i128 + 2*l + 1] & 0xf));
auto m1 = _mm256_mul_ps(vm, _mm256_set1_ps(x8[k][i].scales[8*i128 + 2*l + 0] >> 4));
auto m2 = _mm256_mul_ps(vm, _mm256_set1_ps(x8[k][i].scales[8*i128 + 2*l + 1] >> 4));
auto v0 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q1))), vd, m1);
auto v1 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q1, 1))), vd, m1);
auto v2 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(q2))), vd, m2);
auto v3 = _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(q2, 1))), vd, m2);
auto max = _mm256_max_ps(_mm256_max_ps(_mm256_andnot_ps(sign_bit, v0), _mm256_andnot_ps(sign_bit, v1)),
_mm256_max_ps(_mm256_andnot_ps(sign_bit, v2), _mm256_andnot_ps(sign_bit, v3)));
block_max = _mm256_max_ps(block_max, max);
_mm256_storeu_ps(f_values + 128*i128 + 32*l + 0, v0);
_mm256_storeu_ps(f_values + 128*i128 + 32*l + 8, v1);
_mm256_storeu_ps(f_values + 128*i128 + 32*l + 16, v2);
_mm256_storeu_ps(f_values + 128*i128 + 32*l + 24, v3);
}
}
auto max4 = _mm_max_ps(_mm256_extractf128_ps(block_max, 1), _mm256_castps256_ps128(block_max));
max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
float d = _mm_cvtss_f32(max4)/127.f;
auto id = _mm256_set1_ps(d != 0.0f ? 1/d : 0.0f);
y[i].d[k] = GGML_FP32_TO_FP16(d);
for (int ib32 = 0; ib32 < 8; ++ib32) {
auto v0 = _mm256_loadu_ps(f_values + 32*ib32 + 0);
auto v1 = _mm256_loadu_ps(f_values + 32*ib32 + 8);
auto v2 = _mm256_loadu_ps(f_values + 32*ib32 + 16);
auto v3 = _mm256_loadu_ps(f_values + 32*ib32 + 24);
auto i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v0, id), _MM_ROUND_NEAREST));
auto i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v1, id), _MM_ROUND_NEAREST));
auto i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v2, id), _MM_ROUND_NEAREST));
auto i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(v3, id), _MM_ROUND_NEAREST));
i0 = _mm256_packs_epi32(i0, i1);
i2 = _mm256_packs_epi32(i2, i3);
i0 = _mm256_packs_epi16(i0, i2);
i0 = _mm256_permutevar8x32_epi32(i0, perm);
_mm256_storeu_si256((__m256i *)block, i0);
auto q8 = (uint32_t *)y[i].qs + 64*ib32;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
}
}
}
y += nb;
}
}
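// Repacks 8 q4_K rows into block_q8_1_r8: the 4-bit quants are stored as-is (0...15)
// and the 6-bit block scales/mins become per-block fp16 scale / negated-min pairs.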
void iqk_convert_q4_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
const block_q4_K * x8[8];
block_q8_1_r8 * y = (block_q8_1_r8 *)vy;
ggml_half dh[16];
uint16_t all_ls[128];
uint32_t utmp[4];
const uint8_t * u8 = (const uint8_t *)utmp;
uint32_t block[8];
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q4_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
dh[k+0] = x8[k][i].d;
dh[k+8] = x8[k][i].dmin;
make_q4_scales(x8[k][i].scales, utmp);
auto qs = x8[k][i].qs;
for (int ib64 = 0; ib64 < 4; ++ib64) {
all_ls[8*(2*ib64 + 0) + k ] = u8[2*ib64+0];
all_ls[8*(2*ib64 + 1) + k ] = u8[2*ib64+1];
all_ls[8*(2*ib64 + 0) + k + 64] = u8[2*ib64+8];
all_ls[8*(2*ib64 + 1) + k + 64] = u8[2*ib64+9];
auto bits = _mm256_loadu_si256((const __m256i *)qs+ib64);
auto values1 = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
auto values2 = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
_mm256_storeu_si256((__m256i *)block, values1);
auto q8 = (uint32_t *)y[2*ib64+0].qs;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
_mm256_storeu_si256((__m256i *)block, values2);
q8 = (uint32_t *)y[2*ib64+1].qs;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
}
}
auto vd = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+0));
auto vm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+1));
vm = _mm256_mul_ps(_mm256_set1_ps(-1.f), vm);
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d+0, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32 + 8);
iscales32 = _mm256_cvtepi16_epi32(iscales16);
scales = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d+1, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
}
y += QK_K/32;
}
}
}
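// Same as the q4_K converter above, with the 5th quant bit merged in from qh (quants 0...31).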
void iqk_convert_q5_k_q8_1_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
const block_q5_K * x8[8];
block_q8_1_r8 * y = (block_q8_1_r8 *)vy;
ggml_half dh[16];
uint16_t all_ls[128];
uint32_t utmp[4];
const uint8_t * u8 = (const uint8_t *)utmp;
uint32_t block[8];
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q5_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
dh[k+0] = x8[k][i].d;
dh[k+8] = x8[k][i].dmin;
make_q4_scales(x8[k][i].scales, utmp);
auto qs = x8[k][i].qs;
auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].qh);
for (int ib64 = 0; ib64 < 4; ++ib64) {
all_ls[8*(2*ib64 + 0) + k ] = u8[2*ib64+0];
all_ls[8*(2*ib64 + 1) + k ] = u8[2*ib64+1];
all_ls[8*(2*ib64 + 0) + k + 64] = u8[2*ib64+8];
all_ls[8*(2*ib64 + 1) + k + 64] = u8[2*ib64+9];
auto bits = _mm256_loadu_si256((const __m256i *)qs+ib64);
auto values1 = _mm256_and_si256(bits, _mm256_set1_epi8(0xf));
auto values2 = _mm256_and_si256(_mm256_srli_epi16(bits, 4), _mm256_set1_epi8(0xf));
values1 = _mm256_or_si256(values1, _mm256_and_si256(_mm256_set1_epi8(0x10), _mm256_slli_epi16(hbits, 4)));
values2 = _mm256_or_si256(values2, _mm256_and_si256(_mm256_set1_epi8(0x10), _mm256_slli_epi16(hbits, 3)));
hbits = _mm256_srli_epi16(hbits, 2);
_mm256_storeu_si256((__m256i *)block, values1);
auto q8 = (uint32_t *)y[2*ib64+0].qs;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
_mm256_storeu_si256((__m256i *)block, values2);
q8 = (uint32_t *)y[2*ib64+1].qs;
for (int l = 0; l < 4; ++l) {
q8[8*l + k + 0] = block[l + 0];
q8[8*l + k + 32] = block[l + 4];
}
}
}
auto vd = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+0));
auto vm = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)dh+1));
vm = _mm256_mul_ps(_mm256_set1_ps(-1.f), vm);
for (int ib32 = 0; ib32 < QK_K/32; ++ib32) {
auto iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32);
auto iscales32 = _mm256_cvtepi16_epi32(iscales16);
auto scales = _mm256_mul_ps(vd, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d+0, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
iscales16 = _mm_loadu_si128((const __m128i *)all_ls + ib32 + 8);
iscales32 = _mm256_cvtepi16_epi32(iscales16);
scales = _mm256_mul_ps(vm, _mm256_cvtepi32_ps(iscales32));
_mm_storeu_si128((__m128i *)y[ib32].d+1, _mm256_cvtps_ph(scales, _MM_FROUND_TO_NEAREST_INT));
}
y += QK_K/32;
}
}
}
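// Repacks 8 q6_K rows into q8_0_r8: each 32-quant block (two 16-quant sub-scales) is
// multiplied by its scales, re-quantized to int8 with a single fp16 block scale, and interleaved.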
void iqk_convert_q6_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
const block_q6_K * x8[8];
block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
float all_s[64];
uint32_t block[8];
__m256i values[8];
auto ml = _mm256_set1_epi8(0x0f);
auto mh = _mm256_set1_epi8(0x30);
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q6_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
float d = GGML_FP16_TO_FP32(x8[k][i].d);
auto ql = x8[k][i].ql;
auto qh = x8[k][i].qh;
for (int i128 = 0; i128 < 2; ++i128) {
auto lbits1 = _mm256_loadu_si256((const __m256i *)ql + 2*i128 + 0);
auto lbits2 = _mm256_loadu_si256((const __m256i *)ql + 2*i128 + 1);
auto hbits = _mm256_loadu_si256((const __m256i *)qh + i128);
values[4*i128+0] = _mm256_or_si256(_mm256_and_si256(lbits1, ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 4), mh));
values[4*i128+1] = _mm256_or_si256(_mm256_and_si256(lbits2, ml), _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
values[4*i128+2] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits1, 4), ml), _mm256_and_si256(hbits, mh));
values[4*i128+3] = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(lbits2, 4), ml), _mm256_and_si256(_mm256_srli_epi16(hbits, 2), mh));
}
for (int ib32 = 0; ib32 < 8; ++ib32) {
// We have two blocks of 16 with different scales
// We multiply the quants with the scales, find the max value, and convert to 8-bit quants with a single block scale.
auto q8 = _mm256_add_epi8(values[ib32], _mm256_set1_epi8(-32));
auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q8));
auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q8, 1));
q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(x8[k][i].scales[2*ib32+0]));
q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(x8[k][i].scales[2*ib32+1]));
auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l);
auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h);
auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h);
auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1)));
auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
auto max4 = _mm_cvtepi32_ps(imax4);
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
float max = _mm_cvtss_f32(max4) / 127;
all_s[8*ib32+k] = d*max;
if (max > 1e-9f) {
auto scale = _mm256_set1_ps(1/max);
auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
i0 = _mm256_packs_epi32(i0, i1);
i2 = _mm256_packs_epi32(i2, i3);
i0 = _mm256_packs_epi16(i0, i2);
i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
_mm256_storeu_si256((__m256i *)block, i0);
} else {
_mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256());
}
auto qs = (uint32_t *)y[ib32].qs;
for (int l = 0; l < 4; ++l) {
qs[8*l + k + 0] = block[l + 0];
qs[8*l + k + 32] = block[l + 4];
}
}
}
for (int ib32 = 0; ib32 < 8; ++ib32) {
_mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT));
}
y += QK_K/32;
}
}
}
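// Repacks 8 q3_K rows into q8_0_r8, same scheme as the q6_K converter above but with
// 3-bit quants (offset -4) and the 6-bit q3_K scales.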
void iqk_convert_q3_k_q8_0_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
const block_q3_K * x8[8];
block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
float all_s[64];
uint32_t block[8];
__m256i values[8];
ScaleQ3 sc3;
auto ml = _mm256_set1_epi8(0x03);
auto mh = _mm256_set1_epi8(0x04);
union { __m256i vec; int16_t val[16]; } helper;
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q3_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
float d = GGML_FP16_TO_FP32(x8[k][i].d);
auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].hmask);
for (int i128 = 0; i128 < 2; ++i128) {
auto q2bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs + i128);
values[4*i128+0] = _mm256_and_si256(q2bits, ml);
values[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
values[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
values[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
values[4*i128+0] = _mm256_or_si256(values[4*i128+0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
values[4*i128+1] = _mm256_or_si256(values[4*i128+1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
values[4*i128+2] = _mm256_or_si256(values[4*i128+2], _mm256_and_si256(hbits, mh));
values[4*i128+3] = _mm256_or_si256(values[4*i128+3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
values[4*i128+0] = _mm256_sub_epi8(values[4*i128+0], mh);
values[4*i128+1] = _mm256_sub_epi8(values[4*i128+1], mh);
values[4*i128+2] = _mm256_sub_epi8(values[4*i128+2], mh);
values[4*i128+3] = _mm256_sub_epi8(values[4*i128+3], mh);
hbits = _mm256_srli_epi16(hbits, 4);
}
helper.vec = _mm256_cvtepi8_epi16(sc3.make_scales((const uint16_t *)x8[k][i].scales));
for (int ib32 = 0; ib32 < 8; ++ib32) {
auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(values[ib32]));
auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[ib32], 1));
q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(helper.val[2*ib32+0]));
q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(helper.val[2*ib32+1]));
auto abs_q16_l = _mm256_sign_epi16(q16_l, q16_l);
auto abs_q16_h = _mm256_sign_epi16(q16_h, q16_h);
auto max_q16 = _mm256_max_epi16(abs_q16_l, abs_q16_h);
auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_q16), _mm256_extracti128_si256(max_q16, 1)));
auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
auto max4 = _mm_cvtepi32_ps(imax4);
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
float max = _mm_cvtss_f32(max4) / 127;
all_s[8*ib32+k] = d*max;
if (max > 1e-9f) {
auto scale = _mm256_set1_ps(1/max);
auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
i0 = _mm256_packs_epi32(i0, i1);
i2 = _mm256_packs_epi32(i2, i3);
i0 = _mm256_packs_epi16(i0, i2);
i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
_mm256_storeu_si256((__m256i *)block, i0);
} else {
_mm256_storeu_si256((__m256i *)block, _mm256_setzero_si256());
}
auto qs = (uint32_t *)y[ib32].qs;
for (int l = 0; l < 4; ++l) {
qs[8*l + k + 0] = block[l + 0];
qs[8*l + k + 32] = block[l + 4];
}
}
}
for (int ib32 = 0; ib32 < 8; ++ib32) {
_mm_storeu_si128((__m128i *)y[ib32].d, _mm256_cvtps_ph(_mm256_loadu_ps(all_s + 8*ib32), _MM_FROUND_TO_NEAREST_INT));
}
y += QK_K/32;
}
}
}
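// Repacks 8 q3_K rows into q8_k_r8: the q3_K block scales are folded into the quants;
// if the resulting magnitudes exceed 127 the super-block is rescaled (needs_scaling),
// otherwise the int16 products are just packed back to int8.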
void iqk_convert_q3_k_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
const block_q3_K * x8[8];
block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
uint32_t block[8];
__m256i values[8];
ScaleQ3 sc3;
auto ml = _mm256_set1_epi8(0x03);
auto mh = _mm256_set1_epi8(0x04);
union { __m256i vec; int16_t val[16]; } helper;
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_q3_K *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
float d = GGML_FP16_TO_FP32(x8[k][i].d);
auto hbits = _mm256_loadu_si256((const __m256i *)x8[k][i].hmask);
helper.vec = _mm256_cvtepi8_epi16(sc3.make_scales((const uint16_t *)x8[k][i].scales));
auto max_i16 = _mm256_setzero_si256();
for (int i128 = 0; i128 < 2; ++i128) {
auto q2bits = _mm256_loadu_si256((const __m256i *)x8[k][i].qs + i128);
values[4*i128+0] = _mm256_and_si256(q2bits, ml);
values[4*i128+1] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), ml);
values[4*i128+2] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), ml);
values[4*i128+3] = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), ml);
values[4*i128+0] = _mm256_or_si256(values[4*i128+0], _mm256_and_si256(_mm256_slli_epi16(hbits, 2), mh));
values[4*i128+1] = _mm256_or_si256(values[4*i128+1], _mm256_and_si256(_mm256_slli_epi16(hbits, 1), mh));
values[4*i128+2] = _mm256_or_si256(values[4*i128+2], _mm256_and_si256(hbits, mh));
values[4*i128+3] = _mm256_or_si256(values[4*i128+3], _mm256_and_si256(_mm256_srli_epi16(hbits, 1), mh));
values[4*i128+0] = _mm256_sub_epi8(values[4*i128+0], mh);
values[4*i128+1] = _mm256_sub_epi8(values[4*i128+1], mh);
values[4*i128+2] = _mm256_sub_epi8(values[4*i128+2], mh);
values[4*i128+3] = _mm256_sub_epi8(values[4*i128+3], mh);
hbits = _mm256_srli_epi16(hbits, 4);
for (int l = 0; l < 4; ++l) {
auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(values[4*i128+l]));
auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[4*i128+l], 1));
q16_l = _mm256_mullo_epi16(_mm256_set1_epi16(helper.val[8*i128+2*l+0]), q16_l);
q16_h = _mm256_mullo_epi16(_mm256_set1_epi16(helper.val[8*i128+2*l+1]), q16_h);
max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_l, q16_l));
max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(q16_h, q16_h));
}
}
auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
auto max4 = _mm_cvtepi32_ps(imax4);
max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
bool needs_scaling = true;
float dnew = _mm_cvtss_f32(max4) / 127;
if (dnew < 1.f) {
dnew = 1.f; needs_scaling = false;
}
d *= dnew;
y[i].d[k] = GGML_FP32_TO_FP16(d);
auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
for (int ib32 = 0; ib32 < 8; ++ib32) {
auto q16_l = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(values[ib32]));
auto q16_h = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(values[ib32], 1));
q16_l = _mm256_mullo_epi16(q16_l, _mm256_set1_epi16(helper.val[2*ib32+0]));
q16_h = _mm256_mullo_epi16(q16_h, _mm256_set1_epi16(helper.val[2*ib32+1]));
if (needs_scaling) {
auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_l));
auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_l, 1));
auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(q16_h));
auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(q16_h, 1));
i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
i0 = _mm256_packs_epi32(i0, i1);
i2 = _mm256_packs_epi32(i2, i3);
i0 = _mm256_packs_epi16(i0, i2);
i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
_mm256_storeu_si256((__m256i *)block, i0);
} else {
// After packs_epi16 the 32 quants are ordered 0...7, 16...23, 8...15, 24...31, hence the 64-bit unpacks below
auto i0 = _mm256_packs_epi16(q16_l, q16_h);
auto i0_l = _mm256_castsi256_si128(i0);
auto i0_h = _mm256_extracti128_si256(i0, 1);
_mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
_mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
}
auto qs = (uint32_t *)y[i].qs + 64*ib32;
for (int l = 0; l < 8; ++l) {
qs[8*l + k] = block[l];
}
}
}
}
y += nb;
}
}
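// Shared helper for the converters: multiplies the 8x32 quants in qx by their per-block
// scales, derives a new common scale from the absolute maximum (relative to d0),
// packs/rescales to int8 into the interleaved q8_k_r8 layout, and returns that scale.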
inline float convert_to_q8_k_r8(int k, float d0, const __m256i * qx, const int16_t * scales, uint32_t * block, int8_t * q8_k) {
auto max_i16 = _mm256_setzero_si256();
__m256i qs[16];
for (int ib32 = 0; ib32 < 8; ++ib32) {
qs[2*ib32+0] = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qx[ib32]));
qs[2*ib32+1] = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qx[ib32], 1));
qs[2*ib32+0] = _mm256_mullo_epi16(qs[2*ib32+0], _mm256_set1_epi16(scales[2*ib32+0]));
qs[2*ib32+1] = _mm256_mullo_epi16(qs[2*ib32+1], _mm256_set1_epi16(scales[2*ib32+1]));
max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+0], qs[2*ib32+0]));
max_i16 = _mm256_max_epi16(max_i16, _mm256_sign_epi16(qs[2*ib32+1], qs[2*ib32+1]));
}
auto max_q32 = _mm256_cvtepi16_epi32(_mm_max_epi16(_mm256_castsi256_si128(max_i16), _mm256_extracti128_si256(max_i16, 1)));
auto imax4 = _mm_max_epi32(_mm256_castsi256_si128(max_q32), _mm256_extracti128_si256(max_q32, 1));
auto max4 = _mm_cvtepi32_ps(imax4);
max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4));
max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4));
bool needs_scaling = true;
float dnew = _mm_cvtss_f32(max4) * d0;
if (dnew < 1.f) {
dnew = 1.f; needs_scaling = false;
}
auto scale = _mm256_set1_ps(std::abs(dnew) > 1e-9f ? 1/dnew : 0.f);
for (int ib32 = 0; ib32 < 8; ++ib32) {
if (needs_scaling) {
auto i0 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+0]));
auto i1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+0], 1));
auto i2 = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(qs[2*ib32+1]));
auto i3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(qs[2*ib32+1], 1));
i0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i0)), _MM_ROUND_NEAREST));
i1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i1)), _MM_ROUND_NEAREST));
i2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i2)), _MM_ROUND_NEAREST));
i3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(scale, _mm256_cvtepi32_ps(i3)), _MM_ROUND_NEAREST));
i0 = _mm256_packs_epi32(i0, i1);
i2 = _mm256_packs_epi32(i2, i3);
i0 = _mm256_packs_epi16(i0, i2);
i0 = _mm256_permutevar8x32_epi32(i0, _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
_mm256_storeu_si256((__m256i *)block, i0);
} else {
// After packs_epi16 the 32 quants are ordered 0...7, 16...23, 8...15, 24...31, hence the 64-bit unpacks below
auto i0 = _mm256_packs_epi16(qs[2*ib32+0], qs[2*ib32+1]);
auto i0_l = _mm256_castsi256_si128(i0);
auto i0_h = _mm256_extracti128_si256(i0, 1);
_mm_storeu_si128((__m128i *)block+0, _mm_unpacklo_epi64(i0_l, i0_h));
_mm_storeu_si128((__m128i *)block+1, _mm_unpackhi_epi64(i0_l, i0_h));
}
auto qs = (uint32_t *)q8_k + 64*ib32;
for (int l = 0; l < 8; ++l) {
qs[8*l + k] = block[l];
}
}
return dnew;
}
// TODO: move this to iqk_gemm_iquants
void iqk_convert_iq4_xs_q8_k_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
GGML_ASSERT(nrc_x%8 == 0);
int nb = n/QK_K;
const block_iq4_xs * x8[8];
block_q8_k_r8 * y = (block_q8_k_r8 *)vy;
auto values128 = _mm_loadu_si128((const __m128i *)iq4k_values);
auto values = MM256_SET_M128I(values128, values128);
int16_t ls[16];
float dnew[8];
uint32_t block[8];
__m256i xv[8];
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int k = 0; k < 8; ++k) x8[k] = (const block_iq4_xs *)((const char *)vx + (ix + k)*bx);
for (int i = 0; i < nb; ++i) {
for (int k = 0; k < 8; ++k) {
float d = GGML_FP16_TO_FP32(x8[k][i].d);
for (int ib32 = 0; ib32 < 8; ++ib32) {
ls[2*ib32+0] = ls[2*ib32+1] = (((x8[k][i].scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((x8[k][i].scales_h >> 2*ib32) & 3) << 4)) - 32;
auto bits = _mm_loadu_si128((const __m128i *)x8[k][i].qs + ib32);
xv[ib32] = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(bits, 4), bits), _mm256_set1_epi8(0xf));
xv[ib32] = _mm256_shuffle_epi8(values, xv[ib32]);
}
dnew[k] = d * convert_to_q8_k_r8(k, 1.f/127, xv, ls, block, y[i].qs);
}
_mm_storeu_si128((__m128i *)y[i].d, _mm256_cvtps_ph(_mm256_loadu_ps(dnew), _MM_ROUND_NEAREST));
}
y += nb;
}
}
} // namespace
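// Selects the k-quant kernels for the given pair of types: verifies that the activation
// type B matches what the kernel for type A expects, then fills the kernel table (and,
// where available, the 16-row kernel func16).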
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
auto etypeA = ggml_type(typeA);
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
//: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
: etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
: etypeA == GGML_TYPE_Q4_K || etypeA == GGML_TYPE_Q5_K ||
etypeA == GGML_TYPE_Q6_K ? GGML_TYPE_Q8_2_X4
//etypeA == GGML_TYPE_Q6_K || etypeA == GGML_TYPE_Q3_K ? GGML_TYPE_Q8_2_X4
//: etypeA == GGML_TYPE_Q4_K || etypeA == GGML_TYPE_Q5_K ? GGML_TYPE_Q8_2_X4
: GGML_TYPE_Q8_K;
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
return false;
}
func16 = nullptr;
switch (typeA) {
case GGML_TYPE_Q2_K:
set_functions<DequantizerQ2K>(kernels);
break;
case GGML_TYPE_Q3_K:
set_functions<DequantizerQ3K>(kernels);
//IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_2_X4_T, DequantizerQ3K_AVX2, kernels);
break;
case GGML_TYPE_Q4_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ4K_AVX2, kernels);
//set_functions<DequantizerQ4K>(kernels);
break;
case GGML_TYPE_Q5_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_2_X4_T, DequantizerQ5K_AVX2, kernels);
//set_functions<DequantizerQ5K>(kernels);
break;
case GGML_TYPE_Q6_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qY_K_q8_2_X4_T, DequantizerQ6K_AVX2, kernels);
//set_functions<DequantizerQ6K>(kernels);
break;
case GGML_TYPE_IQ4_XS:
set_functions<DequantizerIQ4XS>(kernels);
break;
case GGML_TYPE_Q2_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q2_k_r4_q8_k, kernels)
break;
case GGML_TYPE_Q3_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q3_k_r4_q8_k, kernels)
break;
case GGML_TYPE_Q4_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_k_r4_q8_k, kernels)
break;
case GGML_TYPE_Q5_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_k_r4_q8_k, kernels)
break;
case GGML_TYPE_Q6_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_k_r4_q8_k, kernels)
break;
case GGML_TYPE_IQ4_XS_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_xs_r8_q8_k_avx2, kernels)
break;
case GGML_TYPE_Q8_K_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_k_r8_q8_k, kernels)
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_q8_k_r8_q8_k<16>;
#endif
break;
case GGML_TYPE_Q8_KV:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_q8_KV, kernels)
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_q8_KV_q8_KV<16>;
#endif
break;
case GGML_TYPE_Q8_KV_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_r8_q8_KV, kernels);
break;
default:
return false;
}
return true;
}
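// On-the-fly repacking of the supported k-quant types into 8-interleaved-row Q8 formats
// (q8_k_r8 / q8_0_r8 / q8_1_r8), using the iqk_convert_* helpers above.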
bool iqk_convert_kquants_q8X_r8(int type, int n, const void * vx, size_t bx, void * vy, int nrc_x) {
switch (ggml_type(type)) {
case GGML_TYPE_Q2_K: iqk_convert_q2_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_Q6_K: iqk_convert_q6_k_q8_0_r8(n, vx, bx, vy, nrc_x); break;
case GGML_TYPE_IQ4_XS: iqk_convert_iq4_xs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
default: return false;
}
return true;
}
#else
// --------------------------------- __aarch64__ --------------------------------------
namespace {
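// Accumulates the block-minimum (or constant offset) corrections for super-blocks with 8 or 16
// sub-blocks, using the Q8 block sums so the mins never have to touch the individual quants.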
template <typename Q8>
inline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
auto q8s = q8.load_bsums8(iy, i);
int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s));
int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s));
float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2));
acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
}
}
template <typename Q8>
inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
auto q8s = q8.load_bsums(iy, i);
int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));
int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));
int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));
int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));
float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));
acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
}
}
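// Unpacks the packed 6-bit scales/mins of q4_K and q5_K, applies the min correction via
// accum_mins_8, and returns the 8 sub-block scales widened to int32.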
struct Scales8 {
uint32_t utmp[4];
const uint8_t * sc8 = (const uint8_t *)utmp;
template <typename Q8, typename Qx>
inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) {
make_q4_scales(x.scales, utmp);
int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8));
accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin));
uint8x8_t scales8 = vld1_u8(sc8);
uint16x8_t scales16 = vmovl_u8(scales8);
int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))),
vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))};
return scales;
}
};
struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 8; }
constexpr static bool should_scale_quants() { return false; }
template <typename Q8>
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
return s8.process_scales_mins(x[i], q8, i, acc);
}
inline void prepare(int i, int j) {
if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
else bits.prepare(x[i].qs+64*j);
}
Q4bits bits;
Scales8 s8;
};
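// Ors the separately stored 5th bit (qh) of q5_K into the unpacked 4-bit quants;
// do_shift moves qh down by 4 to prepare for the second half of the super-block.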
struct HighBit5 {
const uint8x16_t mhb = vdupq_n_u8(0x10);
uint8x16x2_t bits;
inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb));
b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb));
b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb));
b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb));
b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
if (do_shift) {
bits.val[0] = vshrq_n_u8(bits.val[0], 4);
bits.val[1] = vshrq_n_u8(bits.val[1], 4);
}
}
};
struct HighBit3 {
const uint8x16_t mhb = vdupq_n_u8(0x04);
uint8x16x2_t bits;
inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb));
b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb));
b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb));
b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb));
if (do_shift) {
bits.val[0] = vshrq_n_u8(bits.val[0], 4);
bits.val[1] = vshrq_n_u8(bits.val[1], 4);
}
}
};
struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 8; }
constexpr static bool should_scale_quants() { return false; }
template <typename Q8>
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
h.bits = vld1q_u8_x2(x[i].qh);
return s8.process_scales_mins(x[i], q8, i, acc);
}
inline void prepare(int i, int j) {
if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
else bits.prepare(x[i].qs+64*j);
h.apply(bits.b1, bits.b2, j == 0);
}
Q4bits bits;
HighBit5 h;
Scales8 s8;
uint8x16x2_t hbits;
};
inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
int32x4x4_t scales = {
vmovl_s16(vget_low_s16 (scales16.val[0])),
vmovl_s16(vget_high_s16(scales16.val[0])),
vmovl_s16(vget_low_s16 (scales16.val[1])),
vmovl_s16(vget_high_s16(scales16.val[1])),
};
return scales;
}
template <typename Q8>
inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) {
int16x8x2_t scales16;
scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
accum_mins_16(scales16, q8, acc, i, c);
return make_wider(scales16);
}
struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 16; }
constexpr static bool should_scale_quants() { return false; }
template <typename Q8>
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d);
}
inline void prepare(int i, int j) {
auto hbits = vld1q_u8_x2(x[i].qh + 32*j);
bits.prepare64(x[i].ql+64*j);
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb));
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb));
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb));
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb));
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb));
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb));
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb));
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb));
}
Q4bits bits;
const uint8x16_t mhb = vdupq_n_u8(0x30);
};
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 16; }
constexpr static bool should_scale_quants() { return false; }
template <typename Q8>
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
h.bits = vld1q_u8_x2(x[i].hmask);
mask = vdupq_n_u8(0x01);
const uint16_t * sc16 = (const uint16_t *)x[i].scales;
uint32_t aux0 = sc16[0] | (sc16[1] << 16);
uint32_t aux1 = sc16[2] | (sc16[3] << 16);
uint32_t aux2 = sc16[4] | (sc16[5] << 16);
aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);
aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
if (nrc > 1) {
return process_scales_mins_16(scales8, q8, acc, i, -4.f*d);
}
int16x8x2_t scales16;
scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
return make_wider(scales16);
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs+32*j);
if (nrc > 1) {
h.apply(bits.b1, bits.b2, j == 0);
} else {
auto minus4 = vdupq_n_u8(0xfc);
auto zero = vdupq_n_u8(0);
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
}
}
uint32_t aux32[4];
Q2bits bits;
uint8x16_t mask;
HighBit3 h;
};
struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 16; }
constexpr static bool should_scale_quants() { return true; }
template <typename Q8>
inline void process_scales(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
auto scales_and_mins = vld1q_u8(x[i].scales);
auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4));
int16x8x2_t scales16;
scales16.val[0] = vmovl_s8(vget_low_s8(mins8));
scales16.val[1] = vmovl_s8(vget_high_s8(mins8));
accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin));
scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf));
}
template <typename Q8>
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
process_scales(i, q8, acc);
int16x8x2_t scales16;
scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8)));
scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8)));
return make_wider(scales16);
}
template <typename Q8>
inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
auto m1 = vdupq_n_u8(1);
auto shuffle = vdupq_n_u8(8*j);
bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
}
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs+32*j);
}
uint32_t aux32[4];
uint8x16_t scales8;
Q2bits bits;
};
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
static int8x16_t load_values() {
static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
return vld1q_s8(iq4nl_values);
}
DequantizerIQ4XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(load_values()) {}
constexpr static int num_blocks() { return 8; }
constexpr static bool should_scale_quants() { return false; }
inline void new_row(int ix) { x = (const block_iq4_xs *)((const char *)vx + bx*ix); }
template <typename Q8>
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
(void)q8;
(void)acc;
d = GGML_FP16_TO_FP32(x[i].d);
const uint16_t scales_h = x[i].scales_h;
const uint16_t * scales_l = (const uint16_t *)x[i].scales_l;
aux32[0] = scales_l[0] | (scales_l[1] << 16);
aux32[1] = aux32[0] >> 4;
// scl is ordered as 0, 2, 4, 6, 1, 3, 5, 7
uint8x8_t scl8 = vand_u8(vld1_u8((const uint8_t *)aux32), vdup_n_u8(0xf));
uint16_t * aux16 = (uint16_t *)aux32;
aux16[0] = scales_h << 4; aux16[1] = scales_h << 2; aux16[2] = scales_h; aux16[3] = scales_h >> 2;
// sch is ordered as 0, 4, 1, 5, 2, 6, 3, 7
uint8x8_t sch8 = vand_u8(vld1_u8((const uint8_t *)aux16), vdup_n_u8(0x30));
int8x8_t scales8 = vadd_s8(vreinterpret_s8_u8(vorr_u8(scl8, vtbl1_u8(sch8, vreinterpret_u8_u32(hshuff)))), vdup_n_s8(-32));
// shuffle 0, 2, 4, 6, 1, 3, 5, 7 -> 0, 1, 2, 3, 4, 5, 6, 7
scales8 = vtbl1_s8(scales8, vreinterpret_s8_u32(hshuff));
int16x8_t scales16 = vmovl_s8(scales8);
int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
return scales;
}
inline void prepare(int i, int j) {
bits.prepare16(x[i].qs+64*j);
//if (nrc == 1) {
// bits.prepare16_v2(x[i].qs+64*j);
//} else {
// bits.prepare16(x[i].qs+64*j);
//}
for (int k = 0; k < 4; ++k) {
bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b1.val[k]));
bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values, bits.b2.val[k]));
}
}
Q4bits bits;
const int8x16_t values;
uint32_t aux32[2];
constexpr static uint32x2_t hshuff = {0x05010400, 0x07030602};
};
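// Splits 64 interleaved q4_k_r4 bytes into 8 vectors of 16 quants: low nibbles into qx[0..3],
// high nibbles into qx[4..7].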
IQK_ALWAYS_INLINE void prepare_q4_k_quants(const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
qx[0] = vandq_u8(bits.val[0], m4); // 0...3 from the 4 rows
qx[1] = vandq_u8(bits.val[1], m4); // 16..19
qx[2] = vandq_u8(bits.val[2], m4); // 4...7
qx[3] = vandq_u8(bits.val[3], m4); // 20..23
qx[4] = vshrq_n_u8(bits.val[0], 4); // 8..11
qx[5] = vshrq_n_u8(bits.val[1], 4); // 24..27
qx[6] = vshrq_n_u8(bits.val[2], 4); // 12..15
qx[7] = vshrq_n_u8(bits.val[3], 4); // 28..31
}
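// GEMM for 4-interleaved-row q2_K (q2_k_r4) with q8_K activations: the mins (high scale nibbles)
// are folded in via the Q8 block sums before the 2-bit quants are dotted with the activations.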
template <int nrc_y>
void mul_mat_q2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = vdupq_n_u8(0x0f);
auto m03 = vdupq_n_u8(0x03);
int nbl = n / QK_K;
int8x16_t qx[4];
float32x4_t acc[nrc_y] = {};
int16x8x4_t i16scales;
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q2_k_r4 * iq2 = (const block_q2_k_r4 *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < nbl; ++ibl) {
int32x4_t isum[nrc_y] = {};
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
auto m4 = vmulq_f32(vdupq_n_f32(-1.f), vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d+4)));
for (int is = 0; is < 2; ++is) {
auto sl = vld1q_u8_x2(iq2[ibl].scales + 32*is);
auto m = vshrq_n_u8(sl.val[0], 4);
i16scales.val[0] = vmovl_u8(vget_low_u8 (m));
i16scales.val[1] = vmovl_u8(vget_high_u8(m));
m = vshrq_n_u8(sl.val[1], 4);
i16scales.val[2] = vmovl_u8(vget_low_u8 (m));
i16scales.val[3] = vmovl_u8(vget_high_u8(m));
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = vdupq_n_s32(0);
auto bsums = vld1q_s16(q8.y[iy][ibl].bsums + 8*is);
auto b8 = vget_low_s16(bsums);
//auto bsums = q8.load_bsums(iy, ibl);
//auto b8 = vget_low_s16(bsums.val[0]);
sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[0]), b8, 0);
sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[0]), b8, 1);
sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[1]), b8, 2);
sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[1]), b8, 3);
b8 = vget_high_s16(bsums);
sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[2]), b8, 0);
sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[2]), b8, 1);
sumi = vmlal_lane_s16(sumi, vget_low_s16 (i16scales.val[3]), b8, 2);
sumi = vmlal_lane_s16(sumi, vget_high_s16(i16scales.val[3]), b8, 3);
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(m4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
}
m = vandq_u8(sl.val[0], mf);
i16scales.val[0] = vmovl_u8(vget_low_u8 (m));
i16scales.val[1] = vmovl_u8(vget_high_u8(m));
m = vandq_u8(sl.val[1], mf);
i16scales.val[2] = vmovl_u8(vget_low_u8 (m));
i16scales.val[3] = vmovl_u8(vget_high_u8(m));
for (int ib = 0; ib < 4; ++ib) {
auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib);
auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
qx[0] = vreinterpretq_s8_u8(vandq_u8( bits.val[0], m03));
qx[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 2), m03));
qx[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 4), m03));
qx[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[0], 6), m03));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
auto sumi = interleaved_dotq(qx, y);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
qx[0] = vreinterpretq_s8_u8(vandq_u8( bits.val[1], m03));
qx[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 2), m03));
qx[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 4), m03));
qx[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(bits.val[1], 6), m03));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
auto sumi = interleaved_dotq(qx, y);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
}
}
}
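// GEMM for 4-interleaved-row q3_K (q3_k_r4) with q8_K activations: the 6-bit scales are rebuilt from
// scales_l/scales_h, and the high bit in qh decides whether 4 is subtracted from the 2-bit quants.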
template <int nrc_y>
void mul_mat_q3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = vdupq_n_u8(0x0f);
auto m30 = vdupq_n_u8(0x30);
auto m32 = vdupq_n_s8(-32);
auto m03 = vdupq_n_u8(0x03);
auto m04 = vdupq_n_u8(0x04);
int nbl = n / QK_K;
int8x16_t qx[4];
float32x4_t acc[nrc_y] = {};
int8x16x4_t i8scales;
int16x8x4_t i16scales;
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q3_k_r4 * iq3 = (const block_q3_k_r4 *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < nbl; ++ibl) {
int32x4_t isum[nrc_y] = {};
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
auto sl = vld1q_u8_x2(iq3[ibl].scales_l);
auto sh = vld1q_u8(iq3[ibl].scales_h);
i8scales.val[0] = vaddq_s8(m32, vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m30)));
i8scales.val[1] = vaddq_s8(m32, vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(vshlq_n_u8(sh, 2), m30)));
i8scales.val[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m30)));
i8scales.val[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m30)));
for (int is = 0; is < 2; ++is) {
i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
for (int ib = 0; ib < 4; ++ib) {
auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib);
auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib);
hbits = veorq_u8(hbits, vdupq_n_u8(0xff));
auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
qx[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8( lbits.val[0], m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshlq_n_u8(hbits, 2))));
qx[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshlq_n_u8(hbits, 1))));
qx[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03)), vreinterpretq_s8_u8(vandq_u8(m04, hbits)));
qx[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 1))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
auto sumi = interleaved_dotq(qx, y);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
qx[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8( lbits.val[1], m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 2))));
qx[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 3))));
qx[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 4))));
qx[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03)), vreinterpretq_s8_u8(vandq_u8(m04, vshrq_n_u8(hbits, 5))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
auto sumi = interleaved_dotq(qx, y);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
}
}
}
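// GEMM for 4-interleaved-row q4_K (q4_k_r4): the mins are applied via block sums read as floats
// (Q8_K32 activations), then the 4-bit quants are dotted block by block with interleaved_dotq.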
template <int nrc_y>
void mul_mat_q4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = vdupq_n_u8(0xf);
auto m3 = vdupq_n_u8(0x30);
int nbl = n / QK_K;
int8x16_t qx[8];
int8x16x2_t iscales;
int32x4x4_t scales;
float32x4_t acc[nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q4_k_r4 * iq4 = (const block_q4_k_r4 *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < nbl; ++ibl) {
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d));
auto m4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d+4));
m4 = vmulq_f32(m4, vdupq_n_f32(-1.f));
auto sl = vld1q_u8_x2(iq4[ibl].scales_l);
auto sh = vld1q_u8(iq4[ibl].scales_h);
iscales.val[0] = vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(vshlq_n_u8(sh, 2), m3));
iscales.val[1] = vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3));
for (int is = 0; is < 2; ++is) {
auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
float32x4x4_t fscales;
fscales.val[0] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_1))));
fscales.val[1] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_1))));
fscales.val[2] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_2))));
fscales.val[3] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_2))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto m8 = vld1q_f32((const float *)q8.y[iy][ibl].bsums + 4*is);
acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[0], m8, 0);
acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[1], m8, 1);
acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[2], m8, 2);
acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[3], m8, 3);
}
}
iscales.val[0] = vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m3));
iscales.val[1] = vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(sh, m3));
int32x4_t isum[nrc_y] = {};
for (int is = 0; is < 2; ++is) {
auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1));
scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1));
scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2));
scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2));
for (int ib = 0; ib < 4; ++ib) {
auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
prepare_q4_k_quants(mf, bits, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
auto sumi = interleaved_dotq(qx, y);
isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
}
}
}
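// GEMM for 4-interleaved-row q5_K (q5_k_r4): same structure as the q4_k_r4 kernel, with the 5th bit
// taken from qh and or'd into the 4-bit quants before the dot products.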
template <int nrc_y>
void mul_mat_q5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = vdupq_n_u8(0xf);
auto m30 = vdupq_n_u8(0x30);
auto m10 = vdupq_n_u8(0x10);
int nbl = n / QK_K;
int8x16_t qx[8];
int8x16x2_t iscales;
int32x4x4_t scales;
float32x4_t acc[nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q5_k_r4 * iq5 = (const block_q5_k_r4 *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < nbl; ++ibl) {
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d));
auto m4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d+4));
m4 = vmulq_f32(m4, vdupq_n_f32(-1.f));
auto sl = vld1q_u8_x2(iq5[ibl].scales_l);
auto sh = vld1q_u8(iq5[ibl].scales_h);
iscales.val[0] = vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(vshlq_n_u8(sh, 2), m30));
iscales.val[1] = vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m30));
for (int is = 0; is < 2; ++is) {
auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
float32x4x4_t fscales;
fscales.val[0] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_1))));
fscales.val[1] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_1))));
fscales.val[2] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_2))));
fscales.val[3] = vmulq_f32(m4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_2))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto m8 = vld1q_f32((const float *)q8.y[iy][ibl].bsums + 4*is);
acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[0], m8, 0);
acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[1], m8, 1);
acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[2], m8, 2);
acc[iy] = vmlaq_laneq_f32(acc[iy], fscales.val[3], m8, 3);
}
}
iscales.val[0] = vorrq_u8(vandq_u8(sl.val[0], mf), vandq_u8(vshlq_n_u8(sh, 4), m30));
iscales.val[1] = vorrq_u8(vandq_u8(sl.val[1], mf), vandq_u8(sh, m30));
int32x4_t isum[nrc_y] = {};
for (int is = 0; is < 2; ++is) {
auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1));
scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1));
scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2));
scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2));
for (int ib = 0; ib < 4; ++ib) {
auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
auto hbits2 = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
auto hbits1 = vshlq_n_u8(hbits2, 4);
prepare_q4_k_quants(mf, lbits, qx);
qx[0] = vorrq_u8(qx[0], vandq_u8(m10, hbits1));
qx[1] = vorrq_u8(qx[1], vandq_u8(m10, hbits2));
qx[2] = vorrq_u8(qx[2], vandq_u8(m10, vshrq_n_u8(hbits1, 2)));
qx[3] = vorrq_u8(qx[3], vandq_u8(m10, vshrq_n_u8(hbits2, 2)));
qx[4] = vorrq_u8(qx[4], vandq_u8(m10, vshrq_n_u8(hbits1, 1)));
qx[5] = vorrq_u8(qx[5], vandq_u8(m10, vshrq_n_u8(hbits2, 1)));
qx[6] = vorrq_u8(qx[6], vandq_u8(m10, vshrq_n_u8(hbits1, 3)));
qx[7] = vorrq_u8(qx[7], vandq_u8(m10, vshrq_n_u8(hbits2, 3)));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
auto sumi = interleaved_dotq(qx, y);
isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
}
}
}
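// GEMM for 4-interleaved-row q6_K (q6_k_r4): low 4 bits from ql and 2 high bits from qh are combined
// and shifted to the signed range [-32, 31] before the dot products.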
template <int nrc_y>
void mul_mat_q6_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto mf = vdupq_n_u8(0x0f);
auto m3 = vdupq_n_u8(0x30);
auto m32 = vdupq_n_s8(-32);
int nbl = n / QK_K;
int8x16_t qx[4];
float32x4_t acc[nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q6_k_r4 * iq6 = (const block_q6_k_r4 *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < nbl; ++ibl) {
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ibl].d));
int32x4_t isum[nrc_y] = {};
for (int is = 0; is < 2; ++is) {
for (int ib = 0; ib < 4; ++ib) {
auto lbits = vld1q_u8_x4(iq6[ibl].ql + 256*is + 64*ib);
auto hbits = vld1q_u8(iq6[ibl].qh + 128*is + 32*ib);
auto iscales = vmovl_s8(vld1_s8(iq6[ibl].scales + 32*is + 8*ib));
auto scales = vmovl_s16(vget_low_s16(iscales));
qx[0] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[0], mf), vandq_u8(m3, vshlq_n_u8(hbits, 4))));
qx[1] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[2], mf), vandq_u8(m3, hbits)));
qx[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m3, vshlq_n_u8(hbits, 2))));
qx[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m3, vshrq_n_u8(hbits, 2))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
auto sumi = interleaved_dotq(qx, y);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
scales = vmovl_s16(vget_high_s16(iscales));
hbits = vld1q_u8(iq6[ibl].qh + 128*is + 32*ib + 16);
qx[0] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[1], mf), vandq_u8(m3, vshlq_n_u8(hbits, 4))));
qx[1] = vaddq_s8(m32, vorrq_u8(vandq_u8 (lbits.val[3], mf), vandq_u8(m3, hbits)));
qx[2] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m3, vshlq_n_u8(hbits, 2))));
qx[3] = vaddq_s8(m32, vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m3, vshrq_n_u8(hbits, 2))));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
auto sumi = interleaved_dotq(qx, y);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
}
}
}
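// GEMM for 8-interleaved-row q8_K (q8_k_r8): rows 0-3 and 4-7 use separate accumulators. The extra
// term built from -128*bsum[0] presumably compensates for the offset the quants acquire when the
// rows are repacked to q8_k_r8.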
template <int nrc_y>
void mul_mat_q8_k_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
float32x4_t acc[2*nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q8_k_r8 * iq8 = (const block_q8_k_r8 *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < nbl; ++ibl) {
auto d4l = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+0));
auto d4h = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[ibl].d+4));
int32x4_t isum[2*nrc_y] = {};
for (int ib = 0; ib < QK_K/16; ++ib) {
auto q1 = vld1q_s8_x4(iq8[ibl].qs + 128*ib + 0);
auto q2 = vld1q_s8_x4(iq8[ibl].qs + 128*ib + 64);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8.y[iy][ibl].qs+16*ib);
isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[0], y, 0);
isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[1], y, 0);
isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q1.val[2], y, 1);
isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q1.val[3], y, 1);
isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[0], y, 2);
isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[1], y, 2);
isum[2*iy+0] = vdotq_laneq_s32(isum[2*iy+0], q2.val[2], y, 3);
isum[2*iy+1] = vdotq_laneq_s32(isum[2*iy+1], q2.val[3], y, 3);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
const float * bsum = (const float *)q8.y[iy][ibl].bsums;
auto m8 = vdupq_n_f32(-128.f*bsum[0]);
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[2*iy+0]));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[2*iy+1]));
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], d4l, m8);
                acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], d4h, m8);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, acc[2*iy+0]);
info.store(ix+4, iy, acc[2*iy+1]);
acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
}
}
}
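// GEMM for 8-interleaved-row iq4_xs (iq4_xs_r8): rows 0-3 and 4-7 are handled in two passes over the
// same block, with the 4-bit indices mapped through the iq4k_values table.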
template <int nrc_y>
void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
auto m4 = vdupq_n_u8(0xf);
auto m3 = vdupq_n_u8(0x30);
auto m32 = vdupq_n_s8(-32);
auto values = vld1q_s8(iq4k_values);
int nbl = n / QK_K;
int8x16_t qx[8];
int8x16x4_t iscales;
int32x4x2_t scales;
float32x4_t acc[2*nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_iq4_xs_r8 * iq4 = (const block_iq4_xs_r8 *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < nbl; ++ibl) {
auto d4_f16 = vld1q_f16((const float16_t *)iq4[ibl].d);
auto d4l = vcvt_f32_f16(vget_low_f16 (d4_f16));
auto d4h = vcvt_f32_f16(vget_high_f16(d4_f16));
auto sl = vld1q_u8_x2(iq4[ibl].scales_l);
auto sh = vld1q_u8(iq4[ibl].scales_h);
iscales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
iscales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
iscales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
iscales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
int32x4_t isum[nrc_y] = {};
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[ib64]));
auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[ib64]));
scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1));
scales.val[1] = vmovl_s16(vget_low_s16(iscales16_2));
for (int l = 0; l < 2; ++l) {
uint8x16x2_t bits;
bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l);
bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 32);
prepare_iq4_nl_quants_r8(values, m4, bits, qx+0);
bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 64);
bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 96);
prepare_iq4_nl_quants_r8(values, m4, bits, qx+4);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+64*ib64+32*l);
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, qx[1], y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, qx[3], y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, qx[4], y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, qx[6], y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
isum[iy] = vmlaq_s32(isum[iy], sumi, scales.val[l]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(d4l, d8), vcvtq_f32_s32(isum[iy]));
isum[iy] = vdupq_n_s32(0);
}
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[ib64]));
auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[ib64]));
scales.val[0] = vmovl_s16(vget_high_s16(iscales16_1));
scales.val[1] = vmovl_s16(vget_high_s16(iscales16_2));
for (int l = 0; l < 2; ++l) {
uint8x16x2_t bits;
bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 16);
bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 48);
prepare_iq4_nl_quants_r8(values, m4, bits, qx+0);
bits.val[0] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l + 80);
bits.val[1] = vld1q_u8(iq4[ibl].qs + 256*ib64 + 128*l +112);
prepare_iq4_nl_quants_r8(values, m4, bits, qx+4);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+64*ib64+32*l);
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, qx[1], y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, qx[3], y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, qx[4], y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, qx[6], y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
isum[iy] = vmlaq_s32(isum[iy], sumi, scales.val[l]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto d8 = vdupq_n_f32(q8.scale(iy, ibl));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(d4h, d8), vcvtq_f32_s32(isum[iy]));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, acc[2*iy+0]);
info.store(ix+4, iy, acc[2*iy+1]);
acc[2*iy+0] = acc[2*iy+1] = vdupq_n_f32(0.f);
}
}
}
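// Single right-hand-side (nrc_y == 1) Q8_KV x Q8_KV dot product. A Q8_KV row is a float scale,
// a 4-byte auxiliary value, and then the int8 quants; the result is dx*dy*sum(qx*qy).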
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%32 == 0);
int32x4_t acc[4] = {};
auto dptr = (const float *)info.src1_row(0);
const float dy = dptr[0];
auto q8y = (const int8_t *)(dptr + 2);
for (int ix = 0; ix < nrc_x; ++ix) {
auto dx = (const float *)((const char *)vx + ix*bx);
auto q8x = (const int8_t *)(dx + 2);
for (int i = 0; i < n/64; ++i) {
auto qx = vld1q_s8_x4(q8x + 64*i);
for (int j = 0; j < 4; ++j) {
acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 64*i + 16*j));
}
}
if (int i = 2*(n/64); i < n/32) {
auto qx = vld1q_s8_x2(q8x + 32*i);
for (int j = 0; j < 2; ++j) {
acc[j] = ggml_vdotq_s32(acc[j], qx.val[j], vld1q_s8(q8y + 32*i + 16*j));
}
}
acc[0] = vaddq_s32(acc[0], acc[1]);
acc[2] = vaddq_s32(acc[2], acc[3]);
acc[0] = vaddq_s32(acc[0], acc[2]);
info.store(ix, 0, dx[0]*dy*vaddvq_s32(acc[0]));
acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0);
}
}
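// Q8_KV x Q8_KV GEMM over 4 interleaved columns of x: 16-byte chunks of the four x rows are
// transposed into 4-byte groups so that vdotq_laneq_s32 can broadcast the matching 4 bytes of y.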
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
GGML_ASSERT(n%16 == 0);
int8x16_t qx[4];
int32x4_t acc[nrc_y] = {};
float dy[nrc_y];
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
q8y[iy] = (const int8_t *)(dptr + 2);
}
const int8_t * q8x[4];
float dx[4];
for (int ix = 0; ix < nrc_x; ix += 4) {
for (int kx = 0; kx < 4; ++kx) {
auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
dx[kx] = dptr[0];
q8x[kx] = (const int8_t *)(dptr + 2);
}
for (int i = 0; i < n/16; ++i) {
for (int kx = 0; kx < 4; ++kx) qx[kx] = vld1q_s8(q8x[kx] + 16*i);
            auto row01 = vtrnq_s32(vreinterpretq_s32_s8(qx[0]), vreinterpretq_s32_s8(qx[1]));
            auto row23 = vtrnq_s32(vreinterpretq_s32_s8(qx[2]), vreinterpretq_s32_s8(qx[3]));
            qx[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
            qx[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
            qx[2] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
            qx[3] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8y[iy] + 16*i);
acc[iy] = vdotq_laneq_s32(acc[iy], qx[0], y, 0);
acc[iy] = vdotq_laneq_s32(acc[iy], qx[1], y, 1);
acc[iy] = vdotq_laneq_s32(acc[iy], qx[2], y, 2);
acc[iy] = vdotq_laneq_s32(acc[iy], qx[3], y, 3);
}
}
auto scales_x = vld1q_f32(dx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto scale = vmulq_f32(scales_x, vdupq_n_f32(dy[iy]));
info.store(ix, iy, vmulq_f32(scale, vcvtq_f32_s32(acc[iy])));
acc[iy] = vdupq_n_s32(0);
}
}
}
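// GEMM for 8-interleaved-row Q8_KV (q8_KV_r8): each row block starts with 8 float scales followed by
// the interleaved quants; rows 0-3 and 4-7 are kept in separate accumulators.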
template <int nrc_y>
void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
int32x4_t acc[2*nrc_y] = {};
float dy[nrc_y];
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
q8y[iy] = (const int8_t *)(dptr + 2);
}
for (int ix = 0; ix < nrc_x; ix += 8) {
const float * dptr = (const float *)((const char *)vx + ix*bx);
auto q8x = (const int8_t *)(dptr + 8);
for (int ib = 0; ib < n/16; ++ib) {
auto q1 = vld1q_s8_x4(q8x + 128*ib + 0);
auto q2 = vld1q_s8_x4(q8x + 128*ib + 64);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8(q8y[iy]+16*ib);
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[0], y, 0);
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[1], y, 0);
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q1.val[2], y, 1);
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q1.val[3], y, 1);
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[0], y, 2);
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[1], y, 2);
acc[2*iy+0] = vdotq_laneq_s32(acc[2*iy+0], q2.val[2], y, 3);
acc[2*iy+1] = vdotq_laneq_s32(acc[2*iy+1], q2.val[3], y, 3);
}
}
auto scale1_x = vld1q_f32(dptr+0);
auto scale2_x = vld1q_f32(dptr+4);
for (int iy = 0; iy < nrc_y; ++iy) {
auto scale_y = vdupq_n_f32(dy[iy]);
auto scale1 = vmulq_f32(scale1_x, scale_y);
auto scale2 = vmulq_f32(scale2_x, scale_y);
info.store(ix+0, iy, vmulq_f32(scale1, vcvtq_f32_s32(acc[2*iy+0])));
info.store(ix+4, iy, vmulq_f32(scale2, vcvtq_f32_s32(acc[2*iy+1])));
            acc[2*iy+0] = acc[2*iy+1] = vdupq_n_s32(0);
}
}
}
}
bool iqk_convert_kquants_q8X_r8([[maybe_unused]] int type, [[maybe_unused]] int n, [[maybe_unused]] const void * vx, [[maybe_unused]] size_t bx, [[maybe_unused]] void * vy, [[maybe_unused]] int nrc_x) {
return false;
//switch (ggml_type(type)) {
// case GGML_TYPE_Q2_K: iqk_convert_q2_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q3_K: iqk_convert_q3_k_q8_k_r8(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q4_K: iqk_convert_q4_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q5_K: iqk_convert_q5_k_q8_1_r8(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_Q6_K: iqk_convert_q6_k_q8_0_r8(n, vx, bx, vy, nrc_x); break;
// case GGML_TYPE_IQ4_XS: iqk_convert_iq4_xs_q8_k_r8(n, vx, bx, vy, nrc_x); break;
// default: return false;
//}
//return true;
}
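// NEON kernel selection for the k-quant types. Repacking to the Q8 *_r8 formats is not implemented
// on this path, which is why iqk_convert_kquants_q8X_r8 above simply returns false.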
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) {
auto etypeA = ggml_type(typeA);
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
//: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
: etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
: GGML_TYPE_Q8_K;
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
return false;
}
func16 = nullptr;
switch (typeA) {
case GGML_TYPE_Q2_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ2K, kernels)
break;
case GGML_TYPE_Q3_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ3K, kernels)
break;
case GGML_TYPE_Q4_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ4K, kernels)
break;
case GGML_TYPE_Q5_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ5K, kernels)
break;
case GGML_TYPE_Q6_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ6K, kernels)
break;
case GGML_TYPE_IQ4_XS:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ4XS, kernels)
break;
case GGML_TYPE_Q2_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q2_k_r4_q8_k, kernels)
break;
case GGML_TYPE_Q3_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q3_k_r4_q8_k, kernels)
break;
case GGML_TYPE_Q4_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_k_r4_q8_k, kernels)
break;
case GGML_TYPE_Q5_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_k_r4_q8_k, kernels)
break;
case GGML_TYPE_Q6_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_k_r4_q8_k, kernels)
break;
case GGML_TYPE_IQ4_XS_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_xs_r8_q8_k, kernels)
break;
case GGML_TYPE_Q8_K_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_k_r8_q8_k, kernels)
break;
case GGML_TYPE_Q8_KV:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_q8_KV, kernels)
kernels[0] = mul_mat_q8_KV_q8_KV_1;
func16 = mul_mat_q8_KV_q8_KV<16>;
break;
case GGML_TYPE_Q8_KV_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_r8_q8_KV, kernels);
break;
default:
return false;
}
return true;
}
#endif
namespace {
#ifdef __AVX2__
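// x86 Q8_KV x Q8_KV kernels used by the flash-attention helper below. With AVX-512 VNNI the x quants
// are shifted into unsigned range for _mm*_dpbusd_epi32 and the bias is removed using the integer
// row sums stored in the y row headers.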
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%32 == 0);
if (nrc_y == 1 && nrc_x == 1) {
auto dx = (const float *)vx;
auto dy = (const float *)info.src1_row(0);
#ifdef HAVE_FANCY_SIMD
auto sy = (const int32_t *)(dy + 1);
auto x = (const int8_t *)(dx + 2);
auto y = (const int8_t *)(dy + 2);
auto isum = _mm512_setzero_si512();
for (int i = 0; i < n/64; ++i) {
auto qx = _mm512_loadu_si512((const __m512i *)x + i);
auto qy = _mm512_loadu_si512((const __m512i *)y + i);
isum = _mm512_dpbusd_epi32(isum, _mm512_add_epi8(qx, _mm512_set1_epi8(127)), qy);
}
auto isum256 = _mm256_add_epi32(_mm512_castsi512_si256(isum), _mm512_extracti32x8_epi32(isum, 1));
for (int i = 2*(n/64); i < n/32; ++i) {
auto qx = _mm256_loadu_si256((const __m256i *)x + i);
auto qy = _mm256_loadu_si256((const __m256i *)y + i);
isum256 = _mm256_dpbusd_epi32(isum256, _mm256_add_epi8(qx, _mm256_set1_epi8(127)), qy);
}
info.store(0, 0, dx[0]*dy[0]*(hsum_i32_8(isum256) - 127*sy[0]));
#else
auto x = (const int8_t *)(dx + 2);
auto y = (const int8_t *)(dy + 2);
auto isum = _mm256_setzero_si256();
for (int i = 0; i < n/32; ++i) {
auto qx = _mm256_loadu_si256((const __m256i *)x + i);
auto qy = _mm256_loadu_si256((const __m256i *)y + i);
auto dot = _mm256_maddubs_epi16(_mm256_sign_epi8(qx, qx), _mm256_sign_epi8(qy, qx));
isum = _mm256_add_epi32(isum, _mm256_madd_epi16(_mm256_set1_epi16(1), dot));
}
info.store(0, 0, dx[0]*dy[0]*hsum_i32_8(isum));
#endif
return;
}
__m256i qx[2];
__m256i acc[2*nrc_y] = {};
float dy[nrc_y];
#ifdef HAVE_FANCY_SIMD
int32_t sy[nrc_y];
#else
__m256i sx[2];
auto m1 = _mm256_set1_epi16(1);
#endif
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
#ifdef HAVE_FANCY_SIMD
auto iptr = (const int32_t *)(dptr+1);
sy[iy] = -127*iptr[0];
#endif
q8y[iy] = (const int8_t *)(dptr + 2);
}
for (int ix = 0; ix < nrc_x; ++ix) {
auto dx = (const float *)((const char *)vx + ix*bx);
auto q8x = (const int8_t *)(dx + 2);
for (int i = 0; i < n/64; ++i) {
for (int j = 0; j < 2; ++j) {
#ifdef HAVE_FANCY_SIMD
qx[j] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + 2*i + j), _mm256_set1_epi8(127));
#else
qx[j] = _mm256_loadu_si256((const __m256i *)q8x + 2*i + j);
sx[j] = _mm256_sign_epi8(qx[j], qx[j]);
#endif
}
for (int iy = 0; iy < nrc_y; ++iy) {
for (int j = 0; j < 2; ++j) {
#ifdef HAVE_FANCY_SIMD
acc[2*iy+j] = _mm256_dpbusd_epi32(acc[2*iy+j], qx[j], _mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j));
#else
auto dot = _mm256_maddubs_epi16(sx[j], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + 2*i + j), qx[j]));
acc[2*iy+j] = _mm256_add_epi32(acc[2*iy+j], _mm256_madd_epi16(m1, dot));
#endif
}
}
}
if (int i = 2*(n/64); i < n/32) {
#ifdef HAVE_FANCY_SIMD
qx[0] = _mm256_add_epi8(_mm256_loadu_si256((const __m256i *)q8x + i), _mm256_set1_epi8(127));
#else
qx[0] = _mm256_loadu_si256((const __m256i *)q8x + i);
sx[0] = _mm256_sign_epi8(qx[0], qx[0]);
#endif
for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
acc[2*iy] = _mm256_dpbusd_epi32(acc[2*iy], qx[0], _mm256_loadu_si256((const __m256i *)q8y[iy] + i));
#else
auto dot = _mm256_maddubs_epi16(sx[0], _mm256_sign_epi8(_mm256_loadu_si256((const __m256i *)q8y[iy] + i), qx[0]));
acc[2*iy] = _mm256_add_epi32(acc[2*iy], _mm256_madd_epi16(m1, dot));
#endif
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = hsum_i32_8(_mm256_add_epi32(acc[2*iy], acc[2*iy+1]));
#ifdef HAVE_FANCY_SIMD
info.store(ix, iy, dx[0]*dy[iy]*(sumi+sy[iy]));
#else
info.store(ix, iy, dx[0]*dy[iy]*sumi);
#endif
acc[2*iy] = acc[2*iy+1] = _mm256_setzero_si256();
}
}
}
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
GGML_ASSERT(n%32 == 0);
__m512i qx[4];
__m512i acc[nrc_y <= 4 ? 2*nrc_y : nrc_y] = {};
float dy[nrc_y];
int32_t sy[nrc_y];
const int8_t * q8y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
dy[iy] = dptr[0];
auto iptr = (const int32_t *)(dptr + 1);
sy[iy] = -64*iptr[0];
q8y[iy] = (const int8_t *)(dptr + 2);
}
const int8_t * q8x[8];
float dx[8];
for (int ix = 0; ix < nrc_x; ix += 8) {
for (int kx = 0; kx < 8; ++kx) {
auto dptr = (const float *)((const char *)vx + (ix+kx)*bx);
dx[kx] = dptr[0];
q8x[kx] = (const int8_t *)(dptr + 2);
}
for (int i = 0; i < n/32; ++i) {
for (int kx = 0; kx < 4; ++kx) {
qx[kx] = _mm512_inserti32x8(_mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)q8x[kx+0] + i)),
_mm256_loadu_si256((const __m256i *)q8x[kx+4] + i), 1);
}
auto t0 = _mm512_unpacklo_epi32(qx[0], qx[1]);
auto t1 = _mm512_unpacklo_epi32(qx[2], qx[3]);
auto t2 = _mm512_unpackhi_epi32(qx[0], qx[1]);
auto t3 = _mm512_unpackhi_epi32(qx[2], qx[3]);
qx[0] = _mm512_xor_si512(_mm512_unpacklo_epi64(t0, t1), _mm512_set1_epi8(-128));
qx[1] = _mm512_xor_si512(_mm512_unpackhi_epi64(t0, t1), _mm512_set1_epi8(-128));
qx[2] = _mm512_xor_si512(_mm512_unpacklo_epi64(t2, t3), _mm512_set1_epi8(-128));
qx[3] = _mm512_xor_si512(_mm512_unpackhi_epi64(t2, t3), _mm512_set1_epi8(-128));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y256 = _mm256_loadu_si256((const __m256i *)q8y[iy] + i);
auto y = _mm512_inserti32x8(_mm512_castsi256_si512(y256), y256, 1);
if constexpr (nrc_y <= 4) {
acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
} else {
acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
}
}
}
auto scales_x = _mm256_loadu_ps(dx);
for (int iy = 0; iy < nrc_y; ++iy) {
if constexpr (nrc_y <= 4) {
auto ss = _mm512_add_epi32(_mm512_add_epi32(acc[2*iy+0], acc[2*iy+1]), _mm512_set1_epi32(sy[iy]));
auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 0), _mm512_extracti32x4_epi32(ss, 1));
auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(ss, 2), _mm512_extracti32x4_epi32(ss, 3));
auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
} else {
acc[iy] = _mm512_add_epi32(acc[iy], _mm512_set1_epi32(sy[iy]));
auto sum1 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 0), _mm512_extracti32x4_epi32(acc[iy], 1));
auto sum2 = _mm_add_epi32(_mm512_extracti32x4_epi32(acc[iy], 2), _mm512_extracti32x4_epi32(acc[iy], 3));
auto scale = _mm256_mul_ps(scales_x, _mm256_set1_ps(dy[iy]));
info.store(ix+0, iy, _mm_mul_ps(_mm256_castps256_ps128(scale), _mm_cvtepi32_ps(sum1)));
info.store(ix+4, iy, _mm_mul_ps(_mm256_extractf128_ps(scale, 1), _mm_cvtepi32_ps(sum2)));
acc[iy] = _mm512_setzero_si512();
}
}
}
}
#endif
#endif
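// Picks the Q8_KV GEMM kernel and the number of query rows it handles per call, based on the
// K-cache type, the number of queries nq, and the K block step.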
template <int k_step>
inline std::pair<mul_mat_t, int> mul_mat_kernel([[maybe_unused]] int D, int int_typeA, int nq) {
auto typeA = ggml_type(int_typeA);
constexpr int kMaxQ = 8;
#define MAKE_FUNCS(mul_mat, n) \
if (n >= kMaxQ) return std::make_pair(mul_mat, kMaxQ>, kMaxQ);\
else {\
switch (n) {\
case 1: return std::make_pair(mul_mat, 1>, 1);\
case 2: return std::make_pair(mul_mat, 2>, 2);\
case 3: return std::make_pair(mul_mat, 3>, 3);\
case 4: return std::make_pair(mul_mat, 4>, 4);\
case 5: return std::make_pair(mul_mat, 5>, 5);\
case 6: return std::make_pair(mul_mat, 6>, 6);\
case 7: return std::make_pair(mul_mat, 7>, 7);\
}\
}
#define MAKE_FUNCS_ONLY_NRC(mul_mat, n) \
if (n >= kMaxQ) return std::make_pair(mul_mat<kMaxQ>, kMaxQ);\
else {\
switch (n) {\
case 1: return std::make_pair(mul_mat<1>, 1);\
case 2: return std::make_pair(mul_mat<2>, 2);\
case 3: return std::make_pair(mul_mat<3>, 3);\
case 4: return std::make_pair(mul_mat<4>, 4);\
case 5: return std::make_pair(mul_mat<5>, 5);\
case 6: return std::make_pair(mul_mat<6>, 6);\
case 7: return std::make_pair(mul_mat<7>, 7);\
}\
}
if (typeA == GGML_TYPE_Q8_KV) {
#ifdef __aarch64__
if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1, 1);
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
#else
if (nq == 1) return std::make_pair(mul_mat_q8_KV_q8_KV_1<1>, 1);
#ifdef HAVE_FANCY_SIMD
if (D%32 == 0 && k_step%8 == 0) {
if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV_8<16>, 16);
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV_8, nq);
} else {
if (nq%16 == 0) return std::make_pair(mul_mat_q8_KV_q8_KV<16>, 16);
}
#endif
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_q8_KV, nq);
#endif
}
else if (typeA == GGML_TYPE_Q8_KV_R8) {
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_KV_r8_q8_KV, nq);
}
GGML_ABORT("Fatal error");
}
inline std::pair<mul_mat_t, int> mul_mat_kernel(int D, int int_typeA, int nq, int k_step) {
switch (k_step) {
case 32: return mul_mat_kernel< 32>(D, int_typeA, nq);
case 64: return mul_mat_kernel< 64>(D, int_typeA, nq);
case 128: return mul_mat_kernel<128>(D, int_typeA, nq);
default: GGML_ABORT("Fatal error");
}
}
}
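// Flash-attention helper: multiplies nq query rows against a k_step x D block of the Q8_KV K cache,
// using the widest kernel that fits and a narrower one for the remaining rows.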
void iqk_gemm_q8kv_fa(int D, int nq, int type_k, const char * k, size_t stride_k, DataInfo& info, int k_step) {
auto [mul_mat, nrc_q] = mul_mat_kernel(D, type_k, nq, k_step);
for (int iq = 0; iq < nq/nrc_q; ++iq) {
mul_mat(D, k, stride_k, info, k_step);
info.cur_y += nrc_q;
}
int iq = nrc_q*(nq/nrc_q);
if (iq < nq) {
auto [mul_mat1, nrc_q1] = mul_mat_kernel(D, type_k, nq - iq, k_step);
GGML_ASSERT(nrc_q1 == nq - iq);
mul_mat1(D, k, stride_k, info, k_step);
}
}
#endif