iq2_tn: initial NEON version

This commit is contained in:
Iwan Kawrakow
2024-08-05 17:53:02 +02:00
parent a63ba11a25
commit 810285581c

View File

@@ -1484,13 +1484,6 @@ struct DequantizerQ6K final : public BaseDequantizer<block_q6_K> {
struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
DequantizerIQ2TN(const void * vx, size_t bx) : BaseDequantizer(vx, bx) {}
//template <typename Q8>
//inline void new_block(int i, const Q8& q8, __m256i * sumi) {
// d = GGML_FP16_TO_FP32(x[i].d);
// for (int iy = 0; iy < Q8::nrc_y; ++iy) {
// sumi[iy] = q8.load_bsums(iy, i);
// }
//}
// Per-block setup: loads only the super-block scale for block i,
// converting the stored fp16 scale x[i].d to fp32. No per-sub-block
// scales exist for this quant type, so nothing else is prepared here.
inline void new_block(int i) {
d = GGML_FP16_TO_FP32(x[i].d);
}
@@ -3883,6 +3876,62 @@ struct DequantizerQ2K final : public BaseDequantizer<block_q2_K> {
float d;
};
// NEON dequantizer for IQ2_TN blocks: 2-bit packed quants with a single
// fp16 scale per super-block (x[i].d). Plugs into the templated NEON
// mul-mat kernels via the new_block/prepare/compute protocol.
struct DequantizerIQ2TN final : public BaseDequantizer<block_iq2_tn> {
DequantizerIQ2TN(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
// Number of 16-quant sub-blocks the kernel template iterates per super-block.
constexpr static int num_blocks() { return 16; }
// There are no per-sub-block scales for this type, so the template must
// apply scaling on the q8 side instead.
constexpr static bool should_scale_quants() { return true; }
template <typename Q8>
inline void process_scales(int i, [[maybe_unused]] const Q8& q8, [[maybe_unused]] float32x4_t * acc) {
// Only the super-block scale is needed; q8 and acc are unused on this path.
d = GGML_FP16_TO_FP32(x[i].d);
}
template <typename Q8>
inline int32x4x4_t new_block(int i, const Q8& q8, float32x4_t * acc) {
process_scales(i, q8, acc);
// Effective sub-block scales are all 1: the type carries a single scale
// per super-block, already captured in d.
return { vdupq_n_s32(1), vdupq_n_s32(1), vdupq_n_s32(1), vdupq_n_s32(1) };
}
// Accumulates, for each q8 row iy, eight signed 8-bit dot products
// (4 pairs of 16-byte chunks from bits.b1/bits.b2 against the matching
// q8 quants) into the running int32x4 accumulator sumi[iy].
template <typename Q8>
inline void compute(const Q8& q8, int i, int j, int32x4_t * sumi) {
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
auto q8b_1 = q8.load_quants(iy, i, 4*j+0);
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[0]), q8b_1.val[0]),
vreinterpretq_s8_u8(bits.b1.val[1]), q8b_1.val[1]);
auto q8b_2 = q8.load_quants(iy, i, 4*j+1);
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b1.val[2]), q8b_2.val[0]),
vreinterpretq_s8_u8(bits.b1.val[3]), q8b_2.val[1]);
auto q8b_3 = q8.load_quants(iy, i, 4*j+2);
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[0]), q8b_3.val[0]),
vreinterpretq_s8_u8(bits.b2.val[1]), q8b_3.val[1]);
auto q8b_4 = q8.load_quants(iy, i, 4*j+3);
sumi[iy] = ggml_vdotq_s32(ggml_vdotq_s32(sumi[iy], vreinterpretq_s8_u8(bits.b2.val[2]), q8b_4.val[0]),
vreinterpretq_s8_u8(bits.b2.val[3]), q8b_4.val[1]);
}
}
// Unpacks 32 bytes of 2-bit quants into bits.b1/bits.b2, then subtracts 1
// from every lane — presumably recentering the unpacked values (e.g.
// 0..2 -> -1..1 for a ternary type); TODO confirm against the IQ2_TN
// quantization spec.
// NOTE(review): compute() reinterprets b1/b2 lanes u8->s8, yet vsubq_s8 is
// applied to them directly here. Subtraction is bit-identical for signed and
// unsigned lanes, but the intrinsic signature expects int8x16_t — confirm
// the lane type Q2bits declares for b1/b2.
inline void prepare(int i, int j) {
bits.prepare(x[i].qs+32*j);
auto m1 = vdupq_n_s8(1);
bits.b1.val[0] = vsubq_s8(bits.b1.val[0], m1);
bits.b1.val[1] = vsubq_s8(bits.b1.val[1], m1);
bits.b1.val[2] = vsubq_s8(bits.b1.val[2], m1);
bits.b1.val[3] = vsubq_s8(bits.b1.val[3], m1);
bits.b2.val[0] = vsubq_s8(bits.b2.val[0], m1);
bits.b2.val[1] = vsubq_s8(bits.b2.val[1], m1);
bits.b2.val[2] = vsubq_s8(bits.b2.val[2], m1);
bits.b2.val[3] = vsubq_s8(bits.b2.val[3], m1);
}
// 2-bit unpacker shared with the Q2_K path; holds b1/b2 (4 vectors each).
Q2bits bits;
// fp32 super-block scale for the current block (set in process_scales).
float d;
};
// ============================= i-quants
inline int32x4x4_t make_wider_8(const int8x16_t& scales8) {
@@ -5400,6 +5449,9 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_Q2_K:
MulMat::set_functions<DequantizerQ2K>(m);
break;
case GGML_TYPE_IQ2_TN:
MulMat::set_functions<DequantizerIQ2TN>(m);
break;
case GGML_TYPE_Q3_K:
MulMat::set_functions<DequantizerQ3K>(m);
break;