diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 39267677..8c8df2f2 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -12079,7 +12079,7 @@ static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const Data static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8<1, block_q8_1_x4> q8(info); + Q8<1, block_q8_K128> q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); int8x16_t qx[8]; @@ -12091,8 +12091,8 @@ static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const Dat auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr)); auto x = (const block_iq1_s_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { - auto scale_yd = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib].d+0)); - auto scale_ym = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib].d+4)); + auto scale_yd = vdupq_n_f32(q8.y[0][ib].d); + auto scale_ym = vmulq_f32(scale_yd, vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[0][ib].bsums)))); for (int k = 0; k < 4; ++k) { auto sas = vld1_u16(x[4*ib+k].qh); auto scales4 = vand_u16(vshr_n_u16(sas, 12), vdup_n_u16(7)); @@ -12142,23 +12142,22 @@ static void mul_mat_iq1_s_r4_q8_1_1(int n, const void * vx, size_t bx, const Dat template static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); - Q8 q8(info); + Q8 q8(info); int nb = n / 32; GGML_ASSERT(nb%4 == 0); uint8x16_t qx[8]; int32x4_t acc[nrc_y] = {}; auto ms = vdup_n_u16(0x8000); auto mask = vdupq_n_s8(0x03); - float d8[8*nrc_y]; + float d8[4*nrc_y]; for (int ix= 0; ix < nrc_x; ix += 4) { auto dptr = (const ggml_half *)((const char *)vx + ix*bx); auto d1 = vcvt_f32_f16(vld1_f16((const float16_t *)dptr)); auto x = (const block_iq1_s_r4 *)(dptr + 4); for (int ib = 0; ib < nb/4; ++ib) { for (int iy = 0; iy < nrc_y; ++iy) { - auto scales = vld1q_f16((const float16_t *)q8.y[iy][ib].d); - vst1q_f32(d8+8*iy+0, vcvt_f32_f16(vget_low_f16(scales))); - vst1q_f32(d8+8*iy+4, vcvt_f32_f16(vget_high_f16(scales))); + auto scales = vcvtq_f32_s32(vmovl_s16(vld1_s16(q8.y[iy][ib].bsums))); + vst1q_f32(d8+4*iy, vmulq_f32(vdupq_n_f32(q8.y[iy][ib].d), scales)); } for (int k = 0; k < 4; ++k) { auto sas = vld1_u16(x[4*ib+k].qh); @@ -12200,8 +12199,8 @@ static void mul_mat_iq1_s_r4_q8_1(int n, const void * vx, size_t bx, const DataI sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[6]), y.val[1], 2); sumi = vdotq_laneq_s32(sumi, vreinterpretq_s8_u8(qx[7]), y.val[1], 3); sumi = vmulq_s32(scales, sumi); - acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+0]), vcvtq_f32_s32(sumi)); - acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[8*iy+k+4]), delta4); + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.y[iy][ib].d), vcvtq_f32_s32(sumi)); + acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(d8[4*iy+k]), delta4); } } } @@ -13914,7 +13913,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_s_r4_q8_1); m.funcs[0] = mul_mat_iq1_s_r4_q8_1_1; m.func16 = mul_mat_iq1_s_r4_q8_1<16>; - expected_Btype = GGML_TYPE_Q8_1_X4; + expected_Btype = GGML_TYPE_Q8_K128; break; case GGML_TYPE_IQ1_M_R4: SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1_m_r4_q8_0);