iq1_kt: NEON GEMM/GEMV

Pathetic as usual
This commit is contained in:
Iwan Kawrakow
2025-07-16 11:51:15 +02:00
parent 6d1ddf1c26
commit 3b6597c7a1
2 changed files with 130 additions and 0 deletions

View File

@@ -2232,6 +2232,126 @@ void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
}
}
template <int nrc_y>
// GEMM/GEMV kernel (NEON): left matrix in IQ1_KT trellis-quantized format times a
// right matrix repacked as Q8_0_X4. nrc_y is the compile-time number of right-hand
// rows processed per call, nrc_x the number of left-hand rows, n the row length
// (must be a multiple of QK_K). Results are written via info.store(ix, iy, ...).
void mul_mat_iq1_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
const int nb = n/QK_K;  // number of QK_K-sized super-blocks per row
// Trellis decoder: next32(idx, 4096) expands a group of 16-bit indices into 32 int8
// values (int8x16x2_t). Exact generation scheme lives in Trellis3 — not visible here.
Trellis3 trellis;
auto values = vld1q_s8(iq4k_values);  // 16-entry signed LUT for per-32 block scales
// With a single RHS row, two accumulators are used (the disabled nrc_y==1 fast path
// below splits low/high halves); otherwise one accumulator per RHS row.
constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
float32x4_t accd[k_acc];
const block_q8_0_x4 * y[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
}
int8x16x2_t xv[8];     // 8x32 = 256 dequantized (int8) left-hand values of one super-block
uint16x8x4_t idx;      // 32 16-bit trellis indices for the current super-block
int32x4x4_t dot;
uint16_t aux16[8];     // scratch: one idx vector spilled to memory for scalar-indexed next32
// Dot of 128 quantized y values against 4x32 dequantized x values; returns an
// int32x4 where each lane holds the sum for one 32-element sub-block.
auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) {
for (int k = 0; k < 4; ++k) {
auto yv = vld1q_s8_x2(y + 32*k);
dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
}
// Pairwise-add tree reduces each 4-lane partial sum to one lane per sub-block.
dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]);
dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]);
return vpaddq_s32(dot.val[0], dot.val[2]);
};
float32x4x2_t scales;  // 8 per-32 scales of the super-block (low/high 128 values)
for (int ix = 0; ix < nrc_x; ++ix) {
// Row layout: one float row scale, then nb block_iq1_kt blocks.
const float * dptr = (const float *)((const char*)vx + ix*bx);
auto d = vdupq_n_f32(dptr[0]);
const block_iq1_kt * x = (const block_iq1_kt *)(dptr + 1);
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0);
for (int i = 0; i < nb; ++i) {
// sh: 8 bytes; low nibbles index iq4k_values -> 8 signed sub-block scales.
auto sh = vld1_u8(x[i].sh);
auto s16 = vmovl_s8(vqtbl1_s8(values, vand_u8(sh, vdup_n_u8(0xf))));
scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (s16))));
scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
// Index reconstruction: bits 0..7 from ql, bits 8..11 from a qh nibble
// (low nibble for idx 0..15, high nibble for idx 16..31).
auto ql = vld1q_u8_x2(x[i].ql);
auto qh = vld1q_u8(x[i].qh);
auto qhl = vmovl_u8(vget_low_u8(qh));
auto qhh = vmovl_u8(vget_high_u8(qh));
idx.val[0] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 8)));
idx.val[1] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 8)));
idx.val[2] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 4)));
idx.val[3] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 4)));
//for (int k = 0; k < 4; ++k) idx.val[k] = vaddq_u16(idx.val[k], vdupq_n_u16(4096));
// Bit 12 of each index comes from the sh high nibbles: the multiply by
// 0x01020408 spreads the 4 nibble bits into bit 7 of the 4 bytes of each
// u32 lane; masking with 0x80 and shifting left 5 lands them at bit 12.
auto sh16 = vandq_u16(vmovl_u8(sh), vdupq_n_u16(0xf0));
auto sh32l = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_low_u16 (sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
auto sh32h = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_high_u16(sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
idx.val[0] = vaddq_u16(idx.val[0], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32l)), 5));
idx.val[1] = vaddq_u16(idx.val[1], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32l)), 5));
idx.val[2] = vaddq_u16(idx.val[2], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32h)), 5));
idx.val[3] = vaddq_u16(idx.val[3], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32h)), 5));
// NOTE(review): disabled nrc_y==1 variant kept for reference — it decodes
// directly from ql instead of the reconstructed indices; presumably an
// experiment that did not pay off (commit message: "Pathetic as usual").
//if constexpr (nrc_y == 1) {
// const block_q8_0_x4& ybl = y[0][2*i+0];
// const block_q8_0_x4& ybh = y[0][2*i+1];
// auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
// auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
// int32x4x4_t suml = {};
// int32x4x4_t sumh = {};
// for (int ib = 0; ib < 4; ++ib) {
// auto xl = trellis.next32(ql + 4*ib + 0, 4096);
// auto xh = trellis.next32(ql + 4*ib + 16, 4096);
// auto yl = vld1q_s8_x2(ybl.qs + 32*ib);
// auto yh = vld1q_s8_x2(ybh.qs + 32*ib);
// suml.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
// sumh.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
// }
// auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]);
// auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]);
// auto sl = vpaddq_s32(sl1, sl2);
// auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]);
// auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]);
// auto sh = vpaddq_s32(sh1, sh2);
// accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl));
// accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh));
//} else {
// Dequantize the whole super-block once (8 groups of 32 int8 values),
// then reuse it for all nrc_y right-hand rows.
for (int k = 0; k < 4; ++k) {
vst1q_u16(aux16, idx.val[k]);
xv[2*k+0] = trellis.next32(aux16+0, 4096);
xv[2*k+1] = trellis.next32(aux16+4, 4096);
}
for (int iy = 0; iy < nrc_y; ++iy) {
// Two q8_0_x4 blocks (2x128 values) cover one QK_K=256 super-block.
const block_q8_0_x4& ybl = y[iy][2*i+0];
const block_q8_0_x4& ybh = y[iy][2*i+1];
// Combined scale: per-sub-block x scale times per-32 fp16 y scale.
auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
auto sumil = compute_dot(ybl.qs, xv+0);
auto sumih = compute_dot(ybh.qs, xv+4);
//if constexpr (nrc_y == 1) {
// accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil));
// accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih));
//} else {
accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil));
accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih));
//}
}
//}
}
//if constexpr (nrc_y == 1) {
// info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1])));
//} else {
// Horizontal-add each accumulator and store the scalar result.
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vaddvq_f32(accd[iy]));
}
//}
}
}
template <int nrc_y>
void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
@@ -2488,6 +2608,15 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
return false;
}
if (ggml_type(typeA) == GGML_TYPE_IQ1_KT) {
if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_kt_q8_0_x4_T, kernels);
func16 = nullptr;
return true;
}
return false;
}
if (ggml_type(typeB) != GGML_TYPE_F16) {
return false;
}

View File

@@ -964,6 +964,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ1_S_R4:
case GGML_TYPE_IQ1_M_R4:
return iqk_set_kernels_1bit(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ1_KT:
case GGML_TYPE_IQ2_KT:
case GGML_TYPE_IQ3_KT:
case GGML_TYPE_IQ4_KT: