iq2_ks: ARM_NEON

This commit is contained in:
Iwan Kawrakow
2024-10-13 09:47:11 +03:00
parent 18cdf624f8
commit 550c40e27f

View File

@@ -4325,14 +4325,20 @@ struct Q2bits {
}
};
template <typename block_q, bool has_row_scale = false>
template <typename block_q, bool has_row_scale = false, bool scale_is_f16 = false>
struct BaseDequantizer {
BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
inline void new_row(int ix) {
if constexpr (has_row_scale) {
const float * dptr = (const float *)((const char *)vx + ix*bx);
d = *dptr;
x = (const block_q *)(dptr + 1);
if constexpr (scale_is_f16) {
const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx);
d = GGML_FP16_TO_FP32(*dptr);
x = (const block_q *)(dptr + 1);
} else {
const float * dptr = (const float *)((const char *)vx + ix*bx);
d = *dptr;
x = (const block_q *)(dptr + 1);
}
} else {
x = (const block_q *)((const char *)vx + ix*bx);
}
@@ -4939,6 +4945,42 @@ struct DequantizerIQ4KS final : public BaseDequantizer<block_iq4_ks, true> {
const int16x8_t m127 = vdupq_n_s16(-127);
};
// ARM NEON dequantizer for rows of block_iq2_ks.
//
// The base class is instantiated with has_row_scale = true and
// scale_is_f16 = true, i.e. each row is expected to begin with a single
// ggml_half scale that the base converts to float before the blocks start.
struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true> {
    DequantizerIQ2KS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}

    // Number of sub-block scales produced per call to new_block().
    constexpr static int num_blocks() { return 8; }
    constexpr static bool should_scale_quants() { return false; }

    // Unpacks the 8 sub-block scales of block `i` and returns them widened to
    // 32-bit integers (two vectors of four). `q8` and `acc` are unused here.
    template <typename Q8>
    inline int32x4x2_t new_block(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) {
        const uint16_t * sc16 = (const uint16_t *)x[i].scales;
        // Widen to uint32_t *before* shifting: a bare uint16_t is promoted to
        // signed int, and shifting a value >= 0x8000 left by 16 would overflow it.
        uint32_t aux32 = uint32_t(sc16[0]) | (uint32_t(sc16[1]) << 16);
        uint8x8_t scales8 = vreinterpret_u8_u32(vdup_n_u32(aux32));
        // Interleave low and high nibbles of the 4 packed bytes -> 8 values in 0..15.
        scales8 = vand_u8(vzip1_u8(scales8, vshr_n_u8(scales8, 4)), vdup_n_u8(0xf));
        // Lane k of sh is 16 when bit k of (extra >> 8) is clear, otherwise 0.
        uint8x8_t sh = vand_u8(vceq_u8(vand_u8(vdup_n_u8(x[i].extra >> 8), hmask), vdup_n_u8(0)), vdup_n_u8(16));
        // scale_k = nibble_k - sh_k, computed as signed 8-bit and widened to 16-bit.
        int16x8_t scales16 = vmovl_s8(vsub_s8(vreinterpret_s8_u8(scales8), vreinterpret_s8_u8(sh)));
        int32x4x2_t scales = {vmovl_s16(vget_low_s16(scales16)), vmovl_s16(vget_high_s16(scales16))};
        return scales;
    }

    // Unpacks the quants at x[i].qs + 32*j via Q2bits::prepare and maps every
    // index through one of the two lookup tables in `values`. Four successive
    // bits of (x[i].extra >> 4*j) choose the table; each bit covers two
    // consecutive 16-byte vectors.
    inline void prepare(int i, int j) {
        uint8_t extra = x[i].extra >> 4*j;
        bits.prepare(x[i].qs+32*j);
        bits.b1.val[0] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[0]);
        bits.b1.val[1] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[1]); extra >>= 1;
        bits.b1.val[2] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[2]);
        bits.b1.val[3] = vqtbl1q_s8(values.val[extra & 1], bits.b1.val[3]); extra >>= 1;
        bits.b2.val[0] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[0]);
        bits.b2.val[1] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[1]); extra >>= 1;
        bits.b2.val[2] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[2]);
        bits.b2.val[3] = vqtbl1q_s8(values.val[extra & 1], bits.b2.val[3]);
    }

    Q2bits bits;
    // One distinct bit per byte (0x01, 0x02, ..., 0x80), used to test each bit of a broadcast byte.
    const uint8x8_t hmask = vreinterpret_u8_u64(vdup_n_u64(0x8040201008040201));
    // Two 4-entry signed lookup tables stored in the low bytes of each 64-bit
    // lane: {-31, -13, 1, 17} and {-26, -8, 6, 22}. Only indices 0..3 are
    // expected to be looked up (presumably what Q2bits yields -- confirm).
    const int8x16x2_t values = { vreinterpretq_s8_u64(vdupq_n_u64(0x1101f3e1)), vreinterpretq_s8_u64(vdupq_n_u64(0x1606f8e6)) };
};
struct SimpleBits {
uint8x16x4_t b1;
uint8x16x4_t b2;
@@ -6674,6 +6716,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ4_KS:
MulMat::set_functions<DequantizerIQ4KS>(m);
break;
case GGML_TYPE_IQ2_KS:
MulMat::set_functions<DequantizerIQ2KS>(m);
break;
case GGML_TYPE_IQ4_K:
MulMat::set_functions<DequantizerIQ4K>(m);
break;