iq3_xxs_r4: NEON

This commit is contained in:
Iwan Kawrakow
2024-12-19 17:47:39 +01:00
parent b33e128d33
commit 1875b1d8e6

View File

@@ -9585,6 +9585,56 @@ void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i
}
}
template <int nrc_y>
static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
float32x4_t acc[nrc_y] = {};
int32x4_t isum[nrc_y] = {};
int8x16_t qx[8];
SignHelper sh;
for (int ix = 0; ix < nrc_x; ix += 4) {
auto iq3 = (const block_iq3_xxs_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = vmulq_f32(vdupq_n_f32(0.25f), vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)));
auto qs = iq3[ibl].qs;
for (int ib = 0; ib < QK_K/32; ++ib) {
auto sas = vld1q_u8(iq3[ibl].sas + 16*ib);
auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
sh.init();
for (int i = 0; i < 8; ++i) {
qx[i] = vreinterpretq_s8_u32(uint32x4_t{iq3xxs_grid[qs[4*i+0]], iq3xxs_grid[qs[4*i+1]], iq3xxs_grid[qs[4*i+2]], iq3xxs_grid[qs[4*i+3]]});
sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
auto sumi12 = vpaddq_s32(sumi1, sumi2);
auto sumi34 = vpaddq_s32(sumi3, sumi4);
auto sumi = vpaddq_s32(sumi12, sumi34);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
qs += 32;
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
isum[iy] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
}
}
}
template <int nrc_y, int k_shift>
inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales, uint8x16_t extra,
int32x4_t * isum) {
@@ -10882,6 +10932,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_ks_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ3_XXS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_Q2_K_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q2_k_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;