Refactor iqk: factor out 1-bit quants (NEON)

Iwan Kawrakow
2025-05-18 16:54:44 +03:00
parent c63a0af5b7
commit 28b94800c1
8 changed files with 765 additions and 628 deletions
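
The mechanical part of the change replaces the eight hand-written funcs[i] = kernel<i+1> assignments with the IQK_SET_MUL_MAT_FUNCTIONS macro, and moves the NEON (__aarch64__) 1-bit kernels verbatim into the dedicated 1-bit source file while deleting them from their previous location (the large removal-only hunks at the end of this commit). A plausible expansion of the macro, shown only as a sketch since its real definition lives in a shared header that is not part of this diff:

#define IQK_SET_MUL_MAT_FUNCTIONS(kernel, funcs) \
    funcs[0] = kernel<1>; funcs[1] = kernel<2>; funcs[2] = kernel<3>; funcs[3] = kernel<4>; \
    funcs[4] = kernel<5>; funcs[5] = kernel<6>; funcs[6] = kernel<7>; funcs[7] = kernel<8>;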

View File

@@ -525,6 +525,28 @@ struct Q4Bits {
#endif
#else
template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
constexpr static int nrc_y = nrc;
Q8(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
}
inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }
inline int16x8_t load_bsums8(int iy, int i) const {
auto q8s = vld1q_s16_x2(y[iy][i].bsums);
return vpaddq_s16(q8s.val[0], q8s.val[1]);
}
inline float scale(int iy, int i) const { return y[iy][i].d; }
const block_q8 * y[nrc_y];
};
#endif
#endif

View File

@@ -785,6 +785,11 @@ static const uint32_t iq1s_grid_us[2048] = {
};
#endif
}
#ifdef __x86_64__
namespace {
template <int nrc_y>
void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
@@ -1540,14 +1545,7 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
switch (typeA) {
case GGML_TYPE_IQ1_S:
if (ne00%QK_K != 0) return false;
funcs[0] = mul_mat_iq1_s_q8_K<1>;
funcs[1] = mul_mat_iq1_s_q8_K<2>;
funcs[2] = mul_mat_iq1_s_q8_K<3>;
funcs[3] = mul_mat_iq1_s_q8_K<4>;
funcs[4] = mul_mat_iq1_s_q8_K<5>;
funcs[5] = mul_mat_iq1_s_q8_K<6>;
funcs[6] = mul_mat_iq1_s_q8_K<7>;
funcs[7] = mul_mat_iq1_s_q8_K<8>;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_q8_K, funcs);
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_iq1_s_q8_K<16>;
#endif
@@ -1555,66 +1553,31 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
break;
case GGML_TYPE_IQ1_S_R4:
if (ne00%128 != 0) return false;
funcs[0] = mul_mat_iq1_s_r4_q8_1<1>;
funcs[1] = mul_mat_iq1_s_r4_q8_1<2>;
funcs[2] = mul_mat_iq1_s_r4_q8_1<3>;
funcs[3] = mul_mat_iq1_s_r4_q8_1<4>;
funcs[4] = mul_mat_iq1_s_r4_q8_1<5>;
funcs[5] = mul_mat_iq1_s_r4_q8_1<6>;
funcs[6] = mul_mat_iq1_s_r4_q8_1<7>;
funcs[7] = mul_mat_iq1_s_r4_q8_1<8>;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_r4_q8_1, funcs);
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_iq1_s_r4_q8_1<16>;
#endif
break;
case GGML_TYPE_IQ1_M_R4:
if (ne00%128 != 0) return false;
funcs[0] = mul_mat_iq1_m_r4_q8_0<1>;
funcs[1] = mul_mat_iq1_m_r4_q8_0<2>;
funcs[2] = mul_mat_iq1_m_r4_q8_0<3>;
funcs[3] = mul_mat_iq1_m_r4_q8_0<4>;
funcs[4] = mul_mat_iq1_m_r4_q8_0<5>;
funcs[5] = mul_mat_iq1_m_r4_q8_0<6>;
funcs[6] = mul_mat_iq1_m_r4_q8_0<7>;
funcs[7] = mul_mat_iq1_m_r4_q8_0<8>;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_r4_q8_0, funcs);
#ifdef HAVE_FANCY_SIMD
func16 = mul_mat_iq1_m_r4_q8_0<16>;
#endif
break;
case GGML_TYPE_IQ1_BN:
assert (ne00 % QK_IQ1BN == 0);
funcs[0] = mul_mat_iq1bn_q8_K64<1>;
funcs[1] = mul_mat_iq1bn_q8_K64<2>;
funcs[2] = mul_mat_iq1bn_q8_K64<3>;
funcs[3] = mul_mat_iq1bn_q8_K64<4>;
funcs[4] = mul_mat_iq1bn_q8_K64<5>;
funcs[5] = mul_mat_iq1bn_q8_K64<6>;
funcs[6] = mul_mat_iq1bn_q8_K64<7>;
funcs[7] = mul_mat_iq1bn_q8_K64<8>;
if (ne00 % QK_IQ1BN != 0) return false;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1bn_q8_K64, funcs);
expected_typeB = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN:
assert (ne00 % QK_IQ1BN == 0);
funcs[0] = mul_mat_iq2bn_q8_K64<1>;
funcs[1] = mul_mat_iq2bn_q8_K64<2>;
funcs[2] = mul_mat_iq2bn_q8_K64<3>;
funcs[3] = mul_mat_iq2bn_q8_K64<4>;
funcs[4] = mul_mat_iq2bn_q8_K64<5>;
funcs[5] = mul_mat_iq2bn_q8_K64<6>;
funcs[6] = mul_mat_iq2bn_q8_K64<7>;
funcs[7] = mul_mat_iq2bn_q8_K64<8>;
if (ne00 % QK_IQ1BN != 0) return false;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2bn_q8_K64, funcs);
expected_typeB = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN_R4:
assert (ne00 % QK_IQ1BN == 0);
funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
if (ne00 % QK_IQ1BN != 0) return false;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_bn_r4_q8_k16, funcs);
expected_typeB = GGML_TYPE_Q8_K16;
break;
@@ -1626,4 +1589,694 @@ bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t,
}
#else
// -------------------------------- __aarch64__
namespace {
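// Q8_K64 activations: each row starts with 8 floats (scale(iy) returns the first 4, minus(iy) the
// next 4), followed by the int8 quants. Judging from their use below, the minus() values are
// pre-scaled activation sums that let the kernels decode the ternary quants as {0,1,2}
// (prepare_iq1bn_quants_nosub) and subtract the -1 offset only once, in the final store.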
template <int nrc> struct Q8_K64 {
constexpr static int nrc_y = nrc;
Q8_K64(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
y[iy] = (const int8_t *)(dptr + 8);
}
}
inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); }
inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); }
inline float32x4_t scale(int iy) const { return vld1q_f32(d + 8*iy); }
inline float32x4_t minus(int iy) const { return vld1q_f32(d + 8*iy + 4); }
float d[8*nrc_y];
const int8_t * y[nrc_y];
};
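// DequantizerIQ1BN: a block_iq1_bn packs 64 ternary digits into 13 bytes (5 digits per byte for the
// first 12 bytes, 4 in the last). The shuffle replicates each byte five times, the per-lane
// multipliers {81,27,9,3,1} rotate the wanted digit into the top bits of the wrapping 8-bit product,
// and (val + (val >> 1)) >> 7 (roughly 3*val/256) maps it to {0,1,2}. prepare_iq1bn_quants subtracts
// 1 to get {-1,0,1}; the _nosub variant keeps {0,1,2} and leaves the offset to the Q8_K64 minus()
// correction.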
struct DequantizerIQ1BN {
const uint8x16_t m1 = vdupq_n_u8(1);
static inline uint8x16x4_t load_shuffles() {
static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
return vld1q_u8_x4(data);
}
static inline uint8x16x4_t load_mult() {
static const uint8_t data[64] = {81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81,
81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27,
81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 9,
81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 3};
return vld1q_u8_x4(data);
}
const uint8x16x4_t shuff = load_shuffles();
const uint8x16x4_t mult = load_mult();
IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
auto data = vld1q_u8((const uint8_t *)x);
for (int k = 0; k < 4; ++k) {
auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
val = vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6);
v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1);
}
}
IQK_ALWAYS_INLINE void prepare_iq1bn_quants_nosub(const block_iq1_bn * x, int8x16x4_t& v) const {
auto data = vld1q_u8((const uint8_t *)x);
for (int k = 0; k < 4; ++k) {
auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
v.val[k] = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6));
}
}
};
template <int nrc_y>
static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
Q8_K64<nrc_y> q8(info);
DequantizerIQ1BN deq;
int32x4_t accd[nrc_y];
int8x16x4_t v1, v2;
float scale;
ggml_half d16;
char * c16 = (char *)&d16;
for (int ix = 0; ix < nrc_x; ++ix) {
const char * cx = ((const char *)vx + ix*bx);
c16[0] = cx[0]; c16[1] = cx[1];
//std::memcpy(&d16, cx, sizeof(d16));
cx += sizeof(d16);
scale = GGML_FP16_TO_FP32(d16);
const block_iq1_bn * x = (const block_iq1_bn *)cx;
if constexpr (nrc_y == 1) {
int32x4_t acc[4] = {};
for (int i = 0; i < nb/2; ++i) {
deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
auto q = q8.load_quants64(0, i, 0);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
q = q8.load_quants64(0, i, 1);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]);
}
accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
}
else {
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
for (int i = 0; i < nb/2; ++i) {
deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i, 0);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
q = q8.load_quants(iy, i, 1);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
q = q8.load_quants(iy, i, 2);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
q = q8.load_quants(iy, i, 3);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
}
}
}
int i = 2*(nb/2);
if (i < nb) {
deq.prepare_iq1bn_quants_nosub(x+i, v1);
if constexpr (nrc_y == 1) {
auto q = q8.load_quants(0, i/2, 0);
for (int j = 0; j < 4; ++j) {
accd[0] = ggml_vdotq_s32(accd[0], q.val[j], v1.val[j]);
}
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i/2, 0);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
q = q8.load_quants(iy, i/2, 1);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, -scale * vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
}
}
}
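// Q8_16 activations (GGML_TYPE_Q8_K16): 5 floats per row (4 block scales plus the row sum returned
// by sum_row()), followed by the int8 quants. The row sum is used below to compensate the constant
// offset of the 2-bit ternary values with a single fma per row.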
template <int nrc> struct Q8_16 {
constexpr static int nrc_y = nrc;
Q8_16(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto ptr = (const float *)info.src1_row(iy);
std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
y[iy] = (const int8_t *)(ptr + 5);
}
}
inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy] + 64*i); }
inline int8x16x2_t load_quants_32(int iy, int i) const { return vld1q_s8_x2(y[iy] + 32*i); }
inline float scale(int iy, int k) const { return d[5*iy+k]; }
inline float sum_row(int iy) const { return d[5*iy + 4]; }
inline float32x4_t scale(int iy) const { return vld1q_f32(d + 5*iy); }
float d[5*nrc_y];
const int8_t * y[nrc_y];
};
template <int nrc_y>
static IQK_NOINLINE void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
if (nrc_x%4) {
printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
GGML_ABORT("fatal error");
}
Q8_16<nrc_y> q8(info);
auto m3 = vdupq_n_u8(0x3);
int nb = n / QK_IQ1BN;
if constexpr (nrc_y == 1) {
auto mc = vdupq_n_u8(0xc);
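// The 0xc mask extracts the second 2-bit value of each nibble without shifting, i.e. still scaled
// by 4; those dot products go into the odd acc[] entries and the 0.25f factor below undoes the
// scaling, saving two shifts per 16-byte load.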
int32x4_t acc[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
for (int k = 0; k < 8; ++k) acc[k] = vdupq_n_s32(0);
const float * dptr = (const float *)((const char *)vx + ix*bx);
auto dl = vld1q_f32(dptr);
const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
for (int ib = 0; ib < nb; ++ib) {
auto y = q8.load_quants(0, ib);
for (int j = 0; j < 4; ++j) {
auto bits1 = vld1q_u8(iq2 + 64*ib + 16*j);
auto bits2 = vshrq_n_u8(bits1, 4);
acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits1, m3), y.val[j], 0);
acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits1, mc), y.val[j], 1);
acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits2, m3), y.val[j], 2);
acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits2, mc), y.val[j], 3);
}
}
auto dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 0)));
auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), dy);
auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), dy);
dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 1)));
sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), dy);
sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), dy);
dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 2)));
sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), dy);
sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[5]), dy);
dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 3)));
sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), dy);
sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), dy);
auto sumf = vfmaq_f32(sumf1, vdupq_n_f32(0.25f), sumf2);
sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0)));
info.store(ix, 0, sumf);
}
} else {
int32x4_t acc[4*nrc_y] = {};
uint8x16_t qx[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
const float * dptr = (const float *)((const char *)vx + ix*bx);
auto dl = vld1q_f32(dptr);
const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
for (int ib = 0; ib < nb; ++ib) {
auto bits = vld1q_u8_x2(iq2 + 64*ib);
qx[0] = vandq_u8(bits.val[0], m3);
qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
qx[3] = vshrq_n_u8(bits.val[0], 6);
qx[4] = vandq_u8(bits.val[1], m3);
qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
qx[7] = vshrq_n_u8(bits.val[1], 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = q8.load_quants_32(iy, 2*ib+0);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
}
bits = vld1q_u8_x2(iq2 + 64*ib + 32);
qx[0] = vandq_u8(bits.val[0], m3);
qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
qx[3] = vshrq_n_u8(bits.val[0], 6);
qx[4] = vandq_u8(bits.val[1], m3);
qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
qx[7] = vshrq_n_u8(bits.val[1], 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = q8.load_quants_32(iy, 2*ib+1);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto dy = q8.scale(iy);
float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
info.store(ix, iy, sumf);
acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_s32(0);
}
}
}
}
template <int nrc_y>
static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
Q8_K64<nrc_y> q8(info);
int32x4_t accd[nrc_y];
const auto mask2 = vdupq_n_s8(3);
for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char *)vx + ix*bx);
const float d = *dptr;
const block_iq2_bn * x = (const block_iq2_bn *)(dptr + 1);
if constexpr (nrc_y == 1) {
int8x16x4_t v1;
int32x4_t acc[4] = {};
for (int i = 0; i < nb/2; ++i) {
for (int j = 0; j < 2; ++j) {
auto q = q8.load_quants64(0, i, j);
auto q2bits = vld1q_u8(x[2*i+j].qs);
v1.val[0] = vandq_s8(q2bits, mask2);
v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
v1.val[3] = vshrq_n_u8(q2bits, 6);
acc[0] = ggml_vdotq_s32(acc[0], q.val[0], v1.val[0]);
acc[1] = ggml_vdotq_s32(acc[1], q.val[1], v1.val[1]);
acc[2] = ggml_vdotq_s32(acc[2], q.val[2], v1.val[2]);
acc[3] = ggml_vdotq_s32(acc[3], q.val[3], v1.val[3]);
}
}
accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
} else {
int8x16x4_t v1, v2;
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
for (int i = 0; i < nb/2; ++i) {
auto q2bits = vld1q_u8(x[2*i+0].qs);
v1.val[0] = vandq_s8(q2bits, mask2);
v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
v1.val[3] = vshrq_n_u8(q2bits, 6);
q2bits = vld1q_u8(x[2*i+1].qs);
v2.val[0] = vandq_s8(q2bits, mask2);
v2.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
v2.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
v2.val[3] = vshrq_n_u8(q2bits, 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i, 0);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
q = q8.load_quants(iy, i, 1);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
q = q8.load_quants(iy, i, 2);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
q = q8.load_quants(iy, i, 3);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
}
}
}
int i = 2*(nb/2);
if (i < nb) {
auto q2bits = vld1q_u8(x[i].qs);
int8x16x4_t v1;
v1.val[0] = vandq_s8(q2bits, mask2);
v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
v1.val[3] = vshrq_n_u8(q2bits, 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i/2, 0);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
q = q8.load_quants(iy, i/2, 1);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, -d*vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
}
}
}
template <int nrc_y>
static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K128> q8(info);
int nb = n / 32;
GGML_ASSERT(nb%4 == 0);
uint8x16_t qx[8];
float32x4_t acc[nrc_y] = {};
auto ms = vdup_n_u16(0x8000);
auto mask = vdupq_n_s8(0x03);
float d8[4*nrc_y];
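// d8 caches q8 row scale times the per-32-quant bsums; multiplied by delta4 it applies, in a single
// fma, both the -1 offset of the unsigned iq1s_grid_us values and the ±IQ1S_DELTA (0.125) shift,
// which is why -8 is added to the signs below.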
for (int ix= 0; ix < nrc_x; ix += 4) {
auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr));
auto x = (const block_iq1_s_r4 *)(dptr + 4);
for (int ib = 0; ib < nb/4; ++ib) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto scales = vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[iy][ib].bsums)));
vst1q_f32(d8+4*iy, vmulq_f32(vdupq_n_f32(q8.y[iy][ib].d), scales));
}
for (int k = 0; k < 4; ++k) {
auto sas = vld1_u16(x[4*ib+k].qh);
auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7));
scales4 = vorr_u16(vshl_n_u16(scales4, 1), vdup_n_u16(1));
auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1)));
signs = vadd_s16(vdup_n_s16(-8), signs);
auto delta4 = vmulq_f32(vdupq_n_f32(0.125f), vcvtq_f32_s32(vmull_s16(signs, scales4)));
qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]});
qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]});
qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]});
qx[1] = vandq_u8(vshrq_n_u8(qx[0], 4), mask); qx[0] = vandq_u8(qx[0], mask);
qx[3] = vandq_u8(vshrq_n_u8(qx[2], 4), mask); qx[2] = vandq_u8(qx[2], mask);
qx[5] = vandq_u8(vshrq_n_u8(qx[4], 4), mask); qx[4] = vandq_u8(qx[4], mask);
qx[7] = vandq_u8(vshrq_n_u8(qx[6], 4), mask); qx[6] = vandq_u8(qx[6], mask);
auto scales = vmovl_u16(scales4);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib].qs + 32*k);
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[0]), y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[1]), y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[2]), y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[3]), y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[4]), y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[5]), y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[6]), y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[7]), y.val[1], 3);
sumi = vmulq_s32(scales, sumi);
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(sumi));
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), delta4);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vmulq_f32(d1, acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
template <int nrc_y>
static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K128> q8(info);
int nb = n / 32;
GGML_ASSERT(nb%4 == 0);
int8x16_t qx[8];
float32x4_t acc[nrc_y] = {};
int32x4_t isum[nrc_y] = {};
auto shuffle0 = uint32x4_t{0x00000000, 0x01010101, 0x02020202, 0x03030303};
auto step = vdupq_n_u8(4);
auto ms = vdupq_n_u8(0x08);
auto mask = vdupq_n_s8(0x18);
for (int ix= 0; ix < nrc_x; ix += 4) {
auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
auto d1 = vmulq_f32(vdupq_n_f32(0.125f), vcvt_f32_f16(vld1_f16((const float16_t *)dptr)));
auto x = (const block_iq1_m_r4 *)(dptr + 4);
for (int ib = 0; ib < nb/4; ++ib) {
for (int k = 0; k < 4; ++k) {
auto scales4 = vdup_n_u32(((const uint32_t *)x[4*ib+k].scales)[0]);
scales4 = vand_u8(vshl_u32(scales4, int32x2_t{0, -4}), vdup_n_u8(0xf));
auto scales16 = vmovl_u8(scales4);
auto scales1 = vmovl_u16(vget_low_u16(scales16));
auto scales2 = vmovl_u16(vget_high_u16(scales16));
auto qh = (const uint32_t *)x[4*ib+k].qh;
auto idxh = uint32x4_t{qh[0], qh[0] >> 4, qh[1], qh[1] >> 4};
auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(idxh, ms), ms), vdupq_n_u8(1)));
signs = vaddq_s8(signs, vdupq_n_s8(-8));
qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]});
qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]});
qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]});
auto shuffle = shuffle0;
for (int j = 0; j < 4; ++j) {
auto s = vqtbl1q_s8(signs, shuffle);
qx[2*j+1] = vaddq_s8(s, vandq_s8(vshrq_n_s8(qx[2*j+0], 1), mask));
qx[2*j+0] = vaddq_s8(s, vandq_s8(vshlq_n_s8(qx[2*j+0], 3), mask));
shuffle = vaddq_u8(shuffle, step);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib].qs + 32*k);
auto sumi1 = vdupq_n_s32(0);
auto sumi2 = vdupq_n_s32(0);
sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[0]), y.val[0], 0);
sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[1]), y.val[0], 1);
sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[2]), y.val[0], 2);
sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[3]), y.val[0], 3);
sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[4]), y.val[1], 0);
sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[5]), y.val[1], 1);
sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[6]), y.val[1], 2);
sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[7]), y.val[1], 3);
isum[iy] = vmlaq_s32(vmlaq_s32(isum[iy], sumi1, scales1), sumi2, scales2);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(isum[iy]));
isum[iy] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vmulq_f32(d1, acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
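// Dedicated single right-hand-side column (nrc_y == 1) variant of mul_mat_iq1_s_r4_q8_1: it uses
// the signed 64-bit iq1s_grid table directly and applies the ±IQ1S_DELTA correction once per
// 4-block group via the second accumulator. iqk_set_kernels_1bit installs it as funcs[0] in place
// of the generic template.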
void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<1, block_q8_K128> q8(info);
int nb = n / 32;
GGML_ASSERT(nb%4 == 0);
int8x16_t qx[8];
float32x4_t acc[2] = {};
int32x4_t isum[8];
auto ms = vdup_n_u16(0x8000);
for (int ix= 0; ix < nrc_x; ix += 4) {
auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr));
auto x = (const block_iq1_s_r4 *)(dptr + 4);
for (int ib = 0; ib < nb/4; ++ib) {
auto scale_yd = vdupq_n_f32(q8.y[0][ib].d);
auto scale_ym = vmulq_f32(scale_yd, vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[0][ib].bsums))));
for (int k = 0; k < 4; ++k) {
auto sas = vld1_u16(x[4*ib+k].qh);
auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7));
scales4 = vorr_u16(vshl_n_u16(scales4, 1), vdup_n_u16(1));
auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1)));
isum[k+4] = vmull_s16(signs, scales4);
qx[0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)]});
qx[1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)]});
qx[2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)]});
qx[3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)]});
qx[4] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)]});
qx[5] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)]});
qx[6] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]});
qx[7] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]});
auto scales = vmovl_u16(scales4);
auto y = vld1q_s8_x2(q8.y[0][ib].qs + 32*k);
auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
sumi1 = vpaddq_s32(sumi1, sumi2);
sumi3 = vpaddq_s32(sumi3, sumi4);
isum[k] = vmulq_s32(scales, vpaddq_s32(sumi1, sumi3));
}
acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[0]), scale_yd, 0);
acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[1]), scale_yd, 1);
acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[2]), scale_yd, 2);
acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[3]), scale_yd, 3);
acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[4]), scale_ym, 0);
acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[5]), scale_ym, 1);
acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[6]), scale_ym, 2);
acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[7]), scale_ym, 3);
}
info.store(ix, 0, vmulq_f32(d1, vfmaq_f32(acc[0], acc[1], vdupq_n_f32(IQ1S_DELTA))));
acc[0] = acc[1] = vdupq_n_f32(0.f);
}
}
template <int nrc_y>
void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
Q8<nrc_y, block_q8_K> q8(info);
int8x16_t qx[16];
int32x4_t scales[2];
int16x4_t deltas[2];
float32x4_t acc[nrc_y] = {};
auto delta_mask = vdupq_n_u16(0x8000);
for (int ix = 0; ix < nrc_x; ++ix) {
auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < n/QK_K; ++ibl) {
float d = GGML_FP16_TO_FP32(iq1s[ibl].d);
auto qhb = vld1q_u16(iq1s[ibl].qh);
auto scales128 = vandq_u16(vshrq_n_u16(qhb, 12), vdupq_n_u16(7));
scales128 = vaddq_u16(vshlq_n_u16(scales128, 1), vdupq_n_u16(1));
auto mask = vceqq_u16(vandq_u16(qhb, delta_mask), delta_mask);
// Note: we explicitly assume IQ1S_DELTA = 0.125; the single 0.125f factor in the final store then
// covers both the <<3 applied to the scales and the delta term.
auto deltas128 = vsubq_s16(vbicq_s16(scales128, mask), vandq_s16(scales128, mask));
//auto deltas128 = vorrq_s16(vandq_s16(vdupq_n_s16(-1), mask), vbicq_s16(vdupq_n_s16(1), mask));
//deltas128 = vmulq_s16(scales128, deltas128);
scales128 = vshlq_n_u16(scales128, 3);
auto qs = iq1s[ibl].qs;
auto qh = iq1s[ibl].qh;
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
qx[4*ib64+0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[0] | ((qh[2*ib64+0] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[2*ib64+0] << 5) & 0x700)]});
qx[4*ib64+1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[2] | ((qh[2*ib64+0] << 2) & 0x700)], iq1s_grid[qs[3] | ((qh[2*ib64+0] >> 1) & 0x700)]});
qx[4*ib64+2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[4] | ((qh[2*ib64+1] << 8) & 0x700)], iq1s_grid[qs[5] | ((qh[2*ib64+1] << 5) & 0x700)]});
qx[4*ib64+3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[6] | ((qh[2*ib64+1] << 2) & 0x700)], iq1s_grid[qs[7] | ((qh[2*ib64+1] >> 1) & 0x700)]});
qs += 8;
}
scales[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales128)));
scales[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales128)));
deltas[0] = vget_low_s16 (deltas128);
deltas[1] = vget_high_s16(deltas128);
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums8(iy, ibl);
auto sumi = vdupq_n_s32(0);
sumi = vmlal_s16(sumi, deltas[0], vget_low_s16 (bsums));
sumi = vmlal_s16(sumi, deltas[1], vget_high_s16(bsums));
for (int k = 0; k < QK_K/128; ++k) {
auto qy = q8.load_quants_64(iy, ibl, 2*k+0);
auto dot1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+0], qy.val[0]), qx[8*k+1], qy.val[1]);
auto dot2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+2], qy.val[2]), qx[8*k+3], qy.val[3]);
auto dot12 = vpaddq_s32(dot1, dot2);
qy = q8.load_quants_64(iy, ibl, 2*k+1);
auto dot3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+4], qy.val[0]), qx[8*k+5], qy.val[1]);
auto dot4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+6], qy.val[2]), qx[8*k+7], qy.val[3]);
auto dot34 = vpaddq_s32(dot3, dot4);
auto dot = vpaddq_s32(dot12, dot34);
sumi = vmlaq_s32(sumi, dot, scales[k]);
}
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, ibl)), vcvtq_f32_s32(sumi));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, 0.125f*vaddvq_f32(acc[iy]));
acc[iy] = vdupq_n_f32(0);
}
}
}
}
bool iqk_set_kernels_1bit(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& funcs, mul_mat_t& func16) {
auto expected_Btype = GGML_TYPE_Q8_K128;
func16 = nullptr;
switch (typeA) {
case GGML_TYPE_IQ1_BN:
if (ne00 % QK_IQ1BN != 0) return false;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1bn_q8_K64, funcs);
expected_Btype = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN:
if (ne00 % QK_IQ1BN != 0) return false;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2bn_q8_K64, funcs);
expected_Btype = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN_R4:
if (ne00 % QK_IQ1BN != 0) return false;
funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
//funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
//funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
//funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
expected_Btype = GGML_TYPE_Q8_K16;
break;
case GGML_TYPE_IQ1_S:
if (ne00%QK_K != 0) return false;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_q8_K, funcs);
func16 = mul_mat_iq1_s_q8_K<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ1_S_R4:
if (ne00%128 != 0) return false;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_s_r4_q8_1, funcs);
funcs[0] = mul_mat_iq1_s_r4_q8_1_1;
func16 = mul_mat_iq1_s_r4_q8_1<16>;
expected_Btype = GGML_TYPE_Q8_K128;
break;
case GGML_TYPE_IQ1_M_R4:
if (ne00%128 != 0) return false;
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_m_r4_q8_0, funcs);
func16 = mul_mat_iq1_m_r4_q8_0<16>;
expected_Btype = GGML_TYPE_Q8_K128;
break;
default:
return false;
}
return ggml_type(typeB) == expected_Btype;
}
#endif
#endif

View File

@@ -7,6 +7,8 @@
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#ifdef __x86_64__
namespace {
// float matrices - we handle f16, bf16 (if native bf16 support is available) and f32, but only to f32 result
@@ -564,4 +566,9 @@ bool iqk_set_kernels_float(int ne00, int typeA, int typeB, std::array<mul_mat_t,
}
#else
// ----------------------------------- __aarch64__ -----------------------------------------------
#endif
#endif

View File

@@ -7,6 +7,8 @@
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#ifdef __x86_64__
namespace {
#ifdef HAVE_FANCY_SIMD
@@ -2124,4 +2126,9 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_m
}
#else
// ----------------------------------------- __aarch64__ ---------------------------------------------
#endif
#endif

View File

@@ -7,6 +7,8 @@
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#ifdef __x86_64__
namespace {
inline __m256i get_scale_shuffle_8(int i) {
@@ -1627,4 +1629,9 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
}
#else
// --------------------------------------- __aarch64__ ---------------------------------------------
#endif
#endif

View File

@@ -7,6 +7,8 @@
#define GGML_COMMON_IMPL_C
#include "ggml-common.h"
#ifdef __x86_64__
namespace {
// Handles q4_K and q5_K scales/mins
@@ -1776,4 +1778,9 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
}
#else
// --------------------------------- __aarch64__ --------------------------------------
#endif
#endif

View File

@@ -11,6 +11,8 @@
// ============================== Legacy quants
//
#ifdef __x86_64__
namespace {
struct DotHelper {
@@ -1699,4 +1701,9 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
return ggml_type(typeB) == expected_typeB;
}
#else
// ---------------------------- __aarch64__ ----------------------------------------------
#endif
#endif

View File

@@ -3037,88 +3037,6 @@ struct DequantizerIQ1BN {
}
};
template <int nrc_y>
static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
Q8_K64<nrc_y> q8(info);
DequantizerIQ1BN deq;
int32x4_t accd[nrc_y];
int8x16x4_t v1, v2;
float scale;
ggml_half d16;
char * c16 = (char *)&d16;
for (int ix = 0; ix < nrc_x; ++ix) {
const char * cx = ((const char *)vx + ix*bx);
c16[0] = cx[0]; c16[1] = cx[1];
//std::memcpy(&d16, cx, sizeof(d16));
cx += sizeof(d16);
scale = GGML_FP16_TO_FP32(d16);
const block_iq1_bn * x = (const block_iq1_bn *)cx;
if constexpr (nrc_y == 1) {
int32x4_t acc[4] = {};
for (int i = 0; i < nb/2; ++i) {
deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
auto q = q8.load_quants64(0, i, 0);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
q = q8.load_quants64(0, i, 1);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]);
}
accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
}
else {
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
for (int i = 0; i < nb/2; ++i) {
deq.prepare_iq1bn_quants_nosub(x+2*i+0, v1);
deq.prepare_iq1bn_quants_nosub(x+2*i+1, v2);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i, 0);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
q = q8.load_quants(iy, i, 1);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
q = q8.load_quants(iy, i, 2);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
q = q8.load_quants(iy, i, 3);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
}
}
}
int i = 2*(nb/2);
if (i < nb) {
deq.prepare_iq1bn_quants_nosub(x+i, v1);
if constexpr (nrc_y == 1) {
auto q = q8.load_quants(0, i/2, 0);
for (int j = 0; j < 4; ++j) {
accd[0] = ggml_vdotq_s32(accd[0], q.val[j], v1.val[j]);
}
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i/2, 0);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
q = q8.load_quants(iy, i/2, 1);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, -scale * vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
}
}
}
template <int nrc> struct Q8_16 {
constexpr static int nrc_y = nrc;
@@ -3141,195 +3059,6 @@ template <int nrc> struct Q8_16 {
const int8_t * y[nrc_y];
};
template <int nrc_y>
static IQK_NOINLINE void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
if (nrc_x%4) {
printf("%s: %d is not a multiple of 4\n", __func__, nrc_x);
GGML_ABORT("fatal error");
}
Q8_16<nrc_y> q8(info);
auto m3 = vdupq_n_u8(0x3);
int nb = n / QK_IQ1BN;
if constexpr (nrc_y == 1) {
auto mc = vdupq_n_u8(0xc);
int32x4_t acc[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
for (int k = 0; k < 8; ++k) acc[k] = vdupq_n_s32(0);
const float * dptr = (const float *)((const char *)vx + ix*bx);
auto dl = vld1q_f32(dptr);
const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
for (int ib = 0; ib < nb; ++ib) {
auto y = q8.load_quants(0, ib);
for (int j = 0; j < 4; ++j) {
auto bits1 = vld1q_u8(iq2 + 64*ib + 16*j);
auto bits2 = vshrq_n_u8(bits1, 4);
acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits1, m3), y.val[j], 0);
acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits1, mc), y.val[j], 1);
acc[2*j+0] = vdotq_laneq_s32(acc[2*j+0], vandq_u8(bits2, m3), y.val[j], 2);
acc[2*j+1] = vdotq_laneq_s32(acc[2*j+1], vandq_u8(bits2, mc), y.val[j], 3);
}
}
auto dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 0)));
auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), dy);
auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), dy);
dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 1)));
sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), dy);
sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), dy);
dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 2)));
sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), dy);
sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[5]), dy);
dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 3)));
sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), dy);
sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), dy);
auto sumf = vfmaq_f32(sumf1, vdupq_n_f32(0.25f), sumf2);
sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0)));
info.store(ix, 0, sumf);
}
} else {
int32x4_t acc[4*nrc_y] = {};
uint8x16_t qx[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
const float * dptr = (const float *)((const char *)vx + ix*bx);
auto dl = vld1q_f32(dptr);
const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
for (int ib = 0; ib < nb; ++ib) {
auto bits = vld1q_u8_x2(iq2 + 64*ib);
qx[0] = vandq_u8(bits.val[0], m3);
qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
qx[3] = vshrq_n_u8(bits.val[0], 6);
qx[4] = vandq_u8(bits.val[1], m3);
qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
qx[7] = vshrq_n_u8(bits.val[1], 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = q8.load_quants_32(iy, 2*ib+0);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
}
bits = vld1q_u8_x2(iq2 + 64*ib + 32);
qx[0] = vandq_u8(bits.val[0], m3);
qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
qx[3] = vshrq_n_u8(bits.val[0], 6);
qx[4] = vandq_u8(bits.val[1], m3);
qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
qx[7] = vshrq_n_u8(bits.val[1], 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = q8.load_quants_32(iy, 2*ib+1);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto dy = q8.scale(iy);
float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
info.store(ix, iy, sumf);
acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_s32(0);
}
}
}
}
template <int nrc_y>
static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
const int nb = n / QK_IQ1BN;
Q8_K64<nrc_y> q8(info);
int32x4_t accd[nrc_y];
const auto mask2 = vdupq_n_s8(3);
for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char *)vx + ix*bx);
const float d = *dptr;
const block_iq2_bn * x = (const block_iq2_bn *)(dptr + 1);
if constexpr (nrc_y == 1) {
int8x16x4_t v1;
int32x4_t acc[4] = {};
for (int i = 0; i < nb/2; ++i) {
for (int j = 0; j < 2; ++j) {
auto q = q8.load_quants64(0, i, j);
auto q2bits = vld1q_u8(x[2*i+j].qs);
v1.val[0] = vandq_s8(q2bits, mask2);
v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
v1.val[3] = vshrq_n_u8(q2bits, 6);
acc[0] = ggml_vdotq_s32(acc[0], q.val[0], v1.val[0]);
acc[1] = ggml_vdotq_s32(acc[1], q.val[1], v1.val[1]);
acc[2] = ggml_vdotq_s32(acc[2], q.val[2], v1.val[2]);
acc[3] = ggml_vdotq_s32(acc[3], q.val[3], v1.val[3]);
}
}
accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
} else {
int8x16x4_t v1, v2;
for (int iy = 0; iy < nrc_y; ++iy) accd[iy] = vdupq_n_s32(0);
for (int i = 0; i < nb/2; ++i) {
auto q2bits = vld1q_u8(x[2*i+0].qs);
v1.val[0] = vandq_s8(q2bits, mask2);
v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
v1.val[3] = vshrq_n_u8(q2bits, 6);
q2bits = vld1q_u8(x[2*i+1].qs);
v2.val[0] = vandq_s8(q2bits, mask2);
v2.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
v2.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
v2.val[3] = vshrq_n_u8(q2bits, 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i, 0);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
q = q8.load_quants(iy, i, 1);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
q = q8.load_quants(iy, i, 2);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[0]), q.val[1], v2.val[1]);
q = q8.load_quants(iy, i, 3);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v2.val[2]), q.val[1], v2.val[3]);
}
}
}
int i = 2*(nb/2);
if (i < nb) {
auto q2bits = vld1q_u8(x[i].qs);
int8x16x4_t v1;
v1.val[0] = vandq_s8(q2bits, mask2);
v1.val[1] = vandq_s8(vshrq_n_u8(q2bits, 2), mask2);
v1.val[2] = vandq_s8(vshrq_n_u8(q2bits, 4), mask2);
v1.val[3] = vshrq_n_u8(q2bits, 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i/2, 0);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[0]), q.val[1], v1.val[1]);
q = q8.load_quants(iy, i/2, 1);
accd[iy] = ggml_vdotq_s32(ggml_vdotq_s32(accd[iy], q.val[0], v1.val[2]), q.val[1], v1.val[3]);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, -d*vaddvq_f32(vfmsq_f32(q8.minus(iy), q8.scale(iy), vcvtq_f32_s32(accd[iy]))));
}
}
}
IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
@@ -3760,280 +3489,6 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
}
}
static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<1, block_q8_K128> q8(info);
int nb = n / 32;
GGML_ASSERT(nb%4 == 0);
int8x16_t qx[8];
float32x4_t acc[2] = {};
int32x4_t isum[8];
auto ms = vdup_n_u16(0x8000);
for (int ix= 0; ix < nrc_x; ix += 4) {
auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr));
auto x = (const block_iq1_s_r4 *)(dptr + 4);
for (int ib = 0; ib < nb/4; ++ib) {
auto scale_yd = vdupq_n_f32(q8.y[0][ib].d);
auto scale_ym = vmulq_f32(scale_yd, vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[0][ib].bsums))));
for (int k = 0; k < 4; ++k) {
auto sas = vld1_u16(x[4*ib+k].qh);
auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7));
scales4 = vorr_u16(vshl_n_u16(scales4, 1), vdup_n_u16(1));
auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1)));
isum[k+4] = vmull_s16(signs, scales4);
qx[0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)]});
qx[1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)]});
qx[2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)]});
qx[3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)]});
qx[4] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)]});
qx[5] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)]});
qx[6] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]});
qx[7] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)],
iq1s_grid[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]});
auto scales = vmovl_u16(scales4);
auto y = vld1q_s8_x2(q8.y[0][ib].qs + 32*k);
auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
sumi1 = vpaddq_s32(sumi1, sumi2);
sumi3 = vpaddq_s32(sumi3, sumi4);
isum[k] = vmulq_s32(scales, vpaddq_s32(sumi1, sumi3));
}
acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[0]), scale_yd, 0);
acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[1]), scale_yd, 1);
acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[2]), scale_yd, 2);
acc[0] = vfmaq_laneq_f32(acc[0], vcvtq_f32_s32(isum[3]), scale_yd, 3);
acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[4]), scale_ym, 0);
acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[5]), scale_ym, 1);
acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[6]), scale_ym, 2);
acc[1] = vfmaq_laneq_f32(acc[1], vcvtq_f32_s32(isum[7]), scale_ym, 3);
}
info.store(ix, 0, vmulq_f32(d1, vfmaq_f32(acc[0], acc[1], vdupq_n_f32(IQ1S_DELTA))));
acc[0] = acc[1] = vdupq_n_f32(0.f);
}
}
template <int nrc_y>
static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K128> q8(info);
int nb = n / 32;
GGML_ASSERT(nb%4 == 0);
uint8x16_t qx[8];
float32x4_t acc[nrc_y] = {};
auto ms = vdup_n_u16(0x8000);
auto mask = vdupq_n_s8(0x03);
float d8[4*nrc_y];
for (int ix= 0; ix < nrc_x; ix += 4) {
auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr));
auto x = (const block_iq1_s_r4 *)(dptr + 4);
for (int ib = 0; ib < nb/4; ++ib) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto scales = vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[iy][ib].bsums)));
vst1q_f32(d8+4*iy, vmulq_f32(vdupq_n_f32(q8.y[iy][ib].d), scales));
}
for (int k = 0; k < 4; ++k) {
auto sas = vld1_u16(x[4*ib+k].qh);
auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7));
scales4 = vorr_u16(vshl_n_u16(scales4, 1), vdup_n_u16(1));
auto signs = vreinterpret_s16_u16(vorr_u16(vceq_u16(vand_u16(sas, ms), ms), vdup_n_u16(1)));
signs = vadd_s16(vdup_n_s16(-8), signs);
auto delta4 = vmulq_f32(vdupq_n_f32(0.125f), vcvtq_f32_s32(vmull_s16(signs, scales4)));
qx[0] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
qx[2] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 5) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 5) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 5) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 5) & 0x0700)]});
qx[4] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[0] << 2) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[1] << 2) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[2] << 2) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[3] << 2) & 0x0700)]});
qx[6] = vreinterpretq_u8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[0] >> 1) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[1] >> 1) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[2] >> 1) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[3] >> 1) & 0x0700)]});
qx[1] = vandq_u8(vshrq_n_u8(qx[0], 4), mask); qx[0] = vandq_u8(qx[0], mask);
qx[3] = vandq_u8(vshrq_n_u8(qx[2], 4), mask); qx[2] = vandq_u8(qx[2], mask);
qx[5] = vandq_u8(vshrq_n_u8(qx[4], 4), mask); qx[4] = vandq_u8(qx[4], mask);
qx[7] = vandq_u8(vshrq_n_u8(qx[6], 4), mask); qx[6] = vandq_u8(qx[6], mask);
auto scales = vmovl_u16(scales4);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib].qs + 32*k);
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[0]), y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[1]), y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[2]), y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[3]), y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[4]), y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[5]), y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[6]), y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[7]), y.val[1], 3);
sumi = vmulq_s32(scales, sumi);
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(sumi));
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), delta4);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vmulq_f32(d1, acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
template <int nrc_y>
static void mul_mat_iq1_s_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(n%QK_K == 0);
Q8<nrc_y, block_q8_K> q8(info);
int8x16_t qx[16];
int32x4_t scales[2];
int16x4_t deltas[2];
float32x4_t acc[nrc_y] = {};
auto delta_mask = vdupq_n_u16(0x8000);
for (int ix = 0; ix < nrc_x; ++ix) {
auto iq1s = (const block_iq1_s *)((const char *)vx + ix*bx);
for (int ibl = 0; ibl < n/QK_K; ++ibl) {
float d = GGML_FP16_TO_FP32(iq1s[ibl].d);
auto qhb = vld1q_u16(iq1s[ibl].qh);
auto scales128 = vandq_u16(vshrq_n_u16(qhb, 12), vdupq_n_u16(7));
scales128 = vaddq_u16(vshlq_n_u16(scales128, 1), vdupq_n_u16(1));
auto mask = vceqq_u16(vandq_u16(qhb, delta_mask), delta_mask);
// Note: we explicitly assume IQ1S_DELTA = 0.125
auto deltas128 = vsubq_s16(vbicq_s16(scales128, mask), vandq_s16(scales128, mask));
//auto deltas128 = vorrq_s16(vandq_s16(vdupq_n_s16(-1), mask), vbicq_s16(vdupq_n_s16(1), mask));
//deltas128 = vmulq_s16(scales128, deltas128);
scales128 = vshlq_n_u16(scales128, 3);
auto qs = iq1s[ibl].qs;
auto qh = iq1s[ibl].qh;
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
qx[4*ib64+0] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[0] | ((qh[2*ib64+0] << 8) & 0x700)], iq1s_grid[qs[1] | ((qh[2*ib64+0] << 5) & 0x700)]});
qx[4*ib64+1] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[2] | ((qh[2*ib64+0] << 2) & 0x700)], iq1s_grid[qs[3] | ((qh[2*ib64+0] >> 1) & 0x700)]});
qx[4*ib64+2] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[4] | ((qh[2*ib64+1] << 8) & 0x700)], iq1s_grid[qs[5] | ((qh[2*ib64+1] << 5) & 0x700)]});
qx[4*ib64+3] = vreinterpretq_s8_u64(uint64x2_t{iq1s_grid[qs[6] | ((qh[2*ib64+1] << 2) & 0x700)], iq1s_grid[qs[7] | ((qh[2*ib64+1] >> 1) & 0x700)]});
qs += 8;
}
scales[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales128)));
scales[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales128)));
deltas[0] = vget_low_s16 (deltas128);
deltas[1] = vget_high_s16(deltas128);
for (int iy = 0; iy < nrc_y; ++iy) {
auto bsums = q8.load_bsums8(iy, ibl);
auto sumi = vdupq_n_s32(0);
sumi = vmlal_s16(sumi, deltas[0], vget_low_s16 (bsums));
sumi = vmlal_s16(sumi, deltas[1], vget_high_s16(bsums));
for (int k = 0; k < QK_K/128; ++k) {
auto qy = q8.load_quants_64(iy, ibl, 2*k+0);
auto dot1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+0], qy.val[0]), qx[8*k+1], qy.val[1]);
auto dot2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+2], qy.val[2]), qx[8*k+3], qy.val[3]);
auto dot12 = vpaddq_s32(dot1, dot2);
qy = q8.load_quants_64(iy, ibl, 2*k+1);
auto dot3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+4], qy.val[0]), qx[8*k+5], qy.val[1]);
auto dot4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[8*k+6], qy.val[2]), qx[8*k+7], qy.val[3]);
auto dot34 = vpaddq_s32(dot3, dot4);
auto dot = vpaddq_s32(dot12, dot34);
sumi = vmlaq_s32(sumi, dot, scales[k]);
}
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d*q8.scale(iy, ibl)), vcvtq_f32_s32(sumi));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, 0.125f*vaddvq_f32(acc[iy]));
acc[iy] = vdupq_n_f32(0);
}
}
}
template <int nrc_y>
static void mul_mat_iq1_m_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K128> q8(info);
int nb = n / 32;
GGML_ASSERT(nb%4 == 0);
int8x16_t qx[8];
float32x4_t acc[nrc_y] = {};
int32x4_t isum[nrc_y] = {};
auto shuffle0 = uint32x4_t{0x00000000, 0x01010101, 0x02020202, 0x03030303};
auto step = vdupq_n_u8(4);
auto ms = vdupq_n_u8(0x08);
auto mask = vdupq_n_s8(0x18);
for (int ix= 0; ix < nrc_x; ix += 4) {
auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
auto d1 = vmulq_f32(vdupq_n_f32(0.125f), vcvt_f32_f16(vld1_f16((const float16_t *)dptr)));
auto x = (const block_iq1_m_r4 *)(dptr + 4);
for (int ib = 0; ib < nb/4; ++ib) {
for (int k = 0; k < 4; ++k) {
auto scales4 = vdup_n_u32(((const uint32_t *)x[4*ib+k].scales)[0]);
scales4 = vand_u8(vshl_u32(scales4, int32x2_t{0, -4}), vdup_n_u8(0xf));
auto scales16 = vmovl_u8(scales4);
auto scales1 = vmovl_u16(vget_low_u16(scales16));
auto scales2 = vmovl_u16(vget_high_u16(scales16));
auto qh = (const uint32_t *)x[4*ib+k].qh;
auto idxh = uint32x4_t{qh[0], qh[0] >> 4, qh[1], qh[1] >> 4};
auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(idxh, ms), ms), vdupq_n_u8(1)));
signs = vaddq_s8(signs, vdupq_n_s8(-8));
qx[0] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 0] | ((x[4*ib+k].qh[0] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 1] | ((x[4*ib+k].qh[1] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 2] | ((x[4*ib+k].qh[2] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 3] | ((x[4*ib+k].qh[3] << 8) & 0x0700)]});
qx[2] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 4] | ((x[4*ib+k].qh[0] << 4) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 5] | ((x[4*ib+k].qh[1] << 4) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 6] | ((x[4*ib+k].qh[2] << 4) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 7] | ((x[4*ib+k].qh[3] << 4) & 0x0700)]});
qx[4] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[ 8] | ((x[4*ib+k].qh[4] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[ 9] | ((x[4*ib+k].qh[5] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[10] | ((x[4*ib+k].qh[6] << 8) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[11] | ((x[4*ib+k].qh[7] << 8) & 0x0700)]});
qx[6] = vreinterpretq_s8_u32(uint32x4_t{iq1s_grid_us[x[4*ib+k].qs[12] | ((x[4*ib+k].qh[4] << 4) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[13] | ((x[4*ib+k].qh[5] << 4) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[14] | ((x[4*ib+k].qh[6] << 4) & 0x0700)],
iq1s_grid_us[x[4*ib+k].qs[15] | ((x[4*ib+k].qh[7] << 4) & 0x0700)]});
auto shuffle = shuffle0;
for (int j = 0; j < 4; ++j) {
auto s = vqtbl1q_s8(signs, shuffle);
qx[2*j+1] = vaddq_s8(s, vandq_s8(vshrq_n_s8(qx[2*j+0], 1), mask));
qx[2*j+0] = vaddq_s8(s, vandq_s8(vshlq_n_s8(qx[2*j+0], 3), mask));
shuffle = vaddq_u8(shuffle, step);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib].qs + 32*k);
auto sumi1 = vdupq_n_s32(0);
auto sumi2 = vdupq_n_s32(0);
sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[0]), y.val[0], 0);
sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[1]), y.val[0], 1);
sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[2]), y.val[0], 2);
sumi1 = vdotq_laneq_s32(sumi1, vreinterpretq_s8_u8(qx[3]), y.val[0], 3);
sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[4]), y.val[1], 0);
sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[5]), y.val[1], 1);
sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[6]), y.val[1], 2);
sumi2 = vdotq_laneq_s32(sumi2, vreinterpretq_s8_u8(qx[7]), y.val[1], 3);
isum[iy] = vmlaq_s32(vmlaq_s32(isum[iy], sumi1, scales1), sumi2, scales2);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(isum[iy]));
isum[iy] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vmulq_f32(d1, acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
template <int nrc_y>
static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -5699,25 +5154,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ3_S:
MulMat::set_functions<DequantizerIQ3S>(m);
break;
case GGML_TYPE_IQ1_BN:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1bn_q8_K64);
expected_Btype = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2bn_q8_K64);
expected_Btype = GGML_TYPE_Q8_K64;
break;
case GGML_TYPE_IQ2_BN_R4:
m.funcs[0] = mul_mat_iq2_bn_r4_q8_k16<1>;
m.funcs[1] = mul_mat_iq2_bn_r4_q8_k16<2>;
m.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
m.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
m.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
//m.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
//m.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
//m.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
expected_Btype = GGML_TYPE_Q8_K16;
break;
case GGML_TYPE_Q4_0:
MulMat::set_functions<DequantizerQ40>(m);
expected_Btype = GGML_TYPE_Q8_0_X4;
@@ -5773,22 +5209,13 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
m.func16 = mul_mat_iq2_s_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_BN_R4:
case GGML_TYPE_IQ1_S:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_q8_K);
m.func16 = mul_mat_iq1_s_q8_K<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ1_S_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1);
m.funcs[0] = mul_mat_iq1_s_r4_q8_1_1;
m.func16 = mul_mat_iq1_s_r4_q8_1<16>;
expected_Btype = GGML_TYPE_Q8_K128;
break;
case GGML_TYPE_IQ1_M_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_m_r4_q8_0);
m.func16 = mul_mat_iq1_m_r4_q8_0<16>;
expected_Btype = GGML_TYPE_Q8_K128;
break;
return iqk_set_kernels_1bit(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ3_XXS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k);
m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>;