diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ebd3db54..9c7c45e8 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -12063,6 +12063,46 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf } } +template +void mul_mat_q8_0_r4_q8_0_128(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + GGML_ASSERT(nrc_x%4 == 0); + GGML_ASSERT(n == 128); + int8x16x4_t qx[8]; + float32x4_t scales[4]; + float32x4_t scales_y[4]; + for (int ix = 0; ix < nrc_x; ix += 4) { + const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); + for (int k = 0; k < 4; ++k) { + scales[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[k].d)); + qx[2*k+0] = vld1q_s8_x4(iq8[k].qs); + qx[2*k+1] = vld1q_s8_x4(iq8[k].qs+64); + } + for (int iy = 0; iy < nrc_y; ++iy) { + auto by = (const block_q8_0_x4 *)info.src1_row(iy); + auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)by->d)); + scales_y[0] = vmulq_laneq_f32(scales[0], d8, 0); + scales_y[1] = vmulq_laneq_f32(scales[1], d8, 1); + scales_y[2] = vmulq_laneq_f32(scales[2], d8, 2); + scales_y[3] = vmulq_laneq_f32(scales[3], d8, 3); + auto sumf = vdupq_n_f32(0.f); + for (int k = 0; k < 4; ++k) { + auto y = vld1q_s8_x2(by->qs+32*k); + auto sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[0], y.val[0], 0); + sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[1], y.val[1], 0); + sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[2], y.val[0], 1); + sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[3], y.val[1], 1); + sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[0], y.val[0], 2); + sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[1], y.val[1], 2); + sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[2], y.val[0], 3); + sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[3], y.val[1], 3); + sumf = vfmaq_f32(sumf, scales_y[k], vcvtq_f32_s32(sumi)); + } + info.store(ix, iy, sumf); + } + } +} + #define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \ m.funcs[0] = func;\ m.funcs[1] = func;\ @@ -13645,7 +13685,23 @@ struct FlashQKfp32 { } else if constexpr (std::is_same_v>) { #ifdef __aarch64__ - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); + if constexpr (D == 128) { + if (q_step >= 64 && nq >= 64) { + return std::make_pair(mul_mat_q8_0_r4_q8_0_128<64>, 64); + } + else if (q_step >= 32 && nq >= 32) { + return std::make_pair(mul_mat_q8_0_r4_q8_0_128<32>, 32); + } + else if (q_step >= 16 && nq >= 16) { + return std::make_pair(mul_mat_q8_0_r4_q8_0_128<16>, 16); + } + else { + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0_128, nq); + } + } else { + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); + } + //MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); #else #ifdef HAVE_FANCY_SIMD if constexpr (D == 128) {