iq2_bn_r4: NEON

PP-512 is now 296 t/s. TG-128 is ~20% faster than iq2_bn
for 1 thread, but saturates to about the same 93 t/s at
8 threads.
This commit is contained in:
Iwan Kawrakow
2024-12-05 17:40:12 +01:00
parent 32f8a33f5e
commit c848533580

View File

@@ -7386,63 +7386,126 @@ static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const Da
Q8_16<nrc_y> q8(info);
auto m3 = vdupq_n_u8(0x3);
int nb = n / QK_IQ1BN;
int32x4_t acc[4*nrc_y] = {};
uint8x16_t qx[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
const float * dptr = (const float *)((const char *)vx + ix*bx);
auto dl = vld1q_f32(dptr);
const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
for (int ib = 0; ib < nb; ++ib) {
auto bits = vld1q_u8_x2(iq2 + 64*ib);
qx[0] = vandq_u8(bits.val[0], m3);
qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
qx[3] = vshrq_n_u8(bits.val[0], 6);
qx[4] = vandq_u8(bits.val[1], m3);
qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
qx[7] = vshrq_n_u8(bits.val[1], 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = q8.load_quants_32(iy, 2*ib+0);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
if constexpr (nrc_y == 1) {
auto mc = vdupq_n_u8(0xc);
int32x4_t acc[8];
uint8x16x4_t bits2;
//uint8x16x2_t bits2;
for (int ix = 0; ix < nrc_x; ix += 4) {
//acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0);
//acc[4] = acc[5] = acc[6] = acc[7] = vdupq_n_s32(0);
const float * dptr = (const float *)((const char *)vx + ix*bx);
auto dl = vld1q_f32(dptr);
const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
auto bits1 = vld1q_u8_x4(iq2);
for (int k = 0; k < 4; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4);
auto y = q8.load_quants(0, 0);
for (int k = 0; k < 4; ++k) {
acc[2*k+0] = vdotq_laneq_s32(vdupq_n_s32(0), vandq_u8(bits1.val[k], m3), y.val[k], 0);
acc[2*k+1] = vdotq_laneq_s32(vdupq_n_s32(0), vandq_u8(bits1.val[k], mc), y.val[k], 1);
acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits2.val[k], m3), y.val[k], 2);
acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits2.val[k], mc), y.val[k], 3);
}
bits = vld1q_u8_x2(iq2 + 64*ib + 32);
qx[0] = vandq_u8(bits.val[0], m3);
qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
qx[3] = vshrq_n_u8(bits.val[0], 6);
qx[4] = vandq_u8(bits.val[1], m3);
qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
qx[7] = vshrq_n_u8(bits.val[1], 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = q8.load_quants_32(iy, 2*ib+1);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
for (int ib = 1; ib < nb; ++ib) {
bits1 = vld1q_u8_x4(iq2 + 64*ib);
for (int k = 0; k < 4; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4);
y = q8.load_quants(0, ib);
for (int k = 0; k < 4; ++k) {
acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits1.val[k], m3), y.val[k], 0);
acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits1.val[k], mc), y.val[k], 1);
acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits2.val[k], m3), y.val[k], 2);
acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits2.val[k], mc), y.val[k], 3);
}
//auto bits1 = vld1q_u8_x2(iq2 + 64*ib);
//for (int k = 0; k < 2; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4);
//auto y = q8.load_quants_32(0, 2*ib+0);
//for (int k = 0; k < 2; ++k) {
// acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits1.val[k], m3), y.val[k], 0);
// acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits1.val[k], mc), y.val[k], 1);
// acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits2.val[k], m3), y.val[k], 2);
// acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits2.val[k], mc), y.val[k], 3);
//}
//bits1 = vld1q_u8_x2(iq2 + 64*ib + 32);
//for (int k = 0; k < 2; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4);
//y = q8.load_quants_32(0, 2*ib+1);
//for (int k = 0; k < 2; ++k) {
// acc[2*k+4] = vdotq_laneq_s32(acc[2*k+4], vandq_u8(bits1.val[k], m3), y.val[k], 0);
// acc[2*k+5] = vdotq_laneq_s32(acc[2*k+5], vandq_u8(bits1.val[k], mc), y.val[k], 1);
// acc[2*k+4] = vdotq_laneq_s32(acc[2*k+4], vandq_u8(bits2.val[k], m3), y.val[k], 2);
// acc[2*k+5] = vdotq_laneq_s32(acc[2*k+5], vandq_u8(bits2.val[k], mc), y.val[k], 3);
//}
}
auto dy = q8.scale(0);
float32x4_t s = vfmaq_f32(vcvtq_f32_s32(acc[0]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[1]));
float32x4_t sumf = vmulq_f32(s, vmulq_laneq_f32(dl, dy, 0));
s = vfmaq_f32(vcvtq_f32_s32(acc[2]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[3]));
sumf = vfmaq_f32(sumf, s, vmulq_laneq_f32(dl, dy, 1));
s = vfmaq_f32(vcvtq_f32_s32(acc[4]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[5]));
sumf = vfmaq_f32(sumf, s, vmulq_laneq_f32(dl, dy, 2));
s = vfmaq_f32(vcvtq_f32_s32(acc[6]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[7]));
sumf = vfmaq_f32(sumf, s, vmulq_laneq_f32(dl, dy, 3));
sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0)));
info.store(ix, 0, sumf);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto dy = q8.scale(iy);
float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
info.store(ix, iy, sumf);
acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_f32(0.f);
} else {
int32x4_t acc[4*nrc_y] = {};
uint8x16_t qx[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
const float * dptr = (const float *)((const char *)vx + ix*bx);
auto dl = vld1q_f32(dptr);
const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
for (int ib = 0; ib < nb; ++ib) {
auto bits = vld1q_u8_x2(iq2 + 64*ib);
qx[0] = vandq_u8(bits.val[0], m3);
qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
qx[3] = vshrq_n_u8(bits.val[0], 6);
qx[4] = vandq_u8(bits.val[1], m3);
qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
qx[7] = vshrq_n_u8(bits.val[1], 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = q8.load_quants_32(iy, 2*ib+0);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[0], y.val[0], 0);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[1], y.val[0], 1);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[2], y.val[0], 2);
acc[4*iy + 0] = vdotq_laneq_s32(acc[4*iy + 0], qx[3], y.val[0], 3);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[4], y.val[1], 0);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[5], y.val[1], 1);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[6], y.val[1], 2);
acc[4*iy + 1] = vdotq_laneq_s32(acc[4*iy + 1], qx[7], y.val[1], 3);
}
bits = vld1q_u8_x2(iq2 + 64*ib + 32);
qx[0] = vandq_u8(bits.val[0], m3);
qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m3);
qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m3);
qx[3] = vshrq_n_u8(bits.val[0], 6);
qx[4] = vandq_u8(bits.val[1], m3);
qx[5] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m3);
qx[6] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m3);
qx[7] = vshrq_n_u8(bits.val[1], 6);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = q8.load_quants_32(iy, 2*ib+1);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[0], y.val[0], 0);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[1], y.val[0], 1);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[2], y.val[0], 2);
acc[4*iy + 2] = vdotq_laneq_s32(acc[4*iy + 2], qx[3], y.val[0], 3);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[4], y.val[1], 0);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[5], y.val[1], 1);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[6], y.val[1], 2);
acc[4*iy + 3] = vdotq_laneq_s32(acc[4*iy + 3], qx[7], y.val[1], 3);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto dy = q8.scale(iy);
float32x4_t sumf = vmulq_f32(vcvtq_f32_s32(acc[4*iy+0]), vmulq_laneq_f32(dl, dy, 0));
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+1]), vmulq_laneq_f32(dl, dy, 1));
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+2]), vmulq_laneq_f32(dl, dy, 2));
sumf = vfmaq_f32(sumf, vcvtq_f32_s32(acc[4*iy+3]), vmulq_laneq_f32(dl, dy, 3));
sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(iy)));
info.store(ix, iy, sumf);
acc[4*iy+0] = acc[4*iy+1] = acc[4*iy+2] = acc[4*iy+3] = vdupq_n_s32(0);
}
}
}
}
@@ -7998,9 +8061,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
m.funcs[2] = mul_mat_iq2_bn_r4_q8_k16<3>;
m.funcs[3] = mul_mat_iq2_bn_r4_q8_k16<4>;
m.funcs[4] = mul_mat_iq2_bn_r4_q8_k16<5>;
m.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
m.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
m.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
//m.funcs[5] = mul_mat_iq2_bn_r4_q8_k16<6>;
//m.funcs[6] = mul_mat_iq2_bn_r4_q8_k16<7>;
//m.funcs[7] = mul_mat_iq2_bn_r4_q8_k16<8>;
expected_Btype = GGML_TYPE_Q8_K16;
break;
case GGML_TYPE_Q4_0: