mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 23:24:13 +00:00
FA: dedicated mat mul for D = 128 also for ARM_NEON
This commit is contained in:
@@ -12063,6 +12063,46 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
}
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_q8_0_r4_q8_0_128(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
GGML_ASSERT(n == 128);
|
||||
int8x16x4_t qx[8];
|
||||
float32x4_t scales[4];
|
||||
float32x4_t scales_y[4];
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
const block_q8_0_x4 * iq8 = (const block_q8_0_x4 *)((const char *)vx + ix*bx);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
scales[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq8[k].d));
|
||||
qx[2*k+0] = vld1q_s8_x4(iq8[k].qs);
|
||||
qx[2*k+1] = vld1q_s8_x4(iq8[k].qs+64);
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto by = (const block_q8_0_x4 *)info.src1_row(iy);
|
||||
auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)by->d));
|
||||
scales_y[0] = vmulq_laneq_f32(scales[0], d8, 0);
|
||||
scales_y[1] = vmulq_laneq_f32(scales[1], d8, 1);
|
||||
scales_y[2] = vmulq_laneq_f32(scales[2], d8, 2);
|
||||
scales_y[3] = vmulq_laneq_f32(scales[3], d8, 3);
|
||||
auto sumf = vdupq_n_f32(0.f);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto y = vld1q_s8_x2(by->qs+32*k);
|
||||
auto sumi = vdupq_n_s32(0);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[0], y.val[0], 0);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[1], y.val[1], 0);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[2], y.val[0], 1);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+0].val[3], y.val[1], 1);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[0], y.val[0], 2);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[1], y.val[1], 2);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[2], y.val[0], 3);
|
||||
sumi = vdotq_laneq_s32(sumi, qx[2*k+1].val[3], y.val[1], 3);
|
||||
sumf = vfmaq_f32(sumf, scales_y[k], vcvtq_f32_s32(sumi));
|
||||
}
|
||||
info.store(ix, iy, sumf);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
|
||||
m.funcs[0] = func<Dequantizer, 1>;\
|
||||
m.funcs[1] = func<Dequantizer, 2>;\
|
||||
@@ -13645,7 +13685,23 @@ struct FlashQKfp32 {
|
||||
}
|
||||
else if constexpr (std::is_same_v<KHelper, HelperQ80R4<D, k_step>>) {
|
||||
#ifdef __aarch64__
|
||||
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
|
||||
if constexpr (D == 128) {
|
||||
if (q_step >= 64 && nq >= 64) {
|
||||
return std::make_pair(mul_mat_q8_0_r4_q8_0_128<64>, 64);
|
||||
}
|
||||
else if (q_step >= 32 && nq >= 32) {
|
||||
return std::make_pair(mul_mat_q8_0_r4_q8_0_128<32>, 32);
|
||||
}
|
||||
else if (q_step >= 16 && nq >= 16) {
|
||||
return std::make_pair(mul_mat_q8_0_r4_q8_0_128<16>, 16);
|
||||
}
|
||||
else {
|
||||
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0_128, nq);
|
||||
}
|
||||
} else {
|
||||
MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
|
||||
}
|
||||
//MAKE_FUNCS_ONLY_NRC(mul_mat_q8_0_r4_q8_0, nq);
|
||||
#else
|
||||
#ifdef HAVE_FANCY_SIMD
|
||||
if constexpr (D == 128) {
|
||||
|
||||
Reference in New Issue
Block a user