From 56ca4c3ba9ef7a1f3d245a47d4f7277c59c49f8e Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 26 Jan 2025 12:24:38 +0200 Subject: [PATCH] FA: repack Q8_0 to Q8_0_R8 (NEON) Very slightly faster than the general purpose gemm, slightly slower than the D = 128 special case gemm mul_mat_q8_0_r4_q8_0_128. Still removing mul_mat_q8_0_r4_q8_0_128 as we simply don't have enough vector registers to hold 8 interleaved rows, so there is no point to have the special purpose implementation. --- ggml/src/iqk/iqk_mul_mat.cpp | 109 ++++++++++------------------------- 1 file changed, 29 insertions(+), 80 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 44e31819..d8273415 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -12228,46 +12228,6 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf } } -template -void mul_mat_q8_0_r4_q8_0_128(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); - GGML_ASSERT(n == 128); - int8x16x4_t qx[8]; - float32x4_t scales[4]; - float32x4_t scales_y[4]; - for (int ix = 0; ix < nrc_x; ix += 4) { - const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx); - for (int k = 0; k < 4; ++k) { - scales[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[k].d)); - qx[2*k+0] = vld1q_s8_x4(iq8[k].qs); - qx[2*k+1] = vld1q_s8_x4(iq8[k].qs+64); - } - for (int iy = 0; iy < nrc_y; ++iy) { - auto by = (const block_q8_0_x4 *)info.src1_row(iy); - auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)by->d)); - scales_y[0] = vmulq_laneq_f32(scales[0], d8, 0); - scales_y[1] = vmulq_laneq_f32(scales[1], d8, 1); - scales_y[2] = vmulq_laneq_f32(scales[2], d8, 2); - scales_y[3] = vmulq_laneq_f32(scales[3], d8, 3); - auto sumf = vdupq_n_f32(0.f); - for (int k = 0; k < 4; ++k) { - auto y = vld1q_s8_x2(by->qs+32*k); - auto sumi = vdupq_n_s32(0); - sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[0], y.val[0], 0); - sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[1], y.val[1], 0); - sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[2], y.val[0], 1); - sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[3], y.val[1], 1); - sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[0], y.val[0], 2); - sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[1], y.val[1], 2); - sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[2], y.val[0], 3); - sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[3], y.val[1], 3); - sumf = vfmaq_f32(sumf, scales_y[k], vcvtq_f32_s32(sumi)); - } - info.store(ix, iy, sumf); - } - } -} - #define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \ m.funcs[0] = func;\ m.funcs[1] = func;\ @@ -12914,6 +12874,9 @@ struct HelperQ80R4 : public BaseHelper { std::vector result(nblock * nk/8); auto y = result.data(); const block_q8_0 * x8[8]; +#ifdef __ARM_NEON + int8x16x2_t m0, m1, m2, m3; +#endif for (int row = 0; row < nk; row += 8) { for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride); for (int ib = 0; ib < nblock; ++ib) { @@ -12952,31 +12915,33 @@ struct HelperQ80R4 : public BaseHelper { _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2); _mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3); #elif defined __ARM_NEON - auto m0 = vld1q_s8_x2(x4[0][ib].qs); - auto m1 = vld1q_s8_x2(x4[1][ib].qs); - auto m2 = vld1q_s8_x2(x4[2][ib].qs); - auto m3 = vld1q_s8_x2(x4[3][ib].qs); - auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); - auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); - m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); - m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); - m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); - m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); - row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); - row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); - m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); - m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); - m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); - m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); - vst1q_s8_x2(y[ib].qs + 0, m0); - vst1q_s8_x2(y[ib].qs + 32, m1); - vst1q_s8_x2(y[ib].qs + 64, m2); - vst1q_s8_x2(y[ib].qs + 96, m3); + for (int l = 0; l < 2; ++l) { + m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l); + m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l); + m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l); + m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l); + auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0])); + auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0])); + m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1])); + row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1])); + m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0]))); + m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1]))); + vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0); + vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1); + vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2); + vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3); + } #else for (int l = 0; l < 4; ++l) { - for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) { - y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0]; - y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16]; + for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) { + y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0]; + y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16]; } } #endif @@ -13794,23 +13759,7 @@ struct FlashQKfp32 { } else if constexpr (std::is_same_v>) { #ifdef __aarch64__ - if constexpr (D == 128) { - if (q_step >= 64 && nq >= 64) { - return std::make_pair(mul_mat_q8_0_r4_q8_0_128<64>, 64); - } - else if (q_step >= 32 && nq >= 32) { - return std::make_pair(mul_mat_q8_0_r4_q8_0_128<32>, 32); - } - else if (q_step >= 16 && nq >= 16) { - return std::make_pair(mul_mat_q8_0_r4_q8_0_128<16>, 16); - } - else { - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0_128, nq); - } - } else { - MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); - } - //MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); + MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq); #else MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq); #endif