Refactor iqk: factor out repacked legacy quants (NEON)

This commit is contained in:
Iwan Kawrakow
2025-05-19 07:51:28 +03:00
parent bd1e4d4909
commit 2b8a231d87
3 changed files with 339 additions and 431 deletions

View File

@@ -771,7 +771,58 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf
}
}
static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
return sumi;
}
static IQK_ALWAYS_INLINE int32x4x2_t interleaved_dotq_b16(const int8x16_t * qx, const int8x16x2_t& y) {
int32x4x2_t sumi = { vdupq_n_s32(0), vdupq_n_s32(0) };
sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[0], y.val[0], 0);
sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[1], y.val[1], 0);
sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[2], y.val[0], 1);
sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[3], y.val[1], 1);
sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[4], y.val[0], 2);
sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[5], y.val[1], 2);
sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[6], y.val[0], 3);
sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[7], y.val[1], 3);
return sumi;
}
static IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16_t& y) {
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[0], y, 0);
sumi = vdotq_laneq_s32(sumi, qx[1], y, 1);
sumi = vdotq_laneq_s32(sumi, qx[2], y, 2);
sumi = vdotq_laneq_s32(sumi, qx[3], y, 3);
return sumi;
}
static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
}
static IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) {
qx[0] = vqtbl1q_s8(values, vandq_u8( bits.val[0], m4));
qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));
qx[2] = vqtbl1q_s8(values, vandq_u8( bits.val[1], m4));
qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));
}
#endif

View File

@@ -2310,6 +2310,273 @@ static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInf
mul_mat_qX_Y_q8_Y(n, deq1, deq2, q8, info, nrc_x);
}
template <typename Dequantizer, int nrc_y>
void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
Dequantizer deq(vx, bx);
int nb = n / QK4_NL;
int8x16_t qx[8];
float d8[4*nrc_y];
float32x4_t acc[nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 4) {
deq.new_row(ix);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
auto scales = deq.prepare(4*ib4+k, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
auto sumi = interleaved_dotq(qx, y);
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = deq.prepare(ib, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_0 *)q8.y[iy];
auto y = vld1q_s8_x2(qy[ib].qs);
auto sumi = interleaved_dotq(qx, y);
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d)));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, deq.result(acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
template <typename Dequantizer, int nrc_y>
void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
Dequantizer deq(vx, bx);
int nb = n / QK4_NL;
int8x16_t qx[16];
float d8[4*nrc_y];
float32x4_t acc[2*nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 8) {
deq.new_row(ix);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
auto scales = deq.prepare(ib4, k, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
auto sumi1 = interleaved_dotq(qx+0, y);
auto sumi2 = interleaved_dotq(qx+8, y);
auto dy = vdupq_n_f32(d8[4*iy+k]);
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = deq.prepare(ib, 0, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_0 *)q8.y[iy];
auto y = vld1q_s8_x2(qy[ib].qs);
auto sumi1 = interleaved_dotq(qx+0, y);
auto sumi2 = interleaved_dotq(qx+8, y);
auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, deq.result(acc[2*iy+0]));
info.store(ix+4, iy, deq.result(acc[2*iy+1]));
acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
}
}
}
struct IQ4_NL_R4_Dequantizer {
IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {}
inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); }
inline float32x4_t prepare(int ib, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ib].d));
auto bits = vld1q_u8_x4(iq4[ib].qs);
prepare_iq4_nl_quants(values, m4, bits, qx);
return scales;
}
inline float32x4_t result(float32x4_t acc) const {
return acc;
}
const char * cx;
const size_t bx;
const block_iq4_nl_r4 * iq4;
const uint8x16_t m4 = vdupq_n_u8(0x0f);
const int8x16_t values;
};
struct Q4_0_R8_Dequantizer {
Q4_0_R8_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq4 = (const block_iq4_nl_r8 *)(cx + ix*bx); }
inline float32x4x2_t prepare(int ib4, int k, int8x16_t * qx) const {
auto scales16 = vld1q_f16((const float16_t *)iq4[4*ib4+k].d);
float32x4x2_t scales = { vcvt_f32_f16(vget_low_f16(scales16)), vcvt_f32_f16(vget_high_f16(scales16)) };
for (int j = 0; j < 4; ++j) {
auto bits = vld1q_u8_x2(iq4[4*ib4+k].qs + 32*j);
bits.val[0] = veorq_u8(m88, bits.val[0]);
bits.val[1] = veorq_u8(m88, bits.val[1]);
qx[2*j+0] = vshlq_n_u8(bits.val[0], 4);
qx[2*j+1] = vandq_u8(bits.val[0], m4);
qx[2*j+8] = vshlq_n_u8(bits.val[1], 4);
qx[2*j+9] = vandq_u8(bits.val[1], m4);
}
return scales;
}
inline float32x4_t result(float32x4_t acc) const {
return vmulq_f32(norm, acc);
}
const char * cx;
const size_t bx;
const block_iq4_nl_r8 * iq4;
const uint8x16_t m4 = vdupq_n_u8(0xf0);
const uint8x16_t m88 = vdupq_n_u8(0x88);
const float32x4_t norm = vdupq_n_f32(1.f/16);
};
struct Q5_0_R4_Dequantizer {
Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); }
inline float32x4_t prepare(int ib, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ib].d));
auto lbits = vld1q_u8_x4(iq5[ib].qs);
auto hbits = vld1q_u8(iq5[ib].qh);
qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3
qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19
qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7
qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23
qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11
qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27
qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15
qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31
return scales;
}
inline float32x4_t result(float32x4_t acc) const {
return acc;
}
const char * cx;
const size_t bx;
const block_q5_0_r4 * iq5;
const uint8x16_t m4 = vdupq_n_u8(0x0f);
const uint8x16_t m5 = vdupq_n_u8(0x10);
const int8x16_t m16 = vdupq_n_s8(-16);
};
struct Q6_0_R4_Dequantizer {
Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); }
inline float32x4_t prepare(int ib, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ib].d));
auto lbits = vld1q_u8_x4(iq6[ib].qs);
auto hbits = vld1q_u8_x2(iq6[ib].qh);
qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3
qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19
qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7
qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 2), m6), m32); // 20..23
qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits.val[0], m6), m32); // 8..11
qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(hbits.val[1], m6), m32); // 24..27
qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits.val[0], 2), m6), m32); // 12..15
qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits.val[1], 2), m6), m32); // 28..31
return scales;
}
inline float32x4_t result(float32x4_t acc) const {
return acc;
}
const char * cx;
const size_t bx;
const block_q6_0_r4 * iq6;
const uint8x16_t m4 = vdupq_n_u8(0x0f);
const uint8x16_t m6 = vdupq_n_u8(0x30);
const int8x16_t m32 = vdupq_n_s8(-32);
};
inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& sumi1, int32x4_t& sumi2) {
auto y = vld1q_s8_x2(qy);
sumi1 = sumi2 = vdupq_n_s32(0);
sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
}
template <int nrc_y>
void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
int nb = n / QK8_0;
float32x4_t acc[2*nrc_y] = {};
int8x16_t qx[16];
float d8[4*nrc_y];
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
int32x4_t sumi1, sumi2;
for (int iy = 0; iy < nrc_y; ++iy) {
qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
auto dy = vdupq_n_f32(d8[4*iy+k]);
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
int32x4_t sumi1, sumi2;
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_0 *)q8.y[iy];
qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, acc[2*iy+0]);
info.store(ix+4, iy, acc[2*iy+1]);
acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
}
}
}
}
bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
@@ -2344,29 +2611,26 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mu
case GGML_TYPE_IQ4_NL:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0, DequantizerIQ4NL, kernels);
break;
// case GGML_TYPE_Q4_0_R8:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_0_r8_q8_2, kernels)
//#ifdef HAVE_FANCY_SIMD
// func16 = mul_mat_q4_0_r8_q8_2<16>;
//#endif
// break;
// case GGML_TYPE_Q5_0_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_0_r4_q8_2, kernels)
// break;
// case GGML_TYPE_Q6_0_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_0_r4_q8_2, kernels)
// break;
// case GGML_TYPE_Q8_0_R8:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_0_r8_q8_2, kernels)
// break;
// case GGML_TYPE_IQ4_NL_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_nl_r4_q8_2, kernels)
// break;
case GGML_TYPE_Q4_0_R8:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r8_q8_0, Q4_0_R8_Dequantizer, kernels);
break;
case GGML_TYPE_Q5_0_R4:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer, kernels);
break;
case GGML_TYPE_Q6_0_R4:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer, kernels);
break;
case GGML_TYPE_Q8_0_R8:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_0_r8_q8_0, kernels);
break;
case GGML_TYPE_IQ4_NL_R4:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer, kernels);
break;
default:
return false;
}
return ggml_type(typeB) == expected_typeB;
return true;
}
#endif

View File

@@ -938,59 +938,6 @@ template <int nrc> struct Q8_16 {
const int8_t * y[nrc_y];
};
IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
return sumi;
}
IQK_ALWAYS_INLINE int32x4x2_t interleaved_dotq_b16(const int8x16_t * qx, const int8x16x2_t& y) {
int32x4x2_t sumi = { vdupq_n_s32(0), vdupq_n_s32(0) };
sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[0], y.val[0], 0);
sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[1], y.val[1], 0);
sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[2], y.val[0], 1);
sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[3], y.val[1], 1);
sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[4], y.val[0], 2);
sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[5], y.val[1], 2);
sumi.val[0] = vdotq_laneq_s32(sumi.val[0], qx[6], y.val[0], 3);
sumi.val[1] = vdotq_laneq_s32(sumi.val[1], qx[7], y.val[1], 3);
return sumi;
}
IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16_t& y) {
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[0], y, 0);
sumi = vdotq_laneq_s32(sumi, qx[1], y, 1);
sumi = vdotq_laneq_s32(sumi, qx[2], y, 2);
sumi = vdotq_laneq_s32(sumi, qx[3], y, 3);
return sumi;
}
IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
}
IQK_ALWAYS_INLINE void prepare_iq4_nl_quants_r8(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x2_t& bits, int8x16_t * qx) {
qx[0] = vqtbl1q_s8(values, vandq_u8( bits.val[0], m4));
qx[1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));
qx[2] = vqtbl1q_s8(values, vandq_u8( bits.val[1], m4));
qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));
}
template <int nrc_y>
void mul_mat_iq4_xs_r8_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -2573,345 +2520,6 @@ void mul_mat_q8_KV_r8_q8_KV(int n, const void * vx, size_t bx, const DataInfo& i
}
}
void mul_mat_iq4_nl_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<1, block_q8_0_x4> q8(info);
auto m4 = vdupq_n_u8(0xf);
auto values = vld1q_s8(iq4k_values);
int nb = n / QK4_NL;
GGML_ASSERT(nb%4 == 0);
int8x16_t qx[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
auto acc = vdupq_n_f32(0.f);
const block_iq4_nl_r4 * iq4 = (const block_iq4_nl_r4 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
auto y1 = vld1q_s8_x4(q8.y[0][ib4].qs);
auto y2 = vld1q_s8_x4(q8.y[0][ib4].qs+64);
for (int k = 0; k < 4; ++k) {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[0][ib4].d[k])));
auto sumi = vdupq_n_s32(0);
const auto yval = k < 2 ? y1.val + 2*k : y2.val + 2*(k-2);
auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
sumi = vdotq_laneq_s32(sumi, qx[0], yval[0], 0);
sumi = vdotq_laneq_s32(sumi, qx[1], yval[1], 0);
qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
sumi = vdotq_laneq_s32(sumi, qx[2], yval[0], 1);
sumi = vdotq_laneq_s32(sumi, qx[3], yval[1], 1);
qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
sumi = vdotq_laneq_s32(sumi, qx[4], yval[0], 2);
sumi = vdotq_laneq_s32(sumi, qx[5], yval[1], 2);
qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
sumi = vdotq_laneq_s32(sumi, qx[6], yval[0], 3);
sumi = vdotq_laneq_s32(sumi, qx[7], yval[1], 3);
acc = vfmaq_f32(acc, d4d8, vcvtq_f32_s32(sumi));
}
}
info.store(ix, 0, acc);
}
}
template <typename Dequantizer, int nrc_y>
void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
Dequantizer deq(vx, bx);
int nb = n / QK4_NL;
int8x16_t qx[8];
float d8[4*nrc_y];
float32x4_t acc[nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 4) {
deq.new_row(ix);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
auto scales = deq.prepare(4*ib4+k, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
auto sumi = interleaved_dotq(qx, y);
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = deq.prepare(ib, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_0 *)q8.y[iy];
auto y = vld1q_s8_x2(qy[ib].qs);
auto sumi = interleaved_dotq(qx, y);
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d)));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, deq.result(acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
template <typename Dequantizer, int nrc_y>
void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
Dequantizer deq(vx, bx);
int nb = n / QK4_NL;
int8x16_t qx[16];
float d8[4*nrc_y];
float32x4_t acc[2*nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 8) {
deq.new_row(ix);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
auto scales = deq.prepare(ib4, k, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
auto sumi1 = interleaved_dotq(qx+0, y);
auto sumi2 = interleaved_dotq(qx+8, y);
auto dy = vdupq_n_f32(d8[4*iy+k]);
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales = deq.prepare(ib, 0, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_0 *)q8.y[iy];
auto y = vld1q_s8_x2(qy[ib].qs);
auto sumi1 = interleaved_dotq(qx+0, y);
auto sumi2 = interleaved_dotq(qx+8, y);
auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales.val[0], dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales.val[1], dy), vcvtq_f32_s32(sumi2));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, deq.result(acc[2*iy+0]));
info.store(ix+4, iy, deq.result(acc[2*iy+1]));
acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
}
}
}
struct IQ4_NL_R4_Dequantizer {
IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {}
inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); }
inline float32x4_t prepare(int ib, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ib].d));
auto bits = vld1q_u8_x4(iq4[ib].qs);
prepare_iq4_nl_quants(values, m4, bits, qx);
return scales;
}
inline float32x4_t result(float32x4_t acc) const {
return acc;
}
const char * cx;
const size_t bx;
const block_iq4_nl_r4 * iq4;
const uint8x16_t m4 = vdupq_n_u8(0x0f);
const int8x16_t values;
};
struct Q4_0_R4_Dequantizer {
Q4_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); }
inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]);
qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows
qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19
qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7
qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23
qx[4] = vandq_u8(bits.val[0], m4); // 8..11
qx[5] = vandq_u8(bits.val[1], m4); // 24..27
qx[6] = vandq_u8(bits.val[2], m4); // 12..15
qx[7] = vandq_u8(bits.val[3], m4); // 28..31
return scales;
}
inline float32x4_t result(float32x4_t acc) const {
return vmulq_f32(norm, acc);
}
const char * cx;
const size_t bx;
const block_iq4_nl_r4 * iq4;
const uint8x16_t m4 = vdupq_n_u8(0xf0);
const uint8x16_t m88 = vdupq_n_u8(0x88);
const float32x4_t norm = vdupq_n_f32(1.f/16);
};
struct Q4_0_R8_Dequantizer {
Q4_0_R8_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq4 = (const block_iq4_nl_r8 *)(cx + ix*bx); }
inline float32x4x2_t prepare(int ib4, int k, int8x16_t * qx) const {
auto scales16 = vld1q_f16((const float16_t *)iq4[4*ib4+k].d);
float32x4x2_t scales = { vcvt_f32_f16(vget_low_f16(scales16)), vcvt_f32_f16(vget_high_f16(scales16)) };
for (int j = 0; j < 4; ++j) {
auto bits = vld1q_u8_x2(iq4[4*ib4+k].qs + 32*j);
bits.val[0] = veorq_u8(m88, bits.val[0]);
bits.val[1] = veorq_u8(m88, bits.val[1]);
qx[2*j+0] = vshlq_n_u8(bits.val[0], 4);
qx[2*j+1] = vandq_u8(bits.val[0], m4);
qx[2*j+8] = vshlq_n_u8(bits.val[1], 4);
qx[2*j+9] = vandq_u8(bits.val[1], m4);
}
return scales;
}
inline float32x4_t result(float32x4_t acc) const {
return vmulq_f32(norm, acc);
}
const char * cx;
const size_t bx;
const block_iq4_nl_r8 * iq4;
const uint8x16_t m4 = vdupq_n_u8(0xf0);
const uint8x16_t m88 = vdupq_n_u8(0x88);
const float32x4_t norm = vdupq_n_f32(1.f/16);
};
struct Q5_0_R4_Dequantizer {
Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); }
inline float32x4_t prepare(int ib, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ib].d));
auto lbits = vld1q_u8_x4(iq5[ib].qs);
auto hbits = vld1q_u8(iq5[ib].qh);
qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3
qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19
qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7
qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23
qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11
qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27
qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15
qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31
return scales;
}
inline float32x4_t result(float32x4_t acc) const {
return acc;
}
const char * cx;
const size_t bx;
const block_q5_0_r4 * iq5;
const uint8x16_t m4 = vdupq_n_u8(0x0f);
const uint8x16_t m5 = vdupq_n_u8(0x10);
const int8x16_t m16 = vdupq_n_s8(-16);
};
struct Q6_0_R4_Dequantizer {
Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); }
inline float32x4_t prepare(int ib, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ib].d));
auto lbits = vld1q_u8_x4(iq6[ib].qs);
auto hbits = vld1q_u8_x2(iq6[ib].qh);
qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3
qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19
qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7
qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 2), m6), m32); // 20..23
qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits.val[0], m6), m32); // 8..11
qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(hbits.val[1], m6), m32); // 24..27
qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits.val[0], 2), m6), m32); // 12..15
qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits.val[1], 2), m6), m32); // 28..31
return scales;
}
inline float32x4_t result(float32x4_t acc) const {
return acc;
}
const char * cx;
const size_t bx;
const block_q6_0_r4 * iq6;
const uint8x16_t m4 = vdupq_n_u8(0x0f);
const uint8x16_t m6 = vdupq_n_u8(0x30);
const int8x16_t m32 = vdupq_n_s8(-32);
};
inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& sumi1, int32x4_t& sumi2) {
auto y = vld1q_s8_x2(qy);
sumi1 = sumi2 = vdupq_n_s32(0);
sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
}
template <int nrc_y>
void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%8 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
int nb = n / QK8_0;
float32x4_t acc[2*nrc_y] = {};
int8x16_t qx[16];
float d8[4*nrc_y];
for (int ix = 0; ix < nrc_x; ix += 8) {
const block_q8_0_r8 * iq8 = (const block_q8_0_r8 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
}
for (int k = 0; k < 4; ++k) {
auto scales16 = vld1q_f16((const float16_t *)iq8[4*ib4+k].d);
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
int32x4_t sumi1, sumi2;
for (int iy = 0; iy < nrc_y; ++iy) {
qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
auto dy = vdupq_n_f32(d8[4*iy+k]);
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
}
}
}
for (int ib = 4*(nb/4); ib < nb; ++ib) {
auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
int32x4_t sumi1, sumi2;
for (int iy = 0; iy < nrc_y; ++iy) {
auto qy = (const block_q8_0 *)q8.y[iy];
qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, acc[2*iy+0]);
info.store(ix+4, iy, acc[2*iy+1]);
acc[2*iy] = acc[2*iy+1] = vdupq_n_f32(0.f);
}
}
}
#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
m.funcs[0] = func<Dequantizer, 1>;\
m.funcs[1] = func<Dequantizer, 2>;\
@@ -2975,11 +2583,12 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_Q6_0:
case GGML_TYPE_Q8_0:
case GGML_TYPE_IQ4_NL:
return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_Q4_0_R8:
case GGML_TYPE_Q5_0_R4:
case GGML_TYPE_Q6_0_R4:
case GGML_TYPE_Q8_0_R8:
case GGML_TYPE_IQ4_NL_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ4_XS_R8:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r8_q8_k);
expected_Btype = GGML_TYPE_Q8_K32;
@@ -3074,22 +2683,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_ks_r4_q8_k_neon);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_Q4_0_R8:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r8_q8_0, Q4_0_R8_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q5_0_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q6_0_R4:
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer);
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
case GGML_TYPE_Q8_0_R8:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r8_q8_0);
expected_Btype = GGML_TYPE_Q8_0_X4;
break;
default:
return false;
}