mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-01 03:41:53 +00:00
Refactor iqk: factor out iqk quants (NEON)
This commit is contained in:
@@ -2129,6 +2129,511 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_m
|
||||
#else
|
||||
// ----------------------------------------- __aarch64__ ---------------------------------------------
|
||||
|
||||
namespace {
|
||||
|
||||
inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
|
||||
int32x4x4_t scales = {
|
||||
vmovl_s16(vget_low_s16 (scales16.val[0])),
|
||||
vmovl_s16(vget_high_s16(scales16.val[0])),
|
||||
vmovl_s16(vget_low_s16 (scales16.val[1])),
|
||||
vmovl_s16(vget_high_s16(scales16.val[1])),
|
||||
};
|
||||
return scales;
|
||||
}
|
||||
|
||||
inline int32x4x4_t make_wider_8(const int8x16_t& scales8) {
|
||||
int16x8x2_t scales16{vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8))};
|
||||
return make_wider(scales16);
|
||||
}
|
||||
|
||||
template <typename Q8>
|
||||
inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
auto q8s = q8.load_bsums(iy, i);
|
||||
int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));
|
||||
int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));
|
||||
int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));
|
||||
int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));
|
||||
float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));
|
||||
acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
|
||||
}
|
||||
}
|
||||
|
||||
struct Scale16Extra {
|
||||
template <typename Q8>
|
||||
static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val,
|
||||
const int8x16_t& scales8, const Q8& q8, float32x4_t * acc) {
|
||||
uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra));
|
||||
e8 = vceqq_u8(vandq_u8(e8, emask), emask);
|
||||
e8 = vqtbl1q_u8(vandq_u8(e8, vdupq_n_u8(val)), eshuff);
|
||||
int16x8x2_t extra16 = {vmull_s8(vget_low_s8 (e8), vget_low_s8 (scales8)),
|
||||
vmull_s8(vget_high_s8(e8), vget_high_s8(scales8))};
|
||||
accum_mins_16(extra16, q8, acc, i, d);
|
||||
return make_wider_8(scales8);
|
||||
}
|
||||
|
||||
constexpr static uint32x4_t emask = {0x02020101, 0x08080404, 0x20201010, 0x80804040};
|
||||
constexpr static uint32x4_t eshuff = {0x06040200, 0x0e0c0a08, 0x07050301, 0x0f0d0b09};
|
||||
};
|
||||
|
||||
// Note: on ARM_NEON we cannot use the values shifted into the uint8_t range because
|
||||
// the ARM_NEON only has vdotq_s32 or vdotq_u32, where both operands need to
|
||||
// be signed or unsigned. As the Q8_K quants are signed, we need to have the
|
||||
// iq4_s quants also signed. We can only use unsigned values in k-quants
|
||||
// because they are all within the valid int8_t range.
|
||||
struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
|
||||
DequantizerIQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8(iq4k_values)) {}
|
||||
|
||||
constexpr static int num_blocks() { return 16; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_l, x[i].scales_h), q8, acc);
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
bits.prepare16(x[i].qs+64*j);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
|
||||
bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
|
||||
}
|
||||
}
|
||||
inline int8x16_t make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const {
|
||||
uint8x8_t aux = vld1_u8(scales_l);
|
||||
uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
|
||||
const uint32_t * aux32 = (const uint32_t *)scales_h;
|
||||
uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2};
|
||||
uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30));
|
||||
int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, hshuff));
|
||||
return vaddq_s8(vqtbl1q_s8(scales8, hshuff), vdupq_n_s8(-32));
|
||||
}
|
||||
|
||||
Q4bits bits;
|
||||
const int8x16_t values;
|
||||
const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
|
||||
DequantizerIQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq5nl_values)) {}
|
||||
|
||||
constexpr static int num_blocks() { return 16; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
hbits = vld1q_u8_x2(x[i].qh); // hbits.val[0] holds 0....15, 32...47, 64...79, 96...111, 128...143, 160...175, 192...207, 224...239
|
||||
// hbits.val[1] holds 16...31, 48...63, 80...95, 112..127, 144...159, 176...191, 208...223, 240...255
|
||||
return Scale16Extra::new_block(i, d, x[i].extra, 2, make_scales(x[i].scales_l, x[i].scales_h), q8, acc);
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
bits.prepare(x[i].qs+64*j);
|
||||
if (j == 1) {
|
||||
for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4);
|
||||
}
|
||||
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm));
|
||||
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm));
|
||||
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm));
|
||||
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm));
|
||||
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm));
|
||||
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm));
|
||||
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm));
|
||||
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm));
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
bits.b1.val[k] = vqtbl2q_s8(values, bits.b1.val[k]);
|
||||
bits.b2.val[k] = vqtbl2q_s8(values, bits.b2.val[k]);
|
||||
}
|
||||
}
|
||||
inline int8x16_t make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const {
|
||||
uint8x8_t aux = vld1_u8(scales_l);
|
||||
uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
|
||||
const uint32_t * aux32 = (const uint32_t *)scales_h;
|
||||
uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2};
|
||||
uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30));
|
||||
int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, hshuff));
|
||||
return vaddq_s8(vqtbl1q_s8(scales8, hshuff), vdupq_n_s8(-32));
|
||||
}
|
||||
|
||||
Q4bits bits;
|
||||
const int8x16x2_t values;
|
||||
const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
|
||||
const uint8x16_t hm = vdupq_n_u8(0x10);
|
||||
uint8x16x2_t hbits;
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
|
||||
DequantizerIQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x4(iq6nl_values)) {}
|
||||
|
||||
constexpr static int num_blocks() { return 16; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
return Scale16Extra::new_block(i, d, x[i].extra, 1, vld1q_s8(x[i].scales), q8, acc);
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
bits.prepare(x[i].qs+64*j);
|
||||
auto hbits = vld1q_u8_x2(x[i].qh + 32*j);
|
||||
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm));
|
||||
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm));
|
||||
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm));
|
||||
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm));
|
||||
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hm));
|
||||
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hm));
|
||||
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), hm));
|
||||
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), hm));
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]);
|
||||
bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]);
|
||||
}
|
||||
}
|
||||
|
||||
Q4bits bits;
|
||||
const int8x16x4_t values;
|
||||
const uint8x16_t hm = vdupq_n_u8(0x30);
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
|
||||
DequantizerIQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 16; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
return Scale16Extra::new_block(i, d, x[i].extra, 5, make_scales(x[i].scales), q8, acc);
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
bits.prepare(x[i].qs+32*j);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
|
||||
bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
|
||||
}
|
||||
}
|
||||
inline int8x16_t make_scales(const uint8_t * scales_l) const {
|
||||
uint8x8_t aux = vld1_u8(scales_l);
|
||||
uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
|
||||
int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(scl8), vdupq_n_s8(-8));
|
||||
return vqtbl1q_s8(scales, hshuff);
|
||||
}
|
||||
|
||||
Q2bits bits;
|
||||
const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x000000001101f3e1));
|
||||
const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
|
||||
DequantizerIQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 16; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_h, x[i].scales_l), q8, acc);
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
bits.prepare(x[i].qs+32*j);
|
||||
if (j == 0) {
|
||||
hbits = vld1q_u8_x2(x[i].qh);
|
||||
}
|
||||
else {
|
||||
hbits.val[0] = vshrq_n_u8(hbits.val[0], 4);
|
||||
hbits.val[1] = vshrq_n_u8(hbits.val[1], 4);
|
||||
}
|
||||
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hmask));
|
||||
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hmask));
|
||||
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hmask));
|
||||
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hmask));
|
||||
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hmask));
|
||||
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hmask));
|
||||
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 1), hmask));
|
||||
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 1), hmask));
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
|
||||
bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
|
||||
}
|
||||
}
|
||||
inline int8x16_t make_scales(uint16_t sign_bits, const uint8_t * scales_l) const {
|
||||
uint8x8_t aux = vld1_u8(scales_l);
|
||||
uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
|
||||
int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(vshlq_n_u8(scl8, 1)), vdupq_n_s8(1));
|
||||
uint8x16_t signs = vceqq_u8(vandq_u8(vreinterpretq_u8_u16(vdupq_n_u16(sign_bits)), sign_mask), sign_mask);
|
||||
signs = vorrq_u8(signs, vdupq_n_u8(1));
|
||||
// scales are 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15
|
||||
// signs are 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15
|
||||
scales = vmulq_s8(scales, vreinterpretq_s8_u8(vqtbl1q_u8(signs, sign_shuffle)));
|
||||
return vqtbl1q_s8(scales, hshuff);
|
||||
}
|
||||
inline static uint8x16_t load_sign_shuffle() {
|
||||
static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
|
||||
return vld1q_u8(k_shuff);
|
||||
}
|
||||
|
||||
Q2bits bits;
|
||||
uint8x16x2_t hbits;
|
||||
const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x2f1c0d01f6e9d8c1));
|
||||
const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
|
||||
const uint8x16_t hmask = vdupq_n_u8(4);
|
||||
const uint8x16_t sign_mask = vreinterpretq_u8_u64(uint64x2_t{0x0808040402020101, 0x8080404020201010});
|
||||
const uint8x16_t sign_shuffle = load_sign_shuffle();
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
|
||||
|
||||
DequantizerIQ4KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
|
||||
|
||||
constexpr static int num_blocks() { return 8; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
(void)q8;
|
||||
(void)acc;
|
||||
auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(vld1_u8(x[i].scales)), mask)), m127);
|
||||
int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
|
||||
return scales;
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
bits.prepare16(x[i].qs+64*j);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values.val[x[i].scales[4*j+k] & 1], bits.b1.val[k]));
|
||||
bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(values.val[x[i].scales[4*j+k] & 1], bits.b2.val[k]));
|
||||
}
|
||||
}
|
||||
|
||||
Q4bits bits;
|
||||
const int8x16x2_t values;
|
||||
const uint16x8_t mask = vdupq_n_u16(254);
|
||||
const int16x8_t m127 = vdupq_n_s16(-127);
|
||||
};
|
||||
|
||||
struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
|
||||
DequantizerIQ5KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc),
|
||||
values(vld1q_s8_x4(iq5nl_values)) {}
|
||||
|
||||
constexpr static int num_blocks() { return 8; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
(void)q8;
|
||||
(void)acc;
|
||||
auto sas8 = vld1_u8(x[i].scales);
|
||||
auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(sas8), mask)), m127);
|
||||
hbits = vld1q_u8_x2(x[i].qh);
|
||||
sas = vcombine_u8(sas8, sas8);
|
||||
sas = vshlq_n_u8(vandq_u8(sas, vdupq_n_u8(1)), 5);
|
||||
int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
|
||||
return scales;
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
bits.prepare(x[i].qs+64*j);
|
||||
if (j == 1) {
|
||||
for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4);
|
||||
}
|
||||
auto shift = vdupq_n_u8((x[i].scales[4*j+0] & 1) << 5);
|
||||
bits.b1.val[0] = vaddq_u8(shift, vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm)));
|
||||
bits.b1.val[1] = vaddq_u8(shift, vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm)));
|
||||
shift = vdupq_n_u8((x[i].scales[4*j+1] & 1) << 5);
|
||||
bits.b1.val[2] = vaddq_u8(shift, vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm)));
|
||||
bits.b1.val[3] = vaddq_u8(shift, vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm)));
|
||||
for (int k = 0; k < 4; ++k) bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]);
|
||||
shift = vdupq_n_u8((x[i].scales[4*j+2] & 1) << 5);
|
||||
bits.b2.val[0] = vaddq_u8(shift, vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm)));
|
||||
bits.b2.val[1] = vaddq_u8(shift, vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm)));
|
||||
shift = vdupq_n_u8((x[i].scales[4*j+3] & 1) << 5);
|
||||
bits.b2.val[2] = vaddq_u8(shift, vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm)));
|
||||
bits.b2.val[3] = vaddq_u8(shift, vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm)));
|
||||
for (int k = 0; k < 4; ++k) bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]);
|
||||
}
|
||||
|
||||
Q4bits bits;
|
||||
const int8x16x4_t values;
|
||||
const uint8x16_t hm = vdupq_n_u8(0x10);
|
||||
const uint16x8_t mask = vdupq_n_u16(254);
|
||||
const int16x8_t m127 = vdupq_n_s16(-127);
|
||||
uint8x16x2_t hbits;
|
||||
uint8x16_t sas;
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {
|
||||
|
||||
DequantizerIQ4KSS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}
|
||||
|
||||
constexpr static int num_blocks() { return 8; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
(void)q8;
|
||||
(void)acc;
|
||||
auto q4bits_1 = vld1q_u16_x4((const uint16_t *)x[i].qs);
|
||||
q4bits_2 = vld1q_u16_x4((const uint16_t *)x[i].qs + 32);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
aux[k+0] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_1.val[k], m1), shift));
|
||||
aux[k+4] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_2.val[k], m1), shift));
|
||||
q4bits_1.val[k] = vandq_u16(q4bits_1.val[k], bmask);
|
||||
q4bits_1.val[k] = veorq_u16(q4bits_1.val[k], vshrq_n_u16(q4bits_1.val[k], 1));
|
||||
q4bits_2.val[k] = vandq_u16(q4bits_2.val[k], bmask);
|
||||
q4bits_2.val[k] = veorq_u16(q4bits_2.val[k], vshrq_n_u16(q4bits_2.val[k], 1));
|
||||
}
|
||||
make_quants(q4bits_1, bits, aux);
|
||||
auto scales16 = vld1q_s16(aux);
|
||||
scales16 = vaddq_s16(vandq_s16(scales16, mask), m127);
|
||||
int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
|
||||
return scales;
|
||||
}
|
||||
inline void make_quants(uint16x8x4_t& q4bits, Q4bits& bits, const int16_t * aux) const {
|
||||
bits.b1.val[0] = vqtbl1q_s8(values.val[aux[0] & 1], vandq_u8(q4bits.val[0], bits.m4b));
|
||||
bits.b1.val[1] = vqtbl1q_s8(values.val[aux[0] & 1], vshrq_n_u8(q4bits.val[0], 4));
|
||||
bits.b1.val[2] = vqtbl1q_s8(values.val[aux[1] & 1], vandq_u8(q4bits.val[1], bits.m4b));
|
||||
bits.b1.val[3] = vqtbl1q_s8(values.val[aux[1] & 1], vshrq_n_u8(q4bits.val[1], 4));
|
||||
bits.b2.val[0] = vqtbl1q_s8(values.val[aux[2] & 1], vandq_u8(q4bits.val[2], bits.m4b));
|
||||
bits.b2.val[1] = vqtbl1q_s8(values.val[aux[2] & 1], vshrq_n_u8(q4bits.val[2], 4));
|
||||
bits.b2.val[2] = vqtbl1q_s8(values.val[aux[3] & 1], vandq_u8(q4bits.val[3], bits.m4b));
|
||||
bits.b2.val[3] = vqtbl1q_s8(values.val[aux[3] & 1], vshrq_n_u8(q4bits.val[3], 4));
|
||||
}
|
||||
inline void prepare([[maybe_unused]] int i, int j) {
|
||||
if (j == 0) return;
|
||||
make_quants(q4bits_2, bits, aux+4);
|
||||
}
|
||||
static int16x8_t load_shift() {
|
||||
static const int16_t k_shift[8] = {0, 1, 2, 3, 4, 5, 6, 7};
|
||||
return vld1q_s16(k_shift);
|
||||
}
|
||||
|
||||
Q4bits bits;
|
||||
const int8x16x2_t values;
|
||||
const uint16x8_t mask = vdupq_n_s16(254);
|
||||
const uint16x8_t bmask = vdupq_n_u16(0xfffe);
|
||||
const uint16x8_t m1 = vdupq_n_u16(1);
|
||||
const int16x8_t shift = load_shift();
|
||||
const int16x8_t m127 = vdupq_n_s16(-127);
|
||||
uint16x8x4_t q4bits_2;
|
||||
int16_t aux[8];
|
||||
};
|
||||
|
||||
struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
|
||||
DequantizerIQ2KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 8; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x2_t new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) {
|
||||
const uint16_t * sc16 = (const uint16_t *)x[i].scales;
|
||||
uint32_t aux32 = sc16[0] | (sc16[1] << 16);
|
||||
uint8x8_t scales8 = vreinterpret_u8_u32(vdup_n_u32(aux32));
|
||||
scales8 = vand_u8(vzip1_u8(scales8, vshr_n_u8(scales8, 4)), vdup_n_u8(0xf));
|
||||
uint8x8_t sh = vand_u8(vceq_u8(vand_u8(vdup_n_u8(x[i].extra >> 8), hmask), vdup_n_u8(0)), vdup_n_u8(16));
|
||||
int16x8_t scales16 = vmovl_s8(vsub_s8(vreinterpret_s8_u8(scales8), vreinterpret_s8_u8(sh)));
|
||||
int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
|
||||
return scales;
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
uint8_t extra = x[i].extra >> 4*j;
|
||||
bits.prepare(x[i].qs+32*j);
|
||||
bits.b1.val[0] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[0]);
|
||||
bits.b1.val[1] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[1]); extra >>= 1;
|
||||
bits.b1.val[2] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[2]);
|
||||
bits.b1.val[3] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[3]); extra >>= 1;
|
||||
bits.b2.val[0] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[0]);
|
||||
bits.b2.val[1] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[1]); extra >>= 1;
|
||||
bits.b2.val[2] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[2]);
|
||||
bits.b2.val[3] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[3]);
|
||||
}
|
||||
|
||||
Q2bits bits;
|
||||
const uint8x8_t hmask = vreinterpret_u8_u64(vdup_n_u64(0x8040201008040201));
|
||||
const int8x16x2_t values = { vreinterpretq_s8_u64(vdupq_n_u64(0x1101f3e1)), vreinterpretq_s8_u64(vdupq_n_u64(0x1606f8e6)) };
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) {
|
||||
|
||||
auto etypeA = ggml_type(typeA);
|
||||
auto expected_type_B = etypeA == GGML_TYPE_IQ4_KS_R4 || etypeA == GGML_TYPE_IQ5_KS_R4 ? GGML_TYPE_Q8_K32 : GGML_TYPE_Q8_K;
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (typeA) {
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2KS, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_K:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2K, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_K:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ3K, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ4KSS, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ4KS, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_K:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ4K, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ5KS, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_K:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ5K, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ6_K:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ6K, kernels);
|
||||
break;
|
||||
// case GGML_TYPE_IQ2_K_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_k_r4_q8_k, kernels);
|
||||
// break;
|
||||
// case GGML_TYPE_IQ3_K_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_k_r4_q8_k, kernels);
|
||||
//#ifdef HAVE_FANCY_SIMD
|
||||
// func16 = mul_mat_iq3_k_r4_q8_k<16>;
|
||||
//#endif
|
||||
// break;
|
||||
// case GGML_TYPE_IQ4_K_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_k_r4_q8_k, kernels);
|
||||
// func16 = mul_mat_iq4_k_r4_q8_k<16>;
|
||||
// break;
|
||||
// case GGML_TYPE_IQ4_KS_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_ks_r4_q8_k, kernels);
|
||||
//#ifndef HAVE_FANCY_SIMD
|
||||
// // For some reason Zen4 does not like this particular function
|
||||
// func16 = mul_mat_iq4_ks_r4_q8_k<16>;
|
||||
//#endif
|
||||
// break;
|
||||
// case GGML_TYPE_IQ5_KS_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_ks_r4_q8_k, kernels);
|
||||
//#ifndef HAVE_FANCY_SIMD
|
||||
// // For some reason Zen4 does not like this particular function
|
||||
// func16 = mul_mat_iq5_ks_r4_q8_k<16>;
|
||||
//#endif
|
||||
// break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1632,6 +1632,7 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
|
||||
#else
|
||||
// --------------------------------------- __aarch64__ ---------------------------------------------
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
@@ -785,64 +785,6 @@ template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
|
||||
const block_q8 * y[nrc_y];
|
||||
};
|
||||
|
||||
template <typename Q8>
|
||||
inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
|
||||
const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {
|
||||
auto mzero = vdupq_n_s32(0);
|
||||
auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
|
||||
auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
|
||||
vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1
|
||||
auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
|
||||
auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
|
||||
vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2
|
||||
auto p12 = vpaddq_s32(p1, p2);
|
||||
|
||||
auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
|
||||
auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
|
||||
vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 1
|
||||
auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
|
||||
auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
|
||||
vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 2
|
||||
auto p34 = vpaddq_s32(p3, p4);
|
||||
|
||||
auto pall = vpaddq_s32(p12, p34);
|
||||
sumi = vmlaq_s32(sumi, scales.val[j], pall);
|
||||
}
|
||||
|
||||
template <typename Q8>
|
||||
inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
|
||||
const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {
|
||||
|
||||
auto mzero = vdupq_n_s32(0);
|
||||
auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
|
||||
auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
|
||||
ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,
|
||||
auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
|
||||
auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
|
||||
ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 3, 3, 4, 4,
|
||||
auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3
|
||||
sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);
|
||||
|
||||
auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
|
||||
auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
|
||||
ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5,
|
||||
auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
|
||||
auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
|
||||
ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7,
|
||||
auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7
|
||||
sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);
|
||||
}
|
||||
|
||||
template <typename Q8>
|
||||
inline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
auto q8s = q8.load_bsums8(iy, i);
|
||||
int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s));
|
||||
int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s));
|
||||
float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2));
|
||||
acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
|
||||
}
|
||||
}
|
||||
template <typename Q8>
|
||||
inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
@@ -856,23 +798,6 @@ inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * a
|
||||
}
|
||||
}
|
||||
|
||||
struct Scales8 {
|
||||
uint32_t utmp[4];
|
||||
const uint8_t * sc8 = (const uint8_t *)utmp;
|
||||
template <typename Q8, typename Qx>
|
||||
inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) {
|
||||
make_q4_scales(x.scales, utmp);
|
||||
int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8));
|
||||
accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin));
|
||||
|
||||
uint8x8_t scales8 = vld1_u8(sc8);
|
||||
uint16x8_t scales16 = vmovl_u8(scales8);
|
||||
int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))),
|
||||
vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))};
|
||||
return scales;
|
||||
}
|
||||
};
|
||||
|
||||
inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
|
||||
int32x4x4_t scales = {
|
||||
vmovl_s16(vget_low_s16 (scales16.val[0])),
|
||||
@@ -883,15 +808,6 @@ inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
|
||||
return scales;
|
||||
}
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) {
|
||||
int16x8x2_t scales16;
|
||||
scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
|
||||
scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
|
||||
accum_mins_16(scales16, q8, acc, i, c);
|
||||
return make_wider(scales16);
|
||||
}
|
||||
|
||||
// ============================= i-quants
|
||||
|
||||
inline int32x4x4_t make_wider_8(const int8x16_t& scales8) {
|
||||
@@ -921,385 +837,6 @@ struct Scale16Extra {
|
||||
// be signed or unsigned. As the Q8_K quants are signed, we need to have the
|
||||
// iq4_s quants also signed. We can only use unsigned values in k-quants
|
||||
// because they are all within the valid int8_t range.
|
||||
// NEON dequantizer for IQ4_K: 4-bit indices in qs mapped through the 16-entry
// iq4k_values table; 6-bit block scales split between scales_l (low nibbles)
// and scales_h (two high bits per block).
struct DequantizerIQ4K final : public BaseDequantizer<block_iq4_k> {
    DequantizerIQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8(iq4k_values)) {}

    constexpr static int num_blocks() { return 16; }
    constexpr static bool should_scale_quants() { return false; }

    // Decodes the super-block scale d and hands the 16 block scales plus the
    // per-block "extra" bits to Scale16Extra (shift value 4).
    template <typename Q8>
    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        d = GGML_FP16_TO_FP32(x[i].d);
        return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_l, x[i].scales_h), q8, acc);
    }
    // Expands 64 nibbles (j selects the half of the super-block) and maps each
    // through the non-linear value table.
    inline void prepare(int i, int j) {
        bits.prepare16(x[i].qs+64*j);
        for (int k = 0; k < 4; ++k) {
            bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
            bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
        }
    }
    // Reassembles the 16 6-bit block scales: low nibbles from scales_l, the two
    // high bits from scales_h, shuffled into block order, then biased by -32.
    inline int8x16_t make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const {
        uint8x8_t aux = vld1_u8(scales_l);
        uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
        const uint32_t * aux32 = (const uint32_t *)scales_h;
        // Four shifted copies of the 32 high-bit pairs so one 0x30 mask can
        // extract the correct 2-bit field for every block position.
        uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2};
        uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30));
        int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, hshuff));
        return vaddq_s8(vqtbl1q_s8(scales8, hshuff), vdupq_n_s8(-32));
    }

    Q4bits bits;
    const int8x16_t values;
    // Byte interleave 0,8,1,9,...,7,15: pairs low-nibble and high-nibble scales
    // into consecutive block order.
    const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});

};
|
||||
|
||||
// NEON dequantizer for IQ5_K: four low bits per quant in qs plus a fifth bit
// in qh form an index into the 32-entry iq5nl_values table (vqtbl2q lookup).
struct DequantizerIQ5K final : public BaseDequantizer<block_iq5_k> {
    DequantizerIQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq5nl_values)) {}

    constexpr static int num_blocks() { return 16; }
    constexpr static bool should_scale_quants() { return false; }

    // Loads the super-block scale and caches all high bits for the super-block;
    // block scales go through Scale16Extra (shift value 2).
    template <typename Q8>
    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        d = GGML_FP16_TO_FP32(x[i].d);
        hbits = vld1q_u8_x2(x[i].qh); // hbits.val[0] holds 0....15, 32...47, 64...79, 96...111, 128...143, 160...175, 192...207, 224...239
                                      // hbits.val[1] holds 16...31, 48...63, 80...95, 112..127, 144...159, 176...191, 208...223, 240...255
        return Scale16Extra::new_block(i, d, x[i].extra, 2, make_scales(x[i].scales_l, x[i].scales_h), q8, acc);
    }
    // Merges the cached high bit (positioned at 0x10) into each expanded nibble,
    // then maps the 5-bit indices through the value table.
    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs+64*j);
        if (j == 1) {
            // Second half of the super-block uses the upper nibble of qh.
            for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4);
        }
        // Each of the 8 sub-blocks takes its high bit from a different bit
        // position of hbits, hence the descending shift amounts 4..1.
        bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm));
        bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm));
        bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm));
        bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm));
        bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm));
        bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm));
        bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm));
        bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm));
        for (int k = 0; k < 4; ++k) {
            bits.b1.val[k] = vqtbl2q_s8(values, bits.b1.val[k]);
            bits.b2.val[k] = vqtbl2q_s8(values, bits.b2.val[k]);
        }
    }
    // Same 6-bit block-scale reassembly as DequantizerIQ4K::make_scales.
    inline int8x16_t make_scales(const uint8_t * scales_l, const uint8_t * scales_h) const {
        uint8x8_t aux = vld1_u8(scales_l);
        uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
        const uint32_t * aux32 = (const uint32_t *)scales_h;
        uint32x4_t sch_32 = {aux32[0] << 4, aux32[0] << 2, aux32[0], aux32[0] >> 2};
        uint8x16_t sch8 = vandq_u8(vreinterpretq_u8_u32(sch_32), vdupq_n_u8(0x30));
        int8x16_t scales8 = vorrq_u8(scl8, vqtbl1q_u8(sch8, hshuff));
        return vaddq_s8(vqtbl1q_s8(scales8, hshuff), vdupq_n_s8(-32));
    }

    Q4bits bits;
    const int8x16x2_t values;
    const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
    const uint8x16_t hm = vdupq_n_u8(0x10); // bit position of the 5th quant bit
    uint8x16x2_t hbits; // qh bits cached for the whole super-block

};
|
||||
|
||||
// NEON dequantizer for IQ6_K: four low bits in qs plus two high bits in qh
// form a 6-bit index into the 64-entry iq6nl_values table (vqtbl4q lookup).
// Block scales are stored directly as int8 in x[i].scales.
struct DequantizerIQ6K final : public BaseDequantizer<block_iq6_k> {
    DequantizerIQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x4(iq6nl_values)) {}

    constexpr static int num_blocks() { return 16; }
    constexpr static bool should_scale_quants() { return false; }

    // Scales need no unpacking here; extra bits go through Scale16Extra with
    // shift value 1.
    template <typename Q8>
    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        d = GGML_FP16_TO_FP32(x[i].d);
        return Scale16Extra::new_block(i, d, x[i].extra, 1, vld1q_s8(x[i].scales), q8, acc);
    }
    // Positions the 2 high bits at 0x30 for each sub-block (shift left 4/2,
    // pass-through, shift right 2), ORs them onto the expanded nibbles, then
    // maps the 6-bit indices through the value table.
    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs+64*j);
        auto hbits = vld1q_u8_x2(x[i].qh + 32*j);
        bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm));
        bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm));
        bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm));
        bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm))
;
        bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hm));
        bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hm));
        bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), hm));
        bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), hm));
        for (int k = 0; k < 4; ++k) {
            bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]);
            bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]);
        }
    }

    Q4bits bits;
    const int8x16x4_t values;
    const uint8x16_t hm = vdupq_n_u8(0x30); // positions of the two high quant bits

};
|
||||
|
||||
// NEON dequantizer for IQ2_K: 2-bit quants mapped through a 4-entry signed
// value table (bytes 0xe1,0xf3,0x01,0x11 packed into a u64 broadcast);
// 4-bit block scales biased by -8.
struct DequantizerIQ2K final : public BaseDequantizer<block_iq2_k> {
    DequantizerIQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}

    constexpr static int num_blocks() { return 16; }
    constexpr static bool should_scale_quants() { return false; }

    // Scales and extra bits go through Scale16Extra (shift value 5).
    template <typename Q8>
    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        d = GGML_FP16_TO_FP32(x[i].d);
        return Scale16Extra::new_block(i, d, x[i].extra, 5, make_scales(x[i].scales), q8, acc);
    }
    // Expands 2-bit fields and maps each through the 4-entry value table.
    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs+32*j);
        for (int k = 0; k < 4; ++k) {
            bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
            bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
        }
    }
    // Unpacks 16 4-bit scales (low then high nibbles), recenters to [-8,7],
    // and shuffles into block order.
    inline int8x16_t make_scales(const uint8_t * scales_l) const {
        uint8x8_t aux = vld1_u8(scales_l);
        uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
        int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(scl8), vdupq_n_s8(-8));
        return vqtbl1q_s8(scales, hshuff);
    }

    Q2bits bits;
    // Value table in the low 4 bytes of each u64 lane: -31, -13, 1, 17.
    const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x000000001101f3e1));
    // Byte interleave 0,8,1,9,...,7,15 (low/high nibble scales into block order).
    const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});

};
|
||||
|
||||
// NEON dequantizer for IQ3_K: 2 low bits in qs plus 1 high bit in qh form an
// index into an 8-entry signed value table; block scales are 4-bit magnitudes
// (mapped to odd values 1..31) with separate per-block sign bits.
struct DequantizerIQ3K final : public BaseDequantizer<block_iq3_k> {
    DequantizerIQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}

    constexpr static int num_blocks() { return 16; }
    constexpr static bool should_scale_quants() { return false; }

    // scales_h carries the 16 scale sign bits; extra goes through Scale16Extra
    // with shift value 4.
    template <typename Q8>
    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        d = GGML_FP16_TO_FP32(x[i].d);
        return Scale16Extra::new_block(i, d, x[i].extra, 4, make_scales(x[i].scales_h, x[i].scales_l), q8, acc);
    }
    // ORs the high bit (positioned at 0x04) onto the expanded 2-bit fields and
    // maps the 3-bit indices through the value table. qh is loaded once per
    // super-block (j == 0) and shifted down for the second half.
    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs+32*j);
        if (j == 0) {
            hbits = vld1q_u8_x2(x[i].qh);
        }
        else {
            hbits.val[0] = vshrq_n_u8(hbits.val[0], 4);
            hbits.val[1] = vshrq_n_u8(hbits.val[1], 4);
        }
        bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hmask));
        bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hmask));
        bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hmask));
        bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hmask));
        bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], hmask));
        bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], hmask));
        bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 1), hmask));
        bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 1), hmask));
        for (int k = 0; k < 4; ++k) {
            bits.b1.val[k] = vqtbl1q_s8(values, bits.b1.val[k]);
            bits.b2.val[k] = vqtbl1q_s8(values, bits.b2.val[k]);
        }
    }
    // Builds signed block scales: magnitude 2*nibble+1, sign from sign_bits.
    // sign_mask/sign_shuffle realign the bit order to the scale layout.
    inline int8x16_t make_scales(uint16_t sign_bits, const uint8_t * scales_l) const {
        uint8x8_t aux = vld1_u8(scales_l);
        uint8x16_t scl8 = vandq_u8(vcombine_u8(aux, vshr_n_u8(aux, 4)), vdupq_n_u8(0xf));
        int8x16_t scales = vaddq_s8(vreinterpretq_s8_u8(vshlq_n_u8(scl8, 1)), vdupq_n_s8(1));
        // 0xff where the sign bit is set; OR with 1 turns that into -1/+1 bytes.
        uint8x16_t signs = vceqq_u8(vandq_u8(vreinterpretq_u8_u16(vdupq_n_u16(sign_bits)), sign_mask), sign_mask);
        signs = vorrq_u8(signs, vdupq_n_u8(1));
        // scales are 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15
        // signs are 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15
        scales = vmulq_s8(scales, vreinterpretq_s8_u8(vqtbl1q_u8(signs, sign_shuffle)))
;
        return vqtbl1q_s8(scales, hshuff);
    }
    inline static uint8x16_t load_sign_shuffle() {
        static uint8_t k_shuff[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15};
        return vld1q_u8(k_shuff);
    }

    Q2bits bits;
    uint8x16x2_t hbits; // qh bits cached for the whole super-block
    // 8-entry value table: -63, -40, -23, -10, 1, 13, 28, 47.
    const int8x16_t values = vreinterpretq_s8_u64(vdupq_n_u64(0x2f1c0d01f6e9d8c1));
    const uint8x16_t hshuff = vreinterpretq_u8_u32(uint32x4_t{0x09010800, 0x0b030a02, 0x0d050c04, 0x0f070e06});
    const uint8x16_t hmask = vdupq_n_u8(4); // bit position of the 3rd quant bit
    const uint8x16_t sign_mask = vreinterpretq_u8_u64(uint64x2_t{0x0808040402020101, 0x8080404020201010});
    const uint8x16_t sign_shuffle = load_sign_shuffle();

};
|
||||
|
||||
// NEON dequantizer for IQ4_KS: 4-bit indices mapped through one of two
// 16-entry value tables; the low bit of each per-block scale byte selects
// the table, the remaining 7 bits (biased by 127) are the block scale.
struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {

    DequantizerIQ4KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}

    constexpr static int num_blocks() { return 8; }
    constexpr static bool should_scale_quants() { return false; }

    // Block scales: (byte & 254) - 127, widened to two int32x4 vectors.
    // q8/acc are unused for this quant type (no separate mins to accumulate).
    template <typename Q8>
    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        (void)q8;
        (void)acc;
        auto raw16 = vmovl_u8(vld1_u8(x[i].scales));
        auto biased = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(raw16, mask)), m127);
        int32x4x2_t result = {vmovl_s16(vget_low_s16(biased)), vmovl_s16(vget_high_s16(biased))};
        return result;
    }
    // Expands 64 nibbles (j picks the super-block half) and maps each 16-quant
    // group through the table chosen by bit 0 of its scale byte.
    inline void prepare(int i, int j) {
        bits.prepare16(x[i].qs+64*j);
        for (int k = 0; k < 4; ++k) {
            const int8x16_t table = values.val[x[i].scales[4*j+k] & 1];
            bits.b1.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(table, bits.b1.val[k]));
            bits.b2.val[k] = vreinterpretq_u8_s8(vqtbl1q_s8(table, bits.b2.val[k]));
        }
    }

    Q4bits bits;
    const int8x16x2_t values;
    const uint16x8_t mask = vdupq_n_u16(254);  // 0xfe: strip the table-select bit
    const int16x8_t m127 = vdupq_n_s16(-127);  // scale bias
};
|
||||
|
||||
// NEON dequantizer for IQ5_KS: 4 low bits in qs plus a 5th bit in qh index a
// 64-entry table; bit 0 of each per-block scale byte adds 32 to the index
// (selecting the second table half), the remaining 7 bits are the scale.
struct DequantizerIQ5KS final : public BaseDequantizer<block_iq5_ks, true> {
    DequantizerIQ5KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc),
        values(vld1q_s8_x4(iq5nl_values)) {}

    constexpr static int num_blocks() { return 8; }
    constexpr static bool should_scale_quants() { return false; }

    // Block scales: (byte & 254) - 127; also caches the high bits for the
    // super-block. q8/acc are unused for this quant type.
    template <typename Q8>
    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        (void)q8;
        (void)acc;
        auto sas8 = vld1_u8(x[i].scales);
        auto scales16 = vaddq_s16(vreinterpretq_s16_u16(vandq_u16(vmovl_u8(sas8), mask)), m127);
        hbits = vld1q_u8_x2(x[i].qh);
        // NOTE(review): sas (table-select bits << 5) is computed here but not
        // read anywhere in this extract; prepare() re-derives the shift from
        // x[i].scales directly. Confirm sas is used elsewhere or remove it.
        sas = vcombine_u8(sas8, sas8);
        sas = vshlq_n_u8(vandq_u8(sas, vdupq_n_u8(1)), 5);
        int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
        return scales;
    }

    // Merges the 5th bit (at 0x10) into each nibble group, adds the per-block
    // +32 table offset from the scale's low bit, then does the table lookup.
    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs+64*j);
        if (j == 1) {
            // Second half of the super-block uses the upper nibble of qh.
            for (int k = 0; k < 2; ++k) hbits.val[k] = vshrq_n_u8(hbits.val[k], 4);
        }
        auto shift = vdupq_n_u8((x[i].scales[4*j+0] & 1) << 5);
        bits.b1.val[0] = vaddq_u8(shift, vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), hm)));
        bits.b1.val[1] = vaddq_u8(shift, vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), hm)));
        shift = vdupq_n_u8((x[i].scales[4*j+1] & 1) << 5);
        bits.b1.val[2] = vaddq_u8(shift, vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 3), hm)));
        bits.b1.val[3] = vaddq_u8(shift, vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 3), hm)));
        for (int k = 0; k < 4; ++k) bits.b1.val[k] = vqtbl4q_s8(values, bits.b1.val[k]);
        shift = vdupq_n_u8((x[i].scales[4*j+2] & 1) << 5);
        bits.b2.val[0] = vaddq_u8(shift, vorrq_u8(bits.b2.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 2), hm)));
        bits.b2.val[1] = vaddq_u8(shift, vorrq_u8(bits.b2.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 2), hm)));
        shift = vdupq_n_u8((x[i].scales[4*j+3] & 1) << 5);
        bits.b2.val[2] = vaddq_u8(shift, vorrq_u8(bits.b2.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 1), hm)));
        bits.b2.val[3] = vaddq_u8(shift, vorrq_u8(bits.b2.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 1), hm)));
        for (int k = 0; k < 4; ++k) bits.b2.val[k] = vqtbl4q_s8(values, bits.b2.val[k]);
    }

    Q4bits bits;
    const int8x16x4_t values;
    const uint8x16_t hm = vdupq_n_u8(0x10);  // bit position of the 5th quant bit
    const uint16x8_t mask = vdupq_n_u16(254);
    const int16x8_t m127 = vdupq_n_s16(-127);
    uint8x16x2_t hbits; // qh bits cached for the whole super-block
    uint8x16_t sas;

};
|
||||
|
||||
// NEON dequantizer for IQ4_KSS. Per-block metadata (7-bit scale + value-table
// selector) is interleaved into the quant stream itself: bit 0 of every 16-bit
// word carries one metadata bit, and the remaining bits encode the quants with
// a q ^= (q >> 1) transform that is undone here.
struct DequantizerIQ4KSS final : public BaseDequantizer<block_iq4_kss, true> {

    DequantizerIQ4KSS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc), values(vld1q_s8_x2(iq4k_values)) {}

    constexpr static int num_blocks() { return 8; }
    constexpr static bool should_scale_quants() { return false; }

    template <typename Q8>
    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        (void)q8;
        (void)acc;
        auto q4bits_1 = vld1q_u16_x4((const uint16_t *)x[i].qs);
        // Second half is kept in a member so prepare(i, 1) can finish it.
        q4bits_2 = vld1q_u16_x4((const uint16_t *)x[i].qs + 32);
        for (int k = 0; k < 4; ++k) {
            // Gather the 8 metadata bits (bit 0 of each u16, weighted by lane
            // index via the 0..7 shift vector) into one int16 per block.
            aux[k+0] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_1.val[k], m1), shift));
            aux[k+4] = vaddvq_s16(vshlq_s16(vandq_u16(q4bits_2.val[k], m1), shift));
            // Drop the metadata bit, then undo the q ^= (q >> 1) encoding.
            q4bits_1.val[k] = vandq_u16(q4bits_1.val[k], bmask);
            q4bits_1.val[k] = veorq_u16(q4bits_1.val[k], vshrq_n_u16(q4bits_1.val[k], 1));
            q4bits_2.val[k] = vandq_u16(q4bits_2.val[k], bmask);
            q4bits_2.val[k] = veorq_u16(q4bits_2.val[k], vshrq_n_u16(q4bits_2.val[k], 1));
        }
        make_quants(q4bits_1, bits, aux);
        auto scales16 = vld1q_s16(aux);
        // Bits 1..7 of each aux entry are the block scale (biased by 127);
        // bit 0 selects the value table inside make_quants.
        scales16 = vaddq_s16(vandq_s16(scales16, mask), m127);
        int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
        return scales;
    }
    // Expands the 4-bit indices of one super-block half and maps them through
    // the value table chosen by bit 0 of the matching aux entry.
    // (The aux parameter deliberately shadows the member: callers pass aux or aux+4.)
    inline void make_quants(uint16x8x4_t& q4bits, Q4bits& bits, const int16_t * aux) const {
        bits.b1.val[0] = vqtbl1q_s8(values.val[aux[0] & 1], vandq_u8(q4bits.val[0], bits.m4b));
        bits.b1.val[1] = vqtbl1q_s8(values.val[aux[0] & 1], vshrq_n_u8(q4bits.val[0], 4));
        bits.b1.val[2] = vqtbl1q_s8(values.val[aux[1] & 1], vandq_u8(q4bits.val[1], bits.m4b));
        bits.b1.val[3] = vqtbl1q_s8(values.val[aux[1] & 1], vshrq_n_u8(q4bits.val[1], 4));
        bits.b2.val[0] = vqtbl1q_s8(values.val[aux[2] & 1], vandq_u8(q4bits.val[2], bits.m4b));
        bits.b2.val[1] = vqtbl1q_s8(values.val[aux[2] & 1], vshrq_n_u8(q4bits.val[2], 4));
        bits.b2.val[2] = vqtbl1q_s8(values.val[aux[3] & 1], vandq_u8(q4bits.val[3], bits.m4b));
        bits.b2.val[3] = vqtbl1q_s8(values.val[aux[3] & 1], vshrq_n_u8(q4bits.val[3], 4));
    }
    inline void prepare([[maybe_unused]] int i, int j) {
        if (j == 0) return; // first half was already produced in new_block
        make_quants(q4bits_2, bits, aux+4);
    }
    static int16x8_t load_shift() {
        static const int16_t k_shift[8] = {0, 1, 2, 3, 4, 5, 6, 7};
        return vld1q_s16(k_shift);
    }

    Q4bits bits;
    const int8x16x2_t values;
    // 254 = 0xfe keeps bits 1..7 of the scale. Built with the unsigned dup
    // intrinsic to match the uint16x8_t type (consistent with
    // DequantizerIQ4KS/DequantizerIQ5KS; the bit pattern is unchanged).
    const uint16x8_t mask = vdupq_n_u16(254);
    const uint16x8_t bmask = vdupq_n_u16(0xfffe); // strips the metadata bit
    const uint16x8_t m1 = vdupq_n_u16(1);
    const int16x8_t shift = load_shift();
    const int16x8_t m127 = vdupq_n_s16(-127);
    uint16x8x4_t q4bits_2; // second super-block half, finished in prepare()
    int16_t aux[8];        // per-block metadata: bit 0 = table select, bits 1..7 = scale
};
|
||||
|
||||
// NEON dequantizer for IQ2_KS: 2-bit quants mapped through one of two 4-entry
// value tables; the low bits of x[i].extra select the table per group of 32
// quants, the high byte of extra carries the scale sign-extension bits.
struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
    DequantizerIQ2KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}

    constexpr static int num_blocks() { return 8; }
    constexpr static bool should_scale_quants() { return false; }

    // Unpacks 8 4-bit scales and subtracts 16 where the corresponding bit of
    // extra>>8 is clear (making those scales negative), then widens to int32.
    template <typename Q8>
    inline int32x4x2_t new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) {
        const uint16_t * sc16 = (const uint16_t *)x[i].scales;
        uint32_t aux32 = sc16[0] | (sc16[1] << 16);
        uint8x8_t scales8 = vreinterpret_u8_u32(vdup_n_u32(aux32));
        scales8 = vand_u8(vzip1_u8(scales8, vshr_n_u8(scales8, 4)), vdup_n_u8(0xf));
        // 16 where the per-block bit of (extra >> 8) is clear, 0 where set.
        uint8x8_t sh = vand_u8(vceq_u8(vand_u8(vdup_n_u8(x[i].extra >> 8), hmask), vdup_n_u8(0)), vdup_n_u8(16));
        int16x8_t scales16 = vmovl_s8(vsub_s8(vreinterpret_s8_u8(scales8), vreinterpret_s8_u8(sh)));
        int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
        return scales;
    }
    // Expands 2-bit fields; consecutive extra bits pick the value table for
    // successive pairs of 16-quant groups.
    inline void prepare(int i, int j) {
        uint8_t extra = x[i].extra >> 4*j;
        bits.prepare(x[i].qs+32*j);
        bits.b1.val[0] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[0]);
        bits.b1.val[1] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[1]); extra >>= 1;
        bits.b1.val[2] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[2]);
        bits.b1.val[3] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[3]); extra >>= 1;
        bits.b2.val[0] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[0]);
        bits.b2.val[1] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[1]); extra >>= 1;
        bits.b2.val[2] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[2]);
        bits.b2.val[3] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[3]);
    }

    Q2bits bits;
    const uint8x8_t hmask = vreinterpret_u8_u64(vdup_n_u64(0x8040201008040201)); // one bit per block
    // Two 4-entry value tables: {-31,-13,1,17} and {-26,-8,6,22}.
    const int8x16x2_t values = { vreinterpretq_s8_u64(vdupq_n_u64(0x1101f3e1)), vreinterpretq_s8_u64(vdupq_n_u64(0x1606f8e6)) };

};
|
||||
|
||||
struct SimpleBits {
|
||||
uint8x16x4_t b1;
|
||||
@@ -4269,33 +3806,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
case GGML_TYPE_Q6_K:
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
return iqk_set_kernels_kquants(ne00, typeA, typeB, m.funcs, m.func16);
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
MulMat::set_functions<DequantizerIQ4KS>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
MulMat::set_functions<DequantizerIQ4KSS>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_KS:
|
||||
MulMat::set_functions<DequantizerIQ2KS>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_K:
|
||||
MulMat::set_functions<DequantizerIQ4K>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_K:
|
||||
MulMat::set_functions<DequantizerIQ5K>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
MulMat::set_functions<DequantizerIQ5KS>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ6_K:
|
||||
MulMat::set_functions<DequantizerIQ6K>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ2_K:
|
||||
MulMat::set_functions<DequantizerIQ2K>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_K:
|
||||
MulMat::set_functions<DequantizerIQ3K>(m);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_KSS:
|
||||
case GGML_TYPE_IQ4_KS:
|
||||
case GGML_TYPE_IQ4_K:
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
case GGML_TYPE_IQ5_K:
|
||||
case GGML_TYPE_IQ6_K:
|
||||
return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, m.funcs, m.func16);
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
MulMat::set_functions<DequantizerIQ2XXS>(m);
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user