mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
iq2_bn_r4: Experimenting on NEON
The matrix x vvector multiplication is erratic. iq2_bn_r4 is faster at 1, 2, and 4 threads, but saturates to a lower t/s at 8 threads compared to iq2_bn. iq2_bn actually manages 99 t/s at 8 threads and not 93 as I wrore in the last commit. iq2_bn_r4 performance has huge fluctuations at 4 and 8 threads.
This commit is contained in:
@@ -7387,63 +7387,86 @@ static void mul_mat_iq2_bn_r4_q8_k16(int n, const void * vx, size_t bx, const Da
|
||||
auto m3 = vdupq_n_u8(0x3);
|
||||
int nb = n / QK_IQ1BN;
|
||||
if constexpr (nrc_y == 1) {
|
||||
//uint8x16x4_t shuff = {
|
||||
// vreinterpretq_u8_u32(vdupq_n_u32(0x03020100)),
|
||||
// vreinterpretq_u8_u32(vdupq_n_u32(0x07060504)),
|
||||
// vreinterpretq_u8_u32(vdupq_n_u32(0x0b0a0908)),
|
||||
// vreinterpretq_u8_u32(vdupq_n_u32(0x0f0e0d0c)),
|
||||
//};
|
||||
auto mc = vdupq_n_u8(0xc);
|
||||
int32x4_t acc[8];
|
||||
uint8x16x4_t bits2;
|
||||
//uint8x16x2_t bits2;
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
for (int k = 0; k < 8; ++k) acc[k] = vdupq_n_s32(0);
|
||||
//acc[0] = acc[1] = acc[2] = acc[3] = vdupq_n_s32(0);
|
||||
//acc[4] = acc[5] = acc[6] = acc[7] = vdupq_n_s32(0);
|
||||
const float * dptr = (const float *)((const char *)vx + ix*bx);
|
||||
auto dl = vld1q_f32(dptr);
|
||||
const uint8_t * iq2 = (const uint8_t *)(dptr + 4);
|
||||
auto bits1 = vld1q_u8_x4(iq2);
|
||||
for (int k = 0; k < 4; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4);
|
||||
auto y = q8.load_quants(0, 0);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
acc[2*k+0] = vdotq_laneq_s32(vdupq_n_s32(0), vandq_u8(bits1.val[k], m3), y.val[k], 0);
|
||||
acc[2*k+1] = vdotq_laneq_s32(vdupq_n_s32(0), vandq_u8(bits1.val[k], mc), y.val[k], 1);
|
||||
acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits2.val[k], m3), y.val[k], 2);
|
||||
acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits2.val[k], mc), y.val[k], 3);
|
||||
for (int ib = 0; ib < nb; ++ib) {
|
||||
auto y = q8.load_quants(0, ib);
|
||||
auto bits1 = vld1q_u8(iq2 + 64*ib);
|
||||
auto bits2 = vshrq_n_u8(bits1, 4);
|
||||
acc[0] = vdotq_laneq_s32(acc[0], vandq_u8(bits1, m3), y.val[0], 0);
|
||||
acc[1] = vdotq_laneq_s32(acc[1], vandq_u8(bits1, mc), y.val[0], 1);
|
||||
acc[0] = vdotq_laneq_s32(acc[0], vandq_u8(bits2, m3), y.val[0], 2);
|
||||
acc[1] = vdotq_laneq_s32(acc[1], vandq_u8(bits2, mc), y.val[0], 3);
|
||||
//acc[0] = vdotq_s32(acc[0], vandq_u8(bits1, m3), vqtbl1q_s8(y.val[0], shuff.val[0]));
|
||||
//acc[1] = vdotq_s32(acc[1], vandq_u8(bits1, mc), vqtbl1q_s8(y.val[0], shuff.val[1]));
|
||||
//acc[0] = vdotq_s32(acc[0], vandq_u8(bits2, m3), vqtbl1q_s8(y.val[0], shuff.val[2]));
|
||||
//acc[1] = vdotq_s32(acc[1], vandq_u8(bits2, mc), vqtbl1q_s8(y.val[0], shuff.val[3]));
|
||||
bits1 = vld1q_u8(iq2 + 64*ib + 16);
|
||||
bits2 = vshrq_n_u8(bits1, 4);
|
||||
acc[2] = vdotq_laneq_s32(acc[2], vandq_u8(bits1, m3), y.val[1], 0);
|
||||
acc[3] = vdotq_laneq_s32(acc[3], vandq_u8(bits1, mc), y.val[1], 1);
|
||||
acc[2] = vdotq_laneq_s32(acc[2], vandq_u8(bits2, m3), y.val[1], 2);
|
||||
acc[3] = vdotq_laneq_s32(acc[3], vandq_u8(bits2, mc), y.val[1], 3);
|
||||
//acc[2] = vdotq_s32(acc[2], vandq_u8(bits1, m3), vqtbl1q_s8(y.val[1], shuff.val[0]));
|
||||
//acc[3] = vdotq_s32(acc[3], vandq_u8(bits1, mc), vqtbl1q_s8(y.val[1], shuff.val[1]));
|
||||
//acc[2] = vdotq_s32(acc[2], vandq_u8(bits2, m3), vqtbl1q_s8(y.val[1], shuff.val[2]));
|
||||
//acc[3] = vdotq_s32(acc[3], vandq_u8(bits2, mc), vqtbl1q_s8(y.val[1], shuff.val[3]));
|
||||
bits1 = vld1q_u8(iq2 + 64*ib + 32);
|
||||
bits2 = vshrq_n_u8(bits1, 4);
|
||||
acc[4] = vdotq_laneq_s32(acc[4], vandq_u8(bits1, m3), y.val[2], 0);
|
||||
acc[5] = vdotq_laneq_s32(acc[5], vandq_u8(bits1, mc), y.val[2], 1);
|
||||
acc[4] = vdotq_laneq_s32(acc[4], vandq_u8(bits2, m3), y.val[2], 2);
|
||||
acc[5] = vdotq_laneq_s32(acc[5], vandq_u8(bits2, mc), y.val[2], 3);
|
||||
//acc[4] = vdotq_s32(acc[4], vandq_u8(bits1, m3), vqtbl1q_s8(y.val[2], shuff.val[0]));
|
||||
//acc[5] = vdotq_s32(acc[5], vandq_u8(bits1, mc), vqtbl1q_s8(y.val[2], shuff.val[1]));
|
||||
//acc[4] = vdotq_s32(acc[4], vandq_u8(bits2, m3), vqtbl1q_s8(y.val[2], shuff.val[2]));
|
||||
//acc[5] = vdotq_s32(acc[5], vandq_u8(bits2, mc), vqtbl1q_s8(y.val[2], shuff.val[3]));
|
||||
bits1 = vld1q_u8(iq2 + 64*ib + 48);
|
||||
bits2 = vshrq_n_u8(bits1, 4);
|
||||
acc[6] = vdotq_laneq_s32(acc[6], vandq_u8(bits1, m3), y.val[3], 0);
|
||||
acc[7] = vdotq_laneq_s32(acc[7], vandq_u8(bits1, mc), y.val[3], 1);
|
||||
acc[6] = vdotq_laneq_s32(acc[6], vandq_u8(bits2, m3), y.val[3], 2);
|
||||
acc[7] = vdotq_laneq_s32(acc[7], vandq_u8(bits2, mc), y.val[3], 3);
|
||||
//acc[6] = vdotq_s32(acc[6], vandq_u8(bits1, m3), vqtbl1q_s8(y.val[3], shuff.val[0]));
|
||||
//acc[7] = vdotq_s32(acc[7], vandq_u8(bits1, mc), vqtbl1q_s8(y.val[3], shuff.val[1]));
|
||||
//acc[6] = vdotq_s32(acc[6], vandq_u8(bits2, m3), vqtbl1q_s8(y.val[3], shuff.val[2]));
|
||||
//acc[7] = vdotq_s32(acc[7], vandq_u8(bits2, mc), vqtbl1q_s8(y.val[3], shuff.val[3]));
|
||||
}
|
||||
for (int ib = 1; ib < nb; ++ib) {
|
||||
bits1 = vld1q_u8_x4(iq2 + 64*ib);
|
||||
for (int k = 0; k < 4; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4);
|
||||
y = q8.load_quants(0, ib);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits1.val[k], m3), y.val[k], 0);
|
||||
acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits1.val[k], mc), y.val[k], 1);
|
||||
acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits2.val[k], m3), y.val[k], 2);
|
||||
acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits2.val[k], mc), y.val[k], 3);
|
||||
}
|
||||
//auto bits1 = vld1q_u8_x2(iq2 + 64*ib);
|
||||
//for (int k = 0; k < 2; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4);
|
||||
//auto y = q8.load_quants_32(0, 2*ib+0);
|
||||
//for (int k = 0; k < 2; ++k) {
|
||||
// acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits1.val[k], m3), y.val[k], 0);
|
||||
// acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits1.val[k], mc), y.val[k], 1);
|
||||
// acc[2*k+0] = vdotq_laneq_s32(acc[2*k+0], vandq_u8(bits2.val[k], m3), y.val[k], 2);
|
||||
// acc[2*k+1] = vdotq_laneq_s32(acc[2*k+1], vandq_u8(bits2.val[k], mc), y.val[k], 3);
|
||||
//}
|
||||
//bits1 = vld1q_u8_x2(iq2 + 64*ib + 32);
|
||||
//for (int k = 0; k < 2; ++k) bits2.val[k] = vshrq_n_u8(bits1.val[k], 4);
|
||||
//y = q8.load_quants_32(0, 2*ib+1);
|
||||
//for (int k = 0; k < 2; ++k) {
|
||||
// acc[2*k+4] = vdotq_laneq_s32(acc[2*k+4], vandq_u8(bits1.val[k], m3), y.val[k], 0);
|
||||
// acc[2*k+5] = vdotq_laneq_s32(acc[2*k+5], vandq_u8(bits1.val[k], mc), y.val[k], 1);
|
||||
// acc[2*k+4] = vdotq_laneq_s32(acc[2*k+4], vandq_u8(bits2.val[k], m3), y.val[k], 2);
|
||||
// acc[2*k+5] = vdotq_laneq_s32(acc[2*k+5], vandq_u8(bits2.val[k], mc), y.val[k], 3);
|
||||
//}
|
||||
}
|
||||
auto dy = q8.scale(0);
|
||||
float32x4_t s = vfmaq_f32(vcvtq_f32_s32(acc[0]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[1]));
|
||||
float32x4_t sumf = vmulq_f32(s, vmulq_laneq_f32(dl, dy, 0));
|
||||
s = vfmaq_f32(vcvtq_f32_s32(acc[2]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[3]));
|
||||
sumf = vfmaq_f32(sumf, s, vmulq_laneq_f32(dl, dy, 1));
|
||||
s = vfmaq_f32(vcvtq_f32_s32(acc[4]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[5]));
|
||||
sumf = vfmaq_f32(sumf, s, vmulq_laneq_f32(dl, dy, 2));
|
||||
s = vfmaq_f32(vcvtq_f32_s32(acc[6]), vdupq_n_f32(0.25f), vcvtq_f32_s32(acc[7]));
|
||||
sumf = vfmaq_f32(sumf, s, vmulq_laneq_f32(dl, dy, 3));
|
||||
//auto dy = q8.scale(0);
|
||||
//auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), vmulq_laneq_f32(dl, dy, 0));
|
||||
//auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), vmulq_laneq_f32(dl, dy, 0));
|
||||
//sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), vmulq_laneq_f32(dl, dy, 1));
|
||||
//sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), vmulq_laneq_f32(dl, dy, 1));
|
||||
//sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), vmulq_laneq_f32(dl, dy, 2));
|
||||
//sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[5]), vmulq_laneq_f32(dl, dy, 2));
|
||||
//sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), vmulq_laneq_f32(dl, dy, 3));
|
||||
//sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), vmulq_laneq_f32(dl, dy, 3));
|
||||
auto dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 0)));
|
||||
auto sumf1 = vmulq_f32( vcvtq_f32_s32(acc[0]), dy);
|
||||
auto sumf2 = vmulq_f32( vcvtq_f32_s32(acc[1]), dy);
|
||||
dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 1)));
|
||||
sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[2]), dy);
|
||||
sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[3]), dy);
|
||||
dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 2)));
|
||||
sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[4]), dy);
|
||||
sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[5]), dy);
|
||||
dy = vmulq_f32(dl, vdupq_n_f32(q8.scale(0, 3)));
|
||||
sumf1 = vfmaq_f32(sumf1, vcvtq_f32_s32(acc[6]), dy);
|
||||
sumf2 = vfmaq_f32(sumf2, vcvtq_f32_s32(acc[7]), dy);
|
||||
auto sumf = vfmaq_f32(sumf1, vdupq_n_f32(0.25f), sumf2);
|
||||
sumf = vfmaq_f32(sumf, dl, vdupq_n_f32(-q8.sum_row(0)));
|
||||
info.store(ix, 0, sumf);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user