Bitnet(2.25 bpw): NEON

We get PP-512 = 192 t/s, TG-128 = 72 t/s
This commit is contained in:
Kawrakow
2024-06-18 11:11:46 +02:00
parent 39982764d7
commit 766975ecfa

View File

@@ -4199,61 +4199,22 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
Q8_K64<nrc_y> q8(info);
float32x4_t accd[nrc_y];
int8x16x4_t signs;
uint64x2x4_t a;
int8x16x4_t v;
const auto m1 = vdupq_n_u8(1);
const uint8x16x4_t sign_shuffles = {
vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}),
vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}),
vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}),
vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}),
};
const auto shift = vreinterpretq_s8_u32(vdupq_n_u32(0xfafcfe00));
const auto qmask = vdupq_n_u8(3);
const auto shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0909090908080808});
const auto mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
const auto mask2 = vdupq_n_s8(3);
for (int ix = 0; ix < nrc_x; ++ix) {
const block_iq2_bn * x = (const block_iq2_bn *)((const char *)vx + ix*bx);
float d = GGML_FP16_TO_FP32(*(const ggml_half *)x);
auto extra_ptr = (const uint16_t *)x;
const float d = GGML_FP16_TO_FP32(x[0].d);
auto all_signs = vdupq_n_u8(extra_ptr[1]);
all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1);
signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]);
signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]);
signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]);
signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]);
auto ql = (const uint8_t *)(extra_ptr + 2);
auto qh = ql + QK_IQ1BN/8;
a.val[0] = uint64x2_t{iq1bn_grid_u16[ql[0] | ((qh[0] << 8) & 0x0f00)], iq1bn_grid_u16[ql[1] | ((qh[0] << 4) & 0x0f00)]};
a.val[1] = uint64x2_t{iq1bn_grid_u16[ql[2] | ((qh[1] << 8) & 0x0f00)], iq1bn_grid_u16[ql[3] | ((qh[1] << 4) & 0x0f00)]};
a.val[2] = uint64x2_t{iq1bn_grid_u16[ql[4] | ((qh[2] << 8) & 0x0f00)], iq1bn_grid_u16[ql[5] | ((qh[2] << 4) & 0x0f00)]};
a.val[3] = uint64x2_t{iq1bn_grid_u16[ql[6] | ((qh[3] << 8) & 0x0f00)], iq1bn_grid_u16[ql[7] | ((qh[3] << 4) & 0x0f00)]};
v.val[0] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1);
v.val[1] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1);
v.val[2] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1);
v.val[3] = vsubq_s8(vandq_u8(vshlq_u8(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1);
v.val[0] = vmulq_s8(v.val[0], signs.val[0]);
v.val[1] = vmulq_s8(v.val[1], signs.val[1]);
v.val[2] = vmulq_s8(v.val[2], signs.val[2]);
v.val[3] = vmulq_s8(v.val[3], signs.val[3]);
if constexpr (nrc_y == 1) {
auto q = q8.load_quants(0, 0);
int32x4_t sumi = vdupq_n_s32(0);
for (int j = 0; j < 4; ++j) {
sumi = ggml_vdotq_s32(sumi, q.val[j], v.val[j]);
}
accd[0] = vmulq_f32(vdupq_n_f32(q8.scale(0, 0)), vcvtq_f32_s32(sumi));
} else {
{
auto q2bits = vld1q_u8(x[0].qs);
v.val[0] = vsubq_s8(vandq_s8(q2bits, mask2), m1);
v.val[1] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 2), mask2), m1);
v.val[2] = vsubq_s8(vandq_s8(vshrq_n_u8(q2bits, 4), mask2), m1);
v.val[3] = vsubq_s8(vshrq_n_u8(q2bits, 6), m1);
for (int iy = 0; iy < nrc_y; ++iy) {
int32x4_t sumi = vdupq_n_s32(0);
auto q = q8.load_quants(iy, 0, 0);
@@ -4262,6 +4223,7 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
sumi = ggml_vdotq_s32(ggml_vdotq_s32(sumi, q.val[0], v.val[2]), q.val[1], v.val[3]);
accd[iy] = vmulq_f32(vdupq_n_f32(q8.scale(iy, 0)), vcvtq_f32_s32(sumi));
}
}
for (int i = 1; i < nb; ++i) {