mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-24 00:19:19 +00:00
iq1_kt: NEON GEMM/GEMV
Pathetic as usual
This commit is contained in:
@@ -2232,6 +2232,126 @@ void iqk_dequantize_iq3_kt_q80_r8(int n, const void * vx, size_t bx, void * vy,
|
||||
}
|
||||
}
|
||||
|
||||
// NEON GEMM/GEMV kernel: multiplies nrc_x rows of IQ1_KT (trellis-coded 1-bit-ish)
// quantized data against nrc_y Q8_0_x4 activation rows, storing one float per
// (ix, iy) pair via info.store().
//
// Template parameter:
//   nrc_y - number of right-hand-side (activation) rows processed per call.
// Parameters:
//   n     - row length in elements; must be a multiple of QK_K (asserted).
//   vx    - base pointer to the quantized left-hand rows; each row starts with a
//           float row scale followed by block_iq1_kt blocks (see dptr usage below).
//   bx    - byte stride between consecutive rows of vx.
//   info  - provides activation row pointers (src1_row) and the result sink (store).
//   nrc_x - number of left-hand rows to process.
template <int nrc_y>
void mul_mat_iq1_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n%QK_K == 0);
    const int nb = n/QK_K;  // number of QK_K-sized superblocks per row

    Trellis3 trellis;  // generates 32 int8 codebook values per 13-bit index (next32)

    // 16-entry signed lookup table used to decode the 4-bit block-scale codes in x[i].sh.
    auto values = vld1q_s8(iq4k_values);

    // A disabled nrc_y==1 fast path (see commented code below) used two accumulators;
    // k_acc keeps that extra slot reserved so the path can be re-enabled.
    constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
    float32x4_t accd[k_acc];

    // One Q8_0_x4 block stream per activation row.
    const block_q8_0_x4 * y[nrc_y];
    for (int iy = 0; iy < nrc_y; ++iy) {
        y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
    }

    int8x16x2_t xv[8];   // 8 x 32 dequantized int8 values = one QK_K superblock
    uint16x8x4_t idx;    // 32 trellis indices for the current superblock
    int32x4x4_t dot;     // scratch for compute_dot
    uint16_t aux16[8];   // staging buffer to hand 8 indices at a time to trellis.next32

    // Dot product of 128 int8 x-values (xv[0..3], 32 each) against 128 int8 y-values,
    // reduced with pairwise adds so lane k of the result corresponds to 32-element
    // sub-block k (matching the per-sub-block scales applied by the caller).
    auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) {
        for (int k = 0; k < 4; ++k) {
            auto yv = vld1q_s8_x2(y + 32*k);
            dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
        }
        dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]);
        dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]);
        return vpaddq_s32(dot.val[0], dot.val[2]);
    };

    float32x4x2_t scales;  // 8 per-sub-block scales (row scale d * decoded block scale)

    for (int ix = 0; ix < nrc_x; ++ix) {
        // Row layout: one leading float (row scale), then the block_iq1_kt blocks.
        const float * dptr = (const float *)((const char*)vx + ix*bx);
        auto d = vdupq_n_f32(dptr[0]);
        const block_iq1_kt * x = (const block_iq1_kt *)(dptr + 1);

        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0);

        for (int i = 0; i < nb; ++i) {
            // Low nibbles of sh select signed scale values from the iq4k table.
            auto sh = vld1_u8(x[i].sh);
            auto s16 = vmovl_s8(vqtbl1_s8(values, vand_u8(sh, vdup_n_u8(0xf))));
            scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (s16))));
            scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
            // Assemble the trellis indices: low 8 bits come from ql, bits 8..11 from
            // the qh nibbles (low nibble for idx 0/1, high nibble for idx 2/3).
            auto ql = vld1q_u8_x2(x[i].ql);
            auto qh = vld1q_u8(x[i].qh);
            auto qhl = vmovl_u8(vget_low_u8(qh));
            auto qhh = vmovl_u8(vget_high_u8(qh));
            idx.val[0] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 8)));
            idx.val[1] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[0])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 8)));
            idx.val[2] = vaddq_u16(vmovl_u8(vget_low_u8 (ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhl, 4)));
            idx.val[3] = vaddq_u16(vmovl_u8(vget_high_u8(ql.val[1])), vandq_u16(vdupq_n_u16(0xf00), vshlq_n_u16(qhh, 4)));
            //for (int k = 0; k < 4; ++k) idx.val[k] = vaddq_u16(idx.val[k], vdupq_n_u16(4096));
            // Bit 12 of each index is stored in the high nibble of sh, one bit per
            // sub-block of 4: the 0x01020408 multiply fans the 4 nibble bits out to
            // the sign-bit position of 4 consecutive bytes, which are then shifted
            // into bit 12 (0x80 << 5) of the corresponding 16-bit index lanes.
            auto sh16 = vandq_u16(vmovl_u8(sh), vdupq_n_u16(0xf0));
            auto sh32l = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_low_u16 (sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
            auto sh32h = vandq_u8(vreinterpretq_u8_u32(vmulq_u32(vmovl_u16(vget_high_u16(sh16)), vdupq_n_u32(0x01020408))), vdupq_n_u8(0x80));
            idx.val[0] = vaddq_u16(idx.val[0], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32l)), 5));
            idx.val[1] = vaddq_u16(idx.val[1], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32l)), 5));
            idx.val[2] = vaddq_u16(idx.val[2], vshlq_n_u16(vmovl_u8(vget_low_u8 (sh32h)), 5));
            idx.val[3] = vaddq_u16(idx.val[3], vshlq_n_u16(vmovl_u8(vget_high_u8(sh32h)), 5));
            // NOTE(review): disabled nrc_y==1 specialization kept for reference; it
            // dequantized directly from ql and accumulated into accd[0]/accd[1].
            //if constexpr (nrc_y == 1) {
            //    const block_q8_0_x4& ybl = y[0][2*i+0];
            //    const block_q8_0_x4& ybh = y[0][2*i+1];
            //    auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
            //    auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
            //    int32x4x4_t suml = {};
            //    int32x4x4_t sumh = {};
            //    for (int ib = 0; ib < 4; ++ib) {
            //        auto xl = trellis.next32(ql + 4*ib + 0, 4096);
            //        auto xh = trellis.next32(ql + 4*ib + 16, 4096);
            //        auto yl = vld1q_s8_x2(ybl.qs + 32*ib);
            //        auto yh = vld1q_s8_x2(ybh.qs + 32*ib);
            //        suml.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xl.val[0], yl.val[0]), xl.val[1], yl.val[1]);
            //        sumh.val[ib] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xh.val[0], yh.val[0]), xh.val[1], yh.val[1]);
            //    }
            //    auto sl1 = vpaddq_s32(suml.val[0], suml.val[1]);
            //    auto sl2 = vpaddq_s32(suml.val[2], suml.val[3]);
            //    auto sl = vpaddq_s32(sl1, sl2);
            //    auto sh1 = vpaddq_s32(sumh.val[0], sumh.val[1]);
            //    auto sh2 = vpaddq_s32(sumh.val[2], sumh.val[3]);
            //    auto sh = vpaddq_s32(sh1, sh2);
            //    accd[0] = vfmaq_f32(accd[0], dyl, vcvtq_f32_s32(sl));
            //    accd[1] = vfmaq_f32(accd[1], dyh, vcvtq_f32_s32(sh));
            //} else {
            // Dequantize the whole superblock once (shared across all nrc_y rows):
            // each 8-index batch yields 2x32 int8 codebook values, offset by 4096.
            for (int k = 0; k < 4; ++k) {
                vst1q_u16(aux16, idx.val[k]);
                xv[2*k+0] = trellis.next32(aux16+0, 4096);
                xv[2*k+1] = trellis.next32(aux16+4, 4096);
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                // Two Q8_0_x4 blocks (low/high 128 elements) cover one QK_K superblock.
                const block_q8_0_x4& ybl = y[iy][2*i+0];
                const block_q8_0_x4& ybh = y[iy][2*i+1];
                // Combine the 4 per-32 x-scales with the 4 fp16 y-block scales.
                auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
                auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
                auto sumil = compute_dot(ybl.qs, xv+0);
                auto sumih = compute_dot(ybh.qs, xv+4);
                //if constexpr (nrc_y == 1) {
                //    accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil));
                //    accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih));
                //} else {
                accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil));
                accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih));
                //}
            }
            //}
        }

        // Horizontal-add each accumulator to produce the final scalar result.
        //if constexpr (nrc_y == 1) {
        //    info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1])));
        //} else {
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vaddvq_f32(accd[iy]));
        }
        //}
    }
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
@@ -2488,6 +2608,15 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ggml_type(typeA) == GGML_TYPE_IQ1_KT) {
|
||||
if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq1_kt_q8_0_x4_T, kernels);
|
||||
func16 = nullptr;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ggml_type(typeB) != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -964,6 +964,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
case GGML_TYPE_IQ1_S_R4:
|
||||
case GGML_TYPE_IQ1_M_R4:
|
||||
return iqk_set_kernels_1bit(ne00, typeA, typeB, m.funcs, m.func16);
|
||||
case GGML_TYPE_IQ1_KT:
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
case GGML_TYPE_IQ3_KT:
|
||||
case GGML_TYPE_IQ4_KT:
|
||||
|
||||
Reference in New Issue
Block a user