Refactor iqk: factor out k-quants (NEON)

This commit is contained in:
Iwan Kawrakow
2025-05-18 17:41:54 +03:00
parent 28b94800c1
commit c805a19202
3 changed files with 601 additions and 422 deletions

View File

@@ -526,6 +526,7 @@ struct Q4Bits {
#endif
#else
// ------------------------------------ __aarch64__ --------------------------------------------------
template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
@@ -547,6 +548,214 @@ template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
const block_q8 * y[nrc_y];
};
// Common base for the NEON k-quant dequantizers below. Stores the base pointer
// of the quantized matrix (vx), the row stride in bytes (bx), and positions x
// at the blocks of the row currently being processed.
template <typename block_q, bool has_row_scale = false, bool scale_is_f16 = false>
struct BaseDequantizer {
BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
// Point x at row ix. When has_row_scale is set, the row begins with a scale
// (ggml_half if scale_is_f16, float otherwise) that is loaded into d, and the
// quantized blocks follow immediately after it.
inline void new_row(int ix) {
if constexpr (has_row_scale) {
if constexpr (scale_is_f16) {
const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx);
d = GGML_FP16_TO_FP32(*dptr);
x = (const block_q *)(dptr + 1);
} else {
const float * dptr = (const float *)((const char *)vx + ix*bx);
d = *dptr;
x = (const block_q *)(dptr + 1);
}
} else {
x = (const block_q *)((const char *)vx + ix*bx);
}
}
const void * vx;    // base pointer of the quantized matrix
const block_q * x;  // blocks of the current row (set by new_row)
const size_t bx;    // row stride in bytes
const int nrc;      // number of rows of the other operand being processed
float d;            // current scale; derived classes also set it per super-block
};
// Unpacks 4-bit quants into two groups of four 16-byte registers (b1, b2),
// i.e. 128 quants per prepare* call.
struct Q4bits {
const uint8x16_t m4b = vdupq_n_u8(0xf);
uint8x16x4_t b1, b2;
// Split two input registers into nibbles ordered {lo0, lo1, hi0, hi1}.
inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {
b.val[0] = vandq_u8(val[0], m4b);
b.val[2] = vshrq_n_u8(val[0], 4);
b.val[1] = vandq_u8(val[1], m4b);
b.val[3] = vshrq_n_u8(val[1], 4);
}
// Same split, but ordered per input register: {lo0, hi0, lo1, hi1}.
// Used by the 16-block (blocks of 16 quants) layouts.
inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {
b.val[0] = vandq_u8(val[0], m4b);
b.val[1] = vshrq_n_u8(val[0], 4);
b.val[2] = vandq_u8(val[1], m4b);
b.val[3] = vshrq_n_u8(val[1], 4);
}
// Unpack 64 bytes of 4-bit quants using two 32-byte loads.
inline void prepare(const uint8_t * qs) {
auto q4bits = vld1q_u8_x2(qs);
prepare4(b1, q4bits.val);
q4bits = vld1q_u8_x2(qs+32);
prepare4(b2, q4bits.val);
}
// Same result as prepare(), but with a single 64-byte load
// (preferred on the nrc == 1 / GEMV path).
inline void prepare_v2(const uint8_t * qs) {
auto q4bits = vld1q_u8_x4(qs);
prepare4(b1, q4bits.val+0);
prepare4(b2, q4bits.val+2);
}
// Q6_K layout: low nibbles of all 64 bytes go to b1, high nibbles to b2.
inline void prepare64(const uint8_t * qs) {
auto q4bits = vld1q_u8_x4(qs);
b1.val[0] = vandq_u8(q4bits.val[0], m4b);
b1.val[1] = vandq_u8(q4bits.val[1], m4b);
b1.val[2] = vandq_u8(q4bits.val[2], m4b);
b1.val[3] = vandq_u8(q4bits.val[3], m4b);
b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);
b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);
b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);
b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);
}
// prepare()/prepare_v2() analogues using the per-register nibble order.
inline void prepare16(const uint8_t * qs) {
auto q4bits = vld1q_u8_x2(qs);
prepare4_16(b1, q4bits.val);
q4bits = vld1q_u8_x2(qs+32);
prepare4_16(b2, q4bits.val);
}
inline void prepare16_v2(const uint8_t * qs) {
auto q4bits = vld1q_u8_x4(qs);
prepare4_16(b1, q4bits.val+0);
prepare4_16(b2, q4bits.val+2);
}
};
// Unpacks 2-bit quants: 32 input bytes yield 128 quants in b1, b2.
struct Q2bits {
// Mask of the low 2 bits (member name kept for symmetry with Q4bits).
const uint8x16_t m4b = vdupq_n_u8(0x03);
uint8x16x4_t b1, b2;
// Extract four consecutive 2-bit fields by shifting the source registers
// right by 2 between extractions.
inline void prepare(const uint8_t * qs) {
auto q2bits = vld1q_u8_x2(qs);
b1.val[0] = vandq_u8(q2bits.val[0], m4b);
b1.val[1] = vandq_u8(q2bits.val[1], m4b);
q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
b1.val[2] = vandq_u8(q2bits.val[0], m4b);
b1.val[3] = vandq_u8(q2bits.val[1], m4b);
q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
b2.val[0] = vandq_u8(q2bits.val[0], m4b);
b2.val[1] = vandq_u8(q2bits.val[1], m4b);
q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
b2.val[2] = vandq_u8(q2bits.val[0], m4b);
b2.val[3] = vandq_u8(q2bits.val[1], m4b);
}
};
// Dot products for dequantizers with 8 blocks of 32 quants per super-block.
// Each q8.load_quants gives 32 activation quants (two registers); pairwise
// adds (vpaddq) reduce the four per-block sums into one int32 lane each, and
// the 4-lane result is multiply-accumulated with the block scales into sumi.
template <typename Q8>
static inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {
auto mzero = vdupq_n_s32(0);
auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1
auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2
auto p12 = vpaddq_s32(p1, p2);
auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 1
auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 2
auto p34 = vpaddq_s32(p3, p4);
// pall holds one int32 sum per 32-quant block (4 blocks for this j half).
auto pall = vpaddq_s32(p12, p34);
sumi = vmlaq_s32(sumi, scales.val[j], pall);
}
// Dot products for dequantizers with 16 blocks of 16 quants per super-block.
// Each 16-byte register covers one block of 16 quants, so the pairwise adds
// reduce to one int32 lane per block; two scale registers per j half.
template <typename Q8>
static inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {
auto mzero = vdupq_n_s32(0);
auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,
auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 3, 3, 4, 4,
auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3
sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);
auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5,
auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7,
auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7
sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);
}
// Generic k-quant x Q8_K matrix multiplication kernel.
// For each of the nrc_x rows of the quantized matrix, iterates over the
// super-blocks (QK_K quants each), computes int32 dot products against
// nrc_y rows of Q8_K activations, and accumulates
// d * q8.scale * sumi into float accumulators that are reduced and stored
// via info.store at the end of the row.
template <typename Dequantizer, int nrc_y>
static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
Q8<nrc_y, block_q8_K> q8(info);
Dequantizer deq(vx, bx, nrc_y);
for (int ix = 0; ix < nrc_x; ++ix) {
deq.new_row(ix);
float32x4_t acc[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
for (int i = 0; i < nb; ++i) {
int32x4_t sumi[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);
// Fast path (e.g. Q2_K with several RHS rows): the dequantizer applies
// the block scales to the quants itself and computes the dot products.
if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {
deq.process_scales(i, q8, acc);
deq.prepare(i, 0);
deq.compute(q8, i, 0, sumi);
deq.prepare(i, 1);
deq.compute(q8, i, 1, sumi);
} else {
// Generic path: scales come from new_block (which also folds the
// min/bsum corrections into acc), quants from prepare, and the dot
// products are done by the shared compute_{8,16}_blocks helpers.
if constexpr (Dequantizer::num_blocks() == 8) {
auto scales = deq.new_block(i, q8, acc);
deq.prepare(i, 0);
for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
deq.prepare(i, 1);
for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
}
else if constexpr (Dequantizer::num_blocks() == 16) {
auto scales = deq.new_block(i, q8, acc);
deq.prepare(i, 0);
for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
deq.prepare(i, 1);
for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
}
else {
GGML_ASSERT(false);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vaddvq_f32(acc[iy]));
}
}
}
#endif
#endif

View File

@@ -1781,6 +1781,397 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
#else
// --------------------------------- __aarch64__ --------------------------------------
namespace {
// Fold the per-super-block min correction into the float accumulators:
// acc[iy] += c * q8.scale(iy, i) * dot(mins, q8 block sums), for 8 blocks.
// c is typically -dmin so the mins are subtracted.
template <typename Q8>
inline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
        auto bsums = q8.load_bsums8(iy, i);
        // Widening multiplies of the 8 int16 lanes, summed to 4 int32 lanes.
        auto sum32 = vaddq_s32(vmull_s16(vget_low_s16 (mins), vget_low_s16 (bsums)),
                               vmull_s16(vget_high_s16(mins), vget_high_s16(bsums)));
        acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sum32), vdupq_n_f32(c*q8.scale(iy, i)));
    }
}
// Same as accum_mins_8 but for 16 blocks per super-block: the mins and the
// q8 block sums come in two int16x8 registers each.
template <typename Q8>
inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
auto q8s = q8.load_bsums(iy, i);
int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));
int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));
int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));
int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));
float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));
// c is typically -dmin, so the min contributions are subtracted.
acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
}
}
// Handles the packed 6-bit scales/mins of Q4_K / Q5_K super-blocks.
struct Scales8 {
uint32_t utmp[4];
const uint8_t * sc8 = (const uint8_t *)utmp;
// Unpack the 12-byte scales field of block x[i] into 8 scales + 8 mins,
// accumulate the -dmin * bsum correction into acc, and return the 8 block
// scales widened to two int32x4 registers.
template <typename Q8, typename Qx>
inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) {
make_q4_scales(x.scales, utmp);
// Bytes 8..15 of utmp are the mins.
int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8));
accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin));
uint8x8_t scales8 = vld1_u8(sc8);
uint16x8_t scales16 = vmovl_u8(scales8);
int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))),
vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))};
return scales;
}
};
// Dequantizer for Q4_K: 8 blocks of 32 4-bit quants per super-block, with
// 6-bit packed scales/mins handled by Scales8.
struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 8; }
constexpr static bool should_scale_quants() { return false; }
// Load the super-block scale and process scales/mins (min corrections are
// accumulated into acc); returns the widened block scales.
template <typename Q8>
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
return s8.process_scales_mins(x[i], q8, i, acc);
}
// Unpack 64 bytes of quants; j in {0,1} selects the super-block half.
// The single-row (GEMV) path uses the one-load variant.
inline void prepare(int i, int j) {
if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
else bits.prepare(x[i].qs+64*j);
}
Q4bits bits;
Scales8 s8;
};
// ORs the 5th quant bit (0x10) from the Q5_K qh bitfield into 4-bit quants.
// `bits` holds the 32 bytes of high bits for a whole super-block; each apply
// consumes 4 bit-planes, and do_shift moves the next 4 planes into place
// after the first half.
struct HighBit5 {
const uint8x16_t mhb = vdupq_n_u8(0x10);
uint8x16x2_t bits;
inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb));
b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb));
b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb));
b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb));
b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
if (do_shift) {
bits.val[0] = vshrq_n_u8(bits.val[0], 4);
bits.val[1] = vshrq_n_u8(bits.val[1], 4);
}
}
};
// ORs the 3rd quant bit (0x04) from the Q3_K hmask bitfield into 2-bit
// quants; same bit-plane consumption scheme as HighBit5.
struct HighBit3 {
const uint8x16_t mhb = vdupq_n_u8(0x04);
uint8x16x2_t bits;
inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb));
b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb));
b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb));
b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb));
if (do_shift) {
bits.val[0] = vshrq_n_u8(bits.val[0], 4);
bits.val[1] = vshrq_n_u8(bits.val[1], 4);
}
}
};
// Dequantizer for Q5_K: 4-bit low quants (Q4bits) plus a 5th bit taken from
// the qh bitfield (HighBit5), with 6-bit packed scales/mins (Scales8).
struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
    DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}

    constexpr static int num_blocks() { return 8; }
    constexpr static bool should_scale_quants() { return false; }

    // Load the super-block scale and the 32 bytes of high bits, then process
    // scales/mins (min corrections are accumulated into acc).
    template <typename Q8>
    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        d = GGML_FP16_TO_FP32(x[i].d);
        h.bits = vld1q_u8_x2(x[i].qh);
        return s8.process_scales_mins(x[i], q8, i, acc);
    }
    // Unpack 64 bytes of low quants (j selects the super-block half) and OR
    // in the 5th bit; after the first half the high-bit planes are shifted.
    inline void prepare(int i, int j) {
        if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
        else bits.prepare(x[i].qs+64*j);
        h.apply(bits.b1, bits.b2, j == 0);
    }

    Q4bits   bits;
    HighBit5 h;    // high bits live in h.bits (the old unused hbits member was removed)
    Scales8  s8;
};
// Widen two int16x8 scale registers into four int32x4 registers,
// preserving lane order (low half of val[0] first, high half of val[1] last).
inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
    int32x4x4_t wide;
    wide.val[0] = vmovl_s16(vget_low_s16 (scales16.val[0]));
    wide.val[1] = vmovl_s16(vget_high_s16(scales16.val[0]));
    wide.val[2] = vmovl_s16(vget_low_s16 (scales16.val[1]));
    wide.val[3] = vmovl_s16(vget_high_s16(scales16.val[1]));
    return wide;
}
// For 16-block quant types: widen the 16 int8 block scales to int16, use them
// as the mins for the bsum correction (weighted by c), and return the scales
// widened to four int32x4 registers.
template <typename Q8>
inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) {
int16x8x2_t scales16;
scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
accum_mins_16(scales16, q8, acc, i, c);
return make_wider(scales16);
}
// Dequantizer for Q6_K: 4-bit low quants (ql) plus 2 high bits (qh) placed at
// 0x30; the implicit -32 offset of the quants is folded into the bsum
// correction via c = -32*d in new_block.
struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 16; }
constexpr static bool should_scale_quants() { return false; }
template <typename Q8>
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d);
}
// Unpack 64 low nibbles (b1 = lows, b2 = highs of ql) and OR in the two
// high bits from qh at positions 0x30; j selects the super-block half.
inline void prepare(int i, int j) {
auto hbits = vld1q_u8_x2(x[i].qh + 32*j);
bits.prepare64(x[i].ql+64*j);
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb));
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb));
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb));
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb));
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb));
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb));
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb));
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb));
}
Q4bits bits;
const uint8x16_t mhb = vdupq_n_u8(0x30);
};
// Dequantizer for Q3_K: 2-bit low quants plus a high bit from hmask, with
// 6-bit scales packed into 12 bytes. Two code paths: for nrc > 1 the high
// bit is OR'ed in via HighBit3 (quants end up offset by +4, compensated via
// c = -4*d); for nrc == 1 the quants are sign-extended in place.
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 16; }
constexpr static bool should_scale_quants() { return false; }
template <typename Q8>
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
h.bits = vld1q_u8_x2(x[i].hmask);
mask = vdupq_n_u8(0x01);
// Unpack the 16 6-bit scales from the 12 scale bytes into aux32
// (4 low bits from the first 8 bytes, 2 high bits from the last 4).
const uint16_t * sc16 = (const uint16_t *)x[i].scales;
uint32_t aux0 = sc16[0] | (sc16[1] << 16);
uint32_t aux1 = sc16[2] | (sc16[3] << 16);
uint32_t aux2 = sc16[4] | (sc16[5] << 16);
aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);
aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
// Scales are stored with a +32 bias.
auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
if (nrc > 1) {
// HighBit3 path leaves quants at q+4; compensate via the bsums.
return process_scales_mins_16(scales8, q8, acc, i, -4.f*d);
}
int16x8x2_t scales16;
scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
return make_wider(scales16);
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs+32*j);
if (nrc > 1) {
h.apply(bits.b1, bits.b2, j == 0);
} else {
// GEMV path: where the hmask bit is 0, OR in 0xfc so the 2-bit quant
// q in [0,3] becomes q-4 when interpreted as int8. mask walks one
// bit-plane per pair of registers.
auto minus4 = vdupq_n_u8(0xfc);
auto zero = vdupq_n_u8(0);
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
}
}
uint32_t aux32[4];   // scratch for the unpacked 6-bit scales
Q2bits bits;
uint8x16_t mask;     // current hmask bit-plane (nrc == 1 path)
HighBit3 h;
};
// Dequantizer for Q2_K: 2-bit quants with 4-bit scales and 4-bit mins packed
// into one byte per block. When processing several RHS rows the scales are
// multiplied directly into the (unsigned) quants via compute(); otherwise the
// generic widened-scales path (new_block) is used.
struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
    DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}

    constexpr static int num_blocks() { return 16; }
    constexpr static bool should_scale_quants() { return true; }

    // Split the scale/min bytes: high nibbles are the mins (accumulated into
    // acc weighted by -dmin), low nibbles are kept in scales8.
    template <typename Q8>
    inline void process_scales(int i, const Q8& q8, float32x4_t * acc) {
        d = GGML_FP16_TO_FP32(x[i].d);
        auto scales_and_mins = vld1q_u8(x[i].scales);
        auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4));
        int16x8x2_t scales16;
        scales16.val[0] = vmovl_s8(vget_low_s8(mins8));
        scales16.val[1] = vmovl_s8(vget_high_s8(mins8));
        accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin));
        scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf));
    }
    // Generic path: process scales/mins, then widen the 16 block scales.
    template <typename Q8>
    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        process_scales(i, q8, acc);
        int16x8x2_t scales16;
        scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8)));
        scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8)));
        return make_wider(scales16);
    }
    // Fast path: broadcast each block scale over its 16 quants via vqtbl1q
    // (quant <= 3, scale <= 15, so the u8 product cannot overflow) and
    // accumulate the dot products against all RHS rows directly into sumi.
    template <typename Q8>
    inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
        auto m1 = vdupq_n_u8(1);
        auto shuffle = vdupq_n_u8(8*j);
        bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
                    vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
            auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
                    vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
            auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
                    vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
            auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
                    vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
        }
    }
    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs+32*j);
    }

    // Note: the unused `uint32_t aux32[4]` scratch (copied over from the
    // Q3_K dequantizer) was removed.
    uint8x16_t scales8;  // low-nibble block scales, set by process_scales
    Q2bits bits;
};
}
// Select the NEON k-quant matmul kernels for quant type typeA with
// activations of type typeB. Fills `kernels` with the mul_mat_qX_K_q8_K_T
// instantiations for the supported row counts; returns false when the row
// size is not a multiple of QK_K, typeB does not match the expected
// activation type, or typeA is not (yet) handled on this path.
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) {
auto etypeA = ggml_type(typeA);
// Each quant family expects a specific Q8 activation layout.
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
: etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
: GGML_TYPE_Q8_K;
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
return false;
}
// No 16-row kernel on the NEON path.
func16 = nullptr;
switch (typeA) {
case GGML_TYPE_Q2_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ2K, kernels)
break;
case GGML_TYPE_Q3_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ3K, kernels)
break;
case GGML_TYPE_Q4_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ4K, kernels)
break;
case GGML_TYPE_Q5_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ5K, kernels)
break;
case GGML_TYPE_Q6_K:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ6K, kernels)
break;
// The cases below are not yet ported to this refactored NEON path.
// case GGML_TYPE_IQ4_XS:
// set_functions<DequantizerIQ4XS>(kernels);
// break;
// case GGML_TYPE_Q2_K_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q2_k_r4_q8_k, kernels)
// break;
// case GGML_TYPE_Q3_K_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q3_k_r4_q8_k, kernels)
// break;
// case GGML_TYPE_Q4_K_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_k_r4_q8_k, kernels)
// break;
// case GGML_TYPE_Q5_K_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_k_r4_q8_k, kernels)
// break;
// case GGML_TYPE_Q6_K_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_k_r4_q8_k, kernels)
// break;
// case GGML_TYPE_IQ4_XS_R8:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_xs_r8_q8_k_avx2, kernels)
// break;
// case GGML_TYPE_Q8_K_R8:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_k_r8_q8_k, kernels)
//#ifdef HAVE_FANCY_SIMD
// func16 = mul_mat_q8_k_r8_q8_k<16>;
//#endif
// break;
// case GGML_TYPE_Q8_KV:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_q8_KV, kernels)
//#ifdef HAVE_FANCY_SIMD
// func16 = mul_mat_q8_KV_q8_KV<16>;
//#endif
// break;
// case GGML_TYPE_Q8_KV_R8:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_r8_q8_KV, kernels);
// break;
default:
return false;
}
return true;
}
#endif
#endif

View File

@@ -873,195 +873,6 @@ struct Scales8 {
}
};
// Unpacks 4-bit quants into two groups of four 16-byte registers (b1, b2).
struct Q4bits {
const uint8x16_t m4b = vdupq_n_u8(0xf);
uint8x16x4_t b1, b2;
// Split two input registers into nibbles ordered {lo0, lo1, hi0, hi1}.
inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {
b.val[0] = vandq_u8(val[0], m4b);
b.val[2] = vshrq_n_u8(val[0], 4);
b.val[1] = vandq_u8(val[1], m4b);
b.val[3] = vshrq_n_u8(val[1], 4);
}
// Same split, but ordered per input register: {lo0, hi0, lo1, hi1}.
inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {
b.val[0] = vandq_u8(val[0], m4b);
b.val[1] = vshrq_n_u8(val[0], 4);
b.val[2] = vandq_u8(val[1], m4b);
b.val[3] = vshrq_n_u8(val[1], 4);
}
// Unpack 64 bytes of quants using two 32-byte loads.
inline void prepare(const uint8_t * qs) {
auto q4bits = vld1q_u8_x2(qs);
prepare4(b1, q4bits.val);
q4bits = vld1q_u8_x2(qs+32);
prepare4(b2, q4bits.val);
}
// Same result as prepare(), single 64-byte load (GEMV path).
inline void prepare_v2(const uint8_t * qs) {
auto q4bits = vld1q_u8_x4(qs);
prepare4(b1, q4bits.val+0);
prepare4(b2, q4bits.val+2);
}
// Q6_K layout: low nibbles of all 64 bytes -> b1, high nibbles -> b2.
inline void prepare64(const uint8_t * qs) {
auto q4bits = vld1q_u8_x4(qs);
b1.val[0] = vandq_u8(q4bits.val[0], m4b);
b1.val[1] = vandq_u8(q4bits.val[1], m4b);
b1.val[2] = vandq_u8(q4bits.val[2], m4b);
b1.val[3] = vandq_u8(q4bits.val[3], m4b);
b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);
b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);
b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);
b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);
}
// prepare()/prepare_v2() analogues with the per-register nibble order.
inline void prepare16(const uint8_t * qs) {
auto q4bits = vld1q_u8_x2(qs);
prepare4_16(b1, q4bits.val);
q4bits = vld1q_u8_x2(qs+32);
prepare4_16(b2, q4bits.val);
}
inline void prepare16_v2(const uint8_t * qs) {
auto q4bits = vld1q_u8_x4(qs);
prepare4_16(b1, q4bits.val+0);
prepare4_16(b2, q4bits.val+2);
}
};
// Unpacks 2-bit quants: 32 input bytes yield 128 quants in b1, b2.
struct Q2bits {
// Mask of the low 2 bits (name kept for symmetry with Q4bits).
const uint8x16_t m4b = vdupq_n_u8(0x03);
uint8x16x4_t b1, b2;
// Extract four consecutive 2-bit fields, shifting right by 2 each step.
inline void prepare(const uint8_t * qs) {
auto q2bits = vld1q_u8_x2(qs);
b1.val[0] = vandq_u8(q2bits.val[0], m4b);
b1.val[1] = vandq_u8(q2bits.val[1], m4b);
q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
b1.val[2] = vandq_u8(q2bits.val[0], m4b);
b1.val[3] = vandq_u8(q2bits.val[1], m4b);
q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
b2.val[0] = vandq_u8(q2bits.val[0], m4b);
b2.val[1] = vandq_u8(q2bits.val[1], m4b);
q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
b2.val[2] = vandq_u8(q2bits.val[0], m4b);
b2.val[3] = vandq_u8(q2bits.val[1], m4b);
}
};
// Common base for the NEON k-quant dequantizers: matrix base pointer (vx),
// row stride in bytes (bx), and x pointing at the current row's blocks.
template <typename block_q, bool has_row_scale = false, bool scale_is_f16 = false>
struct BaseDequantizer {
BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
// Point x at row ix; when has_row_scale, the row starts with a scale
// (f16 or f32 per scale_is_f16) that is loaded into d.
inline void new_row(int ix) {
if constexpr (has_row_scale) {
if constexpr (scale_is_f16) {
const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx);
d = GGML_FP16_TO_FP32(*dptr);
x = (const block_q *)(dptr + 1);
} else {
const float * dptr = (const float *)((const char *)vx + ix*bx);
d = *dptr;
x = (const block_q *)(dptr + 1);
}
} else {
x = (const block_q *)((const char *)vx + ix*bx);
}
}
const void * vx;    // base pointer of the quantized matrix
const block_q * x;  // blocks of the current row (set by new_row)
const size_t bx;    // row stride in bytes
const int nrc;      // number of rows of the other operand
float d;            // current scale
};
// Dequantizer for Q4_K: 8 blocks of 32 4-bit quants per super-block,
// 6-bit packed scales/mins handled by Scales8.
struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 8; }
constexpr static bool should_scale_quants() { return false; }
// Load the super-block scale; min corrections are accumulated into acc.
template <typename Q8>
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
return s8.process_scales_mins(x[i], q8, i, acc);
}
// Unpack 64 bytes of quants; j selects the super-block half.
inline void prepare(int i, int j) {
if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
else bits.prepare(x[i].qs+64*j);
}
Q4bits bits;
Scales8 s8;
};
// ORs the 5th quant bit (0x10) from the Q5_K qh bitfield into 4-bit quants;
// do_shift advances to the next four bit-planes after the first half.
struct HighBit5 {
const uint8x16_t mhb = vdupq_n_u8(0x10);
uint8x16x2_t bits;
inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb));
b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb));
b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb));
b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb));
b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
if (do_shift) {
bits.val[0] = vshrq_n_u8(bits.val[0], 4);
bits.val[1] = vshrq_n_u8(bits.val[1], 4);
}
}
};
// ORs the 3rd quant bit (0x04) from the Q3_K hmask bitfield into 2-bit
// quants; same bit-plane scheme as HighBit5.
struct HighBit3 {
const uint8x16_t mhb = vdupq_n_u8(0x04);
uint8x16x2_t bits;
inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb));
b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb));
b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb));
b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb));
if (do_shift) {
bits.val[0] = vshrq_n_u8(bits.val[0], 4);
bits.val[1] = vshrq_n_u8(bits.val[1], 4);
}
}
};
// Dequantizer for Q5_K: 4-bit low quants (Q4bits) plus a 5th bit from the qh
// bitfield (HighBit5), with 6-bit packed scales/mins (Scales8).
struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
    DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}

    constexpr static int num_blocks() { return 8; }
    constexpr static bool should_scale_quants() { return false; }

    // Load the super-block scale and the 32 bytes of high bits, then process
    // scales/mins (min corrections are accumulated into acc).
    template <typename Q8>
    inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        d = GGML_FP16_TO_FP32(x[i].d);
        h.bits = vld1q_u8_x2(x[i].qh);
        return s8.process_scales_mins(x[i], q8, i, acc);
    }
    // Unpack 64 bytes of low quants (j selects the super-block half) and OR
    // in the 5th bit; the high-bit planes are shifted after the first half.
    inline void prepare(int i, int j) {
        if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
        else bits.prepare(x[i].qs+64*j);
        h.apply(bits.b1, bits.b2, j == 0);
    }

    Q4bits   bits;
    HighBit5 h;    // high bits live in h.bits (the old unused hbits member was removed)
    Scales8  s8;
};
inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
int32x4x4_t scales = {
vmovl_s16(vget_low_s16 (scales16.val[0])),
@@ -1081,171 +892,6 @@ inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8
return make_wider(scales16);
}
// Dequantizer for Q6_K: 4-bit low quants (ql) plus 2 high bits (qh) at 0x30;
// the implicit -32 quant offset is folded into the bsum correction (-32*d).
struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 16; }
constexpr static bool should_scale_quants() { return false; }
template <typename Q8>
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d);
}
// Unpack 64 low nibbles and OR in the two qh bits; j selects the half.
inline void prepare(int i, int j) {
auto hbits = vld1q_u8_x2(x[i].qh + 32*j);
bits.prepare64(x[i].ql+64*j);
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb));
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb));
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb));
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb));
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb));
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb));
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb));
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb));
}
Q4bits bits;
const uint8x16_t mhb = vdupq_n_u8(0x30);
};
// Dequantizer for Q3_K: 2-bit low quants plus a high bit from hmask, 6-bit
// scales packed into 12 bytes. nrc > 1 ORs the high bit in via HighBit3
// (quants offset by +4, compensated with c = -4*d); nrc == 1 sign-extends
// the quants in place.
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
constexpr static int num_blocks() { return 16; }
constexpr static bool should_scale_quants() { return false; }
template <typename Q8>
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
d = GGML_FP16_TO_FP32(x[i].d);
h.bits = vld1q_u8_x2(x[i].hmask);
mask = vdupq_n_u8(0x01);
// Unpack the 16 6-bit scales from the 12 scale bytes into aux32.
const uint16_t * sc16 = (const uint16_t *)x[i].scales;
uint32_t aux0 = sc16[0] | (sc16[1] << 16);
uint32_t aux1 = sc16[2] | (sc16[3] << 16);
uint32_t aux2 = sc16[4] | (sc16[5] << 16);
aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);
aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
// Scales are stored with a +32 bias.
auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
if (nrc > 1) {
return process_scales_mins_16(scales8, q8, acc, i, -4.f*d);
}
int16x8x2_t scales16;
scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
return make_wider(scales16);
}
inline void prepare(int i, int j) {
bits.prepare(x[i].qs+32*j);
if (nrc > 1) {
h.apply(bits.b1, bits.b2, j == 0);
} else {
// GEMV path: where the hmask bit is 0, OR in 0xfc so the 2-bit quant
// q in [0,3] becomes q-4 when read as int8; mask walks the bit-planes.
auto minus4 = vdupq_n_u8(0xfc);
auto zero = vdupq_n_u8(0);
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
mask = vshlq_n_u8(mask, 1);
}
}
uint32_t aux32[4];   // scratch for the unpacked 6-bit scales
Q2bits bits;
uint8x16_t mask;     // current hmask bit-plane (nrc == 1 path)
HighBit3 h;
};
struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
    // Dequantizer for Q2_K super-blocks: 2-bit quants with per-sub-block 4-bit
    // scales and 4-bit mins packed together in x[i].scales (low/high nibble).
    DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}

    // 16 sub-blocks of 16 quants per QK_K super-block.
    constexpr static int num_blocks() { return 16; }
    // Scales are multiplied into the quants themselves (see compute()).
    constexpr static bool should_scale_quants() { return true; }

    // Loads d/dmin, folds the block mins into the float accumulators, and
    // caches the 4-bit block scales in scales8 for compute()/new_block().
    template <typename Q8>
    inline void process_scales(int i, const Q8& q8, float32x4_t * acc) {
        d = GGML_FP16_TO_FP32(x[i].d);
        auto scales_and_mins = vld1q_u8(x[i].scales);
        // High nibbles are the mins; widen to int16 and accumulate with -dmin.
        auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4));
        int16x8x2_t scales16;
        scales16.val[0] = vmovl_s8(vget_low_s8(mins8));
        scales16.val[1] = vmovl_s8(vget_high_s8(mins8));
        accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin));
        // Low nibbles are the per-sub-block scales.
        scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf));
    }

    // Single-row path: process the scales/mins, then return the scales widened
    // to int32 for the generic 16-block accumulation.
    template <typename Q8>
    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        process_scales(i, q8, acc);
        int16x8x2_t scales16;
        scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8)));
        scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8)));
        return make_wider(scales16);
    }

    // Multi-row path: pre-multiplies the (2-bit, so overflow-free) quants by
    // their sub-block scales, then dot-products against the Q8 activations.
    template <typename Q8>
    inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
        auto m1 = vdupq_n_u8(1);
        // `shuffle` walks one scale byte at a time, broadcasting it over 16 quants.
        auto shuffle = vdupq_n_u8(8*j);
        bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
                    vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
            auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
                    vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
            auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
                    vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
            auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
                    vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
        }
    }

    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs+32*j);
    }

    // NOTE(review): removed unused member `uint32_t aux32[4]` — copy-paste
    // residue from the Q3_K dequantizer; no Q2K method references it.
    uint8x16_t scales8;   // 16 per-sub-block 4-bit scales (low nibbles of x[i].scales)
    Q2bits bits;          // unpacked 2-bit quants
};
// ============================= i-quants
inline int32x4x4_t make_wider_8(const int8x16_t& scales8) {
@@ -1969,64 +1615,6 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
};
template <typename Dequantizer, int nrc_y>
void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    // Generic k-quant × Q8_K matrix multiplication driver (NEON path).
    // Walks nrc_x rows of the quantized left matrix against nrc_y rows of
    // Q8_K activations, one QK_K super-block at a time, accumulating integer
    // dot products and scaling them into float accumulators per block.
    assert(n % QK_K == 0);
    const int nblock = n / QK_K;
    Q8<nrc_y, block_q8_K> q8(info);
    Dequantizer deq(vx, bx, nrc_y);
    for (int ix = 0; ix < nrc_x; ++ix) {
        deq.new_row(ix);
        float32x4_t accf[nrc_y];
        for (int iy = 0; iy < nrc_y; ++iy) accf[iy] = vdupq_n_f32(0.f);
        for (int ib = 0; ib < nblock; ++ib) {
            int32x4_t isum[nrc_y];
            for (int iy = 0; iy < nrc_y; ++iy) isum[iy] = vdupq_n_s32(0);
            if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {
                // Dequantizer pre-scales its quants, so compute() handles
                // the dot products itself; each super-block is two halves.
                deq.process_scales(ib, q8, accf);
                for (int j = 0; j < 2; ++j) {
                    deq.prepare(ib, j);
                    deq.compute(q8, ib, j, isum);
                }
            } else {
                if constexpr (Dequantizer::num_blocks() == 8) {
                    auto scales = deq.new_block(ib, q8, accf);
                    for (int j = 0; j < 2; ++j) {
                        deq.prepare(ib, j);
                        for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, ib, j, isum[iy]);
                    }
                }
                else if constexpr (Dequantizer::num_blocks() == 16) {
                    auto scales = deq.new_block(ib, q8, accf);
                    for (int j = 0; j < 2; ++j) {
                        deq.prepare(ib, j);
                        for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, ib, j, isum[iy]);
                    }
                }
                else {
                    GGML_ASSERT(false);
                }
            }
            // Scale the integer block sums by d * q8-row-scale and accumulate.
            for (int iy = 0; iy < nrc_y; ++iy) {
                accf[iy] = vmlaq_f32(accf[iy], vcvtq_f32_s32(isum[iy]), vdupq_n_f32(deq.d*q8.scale(iy, ib)));
            }
        }
        // Horizontal-add each accumulator and store the final dot product.
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vaddvq_f32(accf[iy]));
        }
    }
}
// =========================================== Legacy quants
template <typename Block>
@@ -5095,20 +4683,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
switch (typeA) {
case GGML_TYPE_Q2_K:
MulMat::set_functions<DequantizerQ2K>(m);
break;
case GGML_TYPE_Q3_K:
MulMat::set_functions<DequantizerQ3K>(m);
break;
case GGML_TYPE_Q4_K:
MulMat::set_functions<DequantizerQ4K>(m);
break;
case GGML_TYPE_Q5_K:
MulMat::set_functions<DequantizerQ5K>(m);
break;
case GGML_TYPE_Q6_K:
MulMat::set_functions<DequantizerQ6K>(m);
break;
return iqk_set_kernels_kquants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ4_XS:
MulMat::set_functions<DequantizerIQ4XS>(m);
break;