iq2_s_r4: NEON

This commit is contained in:
Iwan Kawrakow
2024-12-21 11:13:58 +01:00
parent 74ee045b06
commit fe8eda7b47

View File

@@ -10072,6 +10072,72 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
}
}
template <int nrc_y>
static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
float32x4_t acc[nrc_y] = {};
int32x4_t isum[2*nrc_y] = {};
int8x16_t qx[8];
uint16x8x4_t scales16;
SignHelper sh;
for (int ix = 0; ix < nrc_x; ix += 4) {
auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
auto qs = iq2[ibl].qs;
auto qh = iq2[ibl].qh;
for (int is = 0; is < 2; ++is) {
auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
auto scales2 = vshrq_n_u8(scale_bits, 4);
scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
auto s1 = vzip1q_u8(scales1, scales2);
auto s2 = vzip2q_u8(scales1, scales2);
scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
scales16.val[1] = vmovl_u8(vget_high_u8(s1));
scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
scales16.val[3] = vmovl_u8(vget_high_u8(s2));
for (int ib = 0; ib < QK_K/64; ++ib) {
auto signs128 = vld1q_u8(iq2[ibl].signs + 64*is + 16*ib);
sh.init();
for (int i = 0; i < 4; ++i) {
qx[2*i+0] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+0] | ((qh[i] << 8) & 0x300)], iq2s_grid[qs[4*i+1] | ((qh[i] << 6) & 0x300)]});
sh.apply_signs_1((uint8x16_t *)qx+2*i+0, signs128);
qx[2*i+1] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+2] | ((qh[i] << 4) & 0x300)], iq2s_grid[qs[4*i+3] | ((qh[i] << 2) & 0x300)]});
sh.apply_signs_1((uint8x16_t *)qx+2*i+1, signs128);
}
qs += 16; qh += 4;
auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
template <int nrc_y>
static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -11427,6 +11493,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xs_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ2_S_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_s_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ3_XXS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k);
expected_Btype = GGML_TYPE_Q8_K;