iq1bn(no lookup): NEON

Pretty bad.
This commit is contained in:
Kawrakow
2024-07-15 20:40:14 +02:00
parent cd8fffc3cd
commit 597ea12970

View File

@@ -4375,47 +4375,54 @@ static const uint64_t kall_signs[257] = {
// NEON helper that expands one iq1_bn block into four int8x16_t vectors (64 int8 lanes).
// NOTE(review): this span looks like a rendered diff whose +/- markers were lost. Two
// prepare_iq1bn_quants overloads (one using the kall_signs / iq1bn_grid_u16 tables, one
// using only arithmetic), their constants, and the static load_* helpers are interleaved,
// so the text is not compilable as shown. The comments below describe each statement in
// isolation; confirm which lines belong to which version against the actual commit.
struct DequantizerIQ1BN {
// All bytes set to 1; subtracted from the unpacked values at the end of both overloads.
const uint8x16_t m1 = vdupq_n_u8(1);
// Table-lookup indices for vqtbl1q_u8: table k replicates source byte 2k eight times,
// followed by source byte 2k+1 eight times.
const uint8x16x4_t sign_shuffles = {
vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}),
vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}),
vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}),
vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}),
};
// Per 64 bits the 16-bit lanes are (least significant first) 0x0000, 0xfffe, 0xfffc, 0xfffa,
// i.e. 0, -2, -4, -6 as signed values. Used as vshlq_u16 shift counts, where a negative
// count shifts right.
// NOTE(review): declared int8x16_t although the initializer yields int16x8_t; this
// presumably relies on lax vector conversions -- confirm.
const int8x16_t shift = vreinterpretq_s16_u64(vdupq_n_u64(0xfffafffcfffe0000));
// Keeps the low two bits of every byte.
const uint8x16_t qmask = vdupq_n_u8(3);
// Lookup indices {0,1} x4 followed by {8,9} x4: replicates the 16-bit value at bytes 0-1
// four times, then the 16-bit value at bytes 8-9 four times.
const uint8x16_t shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0100010001000100, 0x0908090809080908});
// One distinct bit (0x01 .. 0x80) per byte; only referenced by the commented-out code below.
const uint8x16_t mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
// Scratch storage written by the table-based overload.
int8x16x4_t signs;
uint64x2x4_t a;
// Table-based overload: fills v from `extra`, the bytes at `ql`, and 4 bytes at `qh`.
inline void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, int8x16x4_t& v) {
// Loads 16 bytes, i.e. entries `extra` and `extra + 1` of the uint64_t table kall_signs.
auto all_signs = vld1q_u8((const uint8_t *)(kall_signs + extra));
//auto all_signs = vdupq_n_u8(extra);
//all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1);
// Broadcast each of the first 8 bytes of all_signs across 8 consecutive lanes.
signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]);
signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]);
signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]);
signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]);
// Split the 4 qh bytes into nibbles: aux32[0] keeps the low nibbles, aux32[1] keeps the
// high nibbles (left in the upper half of each byte).
uint32_t aux32[2];
std::memcpy(aux32, qh, 4);
aux32[1] = aux32[0] & 0xf0f0f0f0;
aux32[0] &= 0x0f0f0f0f;
// h[0..3] = qh[i] & 0x0f, h[4..7] = qh[i] & 0xf0.
const uint8_t * h = (const uint8_t *)aux32;
// Lookup indices placing source bytes 1..8 in the low byte of 16-bit lanes 0..7.
// Index 255 is out of range for vqtbl1q_u8, which yields 0, so the high byte is cleared.
static inline uint8x16_t load_shuffle_l() {
static const uint8_t data[16] = {1, 255, 2, 255, 3, 255, 4, 255, 5, 255, 6, 255, 7, 255, 8, 255};
return vld1q_u8(data);
}
// Lookup indices placing source bytes 9..12 in the low byte of 16-bit lanes 0..3 and
// again in lanes 4..7; high bytes are cleared (index 255).
static inline uint8x16_t load_shuffle_h() {
static const uint8_t data[16] = {9, 255, 10, 255, 11, 255, 12, 255, 9, 255, 10, 255, 11, 255, 12, 255};
return vld1q_u8(data);
}
// Lookup indices placing source byte 0 in the low byte of every 16-bit lane.
static inline uint8x16_t load_shuffle_hh() {
static const uint8_t data[16] = {0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255};
return vld1q_u8(data);
}
// Left-shift counts 12 - k for lane k: moves bit k of the lane value to bit position 12.
static inline int16x8_t load_shift_hh() {
static const int16_t data[8] = {12, 11, 10, 9, 8, 7, 6, 5};
return vld1q_s16(data);
}
// Powers of three from 3^7 down to 3^0.
static inline uint16x8_t load_mult() {
static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
return vld1q_u16(data);
}
// Grid index = ql byte in bits 0..7 plus a qh nibble in bits 8..11. Low nibbles (h[0..3])
// are shifted by 8; high nibbles (h[4..7]) already sit at bits 4..7, so they shift by 4.
// NOTE(review): iq1bn_grid_u16 is defined outside this view; element type not confirmed.
a.val[0] = uint64x2_t{iq1bn_grid_u16[ql[0] | (h[0] << 8)], iq1bn_grid_u16[ql[1] | (h[4] << 4)]};
a.val[1] = uint64x2_t{iq1bn_grid_u16[ql[2] | (h[1] << 8)], iq1bn_grid_u16[ql[3] | (h[5] << 4)]};
a.val[2] = uint64x2_t{iq1bn_grid_u16[ql[4] | (h[2] << 8)], iq1bn_grid_u16[ql[5] | (h[6] << 4)]};
a.val[3] = uint64x2_t{iq1bn_grid_u16[ql[6] | (h[3] << 8)], iq1bn_grid_u16[ql[7] | (h[7] << 4)]};
// Constants used by the arithmetic overload prepare_iq1bn_quants(const block_iq1_bn *, ...).
const uint8x16_t shuff_l = load_shuffle_l();
const uint8x16_t shuff_h = load_shuffle_h();
// 32-bit-lane left shifts: 8 for the first two lanes, 4 for the last two.
const int32x4_t shift_h = {8, 8, 4, 4};
// Keeps bits 8..11 of each 16-bit lane.
const uint16x8_t mask_h = vdupq_n_u16(0x0f00);
const uint8x16_t shuff_hh = load_shuffle_hh();
// 4096 == 1 << 12: keeps bit 12 of each 16-bit lane.
const uint16x8_t mask_hh = vdupq_n_u16(4096);
const int16x8_t shift_hh = load_shift_hh();
const uint16x8_t mult = load_mult();
// Keeps the low 13 bits of each 16-bit lane.
const uint16x8_t mask = vdupq_n_u16(0x1fff);
const uint16x8_t m3 = vdupq_n_u16(3);
// Replicate each 16-bit grid value four times, shift the copies right by 0/2/4/6 bits,
// keep the low two bits of every byte and subtract 1.
v.val[0] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1);
v.val[1] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1);
v.val[2] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1);
v.val[3] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1);
// Lane-wise multiply by the broadcast sign bytes.
v.val[0] = vmulq_s8(v.val[0], signs.val[0]);
v.val[1] = vmulq_s8(v.val[1], signs.val[1]);
v.val[2] = vmulq_s8(v.val[2], signs.val[2]);
v.val[3] = vmulq_s8(v.val[3], signs.val[3]);
// Arithmetic overload: unpacks the block at x into v without the grid/sign tables.
// Each of the 64 output lanes ends up in {-1, 0, +1} (see the loop below).
inline void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
// Reads 16 bytes starting at x; the shuffle tables only select bytes 0..12.
// NOTE(review): confirm a 16-byte read is in bounds for the last block of a row.
auto data = vld1q_u8((const uint8_t *)x);
// Bits 0..7 of lane k: byte k+1 of the block.
auto aux1 = vqtbl1q_u8(data, shuff_l);
// Bits 8..11: low nibble of bytes 9..12 for lanes 0..3, high nibble for lanes 4..7.
auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h);
// Bit 12 of lane k: bit k of byte 0.
auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh);
// Eight 13-bit values, one per 16-bit lane.
auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3);
// Byte indices {0,1} repeated: broadcasts 16-bit lane 0 of `all`; adding `step` (2 per
// byte) advances the broadcast to the next 16-bit lane.
auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
auto step = vdupq_n_u8(2);
for (int k = 0; k < 4; ++k) {
// Iteration k consumes lanes 2k and 2k+1 of `all`, each broadcast to 8 lanes.
auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
// Multiply by 3^7..3^0 (wrapping to 16 bits), keep the low 13 bits, multiply by 3 and
// shift right by 13. The masked value is below 8192, so the result is 0, 1 or 2.
v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
// Narrow both halves to bytes, join them and subtract 1 to map {0,1,2} to {-1,0,+1}.
v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1);
}
}
};
@@ -4438,10 +4445,10 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
if constexpr (nrc_y == 1) {
int32x4_t acc[4] = {};
for (int i = 0; i < nb/2; ++i) {
deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
deq.prepare_iq1bn_quants(x+2*i+0, v1);
auto q = q8.load_quants64(0, i, 0);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v1);
deq.prepare_iq1bn_quants(x+2*i+1, v1);
q = q8.load_quants64(0, i, 1);
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
}
@@ -4453,8 +4460,8 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
for (int i = 0; i < nb/2; ++i) {
deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v2);
deq.prepare_iq1bn_quants(x+2*i+0, v1);
deq.prepare_iq1bn_quants(x+2*i+1, v2);
for (int iy = 0; iy < nrc_y; ++iy) {
auto q = q8.load_quants(iy, i, 0);
@@ -4470,7 +4477,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
}
int i = 2*(nb/2);
if (i < nb) {
deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, v1);
deq.prepare_iq1bn_quants(x+i, v1);
if constexpr (nrc_y == 1) {
auto q = q8.load_quants(0, i/2, 0);
for (int j = 0; j < 4; ++j) {