mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-11 06:20:09 +00:00
New iq4_kt: faster NEON
Throughput is now 9.4 t/s, up from the 6.6 t/s achieved with the f16 trellis.
This commit is contained in:
@@ -1353,6 +1353,9 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
|
||||
return vpaddq_s32(dot.val[0], dot.val[2]);
|
||||
};
|
||||
|
||||
//int32x4x2_t shifts = {int32x4_t{-8, -11, -14, -17}, int32x4_t{-20, -23, -26, -29}};
|
||||
int32x4x2_t shifts = {int32x4_t{4, 1, -2, -5}, int32x4_t{-8, -11, -14, -17}};
|
||||
|
||||
float32x4x2_t scales;
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
@@ -1376,14 +1379,33 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
|
||||
o_helper.vec.val[0] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(1)), 15), vdupq_n_u32(4096));
|
||||
o_helper.vec.val[1] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(1)), 15), vdupq_n_u32(4096));
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
|
||||
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
|
||||
values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
|
||||
values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
|
||||
values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
|
||||
values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
|
||||
}
|
||||
auto vql1 = vmovl_u8(vld1_u8(ql+8*ib));
|
||||
auto vql2 = vmovl_u8(vld1_u8(ql+8*ib+32));
|
||||
auto vqh = vmovl_u8(vld1_u8(qh+8*ib));
|
||||
vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 8)));
|
||||
vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(vqh, 4)));
|
||||
auto sh1_u32 = vdupq_n_u32(shb[ib+0]);
|
||||
auto sh2_u32 = vdupq_n_u32(shb[ib+4]);
|
||||
auto sh1 = vcombine_u16(vmovn_u32(vshlq_u32(sh1_u32, shifts.val[0])), vmovn_u32(vshlq_u32(sh1_u32, shifts.val[1])));
|
||||
auto sh2 = vcombine_u16(vmovn_u32(vshlq_u32(sh2_u32, shifts.val[0])), vmovn_u32(vshlq_u32(sh2_u32, shifts.val[1])));
|
||||
vql1 = vaddq_u16(vql1, vandq_u16(vdupq_n_u16(0x7000), sh1));
|
||||
vql2 = vaddq_u16(vql2, vandq_u16(vdupq_n_u16(0x7000), sh2));
|
||||
auto oh1 = vdupq_n_u32(o_helper.val[ib+0]);
|
||||
auto oh2 = vdupq_n_u32(o_helper.val[ib+4]);
|
||||
vst1q_u32(values +0, vaddq_u32(vmovl_u16(vget_low_u16 (vql1)), oh1));
|
||||
vst1q_u32(values +4, vaddq_u32(vmovl_u16(vget_high_u16(vql1)), oh1));
|
||||
vst1q_u32(values +8, vaddq_u32(vmovl_u16(vget_low_u16 (vql2)), oh2));
|
||||
vst1q_u32(values+12, vaddq_u32(vmovl_u16(vget_high_u16(vql2)), oh2));
|
||||
//auto sh1 = vshlq_u32(vdupq_n_u32(shb[ib+0]), shifts);
|
||||
//auto sh2 = vshlq_u32(vdupq_n_u32(shb[ib+4]), shifts);
|
||||
//for (int j = 0; j < 4; ++j) {
|
||||
// const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
|
||||
// const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
|
||||
// values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
|
||||
// values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
|
||||
// values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
|
||||
// values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
|
||||
//}
|
||||
xv[ib+0] = trellis.next32(values+0);
|
||||
xv[ib+4] = trellis.next32(values+8);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user