mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-11 06:20:09 +00:00
FA: repack Q8_0 to Q8_0_R8 (NEON)
Very slightly faster than the general purpose gemm, slightly slower than the D = 128 special case gemm mul_mat_q8_0_r4_q8_0_128. Still removing mul_mat_q8_0_r4_q8_0_128 as we simply don't have enough vector registers to hold 8 interleaved rows, so there is no point to have the special purpose implementation.
This commit is contained in:
@@ -12228,46 +12228,6 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_q8_0_r4_q8_0_128(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
GGML_ASSERT(n == 128);
|
||||
int8x16x4_t qx[8];
|
||||
float32x4_t scales[4];
|
||||
float32x4_t scales_y[4];
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
scales[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[k].d));
|
||||
qx[2*k+0] = vld1q_s8_x4(iq8[k].qs);
|
||||
qx[2*k+1] = vld1q_s8_x4(iq8[k].qs+64);
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto by = (const block_q8_0_x4 *)info.src1_row(iy);
|
||||
auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)by->d));
|
||||
scales_y[0] = vmulq_laneq_f32(scales[0], d8, 0);
|
||||
scales_y[1] = vmulq_laneq_f32(scales[1], d8, 1);
|
||||
scales_y[2] = vmulq_laneq_f32(scales[2], d8, 2);
|
||||
scales_y[3] = vmulq_laneq_f32(scales[3], d8, 3);
|
||||
auto sumf = vdupq_n_f32(0.f);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto y = vld1q_s8_x2(by->qs+32*k);
|
||||
auto sumi = vdupq_n_s32(0);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[0], y.val[0], 0);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[1], y.val[1], 0);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[2], y.val[0], 1);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[3], y.val[1], 1);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[0], y.val[0], 2);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[1], y.val[1], 2);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[2], y.val[0], 3);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[3], y.val[1], 3);
|
||||
sumf = vfmaq_f32(sumf, scales_y[k], vcvtq_f32_s32(sumi));
|
||||
}
|
||||
info.store(ix, iy, sumf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
|
||||
m.funcs[0] = func<Dequantizer, 1>;\
|
||||
m.funcs[1] = func<Dequantizer, 2>;\
|
||||
@@ -12914,6 +12874,9 @@ struct HelperQ80R4 : public BaseHelper<step> {
|
||||
std::vector<block_q8_0_r8> result(nblock * nk/8);
|
||||
auto y = result.data();
|
||||
const block_q8_0 * x8[8];
|
||||
#ifdef __ARM_NEON
|
||||
int8x16x2_t m0, m1, m2, m3;
|
||||
#endif
|
||||
for (int row = 0; row < nk; row += 8) {
|
||||
for (int k = 0; k < 8; ++k) x8[k] = (const block_q8_0 *)(q8.data + (row + k)*q8.stride);
|
||||
for (int ib = 0; ib < nblock; ++ib) {
|
||||
@@ -12952,31 +12915,33 @@ struct HelperQ80R4 : public BaseHelper<step> {
|
||||
_mm256_storeu_si256((__m256i *)y[ib].qs + 6, m2);
|
||||
_mm256_storeu_si256((__m256i *)y[ib].qs + 7, m3);
|
||||
#elif defined __ARM_NEON
|
||||
auto m0 = vld1q_s8_x2(x4[0][ib].qs);
|
||||
auto m1 = vld1q_s8_x2(x4[1][ib].qs);
|
||||
auto m2 = vld1q_s8_x2(x4[2][ib].qs);
|
||||
auto m3 = vld1q_s8_x2(x4[3][ib].qs);
|
||||
auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
|
||||
auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
|
||||
m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
|
||||
m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
|
||||
m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
|
||||
m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
|
||||
row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
|
||||
row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
|
||||
m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
|
||||
m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
|
||||
m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
|
||||
m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
|
||||
vst1q_s8_x2(y[ib].qs + 0, m0);
|
||||
vst1q_s8_x2(y[ib].qs + 32, m1);
|
||||
vst1q_s8_x2(y[ib].qs + 64, m2);
|
||||
vst1q_s8_x2(y[ib].qs + 96, m3);
|
||||
for (int l = 0; l < 2; ++l) {
|
||||
m0.val[0] = vld1q_s8(x8[0][ib].qs+16*l); m0.val[1] = vld1q_s8(x8[4][ib].qs+16*l);
|
||||
m1.val[0] = vld1q_s8(x8[1][ib].qs+16*l); m1.val[1] = vld1q_s8(x8[5][ib].qs+16*l);
|
||||
m2.val[0] = vld1q_s8(x8[2][ib].qs+16*l); m2.val[1] = vld1q_s8(x8[6][ib].qs+16*l);
|
||||
m3.val[0] = vld1q_s8(x8[3][ib].qs+16*l); m3.val[1] = vld1q_s8(x8[7][ib].qs+16*l);
|
||||
auto row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[0]), vreinterpretq_s32_s8(m1.val[0]));
|
||||
auto row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[0]), vreinterpretq_s32_s8(m3.val[0]));
|
||||
m0.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
|
||||
m1.val[0] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
|
||||
m2.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
|
||||
m3.val[0] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
|
||||
row01 = vtrnq_s32(vreinterpretq_s32_s8(m0.val[1]), vreinterpretq_s32_s8(m1.val[1]));
|
||||
row23 = vtrnq_s32(vreinterpretq_s32_s8(m2.val[1]), vreinterpretq_s32_s8(m3.val[1]));
|
||||
m0.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
|
||||
m1.val[1] = vreinterpretq_s8_s64(vtrn1q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
|
||||
m2.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[0]), vreinterpretq_s64_s32(row23.val[0])));
|
||||
m3.val[1] = vreinterpretq_s8_s64(vtrn2q_s64(vreinterpretq_s64_s32(row01.val[1]), vreinterpretq_s64_s32(row23.val[1])));
|
||||
vst1q_s8_x2(y[ib].qs + 0 + 128*l, m0);
|
||||
vst1q_s8_x2(y[ib].qs + 32 + 128*l, m1);
|
||||
vst1q_s8_x2(y[ib].qs + 64 + 128*l, m2);
|
||||
vst1q_s8_x2(y[ib].qs + 96 + 128*l, m3);
|
||||
}
|
||||
#else
|
||||
for (int l = 0; l < 4; ++l) {
|
||||
for (int k = 0; k < 4; ++k) for (int i = 0; i < 4; ++i) {
|
||||
y[ib].qs[32*l+4*k+i+ 0] = x4[k][ib].qs[i+4*l+ 0];
|
||||
y[ib].qs[32*l+4*k+i+16] = x4[k][ib].qs[i+4*l+16];
|
||||
for (int k = 0; k < 8; ++k) for (int i = 0; i < 4; ++i) {
|
||||
y[ib].qs[32*l+4*k+i+ 0] = x8[k][ib].qs[i+4*l+ 0];
|
||||
y[ib].qs[32*l+4*k+i+128] = x8[k][ib].qs[i+4*l+16];
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -13794,23 +13759,7 @@ struct FlashQKfp32 {
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {
|
||||
#ifdef __aarch64__
|
||||
if constexpr (D == 128) {
|
||||
if (q_step >= 64 && nq >= 64) {
|
||||
return std::make_pair(mul_mat_q8_0_r4_q8_0_128<64>, 64);
|
||||
}
|
||||
else if (q_step >= 32 && nq >= 32) {
|
||||
return std::make_pair(mul_mat_q8_0_r4_q8_0_128<32>, 32);
|
||||
}
|
||||
else if (q_step >= 16 && nq >= 16) {
|
||||
return std::make_pair(mul_mat_q8_0_r4_q8_0_128<16>, 16);
|
||||
}
|
||||
else {
|
||||
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0_128, nq);
|
||||
}
|
||||
} else {
|
||||
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
|
||||
}
|
||||
//MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
|
||||
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
|
||||
#else
|
||||
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_1, nq);
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user