New iq4_kt: faster NEON

We are now at 9.4 t/s, up from 6.6 t/s for the f16 trellis.
This commit is contained in:
Iwan Kawrakow
2025-06-08 10:05:49 +03:00
parent 4102aa998c
commit 07d6e1d4b1

View File

@@ -1353,6 +1353,9 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
return vpaddq_s32(dot.val[0], dot.val[2]);
};
//int32x4x2_t shifts = {int32x4_t{-8, -11, -14, -17}, int32x4_t{-20, -23, -26, -29}};
int32x4x2_t shifts = {int32x4_t{4, 1, -2, -5}, int32x4_t{-8, -11, -14, -17}};
float32x4x2_t scales;
for (int ix = 0; ix < nrc_x; ++ix) {
@@ -1376,14 +1379,33 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
o_helper.vec.val[0] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(1)), 15), vdupq_n_u32(4096));
o_helper.vec.val[1] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(1)), 15), vdupq_n_u32(4096));
for (int ib = 0; ib < 4; ++ib) {
for (int j = 0; j < 4; ++j) {
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
}
auto vql1 = vmovl_u8(vld1_u8(ql+8*ib));
auto vql2 = vmovl_u8(vld1_u8(ql+8*ib+32));
auto vqh = vmovl_u8(vld1_u8(qh+8*ib));
vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 8)));
vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 4)));
auto sh1_u32 = vdupq_n_u32(shb[ib+0]);
auto sh2_u32 = vdupq_n_u32(shb[ib+4]);
auto sh1 = vcombine_u16(vmovn_u32(vshlq_u32(sh1_u32, shifts.val[0])), vmovn_u32(vshlq_u32(sh1_u32, shifts.val[1])));
auto sh2 = vcombine_u16(vmovn_u32(vshlq_u32(sh2_u32, shifts.val[0])), vmovn_u32(vshlq_u32(sh2_u32, shifts.val[1])));
vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0x7000), sh1));
vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0x7000), sh2));
auto oh1 = vdupq_n_u32(o_helper.val[ib+0]);
auto oh2 = vdupq_n_u32(o_helper.val[ib+4]);
vst1q_u32(values +0, vaddq_u32(vmovl_u16(vget_low_u16 (vql1)), oh1));
vst1q_u32(values +4, vaddq_u32(vmovl_u16(vget_high_u16(vql1)), oh1));
vst1q_u32(values +8, vaddq_u32(vmovl_u16(vget_low_u16 (vql2)), oh2));
vst1q_u32(values+12, vaddq_u32(vmovl_u16(vget_high_u16(vql2)), oh2));
//auto sh1 = vshlq_u32(vdupq_n_u32(shb[ib+0]), shifts);
//auto sh2 = vshlq_u32(vdupq_n_u32(shb[ib+4]), shifts);
//for (int j = 0; j < 4; ++j) {
// const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
// const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
// values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
// values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
// values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
// values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
//}
xv[ib+0] = trellis.next32(values+0);
xv[ib+4] = trellis.next32(values+8);
}