iq4_knn: ARM_NEON

Pretty good performance - on M2-Max we get
PP-512(LLaMA-3.1-8B) = 89.5 t/s
TG-128(LLaMA-3.1-8B) = 27.65 t/s
This commit is contained in:
Iwan Kawrakow
2024-10-18 08:02:22 +02:00
parent cc912c3f7c
commit 780929a6d0

View File

@@ -5696,6 +5696,64 @@ void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info
}
}
template <int nrc_y>
void mul_mat_iq4_knn_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n % QK_K == 0);
const int nb = n / QK_K;
Q8<nrc_y, block_q8_K> q8(info);
const char * qrow = (const char *)vx;
auto ml = vdupq_n_u8(0xf);
int8x16_t q4[16];
for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)qrow;
float d = *dptr;
const int8_t * int_values = (const int8_t *)(dptr + 1);
auto values = vld1q_s8(int_values);
auto x = (const block_iq4_knn *)(int_values + 16);
float32x4_t acc[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
for (int i = 0; i < nb; ++i) {
auto bits = vld1q_u8_x4(x[i].qs);
for (int k = 0; k < 4; ++k) {
q4[2*k+0] = vqtbl1q_s8(values, vandq_u8(bits.val[k], ml));
q4[2*k+1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[k], 4));
}
bits = vld1q_u8_x4(x[i].qs+64);
for (int k = 0; k < 4; ++k) {
q4[2*k+8] = vqtbl1q_s8(values, vandq_u8(bits.val[k], ml));
q4[2*k+9] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[k], 4));
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi1 = vdupq_n_s32(0);
auto sumi2 = vdupq_n_s32(0);
for (int k = 0; k < 4; ++k) {
auto qy = q8.load_quants_64(iy, i, k);
sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(sumi1, q4[4*k+0], qy.val[0]), q4[4*k+1], qy.val[1]);
sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(sumi2, q4[4*k+2], qy.val[2]), q4[4*k+3], qy.val[3]);
}
acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)), vdupq_n_f32(d*q8.scale(iy, i)));
//acc[iy] = i > 0 ? vmlaq_f32(acc[iy], vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)), vdupq_n_f32(d*q8.scale(iy, i))) :
// vmulq_f32(vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)), vdupq_n_f32(d*q8.scale(iy, i)));
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vaddvq_f32(acc[iy]));
}
qrow += bx;
}
}
// =========================================== Legacy quants
template <typename Block>
@@ -6969,6 +7027,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ4_KSS:
MulMat::set_functions<DequantizerIQ4KSS>(m);
break;
case GGML_TYPE_IQ4_KNN:
m.funcs[0] = mul_mat_iq4_knn_q8_K<1>;
m.funcs[1] = mul_mat_iq4_knn_q8_K<2>;
m.funcs[2] = mul_mat_iq4_knn_q8_K<3>;
m.funcs[3] = mul_mat_iq4_knn_q8_K<4>;
m.funcs[4] = mul_mat_iq4_knn_q8_K<5>;
m.funcs[5] = mul_mat_iq4_knn_q8_K<6>;
m.funcs[6] = mul_mat_iq4_knn_q8_K<7>;
m.funcs[7] = mul_mat_iq4_knn_q8_K<8>;
break;
case GGML_TYPE_IQ2_KS:
MulMat::set_functions<DequantizerIQ2KS>(m);
break;