Refactor iqk: Factor out GEMM for iq1_bn, iq2_bn, iq2_bn_r4

This commit is contained in:
Iwan Kawrakow
2025-05-17 19:53:48 +03:00
parent d66ec60836
commit 7868545062
2 changed files with 526 additions and 522 deletions

View File

@@ -1043,6 +1043,492 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI
}
}
// Access helper for nrc_y rows of Q8_K64-quantized activations (the right-hand side).
// Each source row starts with 8 floats, read as two groups of 4 by scale()/minus()
// below, followed by the int8 quants.
template <int nrc> struct Q8_K64 {
    constexpr static int nrc_y = nrc;
    Q8_K64(const DataInfo& info) {
        for (int iy = 0; iy < nrc_y; ++iy) {
            const float * dptr = (const float *)info.src1_row(iy);
            std::memcpy(d + 8*iy, dptr, 8*sizeof(float));  // cache the 8 header floats per row
            y[iy] = (const int8_t *)(dptr + 8);            // quants start right after the header
        }
    }
    // 32 int8 quants: j-th 32-byte chunk (j = 0..3) of 128-value group i of row iy.
    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); }
    // First 4 header floats of row iy: the scales applied to the integer dot products.
    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 8*iy); }
    // Last 4 header floats of row iy: subtracted from the scaled sums (see the kernels' fmsub).
    inline __m128 minus(int iy) const { return _mm_loadu_ps(d + 8*iy + 4); }
    float d[8*nrc_y];
    const int8_t * y[nrc_y];
};
// Decoder for IQ1_BN ("BitNet" ternary) blocks. Each block packs 64 ternary values;
// the tables below extract base-3 digits from the packed bytes.
// (Layout inferred from the shuffle tables: bytes 0..11 each carry 5 digits and
// byte 12 is reused once per 16-value group — confirm against the quantization code.)
struct DequantizerIQ1BN {
    const __m256i m1_8 = _mm256_set1_epi8(1);
    // Shuffle patterns replicating each source byte into the low byte of five
    // consecutive 16-bit lanes (byte 12 once per group); 255 zero-fills the high byte.
    static __m256i load_shuffle(int i) {
        static const uint8_t data[128] = {
            0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 1, 255, 1, 255, 1, 255, 1, 255, 1, 255, 2, 255, 2, 255, 2, 255, 2, 255, 2, 255, 12, 255,
            3, 255, 3, 255, 3, 255, 3, 255, 3, 255, 4, 255, 4, 255, 4, 255, 4, 255, 4, 255, 5, 255, 5, 255, 5, 255, 5, 255, 5, 255, 12, 255,
            6, 255, 6, 255, 6, 255, 6, 255, 6, 255, 7, 255, 7, 255, 7, 255, 7, 255, 7, 255, 8, 255, 8, 255, 8, 255, 8, 255, 8, 255, 12, 255,
            9, 255, 9, 255, 9, 255, 9, 255, 9, 255, 10, 255, 10, 255, 10, 255, 10, 255, 10, 255, 11, 255, 11, 255, 11, 255, 11, 255, 11, 255, 12, 255,
        };
        return _mm256_loadu_si256((const __m256i*)data + i);
    }
    const __m256i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) };
    // Per-lane multipliers 3^k * 256 (0x5100=81*256, 0x1b00=27*256, 0x0900=9*256,
    // 0x0300=3*256, 0x0100=1*256). Combined with mulhi(..., 3) this is the classic
    // fixed-point trick to extract the k-th base-3 digit of each replicated byte.
    const __m256i mult[4] = {
        _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
        _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
        _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
        _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
    };
    const __m256i m3 = _mm256_set1_epi16(3);
#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__
    // Selects the even (low) bytes of two 16-bit vectors into one 32-byte result.
    const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
#endif
    // Decode one block into two vectors of 32 byte-values each (digits in 0..2).
    // The load reads 3 bytes past the 13-byte payload; callers must keep that in bounds.
    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const {
        auto data128 = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
        auto data = MM256_SET_M128I(data128, data128);
        // ((q * 3^k * 256) mod 2^16) * 3 >> 16 == k-th ternary digit of q
        auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[0]), mult[0]), m3);
        auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[1]), mult[1]), m3);
        auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3);
        auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3);
#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__
        v1 = _mm256_permutex2var_epi8(val1, bmask, val2);
        v2 = _mm256_permutex2var_epi8(val3, bmask, val4);
#else
        // Pack 16-bit digits to bytes; permute 216 = 0b11011000 restores order after packs.
        v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216);
        v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216);
#endif
    }
};
// GEMM kernel: nrc_x rows of IQ1_BN quants times nrc_y columns of Q8_K64 activations.
// Each x row begins with a single fp16 scale followed by the block_iq1_bn data.
// The integer accumulators are reduced per column as scale * hsum(dy*sumi - minus).
template <int nrc_y>
IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    const int nb = n / QK_IQ1BN;
    Q8_K64<nrc_y> q8(info);
    DequantizerIQ1BN deq;
    __m256i accd[nrc_y];
    __m256i val[4];
#ifndef HAVE_FANCY_SIMD
    const auto m1_16 = _mm256_set1_epi16(1);
#endif
    const block_iq1_bn * x;
    const char * cx0 = (const char *)vx;
    float scale;
    ggml_half d16;
    for (int ix = 0; ix < nrc_x; ++ix) {
        const char * cx = cx0 + ix*bx;
        std::memcpy(&d16, cx, sizeof(d16));  // per-row fp16 scale stored at the row start
        scale = GGML_FP16_TO_FP32(d16);
        cx += sizeof(d16);
        x = (const block_iq1_bn *)cx;
        if constexpr (nrc_y == 1) {
            // Single column: two independent accumulators shorten the dependency chain.
            __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
            for (int i = 0; i < nb/2; ++i) {   // process blocks in pairs
                deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
                deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
#ifdef HAVE_FANCY_SIMD
                acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1));
                acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3));
#else
                // u8*s8 -> 16-bit products; adding two maddubs results is safe here
                // before the madd-by-1 widening to 32 bit.
                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
                                             _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
                                             _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
                acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1));
                acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2));
#endif
            }
            accd[0] = _mm256_add_epi32(acc1, acc2);
        }
        else {
            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
            for (int i = 0; i < nb/2; ++i) {
                deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
                deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
                for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
                                    val[0], q8.load_quants(iy, i, 0)),
                                    val[1], q8.load_quants(iy, i, 1)),
                                    val[2], q8.load_quants(iy, i, 2)),
                                    val[3], q8.load_quants(iy, i, 3));
#else
                    auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
                                                 _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1)));
                    auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
                                                 _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)));
                    dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
                    accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
#endif
                }
            }
        }
        // Tail: one leftover block when nb is odd.
        int i = 2*(nb/2);
        if (i < nb) {
            deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
            for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
                                val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1));
#else
                auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
                                                                     _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
                accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
            }
        }
        // Horizontal reduction: fold 8 lanes to 4, apply scales, subtract correction.
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto vd = q8.scale(iy);
            auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
            auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
            info.store(ix, iy, scale*hsum_float_4(sumf));
        }
    }
}
// Decoder for IQ2_BN blocks: four 2-bit values per byte, unpacked with shift + mask.
struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn, true> {
    DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    // Unpack two consecutive blocks (2*i and 2*i+1) into val[0..3], 32 values each.
    IQK_ALWAYS_INLINE void prepare4(int i, __m256i * val) const {
        auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs);
        auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2);
        // Interleave the two shift states so make2 can peel two bit-planes at a time.
        make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0);
        make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
    }
    // Extract the 2-bit plane at bit 0 and the one at bit 4 of each byte.
    IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const {
        val[0] = _mm256_and_si256(q2_1, mask2);
        val[1] = _mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2);
    }
    // Unpack a single (tail) block into val[0..1].
    IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const {
        auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
        make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val);
    }
    const __m256i m1_8 = _mm256_set1_epi8(1);   // NOTE(review): unused in this struct
    const __m256i mf_8 = _mm256_set1_epi8(16);  // NOTE(review): unused in this struct
    const __m256i mask2 = _mm256_set1_epi8(0x03);
    const __m256i mask3 = _mm256_set1_epi8(0x30); // NOTE(review): unused in this struct
};
// GEMM kernel: nrc_x rows of IQ2_BN (2-bit BitNet) quants times nrc_y columns of
// Q8_K64 activations. Mirrors mul_mat_iq1bn_q8_K64 but uses the cheap shift/mask
// dequantizer. Result per column: deq.d * hsum(dy*sumi - minus).
//
// Fix: in the odd-block tail of the non-HAVE_FANCY_SIMD path, the second product
// multiplied val[1] by quant chunk 0 instead of chunk 1, i.e. the same 32 q8 values
// were used twice. The HAVE_FANCY_SIMD branch and the IQ1_BN kernel both use
// chunk 1 there.
template <int nrc_y>
IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    const int nb = n / QK_IQ1BN;
    Q8_K64<nrc_y> q8(info);
    DequantizeIQ2BN deq(vx, bx);
    __m256i accd[nrc_y];
    __m256i val[4];
#ifndef HAVE_FANCY_SIMD
    const auto m1_16 = _mm256_set1_epi16(1);
#endif
    for (int ix = 0; ix < nrc_x; ++ix) {
        deq.new_row(ix);
        if constexpr (nrc_y == 1) {
            // Single column: two independent accumulators shorten the dependency chain.
            __m256i acc[2] = {};
            for (int i = 0; i < nb/2; ++i) {   // process blocks in pairs
                deq.prepare4(i, val);
#ifdef HAVE_FANCY_SIMD
                acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)),
                        val[1], q8.load_quants(0, i, 1));
                acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)),
                        val[3], q8.load_quants(0, i, 3));
#else
                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
                                             _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
                                             _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
                acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
                acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
#endif
            }
            accd[0] = _mm256_add_epi32(acc[0], acc[1]);
        }
        else {
            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
            for (int i = 0; i < nb/2; ++i) {
                deq.prepare4(i, val);
                for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
                                    val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)),
                                    val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3));
#else
                    auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
                                    _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
                                                     _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))),
                                    _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
                                                     _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)))));
                    accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
                }
            }
        }
        // Tail: one leftover block when nb is odd.
        int i = 2*(nb/2);
        if (i < nb) {
            deq.prepare2(i, val);
            for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)),
                        val[1], q8.load_quants(iy, i/2, 1));
#else
                // Bug fix: second maddubs must use quant chunk 1 (was chunk 0).
                auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
                                                                     _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
                accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
            }
        }
        // Horizontal reduction: fold 8 lanes to 4, apply scales, subtract correction.
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto vd = q8.scale(iy);
            auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
            auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
            info.store(ix, iy, deq.d*hsum_float_4(sumf));
        }
    }
}
// GEMM for IQ2_BN_R4 (4 interleaved rows of 2-bit BitNet quants) against Q8_K16
// activations, AVX2 implementation. Each 4-row group starts with 4 floats (one
// scale per interleaved row) followed by the packed quants, 4 values per byte.
template <int nrc_y>
static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    if (nrc_x%4) {
        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
        GGML_ABORT("fatal error");
    }
    Q8_16<nrc_y> q8(info);
    auto m3 = _mm256_set1_epi8(0x3);
    auto m1 = _mm256_set1_epi16(1);
    int nb = n / QK_IQ1BN;
    __m256i qx[4];
    if constexpr (nrc_y > 4) {
        // Many columns: hold only one accumulator per column by making two passes
        // over the row group (even 32-byte quant chunks first, then odd chunks),
        // presumably to limit register pressure; the first pass's partial result
        // is carried in sum4[].
        __m256i acc[nrc_y] = {};
        __m128 sum4[nrc_y];
        for (int ix = 0; ix < nrc_x; ix += 4) {
            const float * dptr = (const float *)((const char *)vx + ix*bx);
            auto dl = _mm_loadu_ps(dptr);                        // scales of the 4 rows
            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);  // quants follow the scales
            for (int ib = 0; ib < nb; ++ib) {
                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
                // Split each byte into its four 2-bit values
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+0);
                    // Replicate each dword of y within its 128-bit lane and dot with the 4 rows
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
                auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
                s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4);
                // Apply the row-sum correction already here; the second pass adds on top.
                sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4);
                acc[iy] = _mm256_setzero_si256();
            }
            // Second pass: the odd 32-byte quant chunks.
            for (int ib = 0; ib < nb; ++ib) {
                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+1);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
                // Second pass pairs with the upper two q8 scales (0xaa / 0xff).
                auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]);
                s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4);
                info.store(ix, iy, s4);
                acc[iy] = _mm256_setzero_si256();
            }
        }
    } else {
        // Few columns: two accumulators per column, single pass over the quants.
        __m256i acc[2*nrc_y] = {};
        for (int ix = 0; ix < nrc_x; ix += 4) {
            const float * dptr = (const float *)((const char *)vx + ix*bx);
            auto dl = _mm_loadu_ps(dptr);
            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
            for (int ib = 0; ib < nb; ++ib) {
                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+0);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
                bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+1);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]);
                auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]);
                auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
                sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);  // row-sum correction
                info.store(ix, iy, sum4);
                acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256();
            }
        }
    }
}
#ifdef HAVE_FANCY_SIMD
// AVX512 version of the IQ2_BN_R4 x Q8_K16 GEMM: processes two 4-row groups
// (8 rows) per iteration with 512-bit registers and dpbusd. nrc_y == 1 falls back
// to the AVX2 kernel; a leftover group of 4 rows uses half the accumulators.
template <int nrc_y>
static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    if (nrc_x%4) {
        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
        GGML_ABORT("fatal error");
    }
    if constexpr (nrc_y == 1) {
        mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x);
    } else {
        Q8_16<nrc_y> q8(info);
        auto m3 = _mm512_set1_epi8(0x3);
        int nb = n / QK_IQ1BN;
        __m512i acc[2*nrc_y] = {};
        __m512i qx[8];
        for (int ix = 0; ix < nrc_x/8; ++ix) {
            // Two adjacent 4-row groups; 'l'/'h' suffixes = low/high group.
            const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx);
            const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx);
            auto dl = _mm_loadu_ps(dptr1);
            auto dh = _mm_loadu_ps(dptr2);
            const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4);
            const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4);
            for (int ib = 0; ib < nb; ++ib) {
                auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
                auto bits_h = _mm512_loadu_si512((const __m512i *)iq2h + ib);
                // Unpack the four 2-bit planes of each group
                qx[0] = _mm512_and_si512(bits_l, m3);
                qx[1] = _mm512_and_si512(bits_h, m3);
                qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
                qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3);
                qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
                qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3);
                qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
                qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants64(iy, ib);
                    // Replicate each dword of y within its lane; u8*s8 dot into 32-bit lanes
                    auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy);
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy);
                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy);
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy);
                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy);
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy);
                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy);
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy);
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                __m128 sum4;
                for (int k = 0; k < 2; ++k) {   // k = 0: low 4 rows, k = 1: high 4 rows
                    const auto& dx = k == 0 ? dl : dh;
                    auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]);
                    sum4 = _mm_mul_ps (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00)));
                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
                    sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4);  // row-sum correction
                    info.store(8*ix + 4*k, iy, sum4);
                }
                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
            }
        }
        // Leftover group of 4 rows when nrc_x is not a multiple of 8.
        if (int ix = 8*(nrc_x/8); ix < nrc_x) {
            const float * dptr = (const float *)((const char *)vx + ix*bx);
            auto dl = _mm_loadu_ps(dptr);
            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
            for (int ib = 0; ib < nb; ++ib) {
                auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
                qx[0] = _mm512_and_si512(bits_l, m3);
                qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
                qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
                qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants64(iy, ib);
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf = _mm512_cvtepi32_ps(acc[iy]);
                auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
                sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
                info.store(ix, iy, sum4);
            }
        }
    }
}
#else
// Without HAVE_FANCY_SIMD simply forward to the AVX2 implementation.
template <int nrc_y>
static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    if (nrc_x%4) {
        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
        GGML_ABORT("fatal error");
    }
    mul_mat_iq2_bn_r4_q8_k16_avx2<nrc_y>(n, vx, bx, info, nrc_x);
}
#endif
} // namespace
bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& funcs, mul_mat_t& func16) {
@@ -1095,6 +1581,43 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
func16 = mul_mat_iq1_m_r4_q8_0<16>;
#endif
break;
case GGML_TYPE_IQ1_BN:
assert (ne00 % QK_IQ1BN == 0);
funcs[0] = mul_mat_iq1bn_q8_K64<1>;
funcs[1] = mul_mat_iq1bn_q8_K64<2>;
funcs[2] = mul_mat_iq1bn_q8_K64<3>;
funcs[3] = mul_mat_iq1bn_q8_K64<4>;
funcs[4] = mul_mat_iq1bn_q8_K64<5>;
funcs[5] = mul_mat_iq1bn_q8_K64<6>;
funcs[6] = mul_mat_iq1bn_q8_K64<7>;
funcs[7] = mul_mat_iq1bn_q8_K64<8>;
expected_typeB = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN:
assert (ne00 % QK_IQ1BN == 0);
funcs[0] = mul_mat_iq2bn_q8_K64<1>;
funcs[1] = mul_mat_iq2bn_q8_K64<2>;
funcs[2] = mul_mat_iq2bn_q8_K64<3>;
funcs[3] = mul_mat_iq2bn_q8_K64<4>;
funcs[4] = mul_mat_iq2bn_q8_K64<5>;
funcs[5] = mul_mat_iq2bn_q8_K64<6>;
funcs[6] = mul_mat_iq2bn_q8_K64<7>;
funcs[7] = mul_mat_iq2bn_q8_K64<8>;
expected_typeB = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN_R4:
assert (ne00 % QK_IQ1BN == 0);
funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
expected_typeB = GGML_TYPE_Q8_K16;
break;
default:
return false;
}

View File

@@ -1185,228 +1185,6 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
#endif // Zen4 or vanilla AVX2
// GEMM for IQ2_BN_R4 (4 interleaved rows of 2-bit BitNet quants) against Q8_K16
// activations, AVX2 implementation. Each 4-row group starts with 4 floats (one
// scale per interleaved row) followed by the packed quants, 4 values per byte.
template <int nrc_y>
static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    if (nrc_x%4) {
        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
        GGML_ABORT("fatal error");
    }
    Q8_16<nrc_y> q8(info);
    auto m3 = _mm256_set1_epi8(0x3);
    auto m1 = _mm256_set1_epi16(1);
    int nb = n / QK_IQ1BN;
    __m256i qx[4];
    if constexpr (nrc_y > 4) {
        // Many columns: hold only one accumulator per column by making two passes
        // over the row group (even 32-byte quant chunks first, then odd chunks),
        // presumably to limit register pressure; the first pass's partial result
        // is carried in sum4[].
        __m256i acc[nrc_y] = {};
        __m128 sum4[nrc_y];
        for (int ix = 0; ix < nrc_x; ix += 4) {
            const float * dptr = (const float *)((const char *)vx + ix*bx);
            auto dl = _mm_loadu_ps(dptr);                        // scales of the 4 rows
            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);  // quants follow the scales
            for (int ib = 0; ib < nb; ++ib) {
                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
                // Split each byte into its four 2-bit values
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+0);
                    // Replicate each dword of y within its 128-bit lane and dot with the 4 rows
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
                auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
                s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4);
                // Apply the row-sum correction already here; the second pass adds on top.
                sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4);
                acc[iy] = _mm256_setzero_si256();
            }
            // Second pass: the odd 32-byte quant chunks.
            for (int ib = 0; ib < nb; ++ib) {
                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+1);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
                // Second pass pairs with the upper two q8 scales (0xaa / 0xff).
                auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]);
                s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4);
                info.store(ix, iy, s4);
                acc[iy] = _mm256_setzero_si256();
            }
        }
    } else {
        // Few columns: two accumulators per column, single pass over the quants.
        __m256i acc[2*nrc_y] = {};
        for (int ix = 0; ix < nrc_x; ix += 4) {
            const float * dptr = (const float *)((const char *)vx + ix*bx);
            auto dl = _mm_loadu_ps(dptr);
            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
            for (int ib = 0; ib < nb; ++ib) {
                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+0);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
                bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+1);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]);
                auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]);
                auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
                sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);  // row-sum correction
                info.store(ix, iy, sum4);
                acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256();
            }
        }
    }
}
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
if (nrc_x%4) {
printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
GGML_ABORT("fatal error");
}
if constexpr (nrc_y == 1) {
mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x);
} else {
Q8_16<nrc_y> q8(info);
auto m3 = _mm512_set1_epi8(0x3);
int nb = n / QK_IQ1BN;
__m512i acc[2*nrc_y] = {};
__m512i qx[8];
for (int ix = 0; ix < nrc_x/8; ++ix) {
const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx);
const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx);
auto dl = _mm_loadu_ps(dptr1);
auto dh = _mm_loadu_ps(dptr2);
const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4);
const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4);
for (int ib = 0; ib < nb; ++ib) {
auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
auto bits_h = _mm512_loadu_si512((const __m512i *)iq2h + ib);
qx[0] = _mm512_and_si512(bits_l, m3);
qx[1] = _mm512_and_si512(bits_h, m3);
qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3);
qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3);
qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = q8.load_quants64(iy, ib);
auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00));
acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy);
acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy);
sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55));
acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy);
acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy);
sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa));
acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy);
acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy);
sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff));
acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy);
acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto dy = q8.scale(iy);
__m128 sum4;
for (int k = 0; k < 2; ++k) {
const auto& dx = k == 0 ? dl : dh;
auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]);
sum4 = _mm_mul_ps (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00)));
sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4);
info.store(8*ix + 4*k, iy, sum4);
}
acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
}
}
if (int ix = 8*(nrc_x/8); ix < nrc_x) {
const float * dptr = (const float *)((const char *)vx + ix*bx);
auto dl = _mm_loadu_ps(dptr);
const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
for (int ib = 0; ib < nb; ++ib) {
auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
qx[0] = _mm512_and_si512(bits_l, m3);
qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = q8.load_quants64(iy, ib);
acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto dy = q8.scale(iy);
auto sumf = _mm512_cvtepi32_ps(acc[iy]);
auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
info.store(ix, iy, sum4);
}
}
}
}
#else
// Non-AVX512 dispatcher for the iq2_bn_r4 x q8_k16 GEMM: validates the
// LHS row count and delegates to the AVX2 implementation.
template <int nrc_y>
static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    // iq2_bn_r4 interleaves 4 rows, so nrc_x must be divisible by 4.
    if (nrc_x%4 == 0) {
        mul_mat_iq2_bn_r4_q8_k16_avx2<nrc_y>(n, vx, bx, info, nrc_x);
        return;
    }
    printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
    GGML_ABORT("fatal error");
}
#endif
#ifdef HAVE_FANCY_SIMD
template <int nrc_y>
static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -5127,268 +4905,6 @@ static void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const Data
}
}
// Accessor for Q8_K64-quantized activation rows.
// Each right-hand-side row begins with 8 floats — 4 scales (see scale())
// followed by 4 correction terms that get subtracted by the caller (see
// minus()) — and then the int8 quants. The constructor caches the 8 header
// floats per row in d[] and remembers where the quants start in y[].
template <int nrc> struct Q8_K64 {
    constexpr static int nrc_y = nrc;
    Q8_K64(const DataInfo& info) {
        for (int row = 0; row < nrc_y; ++row) {
            auto header = (const float *)info.src1_row(row);
            std::memcpy(d + 8*row, header, 8*sizeof(float));
            y[row] = (const int8_t *)(header + 8);
        }
    }
    // j-th 32-byte group within the i-th 128-quant block of row iy.
    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); }
    // First 4 header floats of row iy: the block scales.
    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 8*iy); }
    // Last 4 header floats of row iy: terms the caller subtracts from the dot product.
    inline __m128 minus(int iy) const { return _mm_loadu_ps(d + 8*iy + 4); }
    float d[8*nrc_y];
    const int8_t * y[nrc_y];
};
// Unpacks block_iq1_bn (ternary, ~1.625 bpw) quants into byte vectors.
// Each 13-byte block encodes 64 ternary values; bytes 0..11 each pack 5
// base-3 digits, byte 12 packs the remaining 4.
struct DequantizerIQ1BN {
    const __m256i m1_8 = _mm256_set1_epi8(1);
    // Shuffle tables that replicate each packed byte into five consecutive
    // 16-bit lanes (byte 12 goes into the last lane of each 32-byte row);
    // the interleaved 255s zero the high byte of every lane.
    static __m256i load_shuffle(int i) {
        static const uint8_t data[128] = {
            0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 1, 255, 1, 255, 1, 255, 1, 255, 1, 255, 2, 255, 2, 255, 2, 255, 2, 255, 2, 255, 12, 255,
            3, 255, 3, 255, 3, 255, 3, 255, 3, 255, 4, 255, 4, 255, 4, 255, 4, 255, 4, 255, 5, 255, 5, 255, 5, 255, 5, 255, 5, 255, 12, 255,
            6, 255, 6, 255, 6, 255, 6, 255, 6, 255, 7, 255, 7, 255, 7, 255, 7, 255, 7, 255, 8, 255, 8, 255, 8, 255, 8, 255, 8, 255, 12, 255,
            9, 255, 9, 255, 9, 255, 9, 255, 9, 255, 10, 255, 10, 255, 10, 255, 10, 255, 10, 255, 11, 255, 11, 255, 11, 255, 11, 255, 11, 255, 12, 255,
        };
        return _mm256_loadu_si256((const __m256i*)data + i);
    }
    const __m256i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) };
    // Per-lane multipliers (powers of 3 shifted into the high byte: 0x51=81,
    // 0x1b=27, 0x09, 0x03, 0x01). mullo shifts the wanted base-3 digit of the
    // replicated byte toward the top of the 16-bit lane, and the subsequent
    // mulhi by 3 brings it down as a value in {0,1,2}. Only the last lane of
    // each row differs between the four tables — it selects which digit of
    // shared byte 12 the row extracts.
    const __m256i mult[4] = {
        _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
        _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
        _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
        _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
    };
    const __m256i m3 = _mm256_set1_epi16(3);
#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__
    // Byte indices 0,2,4,... : gathers the low byte of every 16-bit lane from
    // two input vectors in one permute (VBMI fast path for the pack below).
    const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
#endif
    // Decodes one block into 64 ternary bytes (v1, v2 hold 32 values each).
    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const {
        auto data128 = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
        auto data = MM256_SET_M128I(data128, data128);
        // Each val holds 16 extracted digits, one per 16-bit lane.
        auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[0]), mult[0]), m3);
        auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[1]), mult[1]), m3);
        auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3);
        auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3);
        // Narrow the 16-bit digits to bytes, preserving source order.
#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__
        v1 = _mm256_permutex2var_epi8(val1, bmask, val2);
        v2 = _mm256_permutex2var_epi8(val3, bmask, val4);
#else
        // packs works per 128-bit half; permute 216 (= 0b11011000) restores order.
        v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216);
        v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216);
#endif
    }
};
// GEMM: iq1_bn (ternary) LHS rows times nrc_y Q8_K64-quantized RHS rows.
// Each LHS row starts with an fp16 row scale followed by block_iq1_bn data.
// n is the row length (multiple of QK_IQ1BN), bx the LHS row stride in bytes.
template <int nrc_y>
IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    const int nb = n / QK_IQ1BN;
    Q8_K64<nrc_y> q8(info);
    DequantizerIQ1BN deq;
    __m256i accd[nrc_y];
    __m256i val[4];
#ifndef HAVE_FANCY_SIMD
    const auto m1_16 = _mm256_set1_epi16(1);
#endif
    const block_iq1_bn * x;
    const char * cx0 = (const char *)vx;
    float scale;
    ggml_half d16;
    for (int ix = 0; ix < nrc_x; ++ix) {
        // Row header: fp16 scale, then the quant blocks.
        const char * cx = cx0 + ix*bx;
        std::memcpy(&d16, cx, sizeof(d16));
        scale = GGML_FP16_TO_FP32(d16);
        cx += sizeof(d16);
        x = (const block_iq1_bn *)cx;
        if constexpr (nrc_y == 1) {
            // Single RHS row: two independent accumulators hide dot-product latency.
            __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
            for (int i = 0; i < nb/2; ++i) {
                // Unpack two blocks (128 ternary values) per iteration.
                deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
                deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
#ifdef HAVE_FANCY_SIMD
                acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1));
                acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3));
#else
                // maddubs gives 16-bit pair sums; madd with 1 widens to 32 bit.
                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
                                             _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
                                             _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
                acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1));
                acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2));
#endif
            }
            accd[0] = _mm256_add_epi32(acc1, acc2);
        }
        else {
            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
            for (int i = 0; i < nb/2; ++i) {
                deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
                deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
                for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
                                    val[0], q8.load_quants(iy, i, 0)),
                                    val[1], q8.load_quants(iy, i, 1)),
                                    val[2], q8.load_quants(iy, i, 2)),
                                    val[3], q8.load_quants(iy, i, 3));
#else
                    auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
                                                 _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1)));
                    auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
                                                 _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)));
                    dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
                    accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
#endif
                }
            }
        }
        // Tail: odd number of blocks leaves one block to process.
        int i = 2*(nb/2);
        if (i < nb) {
            deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
            for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
                                val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1));
#else
                auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
                                                                     _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
                accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
            }
        }
        // Reduce 8 lanes to 4, apply scales, subtract the Q8 correction terms,
        // then horizontally sum and scale by the fp16 row scale.
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto vd = q8.scale(iy);
            auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
            auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
            info.store(ix, iy, scale*hsum_float_4(sumf));
        }
    }
}
// Unpacks block_iq2_bn quants: each byte packs four 2-bit values, which are
// extracted pairwise (low 2 bits, then bits 4-5 after the shifts below).
struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn, true> {
    DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
    // Unpacks two consecutive blocks (2*i and 2*i+1) into four 32-byte
    // vectors. The 2-bit shift produces the second bit plane; the 128-bit
    // permutes pair up the halves so val[0..3] match the Q8 quant order.
    IQK_ALWAYS_INLINE void prepare4(int i, __m256i * val) const {
        auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs);
        auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2);
        make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0);
        make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
    }
    // Extracts the low 2-bit field and the field 4 bits up from every byte.
    IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const {
        val[0] = _mm256_and_si256(q2_1, mask2);
        val[1] = _mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2);
    }
    // Tail variant: unpacks a single block into two 32-byte vectors.
    IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const {
        auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
        make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val);
    }
    const __m256i m1_8 = _mm256_set1_epi8(1);     // NOTE(review): not used by this struct's methods — confirm before removing
    const __m256i mf_8 = _mm256_set1_epi8(16);    // NOTE(review): not used by this struct's methods
    const __m256i mask2 = _mm256_set1_epi8(0x03); // 2-bit field mask
    const __m256i mask3 = _mm256_set1_epi8(0x30); // NOTE(review): not used by this struct's methods
};
// GEMM: iq2_bn (2-bit ternary) LHS rows times nrc_y Q8_K64-quantized RHS rows.
// n is the row length (multiple of QK_IQ1BN), bx the LHS row stride in bytes,
// nrc_x the number of LHS rows.
template <int nrc_y>
IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    const int nb = n / QK_IQ1BN;
    Q8_K64<nrc_y> q8(info);
    DequantizeIQ2BN deq(vx, bx);
    __m256i accd[nrc_y];
    __m256i val[4];
#ifndef HAVE_FANCY_SIMD
    const auto m1_16 = _mm256_set1_epi16(1);
#endif
    for (int ix = 0; ix < nrc_x; ++ix) {
        deq.new_row(ix);
        if constexpr (nrc_y == 1) {
            // Single RHS row: two independent accumulators hide dot-product latency.
            __m256i acc[2] = {};
            for (int i = 0; i < nb/2; ++i) {
                deq.prepare4(i, val);
#ifdef HAVE_FANCY_SIMD
                acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)),
                        val[1], q8.load_quants(0, i, 1));
                acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)),
                        val[3], q8.load_quants(0, i, 3));
#else
                // maddubs gives 16-bit pair sums; madd with 1 widens to 32 bit.
                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
                                             _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
                                             _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
                acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
                acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
#endif
            }
            accd[0] = _mm256_add_epi32(acc[0], acc[1]);
        }
        else {
            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
            for (int i = 0; i < nb/2; ++i) {
                deq.prepare4(i, val);
                for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
                                    val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)),
                                    val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3));
#else
                    auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
                                _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
                                                 _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))),
                                _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
                                                 _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)))));
                    accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
                }
            }
        }
        // Tail: odd number of blocks leaves one block to process.
        int i = 2*(nb/2);
        if (i < nb) {
            deq.prepare2(i, val);
            for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)),
                        val[1], q8.load_quants(iy, i/2, 1));
#else
                // Bug fix: the second product must use quant group 1, not group 0,
                // matching the HAVE_FANCY_SIMD branch above and the iq1_bn kernel.
                auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
                                                                     _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
                accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
            }
        }
        // Reduce 8 lanes to 4, apply scales, subtract the Q8 correction terms,
        // then horizontally sum and scale by the LHS row scale.
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto vd = q8.scale(iy);
            auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
            auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
            info.store(ix, iy, deq.d*hsum_float_4(sumf));
        }
    }
}
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
#ifdef HAVE_FANCY_SIMD
m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
@@ -5445,44 +4961,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
return ggml_type(typeB) == GGML_TYPE_Q8_K ? iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs) : false;
case GGML_TYPE_IQ1_BN:
assert (ne00 % QK_IQ1BN == 0);
mm.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
mm.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
mm.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
mm.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
mm.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
mm.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
mm.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
mm.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
expected_typeB = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN:
assert (ne00 % QK_IQ1BN == 0);
mm.funcs[0] = mul_mat_iq2bn_q8_K64<1>;
mm.funcs[1] = mul_mat_iq2bn_q8_K64<2>;
mm.funcs[2] = mul_mat_iq2bn_q8_K64<3>;
mm.funcs[3] = mul_mat_iq2bn_q8_K64<4>;
mm.funcs[4] = mul_mat_iq2bn_q8_K64<5>;
mm.funcs[5] = mul_mat_iq2bn_q8_K64<6>;
mm.funcs[6] = mul_mat_iq2bn_q8_K64<7>;
mm.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
expected_typeB = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN_R4:
assert (ne00 % QK_IQ1BN == 0);
mm.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
mm.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
mm.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
mm.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
mm.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
mm.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
//#ifdef HAVE_FANCY_SIMD
mm.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
mm.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
//#endif
expected_typeB = GGML_TYPE_Q8_K16;
break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -5824,6 +5302,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_S_R4:
case GGML_TYPE_IQ1_M_R4:
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_BN_R4:
return iqk_set_kernels_1bit(ne00, typeA, typeB, mm.funcs, mm.func16);
default: