mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
Refactor iqk: factor out k-quants (NEON)
This commit is contained in:
@@ -526,6 +526,7 @@ struct Q4Bits {
|
||||
#endif
|
||||
|
||||
#else
|
||||
// ------------------------------------ __aarch64__ --------------------------------------------------
|
||||
|
||||
template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
|
||||
|
||||
@@ -547,6 +548,214 @@ template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
|
||||
const block_q8 * y[nrc_y];
|
||||
};
|
||||
|
||||
template <typename block_q, bool has_row_scale = false, bool scale_is_f16 = false>
|
||||
struct BaseDequantizer {
|
||||
BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
|
||||
inline void new_row(int ix) {
|
||||
if constexpr (has_row_scale) {
|
||||
if constexpr (scale_is_f16) {
|
||||
const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx);
|
||||
d = GGML_FP16_TO_FP32(*dptr);
|
||||
x = (const block_q *)(dptr + 1);
|
||||
} else {
|
||||
const float * dptr = (const float *)((const char *)vx + ix*bx);
|
||||
d = *dptr;
|
||||
x = (const block_q *)(dptr + 1);
|
||||
}
|
||||
} else {
|
||||
x = (const block_q *)((const char *)vx + ix*bx);
|
||||
}
|
||||
}
|
||||
const void * vx;
|
||||
const block_q * x;
|
||||
const size_t bx;
|
||||
const int nrc;
|
||||
float d;
|
||||
};
|
||||
|
||||
struct Q4bits {
|
||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||
uint8x16x4_t b1, b2;
|
||||
inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {
|
||||
b.val[0] = vandq_u8(val[0], m4b);
|
||||
b.val[2] = vshrq_n_u8(val[0], 4);
|
||||
b.val[1] = vandq_u8(val[1], m4b);
|
||||
b.val[3] = vshrq_n_u8(val[1], 4);
|
||||
}
|
||||
inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {
|
||||
b.val[0] = vandq_u8(val[0], m4b);
|
||||
b.val[1] = vshrq_n_u8(val[0], 4);
|
||||
b.val[2] = vandq_u8(val[1], m4b);
|
||||
b.val[3] = vshrq_n_u8(val[1], 4);
|
||||
}
|
||||
inline void prepare(const uint8_t * qs) {
|
||||
auto q4bits = vld1q_u8_x2(qs);
|
||||
prepare4(b1, q4bits.val);
|
||||
q4bits = vld1q_u8_x2(qs+32);
|
||||
prepare4(b2, q4bits.val);
|
||||
}
|
||||
inline void prepare_v2(const uint8_t * qs) {
|
||||
auto q4bits = vld1q_u8_x4(qs);
|
||||
prepare4(b1, q4bits.val+0);
|
||||
prepare4(b2, q4bits.val+2);
|
||||
}
|
||||
inline void prepare64(const uint8_t * qs) {
|
||||
auto q4bits = vld1q_u8_x4(qs);
|
||||
b1.val[0] = vandq_u8(q4bits.val[0], m4b);
|
||||
b1.val[1] = vandq_u8(q4bits.val[1], m4b);
|
||||
b1.val[2] = vandq_u8(q4bits.val[2], m4b);
|
||||
b1.val[3] = vandq_u8(q4bits.val[3], m4b);
|
||||
b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);
|
||||
b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);
|
||||
b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);
|
||||
b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);
|
||||
}
|
||||
inline void prepare16(const uint8_t * qs) {
|
||||
auto q4bits = vld1q_u8_x2(qs);
|
||||
prepare4_16(b1, q4bits.val);
|
||||
q4bits = vld1q_u8_x2(qs+32);
|
||||
prepare4_16(b2, q4bits.val);
|
||||
}
|
||||
inline void prepare16_v2(const uint8_t * qs) {
|
||||
auto q4bits = vld1q_u8_x4(qs);
|
||||
prepare4_16(b1, q4bits.val+0);
|
||||
prepare4_16(b2, q4bits.val+2);
|
||||
}
|
||||
};
|
||||
|
||||
struct Q2bits {
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x03);
|
||||
uint8x16x4_t b1, b2;
|
||||
inline void prepare(const uint8_t * qs) {
|
||||
auto q2bits = vld1q_u8_x2(qs);
|
||||
b1.val[0] = vandq_u8(q2bits.val[0], m4b);
|
||||
b1.val[1] = vandq_u8(q2bits.val[1], m4b);
|
||||
|
||||
q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
|
||||
q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
|
||||
b1.val[2] = vandq_u8(q2bits.val[0], m4b);
|
||||
b1.val[3] = vandq_u8(q2bits.val[1], m4b);
|
||||
|
||||
q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
|
||||
q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
|
||||
b2.val[0] = vandq_u8(q2bits.val[0], m4b);
|
||||
b2.val[1] = vandq_u8(q2bits.val[1], m4b);
|
||||
|
||||
q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
|
||||
q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
|
||||
b2.val[2] = vandq_u8(q2bits.val[0], m4b);
|
||||
b2.val[3] = vandq_u8(q2bits.val[1], m4b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Q8>
|
||||
static inline void compute_8_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
|
||||
const int32x4x2_t& scales, int iy, int i, int j, int32x4_t& sumi) {
|
||||
auto mzero = vdupq_n_s32(0);
|
||||
auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
|
||||
auto p1 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
|
||||
vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1]); // block 1
|
||||
auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
|
||||
auto p2 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
|
||||
vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1]); // block 2
|
||||
auto p12 = vpaddq_s32(p1, p2);
|
||||
|
||||
auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
|
||||
auto p3 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
|
||||
vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1]); // block 1
|
||||
auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
|
||||
auto p4 = ggml_vdotq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
|
||||
vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1]); // block 2
|
||||
auto p34 = vpaddq_s32(p3, p4);
|
||||
|
||||
auto pall = vpaddq_s32(p12, p34);
|
||||
sumi = vmlaq_s32(sumi, scales.val[j], pall);
|
||||
}
|
||||
|
||||
template <typename Q8>
|
||||
static inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_t& qx_2, const Q8& q8,
|
||||
const int32x4x4_t& scales, int iy, int i, int j, int32x4_t& sumi) {
|
||||
|
||||
auto mzero = vdupq_n_s32(0);
|
||||
auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
|
||||
auto p1 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[0]), q8b_1.val[0]),
|
||||
ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[1]), q8b_1.val[1])); // blocks 0, 0, 1, 1,
|
||||
auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
|
||||
auto p2 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[2]), q8b_2.val[0]),
|
||||
ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_1.val[3]), q8b_2.val[1])); // blocks 3, 3, 4, 4,
|
||||
auto p12 = vpaddq_s32(p1, p2); // blocks 0, 1, 2, 3
|
||||
sumi = vmlaq_s32(sumi, scales.val[2*j+0], p12);
|
||||
|
||||
auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
|
||||
auto p3 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[0]), q8b_3.val[0]),
|
||||
ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[1]), q8b_3.val[1])); // block 4, 4, 5, 5,
|
||||
auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
|
||||
auto p4 = vpaddq_s32(ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[2]), q8b_4.val[0]),
|
||||
ggml_vdotq_s32(mzero, vreinterpretq_s8_u8(qx_2.val[3]), q8b_4.val[1])); // block 6, 6, 7, 7,
|
||||
auto p34 = vpaddq_s32(p3, p4); // blocks 4, 5, 6, 7
|
||||
sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);
|
||||
}
|
||||
|
||||
template <typename Dequantizer, int nrc_y>
|
||||
static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
const int nb = n / QK_K;
|
||||
|
||||
Q8<nrc_y, block_q8_K> q8(info);
|
||||
|
||||
Dequantizer deq(vx, bx, nrc_y);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
deq.new_row(ix);
|
||||
|
||||
float32x4_t acc[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
int32x4_t sumi[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);
|
||||
|
||||
if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {
|
||||
deq.process_scales(i, q8, acc);
|
||||
deq.prepare(i, 0);
|
||||
deq.compute(q8, i, 0, sumi);
|
||||
deq.prepare(i, 1);
|
||||
deq.compute(q8, i, 1, sumi);
|
||||
} else {
|
||||
if constexpr (Dequantizer::num_blocks() == 8) {
|
||||
auto scales = deq.new_block(i, q8, acc);
|
||||
deq.prepare(i, 0);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
|
||||
deq.prepare(i, 1);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
|
||||
}
|
||||
else if constexpr (Dequantizer::num_blocks() == 16) {
|
||||
auto scales = deq.new_block(i, q8, acc);
|
||||
deq.prepare(i, 0);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
|
||||
deq.prepare(i, 1);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, vaddvq_f32(acc[iy]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1781,6 +1781,397 @@ bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_
|
||||
#else
|
||||
// --------------------------------- __aarch64__ --------------------------------------
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Q8>
|
||||
inline void accum_mins_8(const int16x8_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
auto q8s = q8.load_bsums8(iy, i);
|
||||
int32x4_t b1 = vmull_s16(vget_low_s16(mins), vget_low_s16(q8s));
|
||||
int32x4_t b2 = vmull_s16(vget_high_s16(mins), vget_high_s16(q8s));
|
||||
float32x4_t prod = vcvtq_f32_s32(vaddq_s32(b1, b2));
|
||||
acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
|
||||
}
|
||||
}
|
||||
template <typename Q8>
|
||||
inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
auto q8s = q8.load_bsums(iy, i);
|
||||
int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));
|
||||
int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));
|
||||
int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));
|
||||
int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));
|
||||
float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));
|
||||
acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
|
||||
}
|
||||
}
|
||||
|
||||
struct Scales8 {
|
||||
uint32_t utmp[4];
|
||||
const uint8_t * sc8 = (const uint8_t *)utmp;
|
||||
template <typename Q8, typename Qx>
|
||||
inline int32x4x2_t process_scales_mins(const Qx& x, const Q8& q8, int i, float32x4_t * acc) {
|
||||
make_q4_scales(x.scales, utmp);
|
||||
int16x8_t mins = vmovl_s8(vld1_s8((const int8_t *)sc8 + 8));
|
||||
accum_mins_8(mins, q8, acc, i, -GGML_FP16_TO_FP32(x.dmin));
|
||||
|
||||
uint8x8_t scales8 = vld1_u8(sc8);
|
||||
uint16x8_t scales16 = vmovl_u8(scales8);
|
||||
int32x4x2_t scales = {vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales16))),
|
||||
vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales16)))};
|
||||
return scales;
|
||||
}
|
||||
};
|
||||
|
||||
struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
|
||||
DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 8; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
return s8.process_scales_mins(x[i], q8, i, acc);
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
|
||||
else bits.prepare(x[i].qs+64*j);
|
||||
}
|
||||
|
||||
Q4bits bits;
|
||||
Scales8 s8;
|
||||
|
||||
};
|
||||
|
||||
struct HighBit5 {
|
||||
const uint8x16_t mhb = vdupq_n_u8(0x10);
|
||||
uint8x16x2_t bits;
|
||||
inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
|
||||
b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb));
|
||||
b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb));
|
||||
b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb));
|
||||
b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb));
|
||||
|
||||
b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
|
||||
b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
|
||||
b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
|
||||
b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
|
||||
|
||||
if (do_shift) {
|
||||
bits.val[0] = vshrq_n_u8(bits.val[0], 4);
|
||||
bits.val[1] = vshrq_n_u8(bits.val[1], 4);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct HighBit3 {
|
||||
const uint8x16_t mhb = vdupq_n_u8(0x04);
|
||||
uint8x16x2_t bits;
|
||||
inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
|
||||
b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
|
||||
b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
|
||||
b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
|
||||
b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
|
||||
|
||||
b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb));
|
||||
b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb));
|
||||
b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb));
|
||||
b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb));
|
||||
|
||||
if (do_shift) {
|
||||
bits.val[0] = vshrq_n_u8(bits.val[0], 4);
|
||||
bits.val[1] = vshrq_n_u8(bits.val[1], 4);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
|
||||
DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 8; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
h.bits = vld1q_u8_x2(x[i].qh);
|
||||
return s8.process_scales_mins(x[i], q8, i, acc);
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
|
||||
else bits.prepare(x[i].qs+64*j);
|
||||
h.apply(bits.b1, bits.b2, j == 0);
|
||||
}
|
||||
|
||||
Q4bits bits;
|
||||
HighBit5 h;
|
||||
Scales8 s8;
|
||||
|
||||
uint8x16x2_t hbits;
|
||||
|
||||
};
|
||||
|
||||
inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
|
||||
int32x4x4_t scales = {
|
||||
vmovl_s16(vget_low_s16 (scales16.val[0])),
|
||||
vmovl_s16(vget_high_s16(scales16.val[0])),
|
||||
vmovl_s16(vget_low_s16 (scales16.val[1])),
|
||||
vmovl_s16(vget_high_s16(scales16.val[1])),
|
||||
};
|
||||
return scales;
|
||||
}
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8, float32x4_t * acc, int i, float c) {
|
||||
int16x8x2_t scales16;
|
||||
scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
|
||||
scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
|
||||
accum_mins_16(scales16, q8, acc, i, c);
|
||||
return make_wider(scales16);
|
||||
}
|
||||
|
||||
struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
|
||||
DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 16; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d);
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
|
||||
auto hbits = vld1q_u8_x2(x[i].qh + 32*j);
|
||||
|
||||
bits.prepare64(x[i].ql+64*j);
|
||||
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb));
|
||||
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb));
|
||||
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb));
|
||||
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb));
|
||||
|
||||
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb));
|
||||
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb));
|
||||
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb));
|
||||
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb));
|
||||
|
||||
}
|
||||
|
||||
Q4bits bits;
|
||||
|
||||
const uint8x16_t mhb = vdupq_n_u8(0x30);
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
|
||||
DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 16; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
h.bits = vld1q_u8_x2(x[i].hmask);
|
||||
mask = vdupq_n_u8(0x01);
|
||||
const uint16_t * sc16 = (const uint16_t *)x[i].scales;
|
||||
uint32_t aux0 = sc16[0] | (sc16[1] << 16);
|
||||
uint32_t aux1 = sc16[2] | (sc16[3] << 16);
|
||||
uint32_t aux2 = sc16[4] | (sc16[5] << 16);
|
||||
aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);
|
||||
aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
|
||||
aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
|
||||
aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
|
||||
auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
|
||||
if (nrc > 1) {
|
||||
return process_scales_mins_16(scales8, q8, acc, i, -4.f*d);
|
||||
}
|
||||
int16x8x2_t scales16;
|
||||
scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
|
||||
scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
|
||||
return make_wider(scales16);
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
bits.prepare(x[i].qs+32*j);
|
||||
if (nrc > 1) {
|
||||
h.apply(bits.b1, bits.b2, j == 0);
|
||||
} else {
|
||||
auto minus4 = vdupq_n_u8(0xfc);
|
||||
auto zero = vdupq_n_u8(0);
|
||||
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
|
||||
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
|
||||
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
|
||||
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
|
||||
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t aux32[4];
|
||||
|
||||
Q2bits bits;
|
||||
|
||||
uint8x16_t mask;
|
||||
HighBit3 h;
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
|
||||
DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 16; }
|
||||
constexpr static bool should_scale_quants() { return true; }
|
||||
|
||||
template <typename Q8>
|
||||
inline void process_scales(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
auto scales_and_mins = vld1q_u8(x[i].scales);
|
||||
auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4));
|
||||
int16x8x2_t scales16;
|
||||
scales16.val[0] = vmovl_s8(vget_low_s8(mins8));
|
||||
scales16.val[1] = vmovl_s8(vget_high_s8(mins8));
|
||||
accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin));
|
||||
|
||||
scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf));
|
||||
}
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
process_scales(i, q8, acc);
|
||||
int16x8x2_t scales16;
|
||||
scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8)));
|
||||
scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8)));
|
||||
return make_wider(scales16);
|
||||
}
|
||||
|
||||
template <typename Q8>
|
||||
inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
|
||||
auto m1 = vdupq_n_u8(1);
|
||||
auto shuffle = vdupq_n_u8(8*j);
|
||||
bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
|
||||
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
|
||||
vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
|
||||
|
||||
auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
|
||||
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
|
||||
vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
|
||||
|
||||
auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
|
||||
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
|
||||
vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
|
||||
|
||||
auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
|
||||
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
|
||||
vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
bits.prepare(x[i].qs+32*j);
|
||||
}
|
||||
|
||||
uint32_t aux32[4];
|
||||
|
||||
uint8x16_t scales8;
|
||||
|
||||
Q2bits bits;
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
bool iqk_set_kernels_kquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) {
|
||||
|
||||
auto etypeA = ggml_type(typeA);
|
||||
auto expected_type_B = etypeA == GGML_TYPE_IQ4_XS_R8 || etypeA == GGML_TYPE_Q4_K_R4 || etypeA == GGML_TYPE_Q5_K_R4 ? GGML_TYPE_Q8_K32
|
||||
: etypeA == GGML_TYPE_Q8_K_R8 ? GGML_TYPE_Q8_KR8
|
||||
: etypeA == GGML_TYPE_Q8_KV || etypeA == GGML_TYPE_Q8_KV_R8 ? GGML_TYPE_Q8_KV
|
||||
: GGML_TYPE_Q8_K;
|
||||
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
|
||||
return false;
|
||||
}
|
||||
|
||||
func16 = nullptr;
|
||||
|
||||
switch (typeA) {
|
||||
case GGML_TYPE_Q2_K:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ2K, kernels)
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ3K, kernels)
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ4K, kernels)
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ5K, kernels)
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerQ6K, kernels)
|
||||
break;
|
||||
// case GGML_TYPE_IQ4_XS:
|
||||
// set_functions<DequantizerIQ4XS>(kernels);
|
||||
// break;
|
||||
// case GGML_TYPE_Q2_K_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q2_k_r4_q8_k, kernels)
|
||||
// break;
|
||||
// case GGML_TYPE_Q3_K_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q3_k_r4_q8_k, kernels)
|
||||
// break;
|
||||
// case GGML_TYPE_Q4_K_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_k_r4_q8_k, kernels)
|
||||
// break;
|
||||
// case GGML_TYPE_Q5_K_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_k_r4_q8_k, kernels)
|
||||
// break;
|
||||
// case GGML_TYPE_Q6_K_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_k_r4_q8_k, kernels)
|
||||
// break;
|
||||
// case GGML_TYPE_IQ4_XS_R8:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_xs_r8_q8_k_avx2, kernels)
|
||||
// break;
|
||||
// case GGML_TYPE_Q8_K_R8:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_k_r8_q8_k, kernels)
|
||||
//#ifdef HAVE_FANCY_SIMD
|
||||
// func16 = mul_mat_q8_k_r8_q8_k<16>;
|
||||
//#endif
|
||||
// break;
|
||||
// case GGML_TYPE_Q8_KV:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_q8_KV, kernels)
|
||||
//#ifdef HAVE_FANCY_SIMD
|
||||
// func16 = mul_mat_q8_KV_q8_KV<16>;
|
||||
//#endif
|
||||
// break;
|
||||
// case GGML_TYPE_Q8_KV_R8:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_KV_r8_q8_KV, kernels);
|
||||
// break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
@@ -873,195 +873,6 @@ struct Scales8 {
|
||||
}
|
||||
};
|
||||
|
||||
struct Q4bits {
|
||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||
uint8x16x4_t b1, b2;
|
||||
inline void prepare4(uint8x16x4_t& b, const uint8x16_t * val) const {
|
||||
b.val[0] = vandq_u8(val[0], m4b);
|
||||
b.val[2] = vshrq_n_u8(val[0], 4);
|
||||
b.val[1] = vandq_u8(val[1], m4b);
|
||||
b.val[3] = vshrq_n_u8(val[1], 4);
|
||||
}
|
||||
inline void prepare4_16(uint8x16x4_t& b, const uint8x16_t * val) const {
|
||||
b.val[0] = vandq_u8(val[0], m4b);
|
||||
b.val[1] = vshrq_n_u8(val[0], 4);
|
||||
b.val[2] = vandq_u8(val[1], m4b);
|
||||
b.val[3] = vshrq_n_u8(val[1], 4);
|
||||
}
|
||||
inline void prepare(const uint8_t * qs) {
|
||||
auto q4bits = vld1q_u8_x2(qs);
|
||||
prepare4(b1, q4bits.val);
|
||||
q4bits = vld1q_u8_x2(qs+32);
|
||||
prepare4(b2, q4bits.val);
|
||||
}
|
||||
inline void prepare_v2(const uint8_t * qs) {
|
||||
auto q4bits = vld1q_u8_x4(qs);
|
||||
prepare4(b1, q4bits.val+0);
|
||||
prepare4(b2, q4bits.val+2);
|
||||
}
|
||||
inline void prepare64(const uint8_t * qs) {
|
||||
auto q4bits = vld1q_u8_x4(qs);
|
||||
b1.val[0] = vandq_u8(q4bits.val[0], m4b);
|
||||
b1.val[1] = vandq_u8(q4bits.val[1], m4b);
|
||||
b1.val[2] = vandq_u8(q4bits.val[2], m4b);
|
||||
b1.val[3] = vandq_u8(q4bits.val[3], m4b);
|
||||
b2.val[0] = vshrq_n_u8(q4bits.val[0], 4);
|
||||
b2.val[1] = vshrq_n_u8(q4bits.val[1], 4);
|
||||
b2.val[2] = vshrq_n_u8(q4bits.val[2], 4);
|
||||
b2.val[3] = vshrq_n_u8(q4bits.val[3], 4);
|
||||
}
|
||||
inline void prepare16(const uint8_t * qs) {
|
||||
auto q4bits = vld1q_u8_x2(qs);
|
||||
prepare4_16(b1, q4bits.val);
|
||||
q4bits = vld1q_u8_x2(qs+32);
|
||||
prepare4_16(b2, q4bits.val);
|
||||
}
|
||||
inline void prepare16_v2(const uint8_t * qs) {
|
||||
auto q4bits = vld1q_u8_x4(qs);
|
||||
prepare4_16(b1, q4bits.val+0);
|
||||
prepare4_16(b2, q4bits.val+2);
|
||||
}
|
||||
};
|
||||
|
||||
struct Q2bits {
|
||||
const uint8x16_t m4b = vdupq_n_u8(0x03);
|
||||
uint8x16x4_t b1, b2;
|
||||
inline void prepare(const uint8_t * qs) {
|
||||
auto q2bits = vld1q_u8_x2(qs);
|
||||
b1.val[0] = vandq_u8(q2bits.val[0], m4b);
|
||||
b1.val[1] = vandq_u8(q2bits.val[1], m4b);
|
||||
|
||||
q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
|
||||
q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
|
||||
b1.val[2] = vandq_u8(q2bits.val[0], m4b);
|
||||
b1.val[3] = vandq_u8(q2bits.val[1], m4b);
|
||||
|
||||
q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
|
||||
q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
|
||||
b2.val[0] = vandq_u8(q2bits.val[0], m4b);
|
||||
b2.val[1] = vandq_u8(q2bits.val[1], m4b);
|
||||
|
||||
q2bits.val[0] = vshrq_n_u8(q2bits.val[0], 2);
|
||||
q2bits.val[1] = vshrq_n_u8(q2bits.val[1], 2);
|
||||
b2.val[2] = vandq_u8(q2bits.val[0], m4b);
|
||||
b2.val[3] = vandq_u8(q2bits.val[1], m4b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename block_q, bool has_row_scale = false, bool scale_is_f16 = false>
|
||||
struct BaseDequantizer {
|
||||
BaseDequantizer(const void * vx, size_t bx, int nrc) : vx(vx), x(nullptr), bx(bx), nrc(nrc) {}
|
||||
inline void new_row(int ix) {
|
||||
if constexpr (has_row_scale) {
|
||||
if constexpr (scale_is_f16) {
|
||||
const ggml_half * dptr = (const ggml_half *)((const char *)vx + ix*bx);
|
||||
d = GGML_FP16_TO_FP32(*dptr);
|
||||
x = (const block_q *)(dptr + 1);
|
||||
} else {
|
||||
const float * dptr = (const float *)((const char *)vx + ix*bx);
|
||||
d = *dptr;
|
||||
x = (const block_q *)(dptr + 1);
|
||||
}
|
||||
} else {
|
||||
x = (const block_q *)((const char *)vx + ix*bx);
|
||||
}
|
||||
}
|
||||
const void * vx;
|
||||
const block_q * x;
|
||||
const size_t bx;
|
||||
const int nrc;
|
||||
float d;
|
||||
};
|
||||
|
||||
struct DequantizerQ4K final : public BaseDequantizer<block_q4_K> {
|
||||
DequantizerQ4K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 8; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
return s8.process_scales_mins(x[i], q8, i, acc);
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
|
||||
else bits.prepare(x[i].qs+64*j);
|
||||
}
|
||||
|
||||
Q4bits bits;
|
||||
Scales8 s8;
|
||||
|
||||
};
|
||||
|
||||
struct HighBit5 {
|
||||
const uint8x16_t mhb = vdupq_n_u8(0x10);
|
||||
uint8x16x2_t bits;
|
||||
inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
|
||||
b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 4), mhb));
|
||||
b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 4), mhb));
|
||||
b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 3), mhb));
|
||||
b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 3), mhb));
|
||||
|
||||
b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
|
||||
b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
|
||||
b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
|
||||
b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
|
||||
|
||||
if (do_shift) {
|
||||
bits.val[0] = vshrq_n_u8(bits.val[0], 4);
|
||||
bits.val[1] = vshrq_n_u8(bits.val[1], 4);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct HighBit3 {
|
||||
const uint8x16_t mhb = vdupq_n_u8(0x04);
|
||||
uint8x16x2_t bits;
|
||||
inline void apply(uint8x16x4_t& b1, uint8x16x4_t& b2, bool do_shift) {
|
||||
b1.val[0] = vorrq_u8(b1.val[0], vandq_u8(vshlq_n_u8(bits.val[0], 2), mhb));
|
||||
b1.val[1] = vorrq_u8(b1.val[1], vandq_u8(vshlq_n_u8(bits.val[1], 2), mhb));
|
||||
b1.val[2] = vorrq_u8(b1.val[2], vandq_u8(vshlq_n_u8(bits.val[0], 1), mhb));
|
||||
b1.val[3] = vorrq_u8(b1.val[3], vandq_u8(vshlq_n_u8(bits.val[1], 1), mhb));
|
||||
|
||||
b2.val[0] = vorrq_u8(b2.val[0], vandq_u8(bits.val[0], mhb));
|
||||
b2.val[1] = vorrq_u8(b2.val[1], vandq_u8(bits.val[1], mhb));
|
||||
b2.val[2] = vorrq_u8(b2.val[2], vandq_u8(vshrq_n_u8(bits.val[0], 1), mhb));
|
||||
b2.val[3] = vorrq_u8(b2.val[3], vandq_u8(vshrq_n_u8(bits.val[1], 1), mhb));
|
||||
|
||||
if (do_shift) {
|
||||
bits.val[0] = vshrq_n_u8(bits.val[0], 4);
|
||||
bits.val[1] = vshrq_n_u8(bits.val[1], 4);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct DequantizerQ5K final : public BaseDequantizer<block_q5_K> {
|
||||
DequantizerQ5K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 8; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x2_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
h.bits = vld1q_u8_x2(x[i].qh);
|
||||
return s8.process_scales_mins(x[i], q8, i, acc);
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
if (nrc == 1) bits.prepare_v2(x[i].qs+64*j);
|
||||
else bits.prepare(x[i].qs+64*j);
|
||||
h.apply(bits.b1, bits.b2, j == 0);
|
||||
}
|
||||
|
||||
Q4bits bits;
|
||||
HighBit5 h;
|
||||
Scales8 s8;
|
||||
|
||||
uint8x16x2_t hbits;
|
||||
|
||||
};
|
||||
|
||||
inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
|
||||
int32x4x4_t scales = {
|
||||
vmovl_s16(vget_low_s16 (scales16.val[0])),
|
||||
@@ -1081,171 +892,6 @@ inline int32x4x4_t process_scales_mins_16(const int8x16_t& scales8, const Q8& q8
|
||||
return make_wider(scales16);
|
||||
}
|
||||
|
||||
struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
|
||||
DequantizerQ6K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 16; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
return process_scales_mins_16(vld1q_s8(x[i].scales), q8, acc, i, -32.f*d);
|
||||
}
|
||||
inline void prepare(int i, int j) {
|
||||
|
||||
auto hbits = vld1q_u8_x2(x[i].qh + 32*j);
|
||||
|
||||
bits.prepare64(x[i].ql+64*j);
|
||||
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(vshlq_n_u8(hbits.val[0], 4), mhb));
|
||||
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(vshlq_n_u8(hbits.val[1], 4), mhb));
|
||||
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(vshlq_n_u8(hbits.val[0], 2), mhb));
|
||||
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(vshlq_n_u8(hbits.val[1], 2), mhb));
|
||||
|
||||
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(hbits.val[0], mhb));
|
||||
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(hbits.val[1], mhb));
|
||||
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(vshrq_n_u8(hbits.val[0], 2), mhb));
|
||||
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(vshrq_n_u8(hbits.val[1], 2), mhb));
|
||||
|
||||
}
|
||||
|
||||
Q4bits bits;
|
||||
|
||||
const uint8x16_t mhb = vdupq_n_u8(0x30);
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerQ3K final : public BaseDequantizer<block_q3_K> {
|
||||
DequantizerQ3K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 16; }
|
||||
constexpr static bool should_scale_quants() { return false; }
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
h.bits = vld1q_u8_x2(x[i].hmask);
|
||||
mask = vdupq_n_u8(0x01);
|
||||
const uint16_t * sc16 = (const uint16_t *)x[i].scales;
|
||||
uint32_t aux0 = sc16[0] | (sc16[1] << 16);
|
||||
uint32_t aux1 = sc16[2] | (sc16[3] << 16);
|
||||
uint32_t aux2 = sc16[4] | (sc16[5] << 16);
|
||||
aux32[0] = (aux0 & 0x0f0f0f0f) | ((aux2 << 4) & 0x30303030);
|
||||
aux32[1] = (aux1 & 0x0f0f0f0f) | ((aux2 << 2) & 0x30303030);
|
||||
aux32[2] = ((aux0 >> 4) & 0x0f0f0f0f) | ((aux2 >> 0) & 0x30303030);
|
||||
aux32[3] = ((aux1 >> 4) & 0x0f0f0f0f) | ((aux2 >> 2) & 0x30303030);
|
||||
auto scales8 = vaddq_s8(vld1q_s8((const int8_t *)aux32), vdupq_n_s8(-32));
|
||||
if (nrc > 1) {
|
||||
return process_scales_mins_16(scales8, q8, acc, i, -4.f*d);
|
||||
}
|
||||
int16x8x2_t scales16;
|
||||
scales16.val[0] = vmovl_s8(vget_low_s8(scales8));
|
||||
scales16.val[1] = vmovl_s8(vget_high_s8(scales8));
|
||||
return make_wider(scales16);
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
bits.prepare(x[i].qs+32*j);
|
||||
if (nrc > 1) {
|
||||
h.apply(bits.b1, bits.b2, j == 0);
|
||||
} else {
|
||||
auto minus4 = vdupq_n_u8(0xfc);
|
||||
auto zero = vdupq_n_u8(0);
|
||||
bits.b1.val[0] = vorrq_u8(bits.b1.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
|
||||
bits.b1.val[1] = vorrq_u8(bits.b1.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
bits.b1.val[2] = vorrq_u8(bits.b1.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
|
||||
bits.b1.val[3] = vorrq_u8(bits.b1.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
bits.b2.val[0] = vorrq_u8(bits.b2.val[0], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
|
||||
bits.b2.val[1] = vorrq_u8(bits.b2.val[1], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
bits.b2.val[2] = vorrq_u8(bits.b2.val[2], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[0], mask), zero)));
|
||||
bits.b2.val[3] = vorrq_u8(bits.b2.val[3], vandq_u8(minus4, vceqq_u8(vandq_u8(h.bits.val[1], mask), zero)));
|
||||
mask = vshlq_n_u8(mask, 1);
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t aux32[4];
|
||||
|
||||
Q2bits bits;
|
||||
|
||||
uint8x16_t mask;
|
||||
HighBit3 h;
|
||||
|
||||
};
|
||||
|
||||
struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
|
||||
DequantizerQ2K(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
|
||||
|
||||
constexpr static int num_blocks() { return 16; }
|
||||
constexpr static bool should_scale_quants() { return true; }
|
||||
|
||||
template <typename Q8>
|
||||
inline void process_scales(int i, const Q8& q8, float32x4_t * acc) {
|
||||
d = GGML_FP16_TO_FP32(x[i].d);
|
||||
auto scales_and_mins = vld1q_u8(x[i].scales);
|
||||
auto mins8 = vreinterpretq_s8_u8(vshrq_n_u8(scales_and_mins, 4));
|
||||
int16x8x2_t scales16;
|
||||
scales16.val[0] = vmovl_s8(vget_low_s8(mins8));
|
||||
scales16.val[1] = vmovl_s8(vget_high_s8(mins8));
|
||||
accum_mins_16(scales16, q8, acc, i, -GGML_FP16_TO_FP32(x[i].dmin));
|
||||
|
||||
scales8 = vandq_u8(scales_and_mins, vdupq_n_u8(0xf));
|
||||
}
|
||||
|
||||
template <typename Q8>
|
||||
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
|
||||
process_scales(i, q8, acc);
|
||||
int16x8x2_t scales16;
|
||||
scales16.val[0] = vmovl_s8(vget_low_s8(vreinterpretq_s8_u8(scales8)));
|
||||
scales16.val[1] = vmovl_s8(vget_high_s8(vreinterpretq_s8_u8(scales8)));
|
||||
return make_wider(scales16);
|
||||
}
|
||||
|
||||
template <typename Q8>
|
||||
inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
|
||||
auto m1 = vdupq_n_u8(1);
|
||||
auto shuffle = vdupq_n_u8(8*j);
|
||||
bits.b1.val[0] = vmulq_u8(bits.b1.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b1.val[1] = vmulq_u8(bits.b1.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b1.val[2] = vmulq_u8(bits.b1.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b1.val[3] = vmulq_u8(bits.b1.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b2.val[0] = vmulq_u8(bits.b2.val[0], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b2.val[1] = vmulq_u8(bits.b2.val[1], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b2.val[2] = vmulq_u8(bits.b2.val[2], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
bits.b2.val[3] = vmulq_u8(bits.b2.val[3], vqtbl1q_u8(scales8, shuffle)); shuffle = vaddq_u8(shuffle, m1);
|
||||
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
|
||||
auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
|
||||
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
|
||||
vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
|
||||
|
||||
auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
|
||||
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
|
||||
vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
|
||||
|
||||
auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
|
||||
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
|
||||
vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
|
||||
|
||||
auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
|
||||
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
|
||||
vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void prepare(int i, int j) {
|
||||
bits.prepare(x[i].qs+32*j);
|
||||
}
|
||||
|
||||
uint32_t aux32[4];
|
||||
|
||||
uint8x16_t scales8;
|
||||
|
||||
Q2bits bits;
|
||||
|
||||
};
|
||||
|
||||
// ============================= i-quants
|
||||
|
||||
inline int32x4x4_t make_wider_8(const int8x16_t& scales8) {
|
||||
@@ -1969,64 +1615,6 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
|
||||
|
||||
};
|
||||
|
||||
template <typename Dequantizer, int nrc_y>
|
||||
void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n % QK_K == 0);
|
||||
const int nb = n / QK_K;
|
||||
|
||||
Q8<nrc_y, block_q8_K> q8(info);
|
||||
|
||||
Dequantizer deq(vx, bx, nrc_y);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
|
||||
deq.new_row(ix);
|
||||
|
||||
float32x4_t acc[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
int32x4_t sumi[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);
|
||||
|
||||
if constexpr (nrc_y > 1 && Dequantizer::should_scale_quants()) {
|
||||
deq.process_scales(i, q8, acc);
|
||||
deq.prepare(i, 0);
|
||||
deq.compute(q8, i, 0, sumi);
|
||||
deq.prepare(i, 1);
|
||||
deq.compute(q8, i, 1, sumi);
|
||||
} else {
|
||||
if constexpr (Dequantizer::num_blocks() == 8) {
|
||||
auto scales = deq.new_block(i, q8, acc);
|
||||
deq.prepare(i, 0);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
|
||||
deq.prepare(i, 1);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) compute_8_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
|
||||
}
|
||||
else if constexpr (Dequantizer::num_blocks() == 16) {
|
||||
auto scales = deq.new_block(i, q8, acc);
|
||||
deq.prepare(i, 0);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 0, sumi[iy]);
|
||||
deq.prepare(i, 1);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) compute_16_blocks(deq.bits.b1, deq.bits.b2, q8, scales, iy, i, 1, sumi[iy]);
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
|
||||
}
|
||||
}
|
||||
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, vaddvq_f32(acc[iy]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================== Legacy quants
|
||||
|
||||
template <typename Block>
|
||||
@@ -5095,20 +4683,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
|
||||
switch (typeA) {
|
||||
case GGML_TYPE_Q2_K:
|
||||
MulMat::set_functions<DequantizerQ2K>(m);
|
||||
break;
|
||||
case GGML_TYPE_Q3_K:
|
||||
MulMat::set_functions<DequantizerQ3K>(m);
|
||||
break;
|
||||
case GGML_TYPE_Q4_K:
|
||||
MulMat::set_functions<DequantizerQ4K>(m);
|
||||
break;
|
||||
case GGML_TYPE_Q5_K:
|
||||
MulMat::set_functions<DequantizerQ5K>(m);
|
||||
break;
|
||||
case GGML_TYPE_Q6_K:
|
||||
MulMat::set_functions<DequantizerQ6K>(m);
|
||||
break;
|
||||
return iqk_set_kernels_kquants(ne00, typeA, typeB, m.funcs, m.func16);
|
||||
case GGML_TYPE_IQ4_XS:
|
||||
MulMat::set_functions<DequantizerIQ4XS>(m);
|
||||
break;
|
||||
|
||||
Reference in New Issue
Block a user