iq1_s_r4: Use Q8_K_128 instead of Q8_1_X4 for gemm (Neon)

This commit is contained in:
Iwan Kawrakow
2025-02-09 06:32:21 +02:00
parent b6c4ef9a35
commit 166a157c7a

View File

@@ -12079,7 +12079,7 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data
static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<1, block_q8_1_x4> q8(info);
Q8<1, block_q8_K128> q8(info);
int nb = n / 32;
GGML_ASSERT(nb%4 == 0);
int8x16_t qx[8];
@@ -12091,8 +12091,8 @@ static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const Dat
auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr));
auto x = (const block_iq1_s_r4 *)(dptr + 4);
for (int ib = 0; ib < nb/4; ++ib) {
auto scale_yd = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib].d+0));
auto scale_ym = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib].d+4));
auto scale_yd = vdupq_n_f32(q8.y[0][ib].d);
auto scale_ym = vmulq_f32(scale_yd, vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[0][ib].bsums))));
for (int k = 0; k < 4; ++k) {
auto sas = vld1_u16(x[4*ib+k].qh);
auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7));
@@ -12142,23 +12142,22 @@ static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const Dat
template <int nrc_y>
static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_1_x4> q8(info);
Q8<nrc_y, block_q8_K128> q8(info);
int nb = n / 32;
GGML_ASSERT(nb%4 == 0);
uint8x16_t qx[8];
int32x4_t acc[nrc_y] = {};
auto ms = vdup_n_u16(0x8000);
auto mask = vdupq_n_s8(0x03);
float d8[8*nrc_y];
float d8[4*nrc_y];
for (int ix= 0; ix < nrc_x; ix += 4) {
auto dptr = (const ggml_half *)((const char *)vx + ix*bx);
auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr));
auto x = (const block_iq1_s_r4 *)(dptr + 4);
for (int ib = 0; ib < nb/4; ++ib) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto scales = vld1q_f16((const float16_t *)q8.y[iy][ib].d);
vst1q_f32(d8+8*iy+0, vcvt_f32_f16(vget_low_f16(scales)));
vst1q_f32(d8+8*iy+4, vcvt_f32_f16(vget_high_f16(scales)));
auto scales = vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[iy][ib].bsums)));
vst1q_f32(d8+4*iy, vmulq_f32(vdupq_n_f32(q8.y[iy][ib].d), scales));
}
for (int k = 0; k < 4; ++k) {
auto sas = vld1_u16(x[4*ib+k].qh);
@@ -12200,8 +12199,8 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[6]), y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[7]), y.val[1], 3);
sumi = vmulq_s32(scales, sumi);
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+0]), vcvtq_f32_s32(sumi));
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+4]), delta4);
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(sumi));
acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), delta4);
}
}
}
@@ -13914,7 +13913,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1);
m.funcs[0] = mul_mat_iq1_s_r4_q8_1_1;
m.func16 = mul_mat_iq1_s_r4_q8_1<16>;
expected_Btype = GGML_TYPE_Q8_1_X4;
expected_Btype = GGML_TYPE_Q8_K128;
break;
case GGML_TYPE_IQ1_M_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_m_r4_q8_0);