mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-27 16:44:21 +00:00
New iq4_kt: slightly faster NEON
This commit is contained in:
@@ -1330,7 +1330,7 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
|
||||
|
||||
union { uint32x4x2_t vec; uint32_t val[8]; } o_helper;
|
||||
|
||||
constexpr int k_acc = nrc_y;
|
||||
constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
|
||||
|
||||
float32x4_t accd[k_acc];
|
||||
|
||||
@@ -1339,11 +1339,11 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
|
||||
y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
|
||||
}
|
||||
|
||||
uint32_t values[64];
|
||||
int8x16x2_t xv[4];
|
||||
uint32_t values[16];
|
||||
int8x16x2_t xv[8];
|
||||
int32x4x4_t dot;
|
||||
|
||||
auto compute_dot = [&dot, &xv] (const int8_t * y) {
|
||||
auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) {
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto yv = vld1q_s8_x2(y + 32*k);
|
||||
dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
|
||||
@@ -1379,27 +1379,37 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
|
||||
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
|
||||
values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
|
||||
values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
|
||||
values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
|
||||
values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
|
||||
values[2*j+0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
|
||||
values[2*j+1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
|
||||
values[2*j+8] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
|
||||
values[2*j+9] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
|
||||
}
|
||||
xv[ib+0] = trellis.next32(values+0);
|
||||
xv[ib+4] = trellis.next32(values+8);
|
||||
}
|
||||
for (int i128 = 0; i128 < 2; ++i128) {
|
||||
for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
const block_q8_0_x4& yb = y[iy][2*i+i128];
|
||||
auto dy = vmulq_f32(scales.val[i128], vcvt_f32_f16(vld1_f16((const float16_t *)yb.d)));
|
||||
//auto dy = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16((const uint16_t *)yb.d)), 16));
|
||||
//dy = vmulq_f32(scales.val[i128], dy);
|
||||
auto sumi = compute_dot(yb.qs);
|
||||
accd[iy] = vfmaq_f32(accd[iy], dy, vcvtq_f32_s32(sumi));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
const block_q8_0_x4& ybl = y[iy][2*i+0];
|
||||
const block_q8_0_x4& ybh = y[iy][2*i+1];
|
||||
auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
|
||||
auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
|
||||
auto sumil = compute_dot(ybl.qs, xv+0);
|
||||
auto sumih = compute_dot(ybh.qs, xv+4);
|
||||
if constexpr (nrc_y == 1) {
|
||||
accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil));
|
||||
accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih));
|
||||
} else {
|
||||
accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil));
|
||||
accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, vaddvq_f32(accd[iy]));
|
||||
if constexpr (nrc_y == 1) {
|
||||
info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1])));
|
||||
} else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, vaddvq_f32(accd[iy]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user