mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 23:54:10 +00:00
Refactor iqk: Factor out GEMM for iq1_bn, iq2_bn, iq2_bn_r4
This commit is contained in:
@@ -1043,6 +1043,492 @@ static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataI
|
||||
}
|
||||
}
|
||||
|
||||
// Right-hand-side (activation) accessor for the IQ1_BN/IQ2_BN GEMM kernels.
// Each of the nrc_y rows of src1 is laid out as 8 floats (4 scales followed
// by 4 subtraction terms) and then the int8 quants.
template <int nrc> struct Q8_K64 {

    constexpr static int nrc_y = nrc;

    Q8_K64(const DataInfo& info) {
        // Cache the 8 per-row floats and remember where the quants start.
        for (int iy = 0; iy < nrc_y; ++iy) {
            const float * dptr = (const float *)info.src1_row(iy);
            std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
            y[iy] = (const int8_t *)(dptr + 8);
        }
    }

    // 32 quants: group j (0..3) of super-block i of row iy.
    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); }
    // The 4 per-row scales.
    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 8*iy); }
    // The 4 per-row subtraction terms (applied via fmsub in the epilogue).
    inline __m128 minus(int iy) const { return _mm_loadu_ps(d + 8*iy + 4); }

    float d[8*nrc_y];          // scales + minus terms, 8 floats per row
    const int8_t * y[nrc_y];   // start of the quantized activations per row
};
|
||||
|
||||
// Unpacks iq1_bn (1.625 bpw ternary) blocks into unsigned bytes suitable for
// maddubs/dpbusd against signed q8 activations.
struct DequantizerIQ1BN {
    const __m256i m1_8 = _mm256_set1_epi8(1);
    // Byte shuffle: replicates each packed source byte 5 times (one copy per
    // ternary digit it holds), with 0xff in the odd positions so every copy
    // lands in its own 16-bit lane; byte 12 carries the remaining digits.
    static __m256i load_shuffle(int i) {
        static const uint8_t data[128] = {
            0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 1, 255, 1, 255, 1, 255, 1, 255, 1, 255, 2, 255, 2, 255, 2, 255, 2, 255, 2, 255, 12, 255,
            3, 255, 3, 255, 3, 255, 3, 255, 3, 255, 4, 255, 4, 255, 4, 255, 4, 255, 4, 255, 5, 255, 5, 255, 5, 255, 5, 255, 5, 255, 12, 255,
            6, 255, 6, 255, 6, 255, 6, 255, 6, 255, 7, 255, 7, 255, 7, 255, 7, 255, 7, 255, 8, 255, 8, 255, 8, 255, 8, 255, 8, 255, 12, 255,
            9, 255, 9, 255, 9, 255, 9, 255, 9, 255, 10, 255, 10, 255, 10, 255, 10, 255, 10, 255, 11, 255, 11, 255, 11, 255, 11, 255, 11, 255, 12, 255,
        };
        return _mm256_loadu_si256((const __m256i*)data + i);
    }
    const __m256i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) };
    // Per-lane multipliers: powers of 3 (0x51 = 81, 0x1b = 27, 9, 3, 1) placed
    // so that mullo brings the wanted base-3 digit into the high bits of each
    // 16-bit lane; the mulhi by 3 below then extracts that digit as 0/1/2.
    const __m256i mult[4] = {
        _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
        _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
        _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
        _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
    };
    const __m256i m3 = _mm256_set1_epi16(3);
#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__
    // Indices of the even bytes of two concatenated vectors: lets a single
    // permutex2var compress the 16-bit digits of two vectors into one byte vector.
    const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
#endif

    // Expand one iq1_bn block into 64 byte values (0/1/2) in v1, v2.
    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const {
        auto data128 = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
        auto data = MM256_SET_M128I(data128, data128);
        // Each val holds 16 extracted ternary digits in 16-bit lanes.
        auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[0]), mult[0]), m3);
        auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[1]), mult[1]), m3);
        auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3);
        auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3);
#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__
        // Single-instruction 16-bit -> 8-bit compaction.
        v1 = _mm256_permutex2var_epi8(val1, bmask, val2);
        v2 = _mm256_permutex2var_epi8(val3, bmask, val4);
#else
        // packs interleaves 128-bit halves; permute4x64(216 = 0b11011000) restores order.
        v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216);
        v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216);
#endif
    }

};
|
||||
|
||||
// GEMM/GEMV kernel: A = iq1_bn rows (fp16 row scale + packed ternary blocks),
// B = Q8_K64 activations; computes nrc_y columns for each of nrc_x rows.
template <int nrc_y>
IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    const int nb = n / QK_IQ1BN;
    Q8_K64<nrc_y> q8(info);
    DequantizerIQ1BN deq;
    __m256i accd[nrc_y];
    __m256i val[4];

#ifndef HAVE_FANCY_SIMD
    // Used with madd to widen 16-bit maddubs sums to 32 bits.
    const auto m1_16 = _mm256_set1_epi16(1);
#endif

    const block_iq1_bn * x;
    const char * cx0 = (const char *)vx;
    float scale;
    ggml_half d16;

    for (int ix = 0; ix < nrc_x; ++ix) {

        // Each row starts with an fp16 scale, then the iq1_bn blocks.
        const char * cx = cx0 + ix*bx;
        std::memcpy(&d16, cx, sizeof(d16));
        scale = GGML_FP16_TO_FP32(d16);
        cx += sizeof(d16);
        x = (const block_iq1_bn *)cx;

        if constexpr (nrc_y == 1) {
            // Single column: two accumulators break the dependency chain.
            __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
            for (int i = 0; i < nb/2; ++i) {
                // Unpack two blocks (4 x 32 quants) per iteration.
                deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
                deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
#ifdef HAVE_FANCY_SIMD
                acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1));
                acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3));
#else
                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
                                             _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
                                             _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
                acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1));
                acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2));
#endif
            }
            accd[0] = _mm256_add_epi32(acc1, acc2);
        }
        else {
            // Multi-column: unpack once, reuse against every column.
            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();

            for (int i = 0; i < nb/2; ++i) {

                deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
                deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);

                for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
                                    val[0], q8.load_quants(iy, i, 0)),
                                    val[1], q8.load_quants(iy, i, 1)),
                                    val[2], q8.load_quants(iy, i, 2)),
                                    val[3], q8.load_quants(iy, i, 3));
#else
                    auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
                                                 _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1)));
                    auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
                                                 _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)));
                    dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
                    accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
#endif
                }
            }
        }
        // Tail: one leftover block when nb is odd.
        int i = 2*(nb/2);
        if (i < nb) {
            deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
            for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
                                val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1));
#else
                auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
                                                                     _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
                accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
            }
        }

        // Epilogue: horizontal reduce, apply activation scales, subtract the
        // correction terms (fmsub), then the fp16 row scale.
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto vd = q8.scale(iy);
            auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
            auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
            info.store(ix, iy, scale*hsum_float_4(sumf));
        }

    }
}
|
||||
|
||||
// Unpacks iq2_bn (2.0 bpw) blocks: each byte stores four 2-bit values.
struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn, true> {
    DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}

    // Unpack two consecutive blocks (2*i, 2*i+1) into val[0..3].
    IQK_ALWAYS_INLINE void prepare4(int i, __m256i * val) const {
        auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs);
        auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2);
        // Interleave halves so each make2 call sees bit-planes 0/2 and 1/3.
        make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0);
        make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
    }
    // Extract two 2-bit planes (bits 0-1 and bits 4-5) of q2_1.
    IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const {
        val[0] = _mm256_and_si256(q2_1, mask2);
        val[1] = _mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2);
    }
    // Unpack a single (tail) block into val[0..1].
    IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const {
        auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
        make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val);
    }
    const __m256i m1_8 = _mm256_set1_epi8(1);
    const __m256i mf_8 = _mm256_set1_epi8(16);
    const __m256i mask2 = _mm256_set1_epi8(0x03);
    const __m256i mask3 = _mm256_set1_epi8(0x30);
};
|
||||
|
||||
template <int nrc_y>
|
||||
IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
const int nb = n / QK_IQ1BN;
|
||||
Q8_K64<nrc_y> q8(info);
|
||||
DequantizeIQ2BN deq(vx, bx);
|
||||
__m256i accd[nrc_y];
|
||||
__m256i val[4];
|
||||
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
const auto m1_16 = _mm256_set1_epi16(1);
|
||||
#endif
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
deq.new_row(ix);
|
||||
|
||||
if constexpr (nrc_y == 1) {
|
||||
__m256i acc[2] = {};
|
||||
for (int i = 0; i < nb/2; ++i) {
|
||||
deq.prepare4(i, val);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)),
|
||||
val[1], q8.load_quants(0, i, 1));
|
||||
acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)),
|
||||
val[3], q8.load_quants(0, i, 3));
|
||||
#else
|
||||
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
|
||||
_mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
|
||||
auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
|
||||
_mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
|
||||
acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
|
||||
acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
|
||||
#endif
|
||||
}
|
||||
accd[0] = _mm256_add_epi32(acc[0], acc[1]);
|
||||
}
|
||||
else {
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
|
||||
|
||||
for (int i = 0; i < nb/2; ++i) {
|
||||
deq.prepare4(i, val);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
|
||||
val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)),
|
||||
val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3));
|
||||
#else
|
||||
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
|
||||
_mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
|
||||
_mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))),
|
||||
_mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
|
||||
_mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)))));
|
||||
accd[iy] = _mm256_add_epi32(dot, accd[iy]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
int i = 2*(nb/2);
|
||||
if (i < nb) {
|
||||
deq.prepare2(i, val);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)),
|
||||
val[1], q8.load_quants(iy, i/2, 1));
|
||||
#else
|
||||
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
|
||||
_mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 0))));
|
||||
accd[iy] = _mm256_add_epi32(dot, accd[iy]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto vd = q8.scale(iy);
|
||||
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
|
||||
auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
|
||||
info.store(ix, iy, deq.d*hsum_float_4(sumf));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AVX2 kernel for iq2_bn_r4 (4 interleaved rows per block) against Q8_K16
// activations. Rows are processed 4 at a time; each row group starts with
// 4 floats (per-row scales) followed by the packed 2-bit quants.
template <int nrc_y>
static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    if (nrc_x%4) {
        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
        GGML_ABORT("fatal error");
    }
    Q8_16<nrc_y> q8(info);
    auto m3 = _mm256_set1_epi8(0x3);
    auto m1 = _mm256_set1_epi16(1);
    int nb = n / QK_IQ1BN;
    __m256i qx[4];
    if constexpr (nrc_y > 4) {
        // Many columns: not enough registers for 2 accumulators per column,
        // so make two passes over the row data — first the even 32-byte
        // group of each block, partially reduce into sum4, then the odd group.
        __m256i acc[nrc_y] = {};
        __m128 sum4[nrc_y];
        for (int ix = 0; ix < nrc_x; ix += 4) {
            const float * dptr = (const float *)((const char *)vx + ix*bx);
            auto dl = _mm_loadu_ps(dptr);     // 4 per-row scales
            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
            // Pass 1: even groups.
            for (int ib = 0; ib < nb; ++ib) {
                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
                // Split the four 2-bit planes.
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+0);
                    // Broadcast each 4-byte activation group against the 4 rows.
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
            }
            // Partial reduce of pass 1 (scales 0/1 of dy), plus the sum_row
            // correction; accumulators are reset for pass 2.
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
                auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
                s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4);
                sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4);
                acc[iy] = _mm256_setzero_si256();
            }
            // Pass 2: odd groups.
            for (int ib = 0; ib < nb; ++ib) {
                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+1);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
            }
            // Final reduce of pass 2 (scales 2/3 of dy) and store.
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
                auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]);
                s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4);
                info.store(ix, iy, s4);
                acc[iy] = _mm256_setzero_si256();
            }
        }
    } else {
        // Few columns: 2 accumulators per column fit in registers, single pass.
        __m256i acc[2*nrc_y] = {};
        for (int ix = 0; ix < nrc_x; ix += 4) {
            const float * dptr = (const float *)((const char *)vx + ix*bx);
            auto dl = _mm_loadu_ps(dptr);     // 4 per-row scales
            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
            for (int ib = 0; ib < nb; ++ib) {
                // Even group -> acc[2*iy+0].
                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+0);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
                // Odd group -> acc[2*iy+1].
                bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+1);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
            }
            // Reduce: one scale element of dy per 128-bit half of each
            // accumulator, then the sum_row correction, then store.
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]);
                auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]);
                auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
                sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
                info.store(ix, iy, sum4);
                acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256();
            }
        }
    }
}
|
||||
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
// AVX-512 (HAVE_FANCY_SIMD) kernel for iq2_bn_r4: processes 8 rows (two
// 4-row groups, low/high) per iteration using 512-bit dpbusd. GEMV
// (nrc_y == 1) falls back to the AVX2 path, which is faster there.
template <int nrc_y>
static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    if (nrc_x%4) {
        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
        GGML_ABORT("fatal error");
    }
    if constexpr (nrc_y == 1) {
        mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x);
    } else {
        Q8_16<nrc_y> q8(info);
        auto m3 = _mm512_set1_epi8(0x3);
        int nb = n / QK_IQ1BN;
        __m512i acc[2*nrc_y] = {};   // per column: [0] = low 4 rows, [1] = high 4 rows
        __m512i qx[8];
        for (int ix = 0; ix < nrc_x/8; ++ix) {
            // Two adjacent 4-row groups; each starts with 4 float scales.
            const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx);
            const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx);
            auto dl = _mm_loadu_ps(dptr1);
            auto dh = _mm_loadu_ps(dptr2);
            const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4);
            const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4);
            for (int ib = 0; ib < nb; ++ib) {
                auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
                auto bits_h = _mm512_loadu_si512((const __m512i *)iq2h + ib);
                // Split the four 2-bit planes for both row groups.
                qx[0] = _mm512_and_si512(bits_l, m3);
                qx[1] = _mm512_and_si512(bits_h, m3);
                qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
                qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3);
                qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
                qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3);
                qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
                qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants64(iy, ib);
                    // Broadcast each 4-byte activation group, dot against both
                    // row groups.
                    auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy);
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy);
                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy);
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy);
                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy);
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy);
                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy);
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy);
                }
            }
            // Reduce and store each 4-row group (k = 0: low, k = 1: high).
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                __m128 sum4;
                for (int k = 0; k < 2; ++k) {
                    const auto& dx = k == 0 ? dl : dh;
                    auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]);
                    sum4 = _mm_mul_ps (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00)));
                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
                    sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4);
                    info.store(8*ix + 4*k, iy, sum4);
                }
                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
            }
        }
        // Tail: one remaining 4-row group when nrc_x % 8 == 4.
        if (int ix = 8*(nrc_x/8); ix < nrc_x) {
            const float * dptr = (const float *)((const char *)vx + ix*bx);
            auto dl = _mm_loadu_ps(dptr);
            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
            for (int ib = 0; ib < nb; ++ib) {
                auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
                qx[0] = _mm512_and_si512(bits_l, m3);
                qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
                qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
                qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants64(iy, ib);
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf = _mm512_cvtepi32_ps(acc[iy]);
                auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
                sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
                info.store(ix, iy, sum4);
            }
        }
    }
}
|
||||
#else
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
if (nrc_x%4) {
|
||||
printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
mul_mat_iq2_bn_r4_q8_k16_avx2<nrc_y>(n, vx, bx, info, nrc_x);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
} // namespace
|
||||
|
||||
bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& funcs, mul_mat_t& func16) {
|
||||
@@ -1095,6 +1581,43 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
|
||||
func16 = mul_mat_iq1_m_r4_q8_0<16>;
|
||||
#endif
|
||||
break;
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
assert (ne00 % QK_IQ1BN == 0);
|
||||
funcs[0] = mul_mat_iq1bn_q8_K64<1>;
|
||||
funcs[1] = mul_mat_iq1bn_q8_K64<2>;
|
||||
funcs[2] = mul_mat_iq1bn_q8_K64<3>;
|
||||
funcs[3] = mul_mat_iq1bn_q8_K64<4>;
|
||||
funcs[4] = mul_mat_iq1bn_q8_K64<5>;
|
||||
funcs[5] = mul_mat_iq1bn_q8_K64<6>;
|
||||
funcs[6] = mul_mat_iq1bn_q8_K64<7>;
|
||||
funcs[7] = mul_mat_iq1bn_q8_K64<8>;
|
||||
expected_typeB = GGML_TYPE_Q8_K64;
|
||||
break;
|
||||
case GGML_TYPE_IQ2_BN:
|
||||
assert (ne00 % QK_IQ1BN == 0);
|
||||
funcs[0] = mul_mat_iq2bn_q8_K64<1>;
|
||||
funcs[1] = mul_mat_iq2bn_q8_K64<2>;
|
||||
funcs[2] = mul_mat_iq2bn_q8_K64<3>;
|
||||
funcs[3] = mul_mat_iq2bn_q8_K64<4>;
|
||||
funcs[4] = mul_mat_iq2bn_q8_K64<5>;
|
||||
funcs[5] = mul_mat_iq2bn_q8_K64<6>;
|
||||
funcs[6] = mul_mat_iq2bn_q8_K64<7>;
|
||||
funcs[7] = mul_mat_iq2bn_q8_K64<8>;
|
||||
expected_typeB = GGML_TYPE_Q8_K64;
|
||||
break;
|
||||
case GGML_TYPE_IQ2_BN_R4:
|
||||
assert (ne00 % QK_IQ1BN == 0);
|
||||
funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
|
||||
funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
|
||||
funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
|
||||
funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
|
||||
funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
|
||||
funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
|
||||
funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
|
||||
funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
|
||||
expected_typeB = GGML_TYPE_Q8_K16;
|
||||
break;
|
||||
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -1185,228 +1185,6 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
|
||||
|
||||
#endif // Zen4 or vanilla AVX2
|
||||
|
||||
// AVX2 GEMM for iq2_bn_r4: 4 interleaved rows per group, Q8_K16 activations.
// Each row group begins with 4 per-row float scales, then the 2-bit quants.
// (NOTE(review): this appears to be the pre-refactor copy of the same kernel
// that exists earlier in this diff view.)
template <int nrc_y>
static void mul_mat_iq2_bn_r4_q8_k16_avx2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    if (nrc_x%4) {
        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
        GGML_ABORT("fatal error");
    }
    Q8_16<nrc_y> q8(info);
    auto m3 = _mm256_set1_epi8(0x3);
    auto m1 = _mm256_set1_epi16(1);
    int nb = n / QK_IQ1BN;
    __m256i qx[4];
    if constexpr (nrc_y > 4) {
        // Register pressure: one accumulator per column, two passes over the
        // row data (even groups then odd groups), partial reduce in between.
        __m256i acc[nrc_y] = {};
        __m128 sum4[nrc_y];
        for (int ix = 0; ix < nrc_x; ix += 4) {
            const float * dptr = (const float *)((const char *)vx + ix*bx);
            auto dl = _mm_loadu_ps(dptr);
            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
            for (int ib = 0; ib < nb; ++ib) {
                // Even 32-byte group; extract the four 2-bit planes.
                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+0);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
            }
            // Partial reduce of the even-group pass into sum4.
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
                auto s4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
                s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), s4);
                sum4[iy] = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), s4);
                acc[iy] = _mm256_setzero_si256();
            }
            for (int ib = 0; ib < nb; ++ib) {
                // Odd 32-byte group.
                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+1);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[iy] = _mm256_add_epi32(acc[iy], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
            }
            // Final reduce (scales 2/3) and store.
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf1 = _mm256_cvtepi32_ps(acc[iy]);
                auto s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4[iy]);
                s4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), s4);
                info.store(ix, iy, s4);
                acc[iy] = _mm256_setzero_si256();
            }
        }
    } else {
        // <= 4 columns: two accumulators per column, single pass.
        __m256i acc[2*nrc_y] = {};
        for (int ix = 0; ix < nrc_x; ix += 4) {
            const float * dptr = (const float *)((const char *)vx + ix*bx);
            auto dl = _mm_loadu_ps(dptr);
            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
            for (int ib = 0; ib < nb; ++ib) {
                auto bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+0);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+0);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[2*iy+0] = _mm256_add_epi32(acc[2*iy+0], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
                bits = _mm256_loadu_si256((const __m256i *)iq2l + 2*ib+1);
                qx[0] = _mm256_and_si256(bits, m3);
                qx[1] = _mm256_and_si256(_mm256_srli_epi16(bits, 2), m3);
                qx[2] = _mm256_and_si256(_mm256_srli_epi16(bits, 4), m3);
                qx[3] = _mm256_and_si256(_mm256_srli_epi16(bits, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants(iy, 2*ib+1);
                    auto sumi1 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[0], _mm256_shuffle_epi32(y, 0x00)),
                                                  _mm256_maddubs_epi16(qx[1], _mm256_shuffle_epi32(y, 0x55)));
                    auto sumi2 = _mm256_add_epi16(_mm256_maddubs_epi16(qx[2], _mm256_shuffle_epi32(y, 0xaa)),
                                                  _mm256_maddubs_epi16(qx[3], _mm256_shuffle_epi32(y, 0xff)));
                    acc[2*iy+1] = _mm256_add_epi32(acc[2*iy+1], _mm256_madd_epi16(m1, _mm256_add_epi16(sumi1, sumi2)));
                }
            }
            // Reduce both accumulators (one scale element per 128-bit half),
            // apply the sum_row correction, store 4 results at once.
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf1 = _mm256_cvtepi32_ps(acc[2*iy+0]);
                auto sumf2 = _mm256_cvtepi32_ps(acc[2*iy+1]);
                auto sum4 = _mm_mul_ps(_mm256_extractf128_ps(sumf1, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf1, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
                sum4 = _mm_fmadd_ps(_mm256_extractf128_ps(sumf2, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
                sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
                info.store(ix, iy, sum4);
                acc[2*iy+0] = acc[2*iy+1] = _mm256_setzero_si256();
            }
        }
    }
}
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
// GEMM kernel for IQ2_BN_R4 (2-bit BitNet quants, 4 rows interleaved) against
// Q8_K16 activations, AVX512-VNNI path.
// n      : row length (number of columns in the left matrix)
// vx/bx  : base pointer / row stride of the quantized left matrix
// info   : access to the right-matrix rows and the result store
// nrc_x  : number of left-matrix rows to process; must be a multiple of 4
template <int nrc_y>
static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    if (nrc_x%4) {
        printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
        GGML_ABORT("fatal error");
    }
    if constexpr (nrc_y == 1) {
        // For a single rhs column the AVX2 version is used instead.
        // NOTE(review): presumably it is faster for this shape — confirm with a benchmark.
        mul_mat_iq2_bn_r4_q8_k16_avx2<1>(n, vx, bx, info, nrc_x);
    } else {
        Q8_16<nrc_y> q8(info);
        auto m3 = _mm512_set1_epi8(0x3);          // mask for one 2-bit quant
        int nb = n / QK_IQ1BN;
        __m512i acc[2*nrc_y] = {};                // two accumulators per rhs column (low/high 4-row group)
        __m512i qx[8];
        // Main loop: process 8 interleaved lhs rows at a time (two r4 groups).
        for (int ix = 0; ix < nrc_x/8; ++ix) {
            // Each 4-row group starts with 4 float scales followed by the packed quants.
            const float * dptr1 = (const float *)((const char *)vx + (8*ix+0)*bx);
            const float * dptr2 = (const float *)((const char *)vx + (8*ix+4)*bx);
            auto dl = _mm_loadu_ps(dptr1);        // scales of rows 0..3
            auto dh = _mm_loadu_ps(dptr2);        // scales of rows 4..7
            const uint8_t * iq2l = (const uint8_t *)(dptr1 + 4);
            const uint8_t * iq2h = (const uint8_t *)(dptr2 + 4);
            for (int ib = 0; ib < nb; ++ib) {
                auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
                auto bits_h = _mm512_loadu_si512((const __m512i *)iq2h + ib);
                // Unpack the four 2-bit planes of both row groups.
                qx[0] = _mm512_and_si512(bits_l, m3);
                qx[1] = _mm512_and_si512(bits_h, m3);
                qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
                qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 2), m3);
                qx[4] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
                qx[5] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 4), m3);
                qx[6] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
                qx[7] = _mm512_and_si512(_mm512_srli_epi16(bits_h, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants64(iy, ib);
                    // Broadcast each 4-byte activation group and accumulate the
                    // u8*i8 dot products with VNNI.
                    auto sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[0], sy);
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[1], sy);
                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[2], sy);
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[3], sy);
                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[4], sy);
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[5], sy);
                    sy = _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff));
                    acc[2*iy+0] = _mm512_dpbusd_epi32(acc[2*iy+0], qx[6], sy);
                    acc[2*iy+1] = _mm512_dpbusd_epi32(acc[2*iy+1], qx[7], sy);
                }
            }
            // Convert the integer accumulators to floats, apply the lhs/rhs scales
            // and subtract q8.sum_row (the quants are stored with a +1 offset).
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                __m128 sum4;
                for (int k = 0; k < 2; ++k) {
                    const auto& dx = k == 0 ? dl : dh;
                    auto sumf = _mm512_cvtepi32_ps(acc[2*iy+k]);
                    sum4 = _mm_mul_ps (_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x00)));
                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
                    sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dx, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
                    sum4 = _mm_fmadd_ps(dx, _mm_set1_ps(-q8.sum_row(iy)), sum4);
                    info.store(8*ix + 4*k, iy, sum4);
                }
                acc[2*iy+0] = acc[2*iy+1] = _mm512_setzero_si512();
            }
        }
        // Tail: one remaining 4-row group when nrc_x % 8 == 4
        // (nrc_x is guaranteed to be a multiple of 4 by the check above).
        if (int ix = 8*(nrc_x/8); ix < nrc_x) {
            const float * dptr = (const float *)((const char *)vx + ix*bx);
            auto dl = _mm_loadu_ps(dptr);
            const uint8_t * iq2l = (const uint8_t *)(dptr + 4);
            for (int ib = 0; ib < nb; ++ib) {
                auto bits_l = _mm512_loadu_si512((const __m512i *)iq2l + ib);
                qx[0] = _mm512_and_si512(bits_l, m3);
                qx[1] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 2), m3);
                qx[2] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 4), m3);
                qx[3] = _mm512_and_si512(_mm512_srli_epi16(bits_l, 6), m3);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = q8.load_quants64(iy, ib);
                    // acc[iy] is zero on entry here: either never touched (main loop
                    // did not run) or reset at the end of the main loop.
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[0], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x00)));
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[1], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0x55)));
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[2], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xaa)));
                    acc[iy] = _mm512_dpbusd_epi32(acc[iy], qx[3], _mm512_shuffle_epi32(y, _MM_PERM_ENUM(0xff)));
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto dy = q8.scale(iy);
                auto sumf = _mm512_cvtepi32_ps(acc[iy]);
                auto sum4 = _mm_mul_ps(_mm512_extractf32x4_ps(sumf, 0), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x00)));
                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 1), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0x55)), sum4);
                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 2), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xaa)), sum4);
                sum4 = _mm_fmadd_ps(_mm512_extractf32x4_ps(sumf, 3), _mm_mul_ps(dl, _mm_shuffle_ps(dy, dy, 0xff)), sum4);
                sum4 = _mm_fmadd_ps(dl, _mm_set1_ps(-q8.sum_row(iy)), sum4);
                info.store(ix, iy, sum4);
            }
        }
    }
}
|
||||
#else
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
if (nrc_x%4) {
|
||||
printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
mul_mat_iq2_bn_r4_q8_k16_avx2<nrc_y>(n, vx, bx, info, nrc_x);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq4_nl_r4_q8_2(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
@@ -5127,268 +4905,6 @@ static void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const Data
|
||||
}
|
||||
}
|
||||
|
||||
// Accessor for Q8_K64-quantized activations used by the IQ1_BN/IQ2_BN kernels.
// Each rhs row starts with 8 floats followed by the int8 quants; the first 4
// floats are returned by scale() and the last 4 by minus().
// nrc is the number of rhs columns processed simultaneously.
template <int nrc> struct Q8_K64 {

    constexpr static int nrc_y = nrc;

    Q8_K64(const DataInfo& info) {
        for (int iy = 0; iy < nrc_y; ++iy) {
            const float * dptr = (const float *)info.src1_row(iy);
            std::memcpy(d + 8*iy, dptr, 8*sizeof(float));   // cache the 8 leading floats
            y[iy] = (const int8_t *)(dptr + 8);             // quants follow immediately
        }
    }

    // j-th 32-byte group (j = 0..3) of the i-th 128-quant super-block of column iy.
    inline __m256i load_quants(int iy, int i, int j) const { return _mm256_loadu_si256((const __m256i*)y[iy] + 4*i + j); }
    // First 4 floats of column iy (used as multiplicative scales by the callers).
    inline __m128 scale(int iy) const { return _mm_loadu_ps(d + 8*iy); }
    // Last 4 floats of column iy (subtracted from the scaled sums by the callers).
    inline __m128 minus(int iy) const { return _mm_loadu_ps(d + 8*iy + 4); }

    float d[8*nrc_y];          // cached scales/sums, 8 per column
    const int8_t * y[nrc_y];   // per-column quant pointers
};
|
||||
|
||||
// Unpacks IQ1_BN (1.625-bit BitNet ternary) blocks into bytes in {0,1,2}.
// Each block stores 13 bytes; every byte packs 5 ternary digits, which are
// extracted with a multiply-high trick against powers of 3.
struct DequantizerIQ1BN {
    const __m256i m1_8 = _mm256_set1_epi8(1);
    // Byte shuffle that replicates each source byte 5 times (one copy per packed
    // ternary digit), with 255 in the odd positions to zero the high byte of each
    // 16-bit lane. Byte 12 is the shared 5th-digit byte of the block.
    static __m256i load_shuffle(int i) {
        static const uint8_t data[128] = {
            0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 1, 255, 1, 255, 1, 255, 1, 255, 1, 255, 2, 255, 2, 255, 2, 255, 2, 255, 2, 255, 12, 255,
            3, 255, 3, 255, 3, 255, 3, 255, 3, 255, 4, 255, 4, 255, 4, 255, 4, 255, 4, 255, 5, 255, 5, 255, 5, 255, 5, 255, 5, 255, 12, 255,
            6, 255, 6, 255, 6, 255, 6, 255, 6, 255, 7, 255, 7, 255, 7, 255, 7, 255, 7, 255, 8, 255, 8, 255, 8, 255, 8, 255, 8, 255, 12, 255,
            9, 255, 9, 255, 9, 255, 9, 255, 9, 255, 10, 255, 10, 255, 10, 255, 10, 255, 10, 255, 11, 255, 11, 255, 11, 255, 11, 255, 11, 255, 12, 255,
        };
        return _mm256_loadu_si256((const __m256i*)data + i);
    }
    const __m256i shuff[4] = { load_shuffle(0), load_shuffle(1), load_shuffle(2), load_shuffle(3) };
    // 16-bit multipliers built from powers of 3 (0x51=81, 0x1b=27, 9, 3, 1):
    // mullo then mulhi against m3 isolates one base-3 digit per 16-bit lane.
    const __m256i mult[4] = {
        _mm256_set_epi64x(0x5100010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
        _mm256_set_epi64x(0x1b00010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
        _mm256_set_epi64x(0x0900010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
        _mm256_set_epi64x(0x0300010003000900, 0x1b00510001000300, 0x09001b0051000100, 0x030009001b005100),
    };
    const __m256i m3 = _mm256_set1_epi16(3);
#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__
    // Even-byte gather across two registers for the VBMI compaction path.
    const __m256i bmask = _mm256_set_epi8(62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
#endif

    // Expands one block into two 32-byte vectors v1, v2 of values in {0,1,2}.
    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, __m256i& v1, __m256i& v2) const {
        auto data128 = _mm_loadu_si128((const __m128i *)x); // Note: we load 16 instead of 13 bytes!
        auto data = MM256_SET_M128I(data128, data128);
        auto val1 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[0]), mult[0]), m3);
        auto val2 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[1]), mult[1]), m3);
        auto val3 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[2]), mult[2]), m3);
        auto val4 = _mm256_mulhi_epu16(_mm256_mullo_epi16(_mm256_shuffle_epi8(data, shuff[3]), mult[3]), m3);
#if defined HAVE_FANCY_SIMD && defined __AVX512VBMI__
        // Compact the low bytes of the 16-bit lanes with a single byte permute.
        v1 = _mm256_permutex2var_epi8(val1, bmask, val2);
        v2 = _mm256_permutex2var_epi8(val3, bmask, val4);
#else
        // Pack 16-bit lanes to bytes, then fix the 128-bit lane interleave.
        v1 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val1, val2), 216);
        v2 = _mm256_permute4x64_epi64(_mm256_packs_epi16(val3, val4), 216);
#endif
    }

};
|
||||
|
||||
// GEMM/GEMV kernel: IQ1_BN (ternary) lhs rows against Q8_K64 activations.
// Each lhs row begins with one fp16 row scale followed by the iq1_bn blocks.
// nrc_y is the number of rhs columns handled per call.
template <int nrc_y>
IQK_NOINLINE void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    const int nb = n / QK_IQ1BN;
    Q8_K64<nrc_y> q8(info);
    DequantizerIQ1BN deq;
    __m256i accd[nrc_y];   // one int32 accumulator per rhs column
    __m256i val[4];        // unpacked quants of two consecutive blocks

#ifndef HAVE_FANCY_SIMD
    const auto m1_16 = _mm256_set1_epi16(1);   // for widening 16->32 via madd
#endif

    const block_iq1_bn * x;
    const char * cx0 = (const char *)vx;
    float scale;
    ggml_half d16;

    for (int ix = 0; ix < nrc_x; ++ix) {

        // Row layout: fp16 scale, then the quant blocks.
        const char * cx = cx0 + ix*bx;
        std::memcpy(&d16, cx, sizeof(d16));
        scale = GGML_FP16_TO_FP32(d16);
        cx += sizeof(d16);
        x = (const block_iq1_bn *)cx;

        if constexpr (nrc_y == 1) {
            // GEMV: two independent accumulators to expose more ILP.
            __m256i acc1 = _mm256_setzero_si256(), acc2 = _mm256_setzero_si256();
            for (int i = 0; i < nb/2; ++i) {
                deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
                deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);
#ifdef HAVE_FANCY_SIMD
                acc1 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc1, val[0], q8.load_quants(0, i, 0)), val[1], q8.load_quants(0, i, 1));
                acc2 = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc2, val[2], q8.load_quants(0, i, 2)), val[3], q8.load_quants(0, i, 3));
#else
                // maddubs gives 16-bit sums of u8*i8 pairs; madd by 1 widens to 32 bits.
                auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
                                             _mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
                auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
                                             _mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
                acc1 = _mm256_add_epi32(acc1, _mm256_madd_epi16(m1_16, dot1));
                acc2 = _mm256_add_epi32(acc2, _mm256_madd_epi16(m1_16, dot2));
#endif
            }
            accd[0] = _mm256_add_epi32(acc1, acc2);
        }
        else {
            // GEMM: unpack each block pair once, reuse it for every rhs column.
            for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();

            for (int i = 0; i < nb/2; ++i) {

                deq.prepare_iq1bn_quants(x + 2*i + 0, val[0], val[1]);
                deq.prepare_iq1bn_quants(x + 2*i + 1, val[2], val[3]);

                for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                    accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
                                    val[0], q8.load_quants(iy, i, 0)),
                                    val[1], q8.load_quants(iy, i, 1)),
                                    val[2], q8.load_quants(iy, i, 2)),
                                    val[3], q8.load_quants(iy, i, 3));
#else
                    auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
                                                 _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1)));
                    auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
                                                 _mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)));
                    dot1 = _mm256_madd_epi16(m1_16, _mm256_add_epi16(dot1, dot2));
                    accd[iy] = _mm256_add_epi32(dot1, accd[iy]);
#endif
                }
            }
        }
        // Tail: one leftover block when nb is odd.
        int i = 2*(nb/2);
        if (i < nb) {
            deq.prepare_iq1bn_quants(x + i, val[0], val[1]);
            for (int iy = 0; iy < nrc_y; ++iy) {
#ifdef HAVE_FANCY_SIMD
                accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
                                val[0], q8.load_quants(iy, i/2, 0)), val[1], q8.load_quants(iy, i/2, 1));
#else
                auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
                                                                     _mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 1))));
                accd[iy] = _mm256_add_epi32(dot, accd[iy]);
#endif
            }
        }

        // Reduce to 4 lanes, apply the rhs scales, subtract q8.minus (offset
        // correction, since the unpacked quants are {0,1,2} rather than {-1,0,1}),
        // and finally apply the fp16 row scale.
        for (int iy = 0; iy < nrc_y; ++iy) {
            auto vd = q8.scale(iy);
            auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
            auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
            info.store(ix, iy, scale*hsum_float_4(sumf));
        }

    }
}
|
||||
|
||||
// Unpacks IQ2_BN (2-bit BitNet) blocks: four 2-bit planes per 32-byte load,
// yielding byte values in {0,1,2,3}. Inherits row handling (x, d, new_row)
// from BaseDequantizer.
struct DequantizeIQ2BN final : public BaseDequantizer<block_iq2_bn, true> {
    DequantizeIQ2BN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}

    // Unpack two consecutive blocks (64 quants) into val[0..3].
    IQK_ALWAYS_INLINE void prepare4(int i, __m256i * val) const {
        auto q2bits_1 = _mm256_loadu_si256((const __m256i *)x[2*i].qs);
        auto q2bits_2 = _mm256_srli_epi16(q2bits_1, 2);
        // Interleave the original and shifted halves so each make2 call sees
        // one 128-bit lane of each.
        make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x20), val+0);
        make2(_mm256_permute2x128_si256(q2bits_1, q2bits_2, 0x31), val+2);
    }
    // Extract the low and the 4-bit-shifted 2-bit planes of one register.
    IQK_ALWAYS_INLINE void make2(__m256i q2_1, __m256i * val) const {
        val[0] = _mm256_and_si256(q2_1, mask2);
        val[1] = _mm256_and_si256(_mm256_srli_epi16(q2_1, 4), mask2);
    }
    // Unpack a single block (32 quants) into val[0..1].
    IQK_ALWAYS_INLINE void prepare2(int i, __m256i * val) const {
        auto q2bits_1 = _mm_loadu_si128((const __m128i *)x[i].qs);
        make2(MM256_SET_M128I(_mm_srli_epi16(q2bits_1, 2), q2bits_1), val);
    }
    const __m256i m1_8  = _mm256_set1_epi8(1);
    const __m256i mf_8  = _mm256_set1_epi8(16);
    const __m256i mask2 = _mm256_set1_epi8(0x03);   // one 2-bit quant
    const __m256i mask3 = _mm256_set1_epi8(0x30);
};
|
||||
|
||||
template <int nrc_y>
|
||||
IQK_NOINLINE void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
const int nb = n / QK_IQ1BN;
|
||||
Q8_K64<nrc_y> q8(info);
|
||||
DequantizeIQ2BN deq(vx, bx);
|
||||
__m256i accd[nrc_y];
|
||||
__m256i val[4];
|
||||
|
||||
#ifndef HAVE_FANCY_SIMD
|
||||
const auto m1_16 = _mm256_set1_epi16(1);
|
||||
#endif
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
deq.new_row(ix);
|
||||
|
||||
if constexpr (nrc_y == 1) {
|
||||
__m256i acc[2] = {};
|
||||
for (int i = 0; i < nb/2; ++i) {
|
||||
deq.prepare4(i, val);
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
acc[0] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[0], val[0], q8.load_quants(0, i, 0)),
|
||||
val[1], q8.load_quants(0, i, 1));
|
||||
acc[1] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(acc[1], val[2], q8.load_quants(0, i, 2)),
|
||||
val[3], q8.load_quants(0, i, 3));
|
||||
#else
|
||||
auto dot1 = _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(0, i, 0)),
|
||||
_mm256_maddubs_epi16(val[1], q8.load_quants(0, i, 1)));
|
||||
auto dot2 = _mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(0, i, 2)),
|
||||
_mm256_maddubs_epi16(val[3], q8.load_quants(0, i, 3)));
|
||||
acc[0] = _mm256_add_epi32(acc[0], _mm256_madd_epi16(m1_16, dot1));
|
||||
acc[1] = _mm256_add_epi32(acc[1], _mm256_madd_epi16(m1_16, dot2));
|
||||
#endif
|
||||
}
|
||||
accd[0] = _mm256_add_epi32(acc[0], acc[1]);
|
||||
}
|
||||
else {
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = _mm256_setzero_si256();
|
||||
|
||||
for (int i = 0; i < nb/2; ++i) {
|
||||
deq.prepare4(i, val);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy],
|
||||
val[0], q8.load_quants(iy, i, 0)), val[1], q8.load_quants(iy, i, 1)),
|
||||
val[2], q8.load_quants(iy, i, 2)), val[3], q8.load_quants(iy, i, 3));
|
||||
#else
|
||||
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(
|
||||
_mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i, 0)),
|
||||
_mm256_maddubs_epi16(val[1], q8.load_quants(iy, i, 1))),
|
||||
_mm256_add_epi16(_mm256_maddubs_epi16(val[2], q8.load_quants(iy, i, 2)),
|
||||
_mm256_maddubs_epi16(val[3], q8.load_quants(iy, i, 3)))));
|
||||
accd[iy] = _mm256_add_epi32(dot, accd[iy]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
int i = 2*(nb/2);
|
||||
if (i < nb) {
|
||||
deq.prepare2(i, val);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
accd[iy] = _mm256_dpbusd_epi32(_mm256_dpbusd_epi32(accd[iy], val[0], q8.load_quants(iy, i/2, 0)),
|
||||
val[1], q8.load_quants(iy, i/2, 1));
|
||||
#else
|
||||
auto dot = _mm256_madd_epi16(m1_16, _mm256_add_epi16(_mm256_maddubs_epi16(val[0], q8.load_quants(iy, i/2, 0)),
|
||||
_mm256_maddubs_epi16(val[1], q8.load_quants(iy, i/2, 0))));
|
||||
accd[iy] = _mm256_add_epi32(dot, accd[iy]);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto vd = q8.scale(iy);
|
||||
auto sumi = _mm_add_epi32(_mm256_castsi256_si128(accd[iy]), _mm256_extractf128_si256(accd[iy], 1));
|
||||
auto sumf = _mm_fmsub_ps(vd, _mm_cvtepi32_ps(sumi), q8.minus(iy));
|
||||
info.store(ix, iy, deq.d*hsum_float_4(sumf));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
m.funcs[0] = mul_mat_qX_K_q8_K_AVX512_1<Dequantizer>;
|
||||
@@ -5445,44 +4961,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
case GGML_TYPE_IQ5_K:
|
||||
case GGML_TYPE_IQ6_K:
|
||||
return ggml_type(typeB) == GGML_TYPE_Q8_K ? iqk_set_kernels_iqk_quants(ne00, typeA, typeB, mm.funcs) : false;
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
assert (ne00 % QK_IQ1BN == 0);
|
||||
mm.funcs[0] = mul_mat_iq1bn_q8_K64<1>;
|
||||
mm.funcs[1] = mul_mat_iq1bn_q8_K64<2>;
|
||||
mm.funcs[2] = mul_mat_iq1bn_q8_K64<3>;
|
||||
mm.funcs[3] = mul_mat_iq1bn_q8_K64<4>;
|
||||
mm.funcs[4] = mul_mat_iq1bn_q8_K64<5>;
|
||||
mm.funcs[5] = mul_mat_iq1bn_q8_K64<6>;
|
||||
mm.funcs[6] = mul_mat_iq1bn_q8_K64<7>;
|
||||
mm.funcs[7] = mul_mat_iq1bn_q8_K64<8>;
|
||||
expected_typeB = GGML_TYPE_Q8_K64;
|
||||
break;
|
||||
case GGML_TYPE_IQ2_BN:
|
||||
assert (ne00 % QK_IQ1BN == 0);
|
||||
mm.funcs[0] = mul_mat_iq2bn_q8_K64<1>;
|
||||
mm.funcs[1] = mul_mat_iq2bn_q8_K64<2>;
|
||||
mm.funcs[2] = mul_mat_iq2bn_q8_K64<3>;
|
||||
mm.funcs[3] = mul_mat_iq2bn_q8_K64<4>;
|
||||
mm.funcs[4] = mul_mat_iq2bn_q8_K64<5>;
|
||||
mm.funcs[5] = mul_mat_iq2bn_q8_K64<6>;
|
||||
mm.funcs[6] = mul_mat_iq2bn_q8_K64<7>;
|
||||
mm.funcs[7] = mul_mat_iq2bn_q8_K64<8>;
|
||||
expected_typeB = GGML_TYPE_Q8_K64;
|
||||
break;
|
||||
case GGML_TYPE_IQ2_BN_R4:
|
||||
assert (ne00 % QK_IQ1BN == 0);
|
||||
mm.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
|
||||
mm.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
|
||||
mm.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
|
||||
mm.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
|
||||
mm.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
|
||||
mm.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
|
||||
//#ifdef HAVE_FANCY_SIMD
|
||||
mm.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
|
||||
mm.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
|
||||
//#endif
|
||||
expected_typeB = GGML_TYPE_Q8_K16;
|
||||
break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@@ -5824,6 +5302,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
|
||||
case GGML_TYPE_IQ1_S:
|
||||
case GGML_TYPE_IQ1_S_R4:
|
||||
case GGML_TYPE_IQ1_M_R4:
|
||||
case GGML_TYPE_IQ1_BN:
|
||||
case GGML_TYPE_IQ2_BN:
|
||||
case GGML_TYPE_IQ2_BN_R4:
|
||||
return iqk_set_kernels_1bit(ne00, typeA, typeB, mm.funcs, mm.func16);
|
||||
|
||||
default:
|
||||
|
||||
Reference in New Issue
Block a user