mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-22 23:49:23 +00:00
Make q8_0_r4 work with tensor row sizes that are not a multiple of 128
.. on NEON
This commit is contained in:
@@ -12439,12 +12439,32 @@ struct Q6_0_R4_Dequantizer {
|
||||
const int8x16_t m32 = vdupq_n_s8(-32);
|
||||
};
|
||||
|
||||
inline void qx_0_q8_0_dot(const int8x16_t * qx, const int8_t * qy, int32x4_t& sumi1, int32x4_t& sumi2) {
|
||||
auto y = vld1q_s8_x2(qy);
|
||||
sumi1 = sumi2 = vdupq_n_s32(0);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
|
||||
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%8 == 0);
|
||||
Q8<nrc_y, block_q8_0_x4> q8(info);
|
||||
int nb = n / QK8_0;
|
||||
GGML_ASSERT(nb%4 == 0);
|
||||
float32x4_t acc[2*nrc_y] = {};
|
||||
int8x16_t qx[16];
|
||||
float d8[4*nrc_y];
|
||||
@@ -12459,32 +12479,29 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
|
||||
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
|
||||
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[4*ib4+k].qs + 16*j);
|
||||
int32x4_t sumi1, sumi2;
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
|
||||
auto sumi1 = vdupq_n_s32(0);
|
||||
auto sumi2 = vdupq_n_s32(0);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[0], y.val[0], 0);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[1], y.val[0], 0);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[2], y.val[0], 1);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[3], y.val[0], 1);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[4], y.val[0], 2);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[5], y.val[0], 2);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[6], y.val[0], 3);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[7], y.val[0], 3);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[8+0], y.val[1], 0);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[8+1], y.val[1], 0);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[8+2], y.val[1], 1);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[8+3], y.val[1], 1);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[8+4], y.val[1], 2);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[8+5], y.val[1], 2);
|
||||
sumi1 = vdotq_laneq_s32(sumi1, qx[8+6], y.val[1], 3);
|
||||
sumi2 = vdotq_laneq_s32(sumi2, qx[8+7], y.val[1], 3);
|
||||
qx_0_q8_0_dot(qx, q8.y[iy][ib4].qs+32*k, sumi1, sumi2);
|
||||
auto dy = vdupq_n_f32(d8[4*iy+k]);
|
||||
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
|
||||
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int ib = 4*(nb/4); ib < nb; ++ib) {
|
||||
auto scales16 = vld1q_f16((const float16_t *)iq8[ib].d);
|
||||
auto scales1 = vcvt_f32_f16(vget_low_f16 (scales16));
|
||||
auto scales2 = vcvt_f32_f16(vget_high_f16(scales16));
|
||||
for (int j = 0; j < 16; ++j) qx[j] = vld1q_s8(iq8[ib].qs + 16*j);
|
||||
int32x4_t sumi1, sumi2;
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto qy = (const block_q8_0 *)q8.y[iy];
|
||||
qx_0_q8_0_dot(qx, qy[ib].qs, sumi1, sumi2);
|
||||
auto dy = vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d));
|
||||
acc[2*iy+0] = vfmaq_f32(acc[2*iy+0], vmulq_f32(scales1, dy), vcvtq_f32_s32(sumi1));
|
||||
acc[2*iy+1] = vfmaq_f32(acc[2*iy+1], vmulq_f32(scales2, dy), vcvtq_f32_s32(sumi2));
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix+0, iy, acc[2*iy+0]);
|
||||
info.store(ix+4, iy, acc[2*iy+1]);
|
||||
|
||||
Reference in New Issue
Block a user