FA: repack Q8_0 to Q8_0_R8 (NEON)

Very slightly faster than the general-purpose gemm, and slightly slower
than the D = 128 special-case gemm mul_mat_q8_0_r4_q8_0_128. We still
remove mul_mat_q8_0_r4_q8_0_128: with 8 interleaved rows we simply don't
have enough vector registers to hold a full row of quants, so there is
no point in keeping the special-purpose implementation.
commit 56ca4c3ba9
parent 3484ee6ddb
Author: Iwan Kawrakow
Date:   2025-01-26 12:24:38 +02:00

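For context before the diff: below is a scalar sketch (mine, not part of the commit) of the Q8_0 -> Q8_0_R8 interleave the repack produces; it mirrors the generic fallback in the second hunk. The block definitions are simplified stand-ins for ggml's block_q8_0 / block_q8_0_r8, with QK8_0 == 32 assumed. The register arithmetic behind the removal is visible in the first hunk: the deleted D = 128 kernel keeps qx[8] of int8x16x4_t, i.e. all 32 AArch64 NEON vector registers, to hold just 4 interleaved rows of 128 quants, so the same scheme with 8 interleaved rows would need 64 registers.

#include <cstdint>

// Simplified stand-ins for ggml's block_q8_0 / block_q8_0_r8 (assumed QK8_0 == 32).
struct block_q8_0    { uint16_t d;    int8_t qs[32];  }; // fp16 scale + 32 quants (one row)
struct block_q8_0_r8 { uint16_t d[8]; int8_t qs[256]; }; // 8 rows, 4-byte groups interleaved

// Scalar reference of the repack: 4-byte group l of row k lands at offset
// 32*l + 4*k, and the second half of each row (quants 16..31) fills the
// upper 128 bytes. The NEON path in the diff does the same with vtrnq transposes.
static void repack_q8_0_r8(const block_q8_0 * x8[8], block_q8_0_r8 & y) {
    for (int k = 0; k < 8; ++k) y.d[k] = x8[k]->d;
    for (int l = 0; l < 4; ++l)
        for (int k = 0; k < 8; ++k)
            for (int i = 0; i < 4; ++i) {
                y.qs[32*l + 4*k + i +   0] = x8[k]->qs[4*l + i +  0];
                y.qs[32*l + 4*k + i + 128] = x8[k]->qs[4*l + i + 16];
            }
}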

@@ -12228,46 +12228,6 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
     }
 }
-template <int nrc_y>
-void mul_mat_q8_0_r4_q8_0_128(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    GGML_ASSERT(nrc_x%4 == 0);
-    GGML_ASSERT(n == 128);
-    int8x16x4_t qx[8];
-    float32x4_t scales[4];
-    float32x4_t scales_y[4];
-    for (int ix = 0; ix < nrc_x; ix += 4) {
-        const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx);
-        for (int k = 0; k < 4; ++k) {
-            scales[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[k].d));
-            qx[2*k+0] = vld1q_s8_x4(iq8[k].qs);
-            qx[2*k+1] = vld1q_s8_x4(iq8[k].qs+64);
-        }
-        for (int iy = 0; iy < nrc_y; ++iy) {
-            auto by = (const block_q8_0_x4 *)info.src1_row(iy);
-            auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)by->d));
-            scales_y[0] = vmulq_laneq_f32(scales[0], d8, 0);
-            scales_y[1] = vmulq_laneq_f32(scales[1], d8, 1);
-            scales_y[2] = vmulq_laneq_f32(scales[2], d8, 2);
-            scales_y[3] = vmulq_laneq_f32(scales[3], d8, 3);
-            auto sumf = vdupq_n_f32(0.f);
-            for (int k = 0; k < 4; ++k) {
-                auto y = vld1q_s8_x2(by->qs+32*k);
-                auto sumi = vdupq_n_s32(0);
-                sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[0], y.val[0], 0);
-                sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[1], y.val[1], 0);
-                sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[2], y.val[0], 1);
-                sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[3], y.val[1], 1);
-                sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[0], y.val[0], 2);
-                sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[1], y.val[1], 2);
-                sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[2], y.val[0], 3);
-                sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[3], y.val[1], 3);
-                sumf = vfmaq_f32(sumf, scales_y[k], vcvtq_f32_s32(sumi));
-            }
-            info.store(ix, iy, sumf);
-        }
-    }
-}
 #define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
     m.funcs[0] = func<Dequantizer, 1>;\
     m.funcs[1] = func<Dequantizer, 2>;\
@@ -12914,6 +12874,9 @@ struct HelperQ80R4 : public BaseHelper<step> {
         std::vector<block_q8_0_r8> result(nblock * nk/8);
         auto y = result.data();
         const block_q8_0 * x8[8];
+#ifdef __ARM_NEON
+        int8x16x2_t m0, m1, m2, m3;
+#endif
         for (int row = 0; row < nk; row += 8) {
             for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride);
             for (int ib = 0; ib < nblock; ++ib) {
@@ -12952,31 +12915,33 @@ struct HelperQ80R4 : public BaseHelper<step> {
                 _mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2);
                 _mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3);
 #elif defined __ARM_NEON
-                auto m0 = vld1q_s8_x2(x4[0][ib].qs);
-                auto m1 = vld1q_s8_x2(x4[1][ib].qs);
-                auto m2 = vld1q_s8_x2(x4[2][ib].qs);
-                auto m3 = vld1q_s8_x2(x4[3][ib].qs);
-                auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
-                auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
-                m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
-                m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
-                m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
-                m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
-                row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
-                row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
-                m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
-                m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
-                m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
-                m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
-                vst1q_s8_x2(y[ib].qs +  0, m0);
-                vst1q_s8_x2(y[ib].qs + 32, m1);
-                vst1q_s8_x2(y[ib].qs + 64, m2);
-                vst1q_s8_x2(y[ib].qs + 96, m3);
+                for (int l = 0; l < 2; ++l) {
+                    m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l);
+                    m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l);
+                    m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l);
+                    m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l);
+                    auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
+                    auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
+                    m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                    m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                    m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                    m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                    row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
+                    row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
+                    m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                    m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                    m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
+                    m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
+                    vst1q_s8_x2(y[ib].qs +  0 + 128*l, m0);
+                    vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1);
+                    vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2);
+                    vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3);
+                }
 #else
                 for (int l = 0; l < 4; ++l) {
-                    for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) {
-                        y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0];
-                        y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16];
+                    for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
+                        y[ib].qs[32*l+4*k+i+  0] = x8[k][ib].qs[i+4*l+ 0];
+                        y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
                     }
                 }
 #endif
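A note on the vtrnq idiom above (my illustration, not from the commit): each int8x16 register is treated as four 32-bit lanes, i.e. four 4-byte quant groups. vtrnq_s32 interleaves the even/odd lanes of two rows, and vtrn1q_s64/vtrn2q_s64 then pick the low/high 64-bit halves, which together transpose a 4x4 matrix of groups. A minimal standalone version, assuming AArch64:

#include <arm_neon.h>

// Transpose a 4x4 int32 matrix held in four NEON registers, using the same
// vtrnq_s32 + vtrn1q_s64/vtrn2q_s64 sequence as the repack above.
static inline void transpose_4x4_s32(int32x4_t & r0, int32x4_t & r1,
                                     int32x4_t & r2, int32x4_t & r3) {
    int32x4x2_t t01 = vtrnq_s32(r0, r1); // {r0[0],r1[0],r0[2],r1[2]}, {r0[1],r1[1],r0[3],r1[3]}
    int32x4x2_t t23 = vtrnq_s32(r2, r3);
    r0 = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(t01.val[0]),
                                          vreinterpretq_s64_s32(t23.val[0]))); // column 0
    r1 = vreinterpretq_s32_s64(vtrn1q_s64(vreinterpretq_s64_s32(t01.val[1]),
                                          vreinterpretq_s64_s32(t23.val[1]))); // column 1
    r2 = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(t01.val[0]),
                                          vreinterpretq_s64_s32(t23.val[0]))); // column 2
    r3 = vreinterpretq_s32_s64(vtrn2q_s64(vreinterpretq_s64_s32(t01.val[1]),
                                          vreinterpretq_s64_s32(t23.val[1]))); // column 3
}

In the repack this transpose runs twice per 16-quant half: once on the val[0] registers (rows 0..3) and once on the val[1] registers (rows 4..7), so each 32-byte vst1q_s8_x2 store emits group l of all 8 rows.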
@@ -13794,23 +13759,7 @@ struct FlashQKfp32 {
         }
         else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {
 #ifdef __aarch64__
-            if constexpr (D == 128) {
-                if (q_step >= 64 && nq >= 64) {
-                    return std::make_pair(mul_mat_q8_0_r4_q8_0_128<64>, 64);
-                }
-                else if (q_step >= 32 && nq >= 32) {
-                    return std::make_pair(mul_mat_q8_0_r4_q8_0_128<32>, 32);
-                }
-                else if (q_step >= 16 && nq >= 16) {
-                    return std::make_pair(mul_mat_q8_0_r4_q8_0_128<16>, 16);
-                }
-                else {
-                    MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0_128, nq);
-                }
-            } else {
-                MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
-            }
-            //MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
+            MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
 #else
             MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq);
 #endif