mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 03:11:51 +00:00
iq4_xs_r4: Better ARM implementation
PP-512(LLaMA-3.1-8B) is now 131.3 t/s, up from 115.8 t/s. iq4_xs_r4 is now the prompt-processing champion on ARM.
This commit is contained in:
@@ -8309,42 +8309,38 @@ void mul_mat_iq4_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& i
|
||||
int nbl = n / QK_K;
|
||||
int8x16_t qx[8];
|
||||
int8x16x2_t iscales;
|
||||
float32x4x4_t scales;
|
||||
int32x4x4_t scales;
|
||||
float32x4_t acc[nrc_y] = {};
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
const block_iq4_xs_r4 * iq4 = (const block_iq4_xs_r4 *)((const char *)vx + ix*bx);
|
||||
for (int ibl = 0; ibl < nbl; ++ibl) {
|
||||
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d));
|
||||
if constexpr (nrc_y == 1) {
|
||||
d4 = vmulq_f32(d4, vdupq_n_f32(q8.scale(0, ibl)));
|
||||
}
|
||||
auto sl = vld1q_u8(iq4[ibl].scales_l);
|
||||
auto sh8 = vld1_u8(iq4[ibl].scales_h);
|
||||
auto sh = vcombine_u8(sh8, vshr_n_u8(sh8, 2));
|
||||
iscales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl, m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
|
||||
iscales.val[1] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl, 4), vandq_u8(sh, m3)), m32);
|
||||
int32x4_t isum[nrc_y] = {};
|
||||
for (int is = 0; is < 2; ++is) {
|
||||
auto iscales16_1 = vmovl_s8(vget_low_s8(iscales.val[is]));
|
||||
auto iscales16_2 = vmovl_s8(vget_high_s8(iscales.val[is]));
|
||||
scales.val[0] = vmulq_f32(d4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_1))));
|
||||
scales.val[1] = vmulq_f32(d4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_1))));
|
||||
scales.val[2] = vmulq_f32(d4, vcvtq_f32_s32(vmovl_s16(vget_low_s16(iscales16_2))));
|
||||
scales.val[3] = vmulq_f32(d4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(iscales16_2))));
|
||||
scales.val[0] = vmovl_s16(vget_low_s16(iscales16_1));
|
||||
scales.val[1] = vmovl_s16(vget_high_s16(iscales16_1));
|
||||
scales.val[2] = vmovl_s16(vget_low_s16(iscales16_2));
|
||||
scales.val[3] = vmovl_s16(vget_high_s16(iscales16_2));
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
|
||||
prepare_iq4_nl_quants(values, m4, bits, qx);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
|
||||
auto sumi = interleaved_dotq(qx, y);
|
||||
if constexpr (nrc_y == 1) {
|
||||
acc[iy] = vfmaq_f32(acc[iy], scales.val[ib], vcvtq_f32_s32(sumi));
|
||||
} else {
|
||||
auto d4d8 = vmulq_f32(scales.val[ib], vdupq_n_f32(q8.scale(iy, ibl)));
|
||||
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
|
||||
}
|
||||
isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, acc[iy]);
|
||||
@@ -9027,7 +9023,7 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
break;
|
||||
case GGML_TYPE_IQ4_XS_R4:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k);
|
||||
expected_Btype = GGML_TYPE_Q8_K;
|
||||
expected_Btype = GGML_TYPE_Q8_K32;
|
||||
break;
|
||||
case GGML_TYPE_Q3_K_R4:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_q3_k_r4_q8_k);
|
||||
|
||||
Reference in New Issue
Block a user