iq1_kt: tiny bit better GEMV on NEON

This commit is contained in:
Iwan Kawrakow
2025-07-16 12:47:49 +02:00
parent 882fc0235e
commit 31554f534f

View File

@@ -2315,30 +2315,36 @@ void mul_mat_iq1_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
idx.val[1] = vaddq_u16(idx.val[1], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32l)), 5));
idx.val[2] = vaddq_u16(idx.val[2], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32h)), 5));
idx.val[3] = vaddq_u16(idx.val[3], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32h)), 5));
//if constexpr (nrc_y == 1) {
// const block_q8_0_x4& ybl = y[0][2*i+0];
// const block_q8_0_x4& ybh = y[0][2*i+1];
// auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
// auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
// int32x4x4_t suml = {};
// int32x4x4_t sumh = {};
// for (int ib = 0; ib < 4; ++ib) {
// auto xl = trellis.next32(ql + 4*ib + 0, 4096);
// auto xh = trellis.next32(ql + 4*ib + 16, 4096);
// auto yl = vld1q_s8_x2(ybl.qs + 32*ib);
// auto yh = vld1q_s8_x2(ybh.qs + 32*ib);
// suml.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
// sumh.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
// }
// auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]);
// auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]);
// auto sl = vpaddq_s32(sl1, sl2);
// auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]);
// auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]);
// auto sh = vpaddq_s32(sh1, sh2);
// accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl));
// accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh));
//} else {
// Single-RHS-row path: this branch is compiled only when nrc_y == 1.
// It accumulates into accd[0] (from ybl) and accd[1] (from ybh) without going
// through the xv[] staging used by the general branch below.
if constexpr (nrc_y == 1) {
// Two consecutive q8_0_x4 blocks of the only y row, selected by the loop index i.
const block_q8_0_x4& ybl = y[0][2*i+0];
const block_q8_0_x4& ybh = y[0][2*i+1];
// Four fp16 values are read from each block's d field, widened to f32 and
// multiplied lane-wise with scales.val[0] / scales.val[1].
// NOTE(review): presumably d holds the per-sub-block q8_0 scales — confirm against block_q8_0_x4.
auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
// One int32x4 partial-sum vector per 32-byte slice of qs (4 slices per block).
int32x4x4_t suml = {};
int32x4x4_t sumh = {};
// Each iteration consumes one idx vector for ybl (idx.val[ib]) and one for ybh
// (idx.val[ib+2]), low half first and then high half, paired with qs bytes
// [64*ib, 64*ib+32) and [64*ib+32, 64*ib+64) respectively.
for (int ib = 0; ib < 2; ++ib) {
// NOTE(review): trellis.next32 is defined elsewhere; from its use here it returns
// two vectors (val[0], val[1]) that feed vdotq_s32 — presumably 32 int8 values
// generated from the four 16-bit indices passed in. Confirm against the helper.
auto xl = trellis.next32(vget_low_u16(idx.val[ib+0]));
auto xh = trellis.next32(vget_low_u16(idx.val[ib+2]));
auto yl = vld1q_s8_x2(ybl.qs + 64*ib);
auto yh = vld1q_s8_x2(ybh.qs + 64*ib);
// Two chained vdotq_s32 starting from zero: each int32 lane ends up with the sum
// of two 4-element int8 dot products (one from val[0], one from val[1]).
suml.val[2*ib+0] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
sumh.val[2*ib+0] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
// Same again for the high half of the idx vectors and the next 32 qs bytes.
xl = trellis.next32(vget_high_u16(idx.val[ib+0]));
xh = trellis.next32(vget_high_u16(idx.val[ib+2]));
yl = vld1q_s8_x2(ybl.qs + 64*ib + 32);
yh = vld1q_s8_x2(ybh.qs + 64*ib + 32);
suml.val[2*ib+1] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
sumh.val[2*ib+1] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
}
// Two levels of pairwise adds collapse the four partial-sum vectors into one:
// lane j of sl is the horizontal sum of suml.val[j] (likewise sh / sumh), so
// lane j lines up with lane j of dyl / dyh.
auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]);
auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]);
auto sl = vpaddq_s32(sl1, sl2);
auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]);
auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]);
auto sh = vpaddq_s32(sh1, sh2);
// Convert the integer sums to f32 and fused-multiply-add them, scaled by dyl / dyh,
// into the two float accumulators.
accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl));
accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh));
} else {
for (int k = 0; k < 4; ++k) {
xv[2*k+0] = trellis.next32(vget_low_u16 (idx.val[k]));
xv[2*k+1] = trellis.next32(vget_high_u16(idx.val[k]));
@@ -2358,7 +2364,7 @@ void mul_mat_iq1_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih));
}
}
//}
}
}
if constexpr (nrc_y == 1) {