Refactor iqk: factor out repacked iqk quants (NEON)

This commit is contained in:
Iwan Kawrakow
2025-05-19 08:25:56 +03:00
parent 7e59d2b974
commit 7aa2de6d5a
2 changed files with 671 additions and 742 deletions

View File

@@ -2561,13 +2561,661 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true>
};
template <int nrc_y>
void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    // NEON kernel: multiply 4 interleaved rows of IQ4_KS quants (block_iq4_ks_r4)
    // by nrc_y rows of Q8_K activations.
    //   n     - row length, must be a multiple of QK_K
    //   vx    - packed data; each 4-row group starts with 4 float row scales,
    //           followed by the block_iq4_ks_r4 super-blocks
    //   bx    - byte stride between consecutive 4-row groups
    //   nrc_x - number of rows in vx to process (multiple of 4)
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);           // low-nibble mask for 4-bit quants
    auto values = vld1q_s8(iq4k_values); // 16-entry non-linear IQ4K codebook
    int nbl = n / QK_K;                  // super-blocks per row
    int8x16_t qx[8];
    int16x8x4_t iscales;
    int32x4x4_t scales;
    float32x4_t acc[nrc_y] = {};
    int32x4_t isum[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        auto dptr = (const float *)((const char *)vx + ix*bx);
        auto d4 = vld1q_f32(dptr);       // per-row float scales for this 4-row group
        const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            // Each scale byte keeps the block scale in bits 1..7 (hence the & 254)
            // and a per-block shift flag in bit 0. Scales are biased by -127.
            auto sas = vld1q_u8_x2(iq4[ibl].scales);
            auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
            iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
            iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            // Adding the block shifts costs us ~9% in performance drop.
            // Is there a better way?
            // Turn flag bit 0 into a 0/4 per-block shift (<< 2), then fold
            // shift*scale into the accumulators via the Q8_K block sums, so the
            // shift never has to be applied to the quants themselves.
            sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 2);
            sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 2);
            {
                auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
                auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
                auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
                auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    // bsums holds per-32 sums of the q8 quants; pair-add to per-64 sums.
                    auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
                    auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
                    auto b8 = vget_low_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
                    b8 = vget_high_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
                }
            }
            // Main dot products: 2 halves per super-block, 4 sub-blocks per half.
            for (int is = 0; is < 2; ++is) {
                scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
                scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
                scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
                scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
                    // Split nibbles and map them through the IQ4_NL codebook into qx[0..7].
                    prepare_iq4_nl_quants(values, m4, bits, qx);
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                // Scale the integer block result by the q8 row/block scale.
                acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
                isum[iy] = vdupq_n_s32(0);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            // Apply the 4 per-row float scales and store 4 results at once.
            info.store(ix, iy, vmulq_f32(d4, acc[iy]));
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
template <int nrc_y>
void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    // NEON kernel: multiply 4 interleaved rows of IQ5_KS quants (block_iq5_ks_r4)
    // by nrc_y rows of Q8_K activations. Same structure as the iq4_ks_r4 kernel
    // above, but quants are 5 bits: low nibble from qs, 5th bit (0x10) from qh,
    // looked up in a 32-entry codebook with vqtbl2q.
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);               // low-nibble mask
    auto m10 = vdupq_n_u8(0x10);             // 5th-bit mask
    auto values = vld1q_s8_x2(iq5nl_values); // 32-entry IQ5NL codebook (2 vectors)
    int nbl = n / QK_K;                      // super-blocks per row
    int8x16_t qx[8];
    int16x8x4_t iscales;
    int32x4x4_t scales;
    float32x4_t acc[nrc_y] = {};
    int32x4_t isum[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        auto dptr = (const float *)((const char *)vx + ix*bx);
        auto d4 = vld1q_f32(dptr);           // per-row float scales for this 4-row group
        const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            // Block scale in bits 1..7 of each scale byte (biased by -127);
            // bit 0 is a per-block shift flag.
            auto sas = vld1q_u8_x2(iq5[ibl].scales);
            auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
            iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
            iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            // Adding the block shifts costs us ~9% in performance drop.
            // Is there a better way?
            // Flag bit -> 0/2 shift (<< 1 here vs << 2 in the iq4_ks kernel);
            // folded into the accumulators via the Q8_K block sums below.
            sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 1);
            sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 1);
            {
                auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
                auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
                auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
                auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    // Pair-add the per-32 q8 sums to per-64 sums, then multiply-accumulate.
                    auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
                    auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
                    auto b8 = vget_low_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
                    b8 = vget_high_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
                }
            }
            for (int is = 0; is < 2; ++is) {
                scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
                scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
                scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
                scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
                    auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
                    // Assemble 5-bit indices: nibble | (corresponding qh bit moved to 0x10).
                    qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4)));
                    qx[1] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2)));
                    qx[2] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits));
                    qx[3] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2)));
                    qx[4] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3)));
                    qx[5] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1)));
                    qx[6] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1)));
                    qx[7] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3)));
                    // Map the 5-bit indices through the 32-entry codebook.
                    for (int l = 0; l < 8; ++l) qx[l] = vqtbl2q_s8(values, qx[l]);
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
                isum[iy] = vdupq_n_s32(0);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            // Apply the 4 per-row float scales and store 4 results at once.
            info.store(ix, iy, vmulq_f32(d4, acc[iy]));
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
template <int nrc_y, int k_shift>
inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales, uint8x16_t extra,
        int32x4_t * isum) {
    // Folds the per-block "extra" shift of the *_k_r4 quant types into the integer
    // accumulators isum via the Q8_K block sums (bsums): for each block whose extra
    // bit is set, isum += k_shift * scale * bsum, so the shift never has to be
    // applied to the quants themselves. Used by the nrc_y == 1 paths of the
    // mul_mat_iq{2,3,4,5}_k_r4_q8_k kernels below with k_shift = 5, 4, 4, 2.
    auto ms = vdupq_n_s8(k_shift);
    int8x16_t s8_1, s8_2;
    if constexpr (k_shift == 5) {
        // 5 is not a power of two, so the extra bit cannot be turned into the
        // shift value by shifting/masking alone; build a 0/k_shift multiplier
        // with a compare instead. Note: extra is consumed 2 bits at a time here.
        auto m1 = vdupq_n_u8(1);
        s8_1 = vmulq_s8(i8scales.val[0], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
        s8_2 = vmulq_s8(i8scales.val[1], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
    } else {
        if constexpr (k_shift == 4) {
            // Power-of-two shifts: move the relevant extra bit onto the k_shift
            // bit position and mask - cheaper than the compare above.
            s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 2)));
            s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, extra));
        } else {
            s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 1)));
            s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, vshrq_n_u8(extra, 1)));
        }
    }
    // Widen the (scale * shift) products to 16 bits and multiply-accumulate
    // against the first 8 per-32 block sums of each q8 row.
    auto s16_1 = vmovl_s8(vget_low_s8 (s8_1));
    auto s16_2 = vmovl_s8(vget_high_s8(s8_1));
    auto s16_3 = vmovl_s8(vget_low_s8 (s8_2));
    auto s16_4 = vmovl_s8(vget_high_s8(s8_2));
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto b8 = vld1_s16(q8.y[iy][ibl].bsums);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
        b8 = vld1_s16(q8.y[iy][ibl].bsums+4);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
    }
    // Same again for the second half of the super-block (scales val[2]/val[3],
    // bsums[8..15]).
    if constexpr (k_shift == 5) {
        auto m1 = vdupq_n_u8(1);
        s8_1 = vmulq_s8(i8scales.val[2], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
        s8_2 = vmulq_s8(i8scales.val[3], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
    } else {
        if constexpr (k_shift == 4) {
            s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 2)));
            s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 4)));
        } else {
            s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 3)));
            s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 5)));
        }
    }
    s16_1 = vmovl_s8(vget_low_s8 (s8_1));
    s16_2 = vmovl_s8(vget_high_s8(s8_1));
    s16_3 = vmovl_s8(vget_low_s8 (s8_2));
    s16_4 = vmovl_s8(vget_high_s8(s8_2));
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto b8 = vld1_s16(q8.y[iy][ibl].bsums+8);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
        b8 = vld1_s16(q8.y[iy][ibl].bsums+12);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
    }
}
template <int nrc_y>
void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    // NEON kernel: multiply 4 interleaved rows of IQ2_K quants (block_iq2_k_r4)
    // by nrc_y rows of Q8_K activations. 2-bit quants are looked up in an
    // 8-entry codebook; the per-block "extra" bit adds 4 to the table index
    // (selecting the shifted half of the codebook).
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);   // low-nibble mask (scales)
    auto m03 = vdupq_n_u8(0x03); // 2-bit quant mask
    auto ms = vdupq_n_u8(4);     // index shift selected by the extra bit
    // Broadcasts each of 8 per-block shift bytes over 4 consecutive lanes,
    // matching the 4-row interleaved quant layout (presumably one byte per
    // row/block pair - see the r4 repacking).
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    auto values8 = vld1_s8(iq2nl_values);
    auto values = vcombine_s8(values8, values8); // 8-entry codebook, duplicated to 16 lanes
    int nbl = n / QK_K;                          // super-blocks per row
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d)); // 4 per-row fp16 scales
            auto extra8 = vld1_u8(iq2[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                // Layout expected by iq3_4_add_shift (consumes 2 bits per step).
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            // Unpack 4-bit block scales, biased by -8.
            auto sl = vld1q_u8_x2(iq2[ibl].scales);
            i8scales.val[0] = vaddq_s8(vandq_u8(sl.val[0], m4), vdupq_n_s8(-8));
            i8scales.val[1] = vaddq_s8(vandq_u8(sl.val[1], m4), vdupq_n_s8(-8));
            i8scales.val[2] = vaddq_s8(vshrq_n_u8(sl.val[0], 4), vdupq_n_s8(-8));
            i8scales.val[3] = vaddq_s8(vshrq_n_u8(sl.val[1], 4), vdupq_n_s8(-8));
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                // For a single y row, fold the extra shifts in via bsums once up
                // front; the quant loop below can then use plain table lookups.
                iq3_4_add_shift<nrc_y, 5>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib);
                    // Extract the four 2-bit quant planes of the first 16 bytes.
                    qx[0] = vandq_u8(           bits.val[0],     m03);
                    qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m03);
                    qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m03);
                    qx[3] = vandq_u8(vshrq_n_u8(bits.val[0], 6), m03);
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        // Shifts already folded in above - plain codebook lookup.
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        // Per-block index shift: extra bit -> 0/4, broadcast to
                        // the interleaved lanes, added to the 2-bit index.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 2));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 bytes of this sub-block (same shifts, high scales).
                    qx[0] = vandq_u8(           bits.val[1],     m03);
                    qx[1] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m03);
                    qx[2] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m03);
                    qx[3] = vandq_u8(vshrq_n_u8(bits.val[1], 6), m03);
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                // Combine the 4 per-row fp16 scales with the q8 block scale.
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
template <int nrc_y>
void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    // NEON kernel: multiply 4 interleaved rows of IQ3_K quants (block_iq3_k_r4)
    // by nrc_y rows of Q8_K activations. 3-bit quant indices are assembled from
    // 2 low bits (qs) plus 1 high bit (qh), looked up in a 16-entry codebook;
    // the per-block "extra" bit adds 8 to the index. Block scales carry a
    // separate sign bit in scales_h.
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);
    // Index-shift magnitude: in the nrc_y == 1 path the shift is handled via
    // iq3_4_add_shift (k_shift == 4); otherwise it is an in-table shift of 8.
    auto ms = nrc_y == 1 ? vdupq_n_u8(4) : vdupq_n_u8(8);
    auto m03 = vdupq_n_u8(0x03); // low 2 bits of the quant
    auto m04 = vdupq_n_u8(0x04); // high (3rd) bit of the quant
    // Broadcasts each of 8 per-block shift bytes over 4 consecutive lanes to
    // match the 4-row interleaved quant layout.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    // Bit masks used to test the per-block scale sign bits in scales_h.
    uint8x16x2_t smask = { vcombine_u8(vdup_n_u8(1), vdup_n_u8(2)), vcombine_u8(vdup_n_u8(4), vdup_n_u8(8)) };
    auto values = vld1q_s8(iq3nl_values); // 16-entry codebook (8 values + 8 shifted)
    int nbl = n / QK_K;                   // super-blocks per row
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)); // 4 per-row fp16 scales
            auto extra8 = vld1_u8(iq3[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                // Layout expected by iq3_4_add_shift.
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            // Scale magnitude: 4-bit value * 2 + 1 (odd, 1..31).
            auto sl = vld1q_u8_x2(iq3[ibl].scales_l);
            auto sh8 = vld1_u8(iq3[ibl].scales_h);
            auto sh = vcombine_u8(sh8, sh8);
            i8scales.val[0] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[0], m4), 1), vdupq_n_s8(1));
            i8scales.val[1] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[1], m4), 1), vdupq_n_s8(1));
            i8scales.val[2] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[0], 4), 1), vdupq_n_s8(1));
            i8scales.val[3] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[1], 4), 1), vdupq_n_s8(1));
            // Apply the sign: (ceq ? 0xFF : 0) | 1 yields -1 or +1 as int8.
            i8scales.val[0] = vmulq_s8(i8scales.val[0], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
            i8scales.val[1] = vmulq_s8(i8scales.val[1], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
            sh = vshrq_n_u8(sh, 4);
            i8scales.val[2] = vmulq_s8(i8scales.val[2], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
            i8scales.val[3] = vmulq_s8(i8scales.val[3], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                // Fold the extra shifts in once via bsums; quant loop below then
                // uses plain lookups.
                iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib);
                    auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib);
                    // 3-bit index = 2 low bits | matching qh bit moved to 0x04.
                    qx[0] = vorrq_u8(vandq_u8(           lbits.val[0],     m03), vandq_u8(m04, vshlq_n_u8(hbits, 2)));
                    qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03), vandq_u8(m04, vshlq_n_u8(hbits, 1)));
                    qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03), vandq_u8(m04, hbits));
                    qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 1)));
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        // extra bit -> 0/8 index shift, broadcast per interleaved lane.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 3));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 bytes of this sub-block (same shifts, high scales).
                    qx[0] = vorrq_u8(vandq_u8(           lbits.val[1],     m03), vandq_u8(m04, vshrq_n_u8(hbits, 2)));
                    qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03), vandq_u8(m04, vshrq_n_u8(hbits, 3)));
                    qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03), vandq_u8(m04, vshrq_n_u8(hbits, 4)));
                    qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 5)));
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                // Combine the 4 per-row fp16 scales with the q8 block scale.
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
template <int nrc_y>
void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    // NEON kernel: multiply 4 interleaved rows of IQ4_K quants (block_iq4_k_r4)
    // by nrc_y rows of Q8_K activations. 4-bit quants go through the 16-entry
    // iq4k codebook; unlike the iq2/iq3 kernels, the per-block "extra" shift
    // (0/4) is added to the looked-up VALUES, not to the table index.
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);   // low-nibble mask
    auto m3 = vdupq_n_u8(0x30);  // upper-2-bit field of the 6-bit scales
    auto ms = vdupq_n_u8(4);     // value shift selected by the extra bit
    auto m32 = vdupq_n_s8(-32);  // 6-bit scale bias
    // Broadcasts each of 8 per-block shift bytes over 4 consecutive lanes to
    // match the 4-row interleaved quant layout.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    auto values = vld1q_s8(iq4k_values); // 16-entry non-linear codebook
    int nbl = n / QK_K;                  // super-blocks per row
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d)); // 4 per-row fp16 scales
            auto extra8 = vld1_u8(iq4[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                // Layout expected by iq3_4_add_shift.
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            // 6-bit block scales: low nibble from scales_l, top 2 bits from
            // scales_h, biased by -32.
            auto sl = vld1q_u8_x2(iq4[ibl].scales_l);
            auto sh = vld1q_u8(iq4[ibl].scales_h);
            i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
            i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
            i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
            i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                // Fold the extra value shifts in once via bsums.
                iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        // Shifts already folded in - plain codebook lookups.
                        qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4));   // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4));   // 4...7
                        qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));  // 8..11
                        qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4));  // 12..15
                    } else {
                        // extra bit -> 0/4 shift, added to the dequantized values.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 2));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)));  // 0...3 from the 4 rows
                        qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)));  // 4...7
                        qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4))); // 8..11
                        qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4))); // 12..15
                    }
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 quants of this sub-block (same shifts, high scales).
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4));   // 16..19
                        qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4));   // 20..23
                        qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));  // 24..27
                        qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4));  // 28..31
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)));  // 16..19
                        qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)));  // 20..23
                        qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4))); // 24..27
                        qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4))); // 28..31
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                // Combine the 4 per-row fp16 scales with the q8 block scale.
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
template <int nrc_y>
void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    // NEON kernel: multiply 4 interleaved rows of IQ5_K quants (block_iq5_k_r4)
    // by nrc_y rows of Q8_K activations. 5-bit indices (low nibble from qs,
    // 5th bit from qh) go through the 32-entry iq5nl codebook (vqtbl2q); the
    // per-block "extra" shift (0/2) is added to the looked-up values.
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);   // low-nibble mask
    auto m3 = vdupq_n_u8(0x30);  // upper-2-bit field of the 6-bit scales
    auto ms = vdupq_n_u8(2);     // value shift selected by the extra bit
    auto m32 = vdupq_n_s8(-32);  // 6-bit scale bias
    auto m10 = vdupq_n_u8(0x10); // 5th-bit mask
    // Broadcasts each of 8 per-block shift bytes over 4 consecutive lanes to
    // match the 4-row interleaved quant layout.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    auto values = vld1q_s8_x2(iq5nl_values); // 32-entry codebook (2 vectors)
    int nbl = n / QK_K;                      // super-blocks per row
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d)); // 4 per-row fp16 scales
            auto extra8 = vld1_u8(iq5[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                // Layout expected by iq3_4_add_shift.
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            // 6-bit block scales: low nibble from scales_l, top 2 bits from
            // scales_h, biased by -32.
            auto sl = vld1q_u8_x2(iq5[ibl].scales_l);
            auto sh = vld1q_u8(iq5[ibl].scales_h);
            i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
            i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
            i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
            i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                // Fold the extra value shifts in once via bsums (shift magnitude 2).
                iq3_4_add_shift<nrc_y, 2>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
                    auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
                    // Assemble 5-bit indices: nibble | matching qh bit moved to 0x10.
                    qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4)));  // aligns with 1st half of qx[0] in AVX2
                    qx[1] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits));                 // aligns with 1st half of qx[1] in AVX2
                    qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3))); // aligns with 1st half of qx[2] in AVX2
                    qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1))); // aligns with 1st half of qx[3] in AVX2
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        // Shifts already folded in - plain codebook lookups.
                        qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15
                    } else {
                        // extra bit -> 0/2 shift, added to the dequantized values.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 1));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7
                        qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11
                        qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15
                    }
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 quants of this sub-block (same shifts, high scales).
                    qx[0] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2)));  // aligns with 2nd half of qx[0] in AVX2
                    qx[1] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2)));  // aligns with 2nd half of qx[1] in AVX2
                    qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1))); // aligns with 2nd half of qx[2] in AVX2
                    qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3))); // aligns with 2nd half of qx[3] in AVX2
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7
                        qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11
                        qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                // Combine the 4 per-row fp16 scales with the q8 block scale.
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
}
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) {
auto etypeA = ggml_type(typeA);
auto expected_type_B = etypeA == GGML_TYPE_IQ4_KS_R4 || etypeA == GGML_TYPE_IQ5_KS_R4 ? GGML_TYPE_Q8_K32 : GGML_TYPE_Q8_K;
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
return false;
}
@@ -2599,33 +3247,20 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_m
case GGML_TYPE_IQ6_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ6K, kernels);
break;
// case GGML_TYPE_IQ2_K_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_k_r4_q8_k, kernels);
// break;
// case GGML_TYPE_IQ3_K_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_k_r4_q8_k, kernels);
//#ifdef HAVE_FANCY_SIMD
// func16 = mul_mat_iq3_k_r4_q8_k<16>;
//#endif
// break;
// case GGML_TYPE_IQ4_K_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_k_r4_q8_k, kernels);
// func16 = mul_mat_iq4_k_r4_q8_k<16>;
// break;
// case GGML_TYPE_IQ4_KS_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_ks_r4_q8_k, kernels);
//#ifndef HAVE_FANCY_SIMD
// // For some reason Zen4 does not like this particular function
// func16 = mul_mat_iq4_ks_r4_q8_k<16>;
//#endif
// break;
// case GGML_TYPE_IQ5_KS_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_ks_r4_q8_k, kernels);
//#ifndef HAVE_FANCY_SIMD
// // For some reason Zen4 does not like this particular function
// func16 = mul_mat_iq5_ks_r4_q8_k<16>;
//#endif
// break;
case GGML_TYPE_IQ2_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_k_r4_q8_k, kernels);
break;
case GGML_TYPE_IQ3_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_k_r4_q8_k, kernels);
break;
case GGML_TYPE_IQ4_K_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_k_r4_q8_k, kernels);
break;
case GGML_TYPE_IQ4_KS_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_ks_r4_q8_k, kernels);
break;
case GGML_TYPE_IQ5_KS_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_ks_r4_q8_k, kernels);
default:
return false;
}

View File

@@ -878,44 +878,6 @@ template <int nrc> struct Q8_K64 {
const int8_t * y[nrc_y];
};
struct DequantizerIQ1BN {
    // NEON decoder for IQ1_BN quants. The multiplier table below consists of
    // powers of 3 (81, 27, 9, 3, 1), i.e. each stored byte packs 5 ternary
    // digits base-3; the shuffles replicate each source byte over 5 output
    // lanes (byte 12 supplies the trailing digit of each 16-lane group).
    const uint8x16_t m1 = vdupq_n_u8(1);
    static inline uint8x16x4_t load_shuffles() {
        // Byte-replication pattern: source byte k feeds 5 consecutive lanes.
        static const uint8_t data[64] = {0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 12,
                                         3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 12,
                                         6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 12,
                                         9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12};
        return vld1q_u8_x4(data);
    }
    static inline uint8x16x4_t load_mult() {
        // Per-lane base-3 digit selectors (3^4 .. 3^0); the last lane of each
        // vector selects successive digits of byte 12.
        static const uint8_t data[64] = {81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81,
                                         81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27,
                                         81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 9,
                                         81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 3};
        return vld1q_u8_x4(data);
    }
    const uint8x16x4_t shuff = load_shuffles();
    const uint8x16x4_t mult = load_mult();
    // Decode one block into 64 signed values in {-1, 0, 1}.
    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
        auto data = vld1q_u8((const uint8_t *)x);
        for (int k = 0; k < 4; ++k) {
            auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
            // (val + val/2) / 2 >> 6, i.e. ~ 3*val/256 - maps the scaled product
            // back to its ternary digit (0..2).
            val = vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6);
            v.val[k] = vsubq_s8(vreinterpretq_s8_u8(val), m1); // shift to {-1, 0, 1}
        }
    }
    // Same as above but leaves the digits in {0, 1, 2} (no -1 offset).
    IQK_ALWAYS_INLINE void prepare_iq1bn_quants_nosub(const block_iq1_bn * x, int8x16x4_t& v) const {
        auto data = vld1q_u8((const uint8_t *)x);
        for (int k = 0; k < 4; ++k) {
            auto val = vmulq_u8(vqtbl1q_u8(data, shuff.val[k]), mult.val[k]);
            v.val[k] = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(val, vshrq_n_u8(val, 1)), 6));
        }
    }
};
template <int nrc> struct Q8_16 {
constexpr static int nrc_y = nrc;
@@ -938,166 +900,6 @@ template <int nrc> struct Q8_16 {
const int8_t * y[nrc_y];
};
// NEON matmul kernel: 4 interleaved rows of iq4_ks quants (x) times nrc_y
// rows of q8_K quants (y).
//   n     : row length; assumed to be a multiple of QK_K (see nbl below)
//   vx/bx : x rows, bx bytes apart. Each 4-row group starts with 4 floats
//           (per-row super-scales, d4) followed by block_iq4_ks_r4 blocks.
//   info  : destination accessor; results stored per (ix, iy).
//   nrc_x : number of x rows to process; must be a multiple of 4.
template <int nrc_y>
void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);               // low-nibble mask
    auto values = vld1q_s8(iq4k_values);     // 16-entry non-linear 4-bit codebook
    int nbl = n / QK_K;                      // number of super-blocks per row
    int8x16_t qx[8];
    int16x8x4_t iscales;
    int32x4x4_t scales;
    float32x4_t acc[nrc_y] = {};
    int32x4_t isum[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        auto dptr = (const float *)((const char *)vx + ix*bx);
        auto d4 = vld1q_f32(dptr);           // per-row float scales for these 4 rows
        const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            // Each scale byte: bits 1..7 hold the (biased) block scale, bit 0
            // flags whether this block's quants get an extra +4 shift.
            auto sas = vld1q_u8_x2(iq4[ibl].scales);
            auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
            iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
            iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            // Adding the block shifts costs us ~9% in performance drop.
            // Is there a better way?
            // Turn the shift flag bit into the actual shift value (0 or 4).
            sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 2);
            sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 2);
            {
                // Fold shift*scale*sum(y) into isum via the q8 block sums, so
                // the shift never has to be applied to the quants themselves.
                auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
                auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
                auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
                auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
                    auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
                    auto b8 = vget_low_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
                    b8 = vget_high_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
                }
            }
            // Main quant dot products: 2 halves x 4 sub-blocks per super-block.
            for (int is = 0; is < 2; ++is) {
                scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
                scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
                scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
                scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
                    // Expand the 4-bit nibbles through the iq4k codebook into qx[0..7].
                    prepare_iq4_nl_quants(values, m4, bits, qx);
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
                    }
                }
            }
            // Scale this super-block's integer sums by the q8 row scale.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
                isum[iy] = vdupq_n_s32(0);
            }
        }
        // Apply the per-row float scales and store 4 results per y row.
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vmulq_f32(d4, acc[iy]));
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
// NEON matmul kernel: 4 interleaved rows of iq5_ks quants (x) times nrc_y rows
// of q8_K quants (y). Same overall structure as mul_mat_iq4_ks_r4_q8_k, but
// quants are 5 bits: low nibbles in qs, the 5th bit in qh, decoded through the
// 32-entry iq5nl codebook (vqtbl2q over two 16-byte tables).
template <int nrc_y>
void mul_mat_iq5_ks_r4_q8_k_neon(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);               // low-nibble mask
    auto m10 = vdupq_n_u8(0x10);             // 5th-bit mask after shifting qh into place
    auto values = vld1q_s8_x2(iq5nl_values); // 32-entry non-linear 5-bit codebook
    int nbl = n / QK_K;
    int8x16_t qx[8];
    int16x8x4_t iscales;
    int32x4x4_t scales;
    float32x4_t acc[nrc_y] = {};
    int32x4_t isum[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        auto dptr = (const float *)((const char *)vx + ix*bx);
        auto d4 = vld1q_f32(dptr);           // per-row float scales for these 4 rows
        const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            // Scale bytes: bits 1..7 are the biased block scale, bit 0 is the
            // extra-shift flag (shift of 2 for iq5_ks, vs 4 for iq4_ks).
            auto sas = vld1q_u8_x2(iq5[ibl].scales);
            auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
            iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
            iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            // Adding the block shifts costs us ~9% in performance drop.
            // Is there a better way?
            sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 1);
            sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 1);
            {
                // Account for the block shift via the q8 block sums instead of
                // shifting every quant.
                auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
                auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
                auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
                auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
                    auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
                    auto b8 = vget_low_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
                    b8 = vget_high_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
                }
            }
            for (int is = 0; is < 2; ++is) {
                scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
                scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
                scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
                scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
                    auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
                    // Combine low nibbles with the appropriate high bit from hbits
                    // (each hbits byte carries the 5th bit of 8 different quants).
                    qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4)));
                    qx[1] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2)));
                    qx[2] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits));
                    qx[3] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2)));
                    qx[4] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3)));
                    qx[5] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1)));
                    qx[6] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1)));
                    qx[7] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3)));
                    // Map the 5-bit indices through the codebook.
                    for (int l = 0; l < 8; ++l) qx[l] = vqtbl2q_s8(values, qx[l]);
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
                isum[iy] = vdupq_n_s32(0);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vmulq_f32(d4, acc[iy]));
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
template <int nrc_y>
static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -1407,496 +1209,6 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
}
}
// Shared helper for the iq2_k/iq3_k/iq4_k/iq5_k r4 kernels (nrc_y == 1 path):
// folds the per-block "extra" shift into the accumulators isum[] using the q8
// block sums, so the quants themselves never need to be shifted.
//   k_shift    : the shift magnitude for the quant type (5 / 4 / 2)
//   i8scales   : 4x16 int8 block scales for the 4 interleaved rows
//   extra      : per-row shift flag bits, consumed 1-2 bits per sub-block
// For k_shift == 5 (odd, cannot be built by masking a shifted flag bit) the
// flag is turned into 0/0xFF with vceqq and then masked with ms; for 4 and 2
// the flag bit is simply shifted into position and masked.
template <int nrc_y, int k_shift>
inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales, uint8x16_t extra,
    int32x4_t * isum) {
    auto ms = vdupq_n_s8(k_shift);
    int8x16_t s8_1, s8_2;
    // First half: scale * shift for sub-blocks covered by i8scales.val[0..1].
    if constexpr (k_shift == 5) {
        auto m1 = vdupq_n_u8(1);
        s8_1 = vmulq_s8(i8scales.val[0], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
        s8_2 = vmulq_s8(i8scales.val[1], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
    } else {
        if constexpr (k_shift == 4) {
            s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 2)));
            s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, extra));
        } else {
            s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 1)));
            s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, vshrq_n_u8(extra, 1)));
        }
    }
    auto s16_1 = vmovl_s8(vget_low_s8 (s8_1));
    auto s16_2 = vmovl_s8(vget_high_s8(s8_1));
    auto s16_3 = vmovl_s8(vget_low_s8 (s8_2));
    auto s16_4 = vmovl_s8(vget_high_s8(s8_2));
    // Accumulate scale*shift*bsum for the first 8 q8 sub-block sums.
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto b8 = vld1_s16(q8.y[iy][ibl].bsums);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
        b8 = vld1_s16(q8.y[iy][ibl].bsums+4);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
    }
    // Second half: same procedure with i8scales.val[2..3] and the remaining
    // extra bits.
    if constexpr (k_shift == 5) {
        auto m1 = vdupq_n_u8(1);
        s8_1 = vmulq_s8(i8scales.val[2], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
        s8_2 = vmulq_s8(i8scales.val[3], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
    } else {
        if constexpr (k_shift == 4) {
            s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 2)));
            s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 4)));
        } else {
            s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 3)));
            s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 5)));
        }
    }
    s16_1 = vmovl_s8(vget_low_s8 (s8_1));
    s16_2 = vmovl_s8(vget_high_s8(s8_1));
    s16_3 = vmovl_s8(vget_low_s8 (s8_2));
    s16_4 = vmovl_s8(vget_high_s8(s8_2));
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto b8 = vld1_s16(q8.y[iy][ibl].bsums+8);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
        b8 = vld1_s16(q8.y[iy][ibl].bsums+12);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
    }
}
// NEON matmul kernel: 4 interleaved rows of iq2_k quants (x) times nrc_y rows
// of q8_K quants (y). 2-bit quant indices are decoded through the 4-entry
// iq2nl codebook; the per-block "extra" flag selects a +4 shifted codebook
// half. For nrc_y == 1 the shift is folded into isum via iq3_4_add_shift;
// otherwise it is applied by offsetting the table-lookup index.
template <int nrc_y>
void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);
    auto m03 = vdupq_n_u8(0x03);             // 2-bit quant mask
    auto ms = vdupq_n_u8(4);                 // codebook offset when the extra bit is set
    // Broadcasts each of the 8 per-row shift bytes to the 4 rows it belongs to.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    auto values8 = vld1_s8(iq2nl_values);
    auto values = vcombine_s8(values8, values8);
    int nbl = n / QK_K;
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
            auto extra8 = vld1_u8(iq2[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                // iq3_4_add_shift consumes 2 extra bits per step, so pre-stagger
                // the high half by one bit.
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            // 4-bit block scales, biased by -8.
            auto sl = vld1q_u8_x2(iq2[ibl].scales);
            i8scales.val[0] = vaddq_s8(vandq_u8(sl.val[0], m4), vdupq_n_s8(-8));
            i8scales.val[1] = vaddq_s8(vandq_u8(sl.val[1], m4), vdupq_n_s8(-8));
            i8scales.val[2] = vaddq_s8(vshrq_n_u8(sl.val[0], 4), vdupq_n_s8(-8));
            i8scales.val[3] = vaddq_s8(vshrq_n_u8(sl.val[1], 4), vdupq_n_s8(-8));
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                // Fold the +5 shift contribution into isum via the q8 block sums.
                iq3_4_add_shift<nrc_y, 5>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib);
                    // Unpack four 2-bit quants per byte.
                    qx[0] = vandq_u8( bits.val[0], m03);
                    qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m03);
                    qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m03);
                    qx[3] = vandq_u8(vshrq_n_u8(bits.val[0], 6), m03);
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        // Offset the lookup index by 4 for rows with the extra bit set.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 2));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 quants of this sub-block.
                    qx[0] = vandq_u8( bits.val[1], m03);
                    qx[1] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m03);
                    qx[2] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m03);
                    qx[3] = vandq_u8(vshrq_n_u8(bits.val[1], 6), m03);
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 16..19 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 20..23
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 24..27
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 28..31
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 16..19 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 20..23
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 24..27
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 28..31
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
// NEON matmul kernel: 4 interleaved rows of iq3_k quants (x) times nrc_y rows
// of q8_K quants (y). 3-bit quant indices (2 low bits in qs, 1 high bit in qh)
// are decoded through the iq3nl codebook; block scales are odd-valued
// (2*s + 1) with a separate sign bit taken from scales_h. The per-block
// "extra" flag shifts the codebook index by 4 (or is folded into isum via
// iq3_4_add_shift when nrc_y == 1).
template <int nrc_y>
void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);
    // NOTE(review): ms differs by path; the nrc_y == 1 value feeds
    // iq3_4_add_shift indirectly, the other masks the shifted extra bit.
    auto ms = nrc_y == 1 ? vdupq_n_u8(4) : vdupq_n_u8(8);
    auto m03 = vdupq_n_u8(0x03);             // low 2 bits of the quant
    auto m04 = vdupq_n_u8(0x04);             // high (3rd) bit of the quant
    // Broadcasts each of the 8 per-row shift bytes to the 4 rows it belongs to.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    // One sign bit per scale, tested via these per-16-lane masks.
    uint8x16x2_t smask = { vcombine_u8(vdup_n_u8(1), vdup_n_u8(2)), vcombine_u8(vdup_n_u8(4), vdup_n_u8(8)) };
    auto values = vld1q_s8(iq3nl_values);
    int nbl = n / QK_K;
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
            auto extra8 = vld1_u8(iq3[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                // iq3_4_add_shift consumes 2 extra bits per step; pre-stagger
                // the high half by one bit.
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            auto sl = vld1q_u8_x2(iq3[ibl].scales_l);
            auto sh8 = vld1_u8(iq3[ibl].scales_h);
            auto sh = vcombine_u8(sh8, sh8);
            // Magnitude: 2*nibble + 1 (always odd, never zero).
            i8scales.val[0] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[0], m4), 1), vdupq_n_s8(1));
            i8scales.val[1] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[1], m4), 1), vdupq_n_s8(1));
            i8scales.val[2] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[0], 4), 1), vdupq_n_s8(1));
            i8scales.val[3] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[1], 4), 1), vdupq_n_s8(1));
            // Sign: vceqq yields 0xFF (i.e. -1 as int8) when the sign bit is
            // set; OR with 1 turns that into -1/+1 multipliers.
            i8scales.val[0] = vmulq_s8(i8scales.val[0], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
            i8scales.val[1] = vmulq_s8(i8scales.val[1], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
            sh = vshrq_n_u8(sh, 4);
            i8scales.val[2] = vmulq_s8(i8scales.val[2], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
            i8scales.val[3] = vmulq_s8(i8scales.val[3], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)))
;
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib);
                    auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib);
                    // Assemble 3-bit indices: 2 bits from lbits, bit 2 from hbits.
                    qx[0] = vorrq_u8(vandq_u8( lbits.val[0], m03), vandq_u8(m04, vshlq_n_u8(hbits, 2)));
                    qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03), vandq_u8(m04, vshlq_n_u8(hbits, 1)));
                    qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03), vandq_u8(m04, hbits));
                    qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 1)));
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        // Offset the lookup index by 4 for rows with the extra bit set.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 3));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 quants of this sub-block.
                    qx[0] = vorrq_u8(vandq_u8( lbits.val[1], m03), vandq_u8(m04, vshrq_n_u8(hbits, 2)));
                    qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03), vandq_u8(m04, vshrq_n_u8(hbits, 3)));
                    qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03), vandq_u8(m04, vshrq_n_u8(hbits, 4)));
                    qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 5)));
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 16..19 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 20..23
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 24..27
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 28..31
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 16..19 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 20..23
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 24..27
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 28..31
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
// NEON matmul kernel: 4 interleaved rows of iq4_k quants (x) times nrc_y rows
// of q8_K quants (y). 4-bit quant indices are decoded through the 16-entry
// iq4k codebook; 6-bit block scales come from scales_l (low 4 bits) +
// scales_h (top 2 bits), biased by -32. The per-block "extra" flag adds 4 to
// the decoded quant (applied post-lookup here, unlike the iq2/iq3 kernels
// which offset the lookup index).
template <int nrc_y>
void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);               // low-nibble mask
    auto m3 = vdupq_n_u8(0x30)
;                                            // top-2-bits-of-scale mask (after shifting)
    auto ms = vdupq_n_u8(4);                 // quant shift when the extra bit is set
    auto m32 = vdupq_n_s8(-32);              // scale bias
    // Broadcasts each of the 8 per-row shift bytes to the 4 rows it belongs to.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    auto values = vld1q_s8(iq4k_values);
    int nbl = n / QK_K;
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d));
            auto extra8 = vld1_u8(iq4[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                // iq3_4_add_shift consumes 2 extra bits per step; pre-stagger
                // the high half by one bit.
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            // Reassemble the 6-bit scales and apply the -32 bias.
            sl = vld1q_u8_x2(iq4[ibl].scales_l);
            auto sh = vld1q_u8(iq4[ibl].scales_h);
            i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
            i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
            i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
            i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4));   // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4));   // 4...7
                        qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));  // 8..11
                        qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4));  // 12..15
                    } else {
                        // Add the +4 shift to the decoded quants for rows with
                        // the extra bit set.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 2));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)));   // 0...3 from the 4 rows
                        qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)));   // 4...7
                        qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)));  // 8..11
                        qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)));  // 12..15
                    }
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4));   // 16..19
                        qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4));   // 20..23
                        qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));  // 24..27
                        qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4));  // 28..31
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)));   // 16..19
                        qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)));   // 20..23
                        qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)));  // 24..27
                        qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)));  // 28..31
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
// NEON matmul kernel: 4 interleaved rows of iq5_k quants (x) times nrc_y rows
// of q8_K quants (y). 5-bit quant indices (low nibble in qs, 5th bit in qh)
// are decoded through the 32-entry iq5nl codebook (vqtbl2q). 6-bit block
// scales are reassembled from scales_l/scales_h with a -32 bias. The per-block
// "extra" flag adds 2 to the decoded quant.
template <int nrc_y>
void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);               // low-nibble mask
    auto m3 = vdupq_n_u8(0x30);              // top-2-bits-of-scale mask (after shifting)
    auto ms = vdupq_n_u8(2);                 // quant shift when the extra bit is set
    auto m32 = vdupq_n_s8(-32);              // scale bias
    auto m10 = vdupq_n_u8(0x10);             // 5th-bit mask after shifting qh into place
    // Broadcasts each of the 8 per-row shift bytes to the 4 rows it belongs to.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    auto values = vld1q_s8_x2(iq5nl_values);
    int nbl = n / QK_K;
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d));
            auto extra8 = vld1_u8(iq5[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                // iq3_4_add_shift consumes 2 extra bits per step; pre-stagger
                // the high half by one bit.
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            // Reassemble the 6-bit scales and apply the -32 bias.
            auto sl = vld1q_u8_x2(iq5[ibl].scales_l);
            auto sh = vld1q_u8(iq5[ibl].scales_h);
            i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
            i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
            i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
            i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                iq3_4_add_shift<nrc_y, 2>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
                    auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
                    // Build 5-bit indices: low nibble from lbits plus the 5th
                    // bit extracted from the shared hbits byte.
                    qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4)));  // aligns with 1st half of qx[0] in AVX2
                    qx[1] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits));                 // aligns with 1st half of qx[1] in AVX2
                    qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3))); // aligns with 1st half of qx[2] in AVX2
                    qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1))); // aligns with 1st half of qx[3] in AVX2
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15
                    } else {
                        // Add the +2 shift to the decoded quants for rows with
                        // the extra bit set.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 1));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7
                        qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11
                        qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15
                    }
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 quants of this sub-block.
                    qx[0] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2)));  // aligns with 2nd half of qx[0] in AVX2
                    qx[1] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2)));  // aligns with 2nd half of qx[1] in AVX2
                    qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1))); // aligns with 2nd half of qx[2] in AVX2
                    qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3))); // aligns with 2nd half of qx[3] in AVX2
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl2q_s8(values, qx[0]); // 16..19 from the 4 rows
                        qx[1] = vqtbl2q_s8(values, qx[1]); // 20..23
                        qx[2] = vqtbl2q_s8(values, qx[2]); // 24..27
                        qx[3] = vqtbl2q_s8(values, qx[3]); // 28..31
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 16..19 from the 4 rows
                        qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 20..23
                        qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 24..27
                        qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 28..31
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
m.funcs[0] = func<Dequantizer, 1>;\
m.funcs[1] = func<Dequantizer, 2>;\
@@ -1955,6 +1267,12 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ5_KS:
case GGML_TYPE_IQ5_K:
case GGML_TYPE_IQ6_K:
case GGML_TYPE_IQ2_K_R4:
case GGML_TYPE_IQ3_K_R4:
case GGML_TYPE_IQ4_K_R4:
case GGML_TYPE_IQ5_K_R4:
case GGML_TYPE_IQ4_KS_R4:
case GGML_TYPE_IQ5_KS_R4:
return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
@@ -1975,10 +1293,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_Q8_0_R8:
case GGML_TYPE_IQ4_NL_R4:
return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ4_KS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_ks_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ2_XXS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xxs_r4_q8_k);
m.func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
@@ -2011,26 +1325,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
m.func16 = mul_mat_iq3_s_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ2_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ3_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ4_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ5_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ5_KS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_ks_r4_q8_k_neon);
expected_Btype = GGML_TYPE_Q8_K;
break;
default:
return false;
}