iq3_s_r4: NEON

This commit is contained in:
Iwan Kawrakow
2024-12-22 20:25:47 +01:00
parent ada68c3b37
commit b31c3e9103

View File

@@ -10643,6 +10643,69 @@ static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const Dat
}
}
template <int nrc_y>
// Matrix multiplication kernel (NEON path): left operand is IQ3_S data with 4 rows
// interleaved per block (the "_r4" layout, see block_iq3_s_r4), right operand is
// nrc_y rows of Q8_K-quantized activations. Accumulates in int32 per 256-wide
// superblock, then scales into float32 accumulators and stores 4 results per call
// of info.store. Interface: n = row length (multiple of QK_K assumed via nbl),
// vx/bx = left data base and row-group stride, nrc_x = number of left rows
// (must be a multiple of 4 since rows are interleaved in groups of 4).
static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    int nbl = n / QK_K;                    // number of 256-element superblocks per row
    float32x4_t acc[nrc_y] = {};           // per-y-row float accumulators (4 x-rows each)
    int32x4_t isum[nrc_y] = {};            // per-y-row integer accumulators, reset each superblock
    int8x16_t qx[8];                       // 8 x 16 = 128 dequantized (sign-applied) left values
    SignHelper sh;
    uint32_t stored_scales[8];             // unpacked 6-bit-ish scales, one uint32 (4 bytes = 4 rows) per 32-sample block
    for (int ix = 0; ix < nrc_x; ix += 4) {
        auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
            // Per-row fp16 super-scales for the 4 interleaved rows.
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
            auto qs = iq3[ibl].qs;
            auto qh = iq3[ibl].qh;
            // Unpack 4-bit block scales: low nibbles -> val[0], high nibbles -> val[1].
            auto scale_bits = vld1q_u8(iq3[ibl].scales);
            uint8x16x2_t scales8 = { vandq_u8(scale_bits, vdupq_n_u8(0xf)), vshrq_n_u8(scale_bits, 4) };
            // Interleave low/high-nibble scales at 32-bit granularity so each uint32
            // holds the 4 row-scales of one 32-sample block.
            // NOTE(review): vzip1q_u32/vzip2q_u32 take uint32x4_t but are given
            // uint8x16_t here without vreinterpretq — strict ACLE type checking
            // would reject this; confirm the intended casts against the compiler used.
            auto tmp = vzip1q_u32(scales8.val[0], scales8.val[1]);
            scales8.val[1] = vzip2q_u32(scales8.val[0], scales8.val[1]);
            // IQ3_S scale decode: stored 4-bit s becomes 2*s + 1 (shift left 1, OR 1).
            scales8.val[0] = vorrq_u8(vshlq_n_u8(tmp, 1), vdupq_n_u8(1));
            scales8.val[1] = vorrq_u8(vshlq_n_u8(scales8.val[1], 1), vdupq_n_u8(1));
            vst1q_u8_x2((uint8_t *)stored_scales, scales8);
            for (int ib = 0; ib < QK_K/32; ++ib) {
                // 128 sign bits for this 32-sample block (4 rows x 32 samples).
                auto signs128 = vld1q_u8(iq3[ibl].signs+16*ib);
                sh.init();
                for (int i = 0; i < 4; ++i) {
                    // Each qs byte plus one high bit from qh forms a 9-bit index
                    // into the iq3s_grid codebook (4 uint32 lanes = 16 int8 values
                    // per vector, two vectors per row i of the 4 interleaved rows).
                    qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[8*i+0] | ((qh[i] << 8) & 0x100)], iq3s_grid[qs[8*i+1] | ((qh[i] << 7) & 0x100)],
                                                               iq3s_grid[qs[8*i+2] | ((qh[i] << 6) & 0x100)], iq3s_grid[qs[8*i+3] | ((qh[i] << 5) & 0x100)]});
                    sh.apply_signs_1((uint8x16_t *)qx+2*i+0, signs128);
                    qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[8*i+4] | ((qh[i] << 4) & 0x100)], iq3s_grid[qs[8*i+5] | ((qh[i] << 3) & 0x100)],
                                                               iq3s_grid[qs[8*i+6] | ((qh[i] << 2) & 0x100)], iq3s_grid[qs[8*i+7] | ((qh[i] << 1) & 0x100)]});
                    sh.apply_signs_1((uint8x16_t *)qx+2*i+1, signs128);
                }
                // Widen the 4 packed per-row scale bytes of this block to int32x4.
                auto sc16 = vmovl_s8(vreinterpret_s8_u32(vdup_n_u32(stored_scales[ib])));
                auto scales = vmovl_s16(vget_low_s16(sc16));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
                    //auto sumi = interleaved_dotq(qx, y);
                    // Four row-dot-products over the 32 q8 values; the two vpaddq
                    // stages reduce each to one int32 lane (lane k = row k's sum).
                    auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
                    auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
                    auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
                    auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
                    auto sumi12 = vpaddq_s32(sumi1, sumi2);
                    auto sumi34 = vpaddq_s32(sumi3, sumi4);
                    auto sumi = vpaddq_s32(sumi12, sumi34);
                    isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                }
                qs += 32;  // 32 codebook indices consumed per 32-sample block (4 rows x 8)
                qh += 4;   // 1 high-bit byte per row per block
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                // acc += d4 * q8_scale * isum ; then reset the integer accumulator
                // for the next superblock.
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
                isum[iy] = vdupq_n_s32(0);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            // Store the 4 results (x-rows ix..ix+3) for y-row iy, then clear for
            // the next group of 4 x-rows.
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
template <int nrc_y, int k_shift>
inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales, uint8x16_t extra,
int32x4_t * isum) {
@@ -11960,6 +12023,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ3_S_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_s_r4_q8_k);
m.func16 = mul_mat_iq3_s_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_Q2_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q2_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;