mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-02 10:00:07 +00:00
iq1bn(no lookup): NEON
Pretty bad.
This commit is contained in:
@@ -4375,47 +4375,54 @@ static const uint64_t kall_signs[257] = {
|
||||
|
||||
struct DequantizerIQ1BN {
|
||||
const uint8x16_t m1 = vdupq_n_u8(1);
|
||||
const uint8x16x4_t sign_shuffles = {
|
||||
vreinterpretq_u8_u64(uint64x2_t{0x0000000000000000, 0x0101010101010101}),
|
||||
vreinterpretq_u8_u64(uint64x2_t{0x0202020202020202, 0x0303030303030303}),
|
||||
vreinterpretq_u8_u64(uint64x2_t{0x0404040404040404, 0x0505050505050505}),
|
||||
vreinterpretq_u8_u64(uint64x2_t{0x0606060606060606, 0x0707070707070707}),
|
||||
};
|
||||
const int8x16_t shift = vreinterpretq_s16_u64(vdupq_n_u64(0xfffafffcfffe0000));
|
||||
const uint8x16_t qmask = vdupq_n_u8(3);
|
||||
const uint8x16_t shuff1 = vreinterpretq_u8_u64(uint64x2_t{0x0100010001000100, 0x0908090809080908});
|
||||
const uint8x16_t mask1 = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
|
||||
int8x16x4_t signs;
|
||||
uint64x2x4_t a;
|
||||
inline void prepare_iq1bn_quants(uint8_t extra, const uint8_t * ql, const uint8_t * qh, int8x16x4_t& v) {
|
||||
auto all_signs = vld1q_u8((const uint8_t *)(kall_signs + extra));
|
||||
//auto all_signs = vdupq_n_u8(extra);
|
||||
//all_signs = vorrq_u8(vceqq_u8(vandq_u8(all_signs, mask1), mask1), m1);
|
||||
signs.val[0] = vqtbl1q_u8(all_signs, sign_shuffles.val[0]);
|
||||
signs.val[1] = vqtbl1q_u8(all_signs, sign_shuffles.val[1]);
|
||||
signs.val[2] = vqtbl1q_u8(all_signs, sign_shuffles.val[2]);
|
||||
signs.val[3] = vqtbl1q_u8(all_signs, sign_shuffles.val[3]);
|
||||
|
||||
uint32_t aux32[2];
|
||||
std::memcpy(aux32, qh, 4);
|
||||
aux32[1] = aux32[0] & 0xf0f0f0f0;
|
||||
aux32[0] &= 0x0f0f0f0f;
|
||||
const uint8_t * h = (const uint8_t *)aux32;
|
||||
// Returns the vqtbl1q_u8 index table that initialises shuff_l.
// Even positions hold source indices 1..8; odd positions hold 255, which is out of range for a
// 16-byte table lookup and therefore produces 0. Viewed as eight 16-bit lanes (little-endian lane
// layout), lane j of the shuffled result is source byte 1+j zero-extended to 16 bits.
static inline uint8x16_t load_shuffle_l() {
|
||||
static const uint8_t data[16] = {1, 255, 2, 255, 3, 255, 4, 255, 5, 255, 6, 255, 7, 255, 8, 255};
|
||||
return vld1q_u8(data);
|
||||
}
|
||||
// Returns the vqtbl1q_u8 index table that initialises shuff_h.
// Source bytes 9..12 are selected twice: once into 16-bit lanes 0..3 and again into lanes 4..7.
// The odd positions hold 255 (out of range), so the high byte of every 16-bit lane becomes 0.
static inline uint8x16_t load_shuffle_h() {
|
||||
static const uint8_t data[16] = {9, 255, 10, 255, 11, 255, 12, 255, 9, 255, 10, 255, 11, 255, 12, 255};
|
||||
return vld1q_u8(data);
|
||||
}
|
||||
// Returns the vqtbl1q_u8 index table that initialises shuff_hh.
// Every even position selects source byte 0 and every odd position is 255 (out of range -> 0),
// so all eight 16-bit lanes of the shuffled result hold source byte 0 zero-extended.
static inline uint8x16_t load_shuffle_hh() {
|
||||
static const uint8_t data[16] = {0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255, 0, 255};
|
||||
return vld1q_u8(data);
|
||||
}
|
||||
// Returns the per-lane left-shift counts that initialise shift_hh.
// Lane j shifts by 12-j, so bit j of the value held in that lane ends up at bit position 12
// (the only bit that survives the subsequent AND with mask_hh == 4096).
static inline int16x8_t load_shift_hh() {
|
||||
static const int16_t data[8] = {12, 11, 10, 9, 8, 7, 6, 5};
|
||||
return vld1q_s16(data);
|
||||
}
|
||||
// Returns the per-lane multipliers that initialise mult: the powers of three 3^7, 3^6, ..., 3^0.
// They are applied with vmulq_u16, i.e. each product is taken modulo 2^16.
static inline uint16x8_t load_mult() {
|
||||
static const uint16_t data[8] = {2187, 729, 243, 81, 27, 9, 3, 1};
|
||||
return vld1q_u16(data);
|
||||
}
|
||||
|
||||
a.val[0] = uint64x2_t{iq1bn_grid_u16[ql[0] | (h[0] << 8)], iq1bn_grid_u16[ql[1] | (h[4] << 4)]};
|
||||
a.val[1] = uint64x2_t{iq1bn_grid_u16[ql[2] | (h[1] << 8)], iq1bn_grid_u16[ql[3] | (h[5] << 4)]};
|
||||
a.val[2] = uint64x2_t{iq1bn_grid_u16[ql[4] | (h[2] << 8)], iq1bn_grid_u16[ql[5] | (h[6] << 4)]};
|
||||
a.val[3] = uint64x2_t{iq1bn_grid_u16[ql[6] | (h[3] << 8)], iq1bn_grid_u16[ql[7] | (h[7] << 4)]};
|
||||
// Constants consumed by prepare_iq1bn_quants(const block_iq1_bn *, int8x16x4_t&).
// shuff_l: spreads source bytes 1..8 into the low bytes of eight 16-bit lanes.
const uint8x16_t shuff_l = load_shuffle_l();
|
||||
// shuff_h: spreads source bytes 9..12 into 16-bit lanes 0..3 and again into lanes 4..7.
const uint8x16_t shuff_h = load_shuffle_h();
|
||||
// shift_h: per-32-bit-lane left shifts applied after shuff_h. Each 32-bit lane covers two 16-bit
// lanes, so 16-bit lanes 0..3 are shifted by 8 and lanes 4..7 by 4.
const int32x4_t shift_h = {8, 8, 4, 4};
|
||||
// mask_h: keeps bits 8..11 of each 16-bit lane. After the shifts above that is the low nibble of
// the source byte in lanes 0..3 and the high nibble in lanes 4..7.
const uint16x8_t mask_h = vdupq_n_u16(0x0f00);
|
||||
// shuff_hh: broadcasts source byte 0 into every 16-bit lane.
const uint8x16_t shuff_hh = load_shuffle_hh();
|
||||
// mask_hh: 4096 == 1 << 12, keeps only bit 12 of each 16-bit lane.
const uint16x8_t mask_hh = vdupq_n_u16(4096);
|
||||
// shift_hh: lane j is shifted left by 12-j, moving bit j of source byte 0 to bit 12.
const int16x8_t shift_hh = load_shift_hh();
|
||||
// mult: powers of three 3^7 .. 3^0, one per 16-bit lane.
const uint16x8_t mult = load_mult();
|
||||
// mask: keeps the low 13 bits of each 16-bit lane.
const uint16x8_t mask = vdupq_n_u16(0x1fff);
|
||||
// m3: the factor 3 used to extract one base-3 digit from a 13-bit fraction.
const uint16x8_t m3 = vdupq_n_u16(3);
|
||||
|
||||
v.val[0] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[0]), shuff1), shift), qmask), m1);
|
||||
v.val[1] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[1]), shuff1), shift), qmask), m1);
|
||||
v.val[2] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[2]), shuff1), shift), qmask), m1);
|
||||
v.val[3] = vsubq_s8(vandq_u8(vshlq_u16(vqtbl1q_u8(vreinterpretq_u8_u64(a.val[3]), shuff1), shift), qmask), m1);
|
||||
|
||||
v.val[0] = vmulq_s8(v.val[0], signs.val[0]);
|
||||
v.val[1] = vmulq_s8(v.val[1], signs.val[1]);
|
||||
v.val[2] = vmulq_s8(v.val[2], signs.val[2]);
|
||||
v.val[3] = vmulq_s8(v.val[3], signs.val[3]);
|
||||
// Unpacks one block into 64 signed bytes with values in {-1, 0, 1} without using a lookup table.
//
// x: pointer to the block; its first 16 bytes are loaded, of which bytes 0..12 are consumed.
//    NOTE(review): presumably byte 0 is `extra`, bytes 1..8 are `ql` and bytes 9..12 are `qh`;
//    the layout of block_iq1_bn is not visible here - confirm.
//    NOTE(review): the 16-byte load reads 3 bytes beyond byte 12 - confirm that this cannot run past
//    the end of the buffer for the last block.
// v: output; v.val[k] receives the 16 values decoded from 16-bit words 2k and 2k+1 (see below).
//
// NOTE(review): several calls pass a vector whose element type differs from what the intrinsic
// declares (e.g. uint8x16_t results fed to vshlq_u32 / vorrq_u16, uint8x16_t m1 fed to vsubq_s8);
// this presumably relies on the compiler accepting implicit same-size vector conversions - confirm
// the build flags.
inline void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
|
||||
auto data = vld1q_u8((const uint8_t *)x);
|
||||
// aux1: 16-bit lane j = source byte 1+j (bits 0..7 of the packed word).
auto aux1 = vqtbl1q_u8(data, shuff_l);
|
||||
// aux2: bits 8..11 of the packed word. Lanes 0..3 take the low nibble of source bytes 9..12,
// lanes 4..7 take the high nibble of the same bytes.
auto aux2 = vandq_u16(vshlq_u32(vqtbl1q_u8(data, shuff_h), shift_h), mask_h);
|
||||
// aux3: bit 12 of the packed word. Lane j takes bit j of source byte 0.
auto aux3 = vandq_u16(vshlq_u16(vqtbl1q_u8(data, shuff_hh), shift_hh), mask_hh);
|
||||
// all: eight 13-bit words, one per 16-bit lane.
auto all = vorrq_u16(vorrq_u16(aux1, aux2), aux3);
|
||||
// Byte indices {0,1,0,1,...}: used with vqtbl1q_u8 this broadcasts 16-bit lane 0 of `all` to every
// lane. Adding `step` (2 to every index byte) advances the broadcast to the next 16-bit lane.
auto shuffle = vreinterpretq_u8_u16(vdupq_n_u16(0x0100));
|
||||
auto step = vdupq_n_u8(2);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
// v1 / v2: 16-bit words 2k and 2k+1 of `all`, each replicated across all eight lanes.
auto v1 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
|
||||
auto v2 = vreinterpretq_u16_u8(vqtbl1q_u8(all, shuffle)); shuffle = vaddq_u8(shuffle, step);
|
||||
// Per lane: multiply by a power of three (mod 2^16), keep the low 13 bits, multiply by 3 and take
// bits 13 and up. Since (w & 0x1fff) * 3 < 3 * 8192, the result is 0, 1 or 2.
v1 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v1, mult), mask), m3), 13);
|
||||
v2 = vshrq_n_u16(vmulq_u16(vandq_u16(vmulq_u16(v2, mult), mask), m3), 13);
|
||||
// Narrow both halves to bytes, concatenate them, and subtract m1 (all ones) to map {0,1,2} to {-1,0,1}.
v.val[k] = vsubq_s8(vreinterpretq_s8_u8(vcombine_u8(vmovn_u16(v1), vmovn_u16(v2))), m1);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -4438,10 +4445,10 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
|
||||
if constexpr (nrc_y == 1) {
|
||||
int32x4_t acc[4] = {};
|
||||
for (int i = 0; i < nb/2; ++i) {
|
||||
deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
|
||||
deq.prepare_iq1bn_quants(x+2*i+0, v1);
|
||||
auto q = q8.load_quants64(0, i, 0);
|
||||
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
|
||||
deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v1);
|
||||
deq.prepare_iq1bn_quants(x+2*i+1, v1);
|
||||
q = q8.load_quants64(0, i, 1);
|
||||
for (int j = 0; j < 4; ++j) acc[j] = ggml_vdotq_s32(acc[j], q.val[j], v1.val[j]);
|
||||
}
|
||||
@@ -4453,8 +4460,8 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
|
||||
|
||||
for (int i = 0; i < nb/2; ++i) {
|
||||
|
||||
deq.prepare_iq1bn_quants(x[2*i+0].extra, x[2*i+0].ql, x[2*i+0].qh, v1);
|
||||
deq.prepare_iq1bn_quants(x[2*i+1].extra, x[2*i+1].ql, x[2*i+1].qh, v2);
|
||||
deq.prepare_iq1bn_quants(x+2*i+0, v1);
|
||||
deq.prepare_iq1bn_quants(x+2*i+1, v2);
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto q = q8.load_quants(iy, i, 0);
|
||||
@@ -4470,7 +4477,7 @@ static void mul_mat_iq1bn_q8_K64(int n, const void * vx, size_t bx, const DataIn
|
||||
}
|
||||
int i = 2*(nb/2);
|
||||
if (i < nb) {
|
||||
deq.prepare_iq1bn_quants(x[i].extra, x[i].ql, x[i].qh, v1);
|
||||
deq.prepare_iq1bn_quants(x+i, v1);
|
||||
if constexpr (nrc_y == 1) {
|
||||
auto q = q8.load_quants(0, i/2, 0);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
|
||||
Reference in New Issue
Block a user