iq1bn(no lookup): NEON attempts

We are at TG-128 = 25.7 t/s, which is quite a bit worse than
lookup.
This commit is contained in:
Kawrakow
2024-07-16 08:32:15 +02:00
parent 597ea12970
commit d0f9d146b8

View File

@@ -4393,9 +4393,19 @@ struct DequantizerIQ1BN {
return vld1q_s16(data);
}
static inline uint16x8_t load_mult() {
static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
//static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
static const uint16_t data[8] = {2187*8, 729*8, 243*8, 81*8, 27*8, 9*8, 3*8, 1*8};
return vld1q_u16(data);
}
//static inline uint8x16x4_t load_shuffles(uint16_t s0) {
// uint8x16x4_t r;
// auto step = vdupq_n_u8(4);
// r.val[0] = vreinterpretq_u8_u16(vdupq_n_u16(s0));
// r.val[1] = vaddq_u8(r.val[0], step);
// r.val[2] = vaddq_u8(r.val[1], step);
// r.val[3] = vaddq_u8(r.val[2], step);
// return r;
//}
const uint8x16_t shuff_l = load_shuffle_l();
const uint8x16_t shuff_h = load_shuffle_h();
@@ -4405,22 +4415,33 @@ struct DequantizerIQ1BN {
const uint16x8_t mask_hh = vdupq_n_u16(4096);
const int16x8_t shift_hh = load_shift_hh();
const uint16x8_t mult = load_mult();
const uint16x8_t mask = vdupq_n_u16(0x1fff);
const uint16x8_t m3 = vdupq_n_u16(3);
const uint8x16_t step = vdupq_n_u8(2);
const uint8x16_t shuff0 = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
//const uint8x16x4_t shuff1 = load_shuffles(0x0100);
//const uint8x16x4_t shuff2 = load_shuffles(0x0302);
//const uint16x8_t mask = vdupq_n_u16(0x1fff);
//const uint16x8_t m3 = vdupq_n_u16(3);
inline void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
auto data = vld1q_u8((const uint8_t *)x);
auto aux1 = vqtbl1q_u8(data, shuff_l);
auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h);
auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh);
auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3);
auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
auto step = vdupq_n_u8(2);
auto shuffle = shuff0;
//auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
//auto step = vdupq_n_u8(2);
for (int k = 0; k < 4; ++k) {
auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
//auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff1.val[k]));
//auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuff2.val[k]));
v1 = vmulq_u16(v1, mult);
v2 = vmulq_u16(v2, mult);
v1 = vshrq_n_u16(vhaddq_u16(v1, vshrq_n_u16(v1, 1)), 14);
v2 = vshrq_n_u16(vhaddq_u16(v2, vshrq_n_u16(v2, 1)), 14);
//v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
//v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1);
}
}
@@ -4448,9 +4469,9 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
deq.prepare_iq1bn_quants(x+2*i+0, v1);
auto q = q8.load_quants64(0, i, 0);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
deq.prepare_iq1bn_quants(x+2*i+1, v1);
deq.prepare_iq1bn_quants(x+2*i+1, v2);
q = q8.load_quants64(0, i, 1);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v2.val[j]);
}
accd[0] = vaddq_s32(vaddq_s32(acc[0], acc[1]), vaddq_s32(acc[2], acc[3]));
}