This commit is contained in:
Iwan Kawrakow
2024-12-04 15:05:55 +01:00
parent 59a3e12508
commit 6dec39627c

View File

@@ -7253,6 +7253,30 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
}
IQK_ALWAYS_INLINE int32x4_t interleaved_dotq(const int8x16_t * qx, const int8x16x2_t& y) {
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
return sumi;
}
IQK_ALWAYS_INLINE void prepare_iq4_nl_quants(const int8x16_t& values, const uint8x16_t& m4, const uint8x16x4_t& bits, int8x16_t * qx) {
qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
}
template <int nrc_y>
void mul_mat_iq4_nl_x4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
@@ -7269,25 +7293,10 @@ void mul_mat_iq4_nl_x4_q8_0(int n, const void * vx, size_t bx, const DataInfo& i
for (int k = 0; k < 4; ++k) {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
prepare_iq4_nl_quants(values, m4, bits, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
auto sumi = interleaved_dotq(qx, y);
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
}
@@ -7323,25 +7332,10 @@ void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i
auto sl32 = vmovl_s16(vget_low_s16(sl16));
auto scales = vmulq_f32(d4, vcvtq_f32_s32(sl32));
auto bits = vld1q_u8_x4(iq4[ibl].qs + 64*ib);
qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
qx[2] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
qx[3] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
qx[4] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
qx[5] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
qx[6] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
qx[7] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
prepare_iq4_nl_quants(values, m4, bits, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+32*ib);
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
auto sumi = interleaved_dotq(qx, y);
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(q8.scale(iy, ibl)));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
}