mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-08 07:20:12 +00:00
Cleanup - Arm i-quants should be good now
Still missing iq1_s and iq1_m, but I don't think I'll do those.
This commit is contained in:
210
iqk_mul_mat.cpp
210
iqk_mul_mat.cpp
@@ -1890,6 +1890,8 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
|
||||
float d;
|
||||
};
|
||||
|
||||
// ============================= i-quants
|
||||
|
||||
struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
|
||||
|
||||
static int8x16_t load_values() {
|
||||
@@ -1948,40 +1950,6 @@ struct DequantizerIQ4XS final : public BaseDequantizer<block_iq4_xs> {
|
||||
float d;
|
||||
};
|
||||
|
||||
//static const int8_t keven_signs_q2xs[1024] = {
|
||||
// 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1,
|
||||
// 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1,
|
||||
// 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1,
|
||||
// 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1,
|
||||
// 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1,
|
||||
// 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1,
|
||||
// 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1,
|
||||
// 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1,
|
||||
// 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1,
|
||||
// 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1,
|
||||
// 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1,
|
||||
// 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1,
|
||||
// 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1,
|
||||
// 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1,
|
||||
// 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1,
|
||||
// 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1,
|
||||
// 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1,
|
||||
// 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1,
|
||||
// 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1,
|
||||
// 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1,
|
||||
// 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1,
|
||||
// 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1,
|
||||
// 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1,
|
||||
// 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1,
|
||||
// 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1,
|
||||
// 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1,
|
||||
// 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1,
|
||||
// 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1,
|
||||
// 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1,
|
||||
// 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1,
|
||||
// 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1,
|
||||
// 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
//};
|
||||
const uint64_t keven_signs[128] = {
|
||||
0x0101010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0x010101010101ffff,
|
||||
0xff01010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0xff01010101ffffff,
|
||||
@@ -2022,73 +1990,6 @@ struct SimpleBits {
|
||||
uint8x16x4_t b2;
|
||||
};
|
||||
|
||||
// Lookup table of 256 packed byte patterns, indexed by an 8-bit mask.
// For index i, byte k of kall_signs[i] (k = 0 is the least significant byte)
// is 0xff when bit k of i is set and 0x01 otherwise, i.e. -1 / +1 when the
// bytes are read as int8_t.
// NOTE(review): presumably used as per-byte sign multipliers for groups of 8
// quantized values; no user of this table is visible in this view - confirm.
const uint64_t kall_signs[256] = {
|
||||
0x0101010101010101, 0x01010101010101ff, 0x010101010101ff01, 0x010101010101ffff,
|
||||
0x0101010101ff0101, 0x0101010101ff01ff, 0x0101010101ffff01, 0x0101010101ffffff,
|
||||
0x01010101ff010101, 0x01010101ff0101ff, 0x01010101ff01ff01, 0x01010101ff01ffff,
|
||||
0x01010101ffff0101, 0x01010101ffff01ff, 0x01010101ffffff01, 0x01010101ffffffff,
|
||||
0x010101ff01010101, 0x010101ff010101ff, 0x010101ff0101ff01, 0x010101ff0101ffff,
|
||||
0x010101ff01ff0101, 0x010101ff01ff01ff, 0x010101ff01ffff01, 0x010101ff01ffffff,
|
||||
0x010101ffff010101, 0x010101ffff0101ff, 0x010101ffff01ff01, 0x010101ffff01ffff,
|
||||
0x010101ffffff0101, 0x010101ffffff01ff, 0x010101ffffffff01, 0x010101ffffffffff,
|
||||
0x0101ff0101010101, 0x0101ff01010101ff, 0x0101ff010101ff01, 0x0101ff010101ffff,
|
||||
0x0101ff0101ff0101, 0x0101ff0101ff01ff, 0x0101ff0101ffff01, 0x0101ff0101ffffff,
|
||||
0x0101ff01ff010101, 0x0101ff01ff0101ff, 0x0101ff01ff01ff01, 0x0101ff01ff01ffff,
|
||||
0x0101ff01ffff0101, 0x0101ff01ffff01ff, 0x0101ff01ffffff01, 0x0101ff01ffffffff,
|
||||
0x0101ffff01010101, 0x0101ffff010101ff, 0x0101ffff0101ff01, 0x0101ffff0101ffff,
|
||||
0x0101ffff01ff0101, 0x0101ffff01ff01ff, 0x0101ffff01ffff01, 0x0101ffff01ffffff,
|
||||
0x0101ffffff010101, 0x0101ffffff0101ff, 0x0101ffffff01ff01, 0x0101ffffff01ffff,
|
||||
0x0101ffffffff0101, 0x0101ffffffff01ff, 0x0101ffffffffff01, 0x0101ffffffffffff,
|
||||
0x01ff010101010101, 0x01ff0101010101ff, 0x01ff01010101ff01, 0x01ff01010101ffff,
|
||||
0x01ff010101ff0101, 0x01ff010101ff01ff, 0x01ff010101ffff01, 0x01ff010101ffffff,
|
||||
0x01ff0101ff010101, 0x01ff0101ff0101ff, 0x01ff0101ff01ff01, 0x01ff0101ff01ffff,
|
||||
0x01ff0101ffff0101, 0x01ff0101ffff01ff, 0x01ff0101ffffff01, 0x01ff0101ffffffff,
|
||||
0x01ff01ff01010101, 0x01ff01ff010101ff, 0x01ff01ff0101ff01, 0x01ff01ff0101ffff,
|
||||
0x01ff01ff01ff0101, 0x01ff01ff01ff01ff, 0x01ff01ff01ffff01, 0x01ff01ff01ffffff,
|
||||
0x01ff01ffff010101, 0x01ff01ffff0101ff, 0x01ff01ffff01ff01, 0x01ff01ffff01ffff,
|
||||
0x01ff01ffffff0101, 0x01ff01ffffff01ff, 0x01ff01ffffffff01, 0x01ff01ffffffffff,
|
||||
0x01ffff0101010101, 0x01ffff01010101ff, 0x01ffff010101ff01, 0x01ffff010101ffff,
|
||||
0x01ffff0101ff0101, 0x01ffff0101ff01ff, 0x01ffff0101ffff01, 0x01ffff0101ffffff,
|
||||
0x01ffff01ff010101, 0x01ffff01ff0101ff, 0x01ffff01ff01ff01, 0x01ffff01ff01ffff,
|
||||
0x01ffff01ffff0101, 0x01ffff01ffff01ff, 0x01ffff01ffffff01, 0x01ffff01ffffffff,
|
||||
0x01ffffff01010101, 0x01ffffff010101ff, 0x01ffffff0101ff01, 0x01ffffff0101ffff,
|
||||
0x01ffffff01ff0101, 0x01ffffff01ff01ff, 0x01ffffff01ffff01, 0x01ffffff01ffffff,
|
||||
0x01ffffffff010101, 0x01ffffffff0101ff, 0x01ffffffff01ff01, 0x01ffffffff01ffff,
|
||||
0x01ffffffffff0101, 0x01ffffffffff01ff, 0x01ffffffffffff01, 0x01ffffffffffffff,
|
||||
0xff01010101010101, 0xff010101010101ff, 0xff0101010101ff01, 0xff0101010101ffff,
|
||||
0xff01010101ff0101, 0xff01010101ff01ff, 0xff01010101ffff01, 0xff01010101ffffff,
|
||||
0xff010101ff010101, 0xff010101ff0101ff, 0xff010101ff01ff01, 0xff010101ff01ffff,
|
||||
0xff010101ffff0101, 0xff010101ffff01ff, 0xff010101ffffff01, 0xff010101ffffffff,
|
||||
0xff0101ff01010101, 0xff0101ff010101ff, 0xff0101ff0101ff01, 0xff0101ff0101ffff,
|
||||
0xff0101ff01ff0101, 0xff0101ff01ff01ff, 0xff0101ff01ffff01, 0xff0101ff01ffffff,
|
||||
0xff0101ffff010101, 0xff0101ffff0101ff, 0xff0101ffff01ff01, 0xff0101ffff01ffff,
|
||||
0xff0101ffffff0101, 0xff0101ffffff01ff, 0xff0101ffffffff01, 0xff0101ffffffffff,
|
||||
0xff01ff0101010101, 0xff01ff01010101ff, 0xff01ff010101ff01, 0xff01ff010101ffff,
|
||||
0xff01ff0101ff0101, 0xff01ff0101ff01ff, 0xff01ff0101ffff01, 0xff01ff0101ffffff,
|
||||
0xff01ff01ff010101, 0xff01ff01ff0101ff, 0xff01ff01ff01ff01, 0xff01ff01ff01ffff,
|
||||
0xff01ff01ffff0101, 0xff01ff01ffff01ff, 0xff01ff01ffffff01, 0xff01ff01ffffffff,
|
||||
0xff01ffff01010101, 0xff01ffff010101ff, 0xff01ffff0101ff01, 0xff01ffff0101ffff,
|
||||
0xff01ffff01ff0101, 0xff01ffff01ff01ff, 0xff01ffff01ffff01, 0xff01ffff01ffffff,
|
||||
0xff01ffffff010101, 0xff01ffffff0101ff, 0xff01ffffff01ff01, 0xff01ffffff01ffff,
|
||||
0xff01ffffffff0101, 0xff01ffffffff01ff, 0xff01ffffffffff01, 0xff01ffffffffffff,
|
||||
0xffff010101010101, 0xffff0101010101ff, 0xffff01010101ff01, 0xffff01010101ffff,
|
||||
0xffff010101ff0101, 0xffff010101ff01ff, 0xffff010101ffff01, 0xffff010101ffffff,
|
||||
0xffff0101ff010101, 0xffff0101ff0101ff, 0xffff0101ff01ff01, 0xffff0101ff01ffff,
|
||||
0xffff0101ffff0101, 0xffff0101ffff01ff, 0xffff0101ffffff01, 0xffff0101ffffffff,
|
||||
0xffff01ff01010101, 0xffff01ff010101ff, 0xffff01ff0101ff01, 0xffff01ff0101ffff,
|
||||
0xffff01ff01ff0101, 0xffff01ff01ff01ff, 0xffff01ff01ffff01, 0xffff01ff01ffffff,
|
||||
0xffff01ffff010101, 0xffff01ffff0101ff, 0xffff01ffff01ff01, 0xffff01ffff01ffff,
|
||||
0xffff01ffffff0101, 0xffff01ffffff01ff, 0xffff01ffffffff01, 0xffff01ffffffffff,
|
||||
0xffffff0101010101, 0xffffff01010101ff, 0xffffff010101ff01, 0xffffff010101ffff,
|
||||
0xffffff0101ff0101, 0xffffff0101ff01ff, 0xffffff0101ffff01, 0xffffff0101ffffff,
|
||||
0xffffff01ff010101, 0xffffff01ff0101ff, 0xffffff01ff01ff01, 0xffffff01ff01ffff,
|
||||
0xffffff01ffff0101, 0xffffff01ffff01ff, 0xffffff01ffffff01, 0xffffff01ffffffff,
|
||||
0xffffffff01010101, 0xffffffff010101ff, 0xffffffff0101ff01, 0xffffffff0101ffff,
|
||||
0xffffffff01ff0101, 0xffffffff01ff01ff, 0xffffffff01ffff01, 0xffffffff01ffffff,
|
||||
0xffffffffff010101, 0xffffffffff0101ff, 0xffffffffff01ff01, 0xffffffffff01ffff,
|
||||
0xffffffffffff0101, 0xffffffffffff01ff, 0xffffffffffffff01, 0xffffffffffffffff,
|
||||
};
|
||||
|
||||
inline int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) {
|
||||
int32x4x2_t scales;
|
||||
scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v1, 28), 1), vdupq_n_u32(1)));
|
||||
@@ -2144,21 +2045,11 @@ struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
|
||||
};
|
||||
|
||||
inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) {
|
||||
//constexpr static const uint64x2_t scale_shuffle = { 0x0b030a0209010800, 0x0f070e060d050c04 };
|
||||
//auto aux1 = vld1_u8(sc);
|
||||
//auto aux2 = vshr_n_u8(aux1, 4);
|
||||
//auto scales8 = vqtbl1q_u8(vandq_u8(vcombine_u8(aux1, aux2), vdupq_n_u8(0xf)), vreinterpretq_u8_u64(scale_shuffle));
|
||||
//scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(scales8, 1), vdupq_n_u8(1)));
|
||||
//
|
||||
auto aux = vld1_u8(sc);
|
||||
auto scales_l = vand_u8(aux, vdup_n_u8(0xf));
|
||||
auto scales_h = vshr_n_u8(aux, 4);
|
||||
auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
|
||||
|
||||
//auto scales_l = vld1_u8(sc);
|
||||
//auto scales_h = vshr_n_u8(scales_l, 4);
|
||||
//auto aux1 = vandq_u8(vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)), vdupq_n_u8(0xf));
|
||||
|
||||
auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1)));
|
||||
int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) };
|
||||
return make_wider(scales16);
|
||||
@@ -2187,24 +2078,6 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
|
||||
b[1] = make1(qs + 2);
|
||||
b[2] = make1(qs + 4);
|
||||
b[3] = make1(qs + 6);
|
||||
// The following is actually slower
|
||||
//auto bits = vld1q_u16(qs);
|
||||
//auto vidx = vandq_u16(bits, vdupq_n_u16(511));
|
||||
//const uint16_t * idx = (const uint16_t *)&vidx;
|
||||
//b[0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + idx[0])), vld1_u8((const uint8_t *)(iq2xs_grid + idx[1])));
|
||||
//b[1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + idx[2])), vld1_u8((const uint8_t *)(iq2xs_grid + idx[3])));
|
||||
//b[2] = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + idx[4])), vld1_u8((const uint8_t *)(iq2xs_grid + idx[5])));
|
||||
//b[3] = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + idx[6])), vld1_u8((const uint8_t *)(iq2xs_grid + idx[7])));
|
||||
//vidx = vshrq_n_u16(bits, 9);
|
||||
//b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]),
|
||||
// vcombine_s8(vld1_s8((const int8_t *)(keven_signs + idx[0])), vld1_s8((const int8_t *)(keven_signs + idx[1])))));
|
||||
//b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]),
|
||||
// vcombine_s8(vld1_s8((const int8_t *)(keven_signs + idx[2])), vld1_s8((const int8_t *)(keven_signs + idx[3])))));
|
||||
//b[2] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[2]),
|
||||
// vcombine_s8(vld1_s8((const int8_t *)(keven_signs + idx[4])), vld1_s8((const int8_t *)(keven_signs + idx[5])))));
|
||||
//b[3] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[3]),
|
||||
// vcombine_s8(vld1_s8((const int8_t *)(keven_signs + idx[6])), vld1_s8((const int8_t *)(keven_signs + idx[7])))));
|
||||
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
@@ -2218,13 +2091,22 @@ struct DequantizerIQ2XS final : public BaseDequantizer<block_iq2_xs> {
|
||||
|
||||
};
|
||||
|
||||
// Multiplies the 16 int8 lanes of b[0] by a per-lane factor derived from the
// sign bytes in signs16, then advances the shuffle indices for the next call.
//   b       - vector to modify in place (only b[0] is touched)
//   signs16 - 16 bytes of packed sign bits, used as a table for the lookup
//   smask   - per-lane bit mask tested against the selected sign byte
//   step    - per-lane increment added to shuffle after each call
//   m1      - value OR-ed into every lane of the factor
//   shuffle - in/out: per-lane byte index into signs16
// The DequantizerIQ2S/IQ3S prepare() bodies visible in this view pass
// smask = 0x8040201008040201 broadcast per 64-bit half, m1 = 1 and step = 2.
inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16, const uint8x16_t& smask, const uint8x16_t& step,
|
||||
const uint8x16_t& m1, uint8x16_t& shuffle) {
|
||||
// Broadcast the sign bytes selected by shuffle into the lanes.
auto aux = vqtbl1q_u8(signs16, shuffle);
|
||||
// Lanes where all smask bits are set in aux become 0xff, others 0x00; OR-ing
// with m1 fills in the rest. With m1 = 1 this gives -1 / +1 as int8.
auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));
|
||||
b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));
|
||||
// Move on to the next sign bytes for the following call.
shuffle = vaddq_u8(shuffle, step);
|
||||
}
|
||||
// Helper that applies packed sign bits to vectors of 16 int8 values while
// keeping track of which sign bytes to use next.
// Usage as seen in the prepare() bodies in this view: call init() once, then
// call apply_signs_1() repeatedly with the same signs16 vector; each call
// consumes the next two sign bytes.
struct SignHelper {
|
||||
|
||||
// Reset the byte selector: lanes 0-7 pick sign byte 0, lanes 8-15 pick sign byte 1.
inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); }
|
||||
|
||||
// Multiplies the 16 int8 lanes of b[0] by -1 where the corresponding sign
// bit in signs16 is set and by +1 otherwise, then advances shuffle by step.
// Only b[0] is modified.
inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) {
|
||||
// Broadcast the two currently selected sign bytes across the two 8-lane halves.
auto aux = vqtbl1q_u8(signs16, shuffle);
|
||||
// 0xff (-1) in lanes whose smask bit is set in aux, 0x01 (+1) elsewhere.
auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));
|
||||
b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));
|
||||
// Select the next pair of sign bytes for the following call.
shuffle = vaddq_u8(shuffle, step);
|
||||
}
|
||||
|
||||
// One distinct bit per byte within each 64-bit half (0x01, 0x02, ..., 0x80,
// assuming little-endian lane order), so lane k tests bit k of its sign byte.
const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
|
||||
// OR-ed into the compare result so lanes with a clear sign bit become +1.
const uint8x16_t m1 = vdupq_n_u8(1);
|
||||
// Each call uses two sign bytes, so the selector advances by 2.
const uint8x16_t step = vdupq_n_u8(2);
|
||||
// Per-lane byte index into signs16; set by init(), advanced by apply_signs_1().
uint8x16_t shuffle;
|
||||
};
|
||||
|
||||
struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
|
||||
DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
@@ -2238,8 +2120,7 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
|
||||
return prepare_4bit_scales16(x[i].scales);
|
||||
}
|
||||
|
||||
static inline void make4(const uint8x16_t& signs16, uint8x16_t& shuffle, const uint8_t * qs, const uint8_t * qh,
|
||||
const uint8x16_t& smask, const uint8x16_t& step, const uint8x16_t& m1, uint8x16_t * b) {
|
||||
static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) {
|
||||
uint32_t aux32[2];
|
||||
const uint16_t * aux16 = (const uint16_t *)aux32;
|
||||
for (int k = 0; k < 2; ++k) {
|
||||
@@ -2248,30 +2129,27 @@ struct DequantizerIQ2S final : public BaseDequantizer<block_iq2_s> {
|
||||
aux32[1] &= 0x03000300;
|
||||
b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))),
|
||||
vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1]))));
|
||||
apply_signs_1(b+2*k+0, signs16, smask, step, m1, shuffle);
|
||||
sh.apply_signs_1(b+2*k+0, signs16);
|
||||
|
||||
b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))),
|
||||
vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3]))));
|
||||
apply_signs_1(b+2*k+1, signs16, smask, step, m1, shuffle);
|
||||
sh.apply_signs_1(b+2*k+1, signs16);
|
||||
}
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
|
||||
const auto smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
|
||||
const auto m1 = vdupq_n_u8(1);
|
||||
const auto step = vdupq_n_u8(2);
|
||||
|
||||
const auto * qs = x[i].qs + 16*j;
|
||||
const auto * qh = x[i].qh + 4*j;
|
||||
const auto signs16 = vld1q_u8(qs + QK_K/8);
|
||||
|
||||
auto shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1));
|
||||
make4(signs16, shuffle, qs+0, qh+0, smask, step, m1, bits.b1.val);
|
||||
make4(signs16, shuffle, qs+8, qh+2, smask, step, m1, bits.b2.val);
|
||||
sh.init();
|
||||
make4(sh, signs16, qs+0, qh+0, bits.b1.val);
|
||||
make4(sh, signs16, qs+8, qh+2, bits.b2.val);
|
||||
}
|
||||
|
||||
SimpleBits bits;
|
||||
SignHelper sh;
|
||||
|
||||
float d;
|
||||
|
||||
@@ -2333,52 +2211,38 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
|
||||
return scales;
|
||||
}
|
||||
|
||||
static inline void make2(const uint8x16_t& signs16, uint8x16_t& shuffle, const uint16x8_t& idx_l, uint8_t qh,
|
||||
const uint8x16_t& smask, const uint8x16_t& step, const uint8x16_t& m1, const int8x16_t& hshift, uint8x16_t * b) {
|
||||
static inline void make2(SignHelper& sh, const uint8x16_t& signs16, const uint16x8_t& idx_l, uint8_t qh,
|
||||
const int8x16_t& hshift, uint8x16_t * b) {
|
||||
auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256)));
|
||||
const uint16_t * idx = (const uint16_t *)&vindex;
|
||||
b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});
|
||||
b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});
|
||||
apply_signs_1(b+0, signs16, smask, step, m1, shuffle);
|
||||
apply_signs_1(b+1, signs16, smask, step, m1, shuffle);
|
||||
sh.apply_signs_1(b+0, signs16);
|
||||
sh.apply_signs_1(b+1, signs16);
|
||||
}
|
||||
static inline void make4(const uint8x16_t& signs16, uint8x16_t& shuffle, const uint8_t * qs, const uint8_t * qh,
|
||||
const uint8x16_t& smask, const uint8x16_t& step, const uint8x16_t& m1, const int8x16_t& hshift, uint8x16_t * b) {
|
||||
static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh,
|
||||
const int8x16_t& hshift, uint8x16_t * b) {
|
||||
auto idx_l = vld1q_u8(qs);
|
||||
make2(signs16, shuffle, vmovl_u8(vget_low_u8 (idx_l)), qh[0], smask, step, m1, hshift, b+0);
|
||||
make2(signs16, shuffle, vmovl_u8(vget_high_u8(idx_l)), qh[1], smask, step, m1, hshift, b+2);
|
||||
//auto vindex = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[0]), hshift), vdupq_n_u16(256)));
|
||||
//const uint16_t * idx = (const uint16_t *)&vindex;
|
||||
//b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});
|
||||
//b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});
|
||||
//apply_signs_1(b+0, signs16, smask, step, m1, shuffle);
|
||||
//apply_signs_1(b+1, signs16, smask, step, m1, shuffle);
|
||||
//vindex = vorrq_u16(vmovl_u8(vget_high_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[1]), hshift), vdupq_n_u16(256)));
|
||||
//b[2] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]});
|
||||
//b[3] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]});
|
||||
//apply_signs_1(b+2, signs16, smask, step, m1, shuffle);
|
||||
//apply_signs_1(b+3, signs16, smask, step, m1, shuffle);
|
||||
make2(sh, signs16, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0);
|
||||
make2(sh, signs16, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2);
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
|
||||
static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1};
|
||||
|
||||
const auto smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
|
||||
const auto m1 = vdupq_n_u8(1);
|
||||
const auto step = vdupq_n_u8(2);
|
||||
const auto hshift = vld1q_s16(k_shift);
|
||||
|
||||
const auto * qs = x[i].qs + 32*j;
|
||||
const auto * qh = x[i].qh + 4*j;
|
||||
const auto signs16 = vld1q_u8(x[i].signs + 16*j);
|
||||
|
||||
auto shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1));
|
||||
make4(signs16, shuffle, qs+ 0, qh+0, smask, step, m1, hshift, bits.b1.val);
|
||||
make4(signs16, shuffle, qs+16, qh+2, smask, step, m1, hshift, bits.b2.val);
|
||||
sh.init();
|
||||
make4(sh, signs16, qs+ 0, qh+0, hshift, bits.b1.val);
|
||||
make4(sh, signs16, qs+16, qh+2, hshift, bits.b2.val);
|
||||
}
|
||||
|
||||
SimpleBits bits;
|
||||
SignHelper sh;
|
||||
uint32x4x2_t gas;
|
||||
|
||||
float d;
|
||||
|
||||
Reference in New Issue
Block a user