iq2_tn: NEON

For TriLM-3.9B running on the M2-Max we get PP-512 = 193.5 t/s,
TG-128 = 75.5 t/s. This is in line with what we have for
iq2_bn and the 3.3B Bitnet model.
Iwan Kawrakow
2024-08-06 06:19:59 +02:00
parent 810285581c
commit e528505fc8


@@ -3876,62 +3876,6 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
    float d;
};
struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
    DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}

    constexpr static int num_blocks() { return 16; }
    constexpr static bool should_scale_quants() { return true; }

    template <typename Q8>
    inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) {
        d = GGML_FP16_TO_FP32(x[i].d);
    }

    template <typename Q8>
    inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
        process_scales(i, q8, acc);
        return { vdupq_n_s32(1), vdupq_n_s32(1), vdupq_n_s32(1), vdupq_n_s32(1) };
    }

    template <typename Q8>
    inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
                    vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
            auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
                    vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
            auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
                    vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
            auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
                    vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
        }
    }

    inline void prepare(int i, int j) {
        bits.prepare(x[i].qs+32*j);
        auto m1 = vdupq_n_s8(1);
        bits.b1.val[0] = vsubq_s8(bits.b1.val[0], m1);
        bits.b1.val[1] = vsubq_s8(bits.b1.val[1], m1);
        bits.b1.val[2] = vsubq_s8(bits.b1.val[2], m1);
        bits.b1.val[3] = vsubq_s8(bits.b1.val[3], m1);
        bits.b2.val[0] = vsubq_s8(bits.b2.val[0], m1);
        bits.b2.val[1] = vsubq_s8(bits.b2.val[1], m1);
        bits.b2.val[2] = vsubq_s8(bits.b2.val[2], m1);
        bits.b2.val[3] = vsubq_s8(bits.b2.val[3], m1);
    }

    Q2bits bits;
    float d;
};
// ============================= i-quants
inline int32x4x4_t make_wider_8(const int8x16_t& scales8) {
@@ -4460,6 +4404,151 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
};
struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
    DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}

    constexpr static int num_blocks() { return 16; }
    constexpr static bool should_scale_quants() { return true; }

    //template <typename Q8>
    //inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) {
    //    d = GGML_FP16_TO_FP32(x[i].d);
    //}
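    // iq2_tn carries a single fp16 scale per 256-weight block and no per-sub-block
    // scales, so starting a new block only requires loading d.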
    inline void new_block(int i) {
        d = GGML_FP16_TO_FP32(x[i].d);
    }

    template <typename Q8>
    inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
            auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
                    vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
            auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
                    vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
            auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
                    vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
            auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
            sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
                    vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
        }
    }
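    // Single-column variant of compute(): it alternates between two independent
    // accumulators (sumi[0] and sumi[1]) so the vdot instructions do not form one
    // long dependency chain when there is only a single q8 row to process.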
    template <typename Q8>
    inline void compute1(const Q8& q8, int i, int j, int32x4_t * sumi) {
        auto q8b_1 = q8.load_quants(0, i, 4*j+0);
        sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
                vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
        auto q8b_2 = q8.load_quants(0, i, 4*j+1);
        sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
                vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
        q8b_1 = q8.load_quants(0, i, 4*j+2);
        sumi[0] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[0], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_1.val[0]),
                vreinterpretq_s8_u8(bits.b2.val[1]), q8b_1.val[1]);
        q8b_2 = q8.load_quants(0, i, 4*j+3);
        sumi[1] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[1], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_2.val[0]),
                vreinterpretq_s8_u8(bits.b2.val[3]), q8b_2.val[1]);
    }

    IQK_ALWAYS_INLINE void prepare(int i, int j) {
        bits.prepare(x[i].qs+32*j);
        auto m1 = vdupq_n_s8(1);
        for (int k = 0; k < 4; ++k) {
            bits.b1.val[k] = vsubq_s8(bits.b1.val[k], m1);
            bits.b2.val[k] = vsubq_s8(bits.b2.val[k], m1);
        }
    }

    Q2bits bits;
    float d;
};
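// For reference, a minimal scalar sketch of what the NEON path above computes.
// It assumes block_iq2_tn holds one fp16 scale d plus QK_K/4 bytes of 2-bit quants
// in the usual K-quant 2-bit layout (within each 32-byte group, the fields at bit
// offsets 0/2/4/6 of byte l map to elements l, l+32, l+64, l+96); the helper name
// dequantize_iq2_tn_ref is illustrative, not an existing function.
//
//   void dequantize_iq2_tn_ref(const block_iq2_tn * x, float * y, int nblocks) {
//       for (int i = 0; i < nblocks; ++i) {
//           const float d = GGML_FP16_TO_FP32(x[i].d);
//           const uint8_t * qs = x[i].qs;
//           for (int j = 0; j < QK_K/128; ++j) {   // two 128-element halves per block
//               for (int s = 0; s < 4; ++s) {
//                   for (int l = 0; l < 32; ++l) {
//                       // stored values are in {0, 1, 2}; subtracting 1 yields the
//                       // ternary weight in {-1, 0, +1}, as in prepare() above
//                       y[128*j + 32*s + l] = d * (((qs[l] >> 2*s) & 3) - 1);
//                   }
//               }
//               qs += 32;
//           }
//           y += QK_K;
//       }
//   }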
template <int nrc_y>
void mul_mat_iq2tn_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n % QK_K == 0);
    const int nb = n / QK_K;
    Q8<nrc_y, block_q8_K> q8(info);
    DequantizerIQ2TN deq(vx, bx, nrc_y);
    for (int ix = 0; ix < nrc_x; ++ix) {
        deq.new_row(ix);
        float32x4_t acc[nrc_y];
        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
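        // Walk the row one 256-weight block at a time. prepare() already shifts the
        // quants to {-1, 0, +1}, so the integer dot products need no row-sum correction.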
        for (int i = 0; i < nb; ++i) {
            int32x4_t sumi[nrc_y];
            for (int iy = 0; iy < nrc_y; ++iy) sumi[iy] = vdupq_n_s32(0);
            //deq.process_scales(i, q8, acc);
            deq.new_block(i);
            deq.prepare(i, 0);
            deq.compute(q8, i, 0, sumi);
            deq.prepare(i, 1);
            deq.compute(q8, i, 1, sumi);
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(sumi[iy]), vdupq_n_f32(deq.d*q8.scale(iy, i)));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vaddvq_f32(acc[iy]));
        }
    }
}
void mul_mat_iq2tn_K_q8_K_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n % QK_K == 0);
    const int nb = n / QK_K;
    Q8<1, block_q8_K> q8(info);
    DequantizerIQ2TN deq(vx, bx, 1);
    auto m1 = vdup_n_s16(-1);
    for (int ix = 0; ix < nrc_x; ++ix) {
        deq.new_row(ix);
        float32x4_t acc[2] = {};
        for (int i = 0; i < nb; ++i) {
            int32x4_t sumi[2] = {};
            deq.new_block(i);
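            // Here the quants are used as stored, i.e. in {0, 1, 2}. Because
            // sum((q-1)*y) = sum(q*y) - sum(y), the -1 shift is folded in once via
            // the precomputed q8 block sums (bsums) multiplied by m1 = -1, saving
            // eight vector subtractions per 256-weight block.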
            auto bsums = q8.load_bsums(0, i);
            bsums.val[0] = vaddq_s32(bsums.val[0], bsums.val[1]);
            sumi[0] = vmlal_s16(sumi[0], vget_low_s16 (bsums.val[0]), m1);
            sumi[1] = vmlal_s16(sumi[1], vget_high_s16(bsums.val[0]), m1);
            deq.bits.prepare(deq.x[i].qs);
            deq.compute1(q8, i, 0, sumi);
            deq.bits.prepare(deq.x[i].qs+32);
            deq.compute1(q8, i, 1, sumi);
            auto vd = vdupq_n_f32(deq.d*q8.scale(0, i));
            acc[0] = vmlaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vd);
            acc[1] = vmlaq_f32(acc[1], vcvtq_f32_s32(sumi[1]), vd);
        }
        acc[0] = vaddq_f32(acc[0], acc[1]);
        info.store(ix, 0, vaddvq_f32(acc[0]));
    }
}
template <int nrc_y, typename Dequantizer>
void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -5450,7 +5539,15 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
            MulMat::set_functions<DequantizerQ2K>(m);
            break;
        case GGML_TYPE_IQ2_TN:
            MulMat::set_functions<DequantizerIQ2TN>(m);
            //MulMat::set_functions<DequantizerIQ2TN>(m);
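            // funcs[k] handles k+1 columns of right-hand-side activations; the
            // single-column slot gets the specialized kernel with the bsums trick.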
            m.funcs[0] = mul_mat_iq2tn_K_q8_K_1;
            m.funcs[1] = mul_mat_iq2tn_K_q8_K_T<2>;
            m.funcs[2] = mul_mat_iq2tn_K_q8_K_T<3>;
            m.funcs[3] = mul_mat_iq2tn_K_q8_K_T<4>;
            m.funcs[4] = mul_mat_iq2tn_K_q8_K_T<5>;
            m.funcs[5] = mul_mat_iq2tn_K_q8_K_T<6>;
            m.funcs[6] = mul_mat_iq2tn_K_q8_K_T<7>;
            m.funcs[7] = mul_mat_iq2tn_K_q8_K_T<8>;
            break;
        case GGML_TYPE_Q3_K:
            MulMat::set_functions<DequantizerQ3K>(m);