mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
iq4_knn: ARM_NEON
Pretty good performance - on M2-Max we get PP-512(LLaMA-3.1-8B) = 89.5 t/s and TG-128(LLaMA-3.1-8B) = 27.65 t/s
This commit is contained in:
@@ -5696,6 +5696,64 @@ void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info
|
||||
}
|
||||
}
|
||||
|
||||
// ARM_NEON matrix-multiplication kernel: IQ4_KNN-quantized weight rows
// against Q8_K-quantized activation rows.
//
//   n     : row length in elements; must be a multiple of QK_K (asserted)
//   vx    : base pointer to the first quantized weight row
//   bx    : byte stride between consecutive weight rows
//   info  : result sink / activation accessor (results written via info.store)
//   nrc_x : number of weight rows to process
//   nrc_y : (template) number of activation rows handled per call
//
// Row layout (derived from the pointer arithmetic below): one float scale d,
// followed by 16 int8 lookup values, followed by the block_iq4_knn data.
template <int nrc_y>
void mul_mat_iq4_knn_q8_K(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n % QK_K == 0);
    const int nb = n / QK_K;   // number of QK_K-sized blocks per row

    Q8<nrc_y, block_q8_K> q8(info);

    const char * qrow = (const char *)vx;

    auto ml = vdupq_n_u8(0xf); // low-nibble mask

    int8x16_t q4[16];          // one block's dequantized values: 16 vectors x 16 lanes

    for (int ix = 0; ix < nrc_x; ++ix) {

        // Per-row header: scale, then the 16-entry int8 value table, then the blocks.
        const float * dptr = (const float *)qrow;
        float d = *dptr;
        const int8_t * int_values = (const int8_t *)(dptr + 1);
        auto values = vld1q_s8(int_values);   // table for vqtbl1q_s8 nibble -> int8 lookup
        auto x = (const block_iq4_knn *)(int_values + 16);

        float32x4_t acc[nrc_y];
        for (int iy = 0; iy < nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);

        for (int i = 0; i < nb; ++i) {
            // Unpack 2 x 64 nibble-packed bytes into 256 int8 values via table lookup:
            // even q4[] entries come from low nibbles, odd entries from high nibbles.
            auto bits = vld1q_u8_x4(x[i].qs);
            for (int k = 0; k < 4; ++k) {
                q4[2*k+0] = vqtbl1q_s8(values, vandq_u8(bits.val[k], ml));
                q4[2*k+1] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[k], 4));
            }
            bits = vld1q_u8_x4(x[i].qs+64);
            for (int k = 0; k < 4; ++k) {
                q4[2*k+8] = vqtbl1q_s8(values, vandq_u8(bits.val[k], ml));
                q4[2*k+9] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[k], 4));
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                auto sumi1 = vdupq_n_s32(0);
                auto sumi2 = vdupq_n_s32(0);
                for (int k = 0; k < 4; ++k) {
                    // 64 activation quants at a time; two dot-product accumulators
                    // to break the dependency chain.
                    auto qy = q8.load_quants_64(iy, i, k);
                    sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(sumi1, q4[4*k+0], qy.val[0]), q4[4*k+1], qy.val[1]);
                    sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(sumi2, q4[4*k+2], qy.val[2]), q4[4*k+3], qy.val[3]);
                }
                // Fold the integer dot product into the float accumulator, scaled by
                // the weight-row scale d and the per-block activation scale.
                acc[iy] = vmlaq_f32(acc[iy], vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)), vdupq_n_f32(d*q8.scale(iy, i)));
                //acc[iy] = i > 0 ? vmlaq_f32(acc[iy], vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)), vdupq_n_f32(d*q8.scale(iy, i))) :
                //                  vmulq_f32(vcvtq_f32_s32(vaddq_s32(sumi1, sumi2)), vdupq_n_f32(d*q8.scale(iy, i)));
            }
        }

        for (int iy = 0; iy < nrc_y; ++iy) {
            // Horizontal add of the 4 float lanes gives the final scalar result.
            info.store(ix, iy, vaddvq_f32(acc[iy]));
        }

        qrow += bx;
    }
}
|
||||
|
||||
// =========================================== Legacy quants
|
||||
|
||||
template <typename Block>
|
||||
@@ -6969,6 +7027,16 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
case GGML_TYPE_IQ4_KSS:
    MulMat::set_functions<DequantizerIQ4KSS>(m);
    break;
case GGML_TYPE_IQ4_KNN:
    // IQ4_KNN uses a dedicated kernel instead of the Dequantizer-based
    // set_functions<> path: funcs[j] processes j+1 activation rows per
    // call (template parameter nrc_y = 1..8).
    m.funcs[0] = mul_mat_iq4_knn_q8_K<1>;
    m.funcs[1] = mul_mat_iq4_knn_q8_K<2>;
    m.funcs[2] = mul_mat_iq4_knn_q8_K<3>;
    m.funcs[3] = mul_mat_iq4_knn_q8_K<4>;
    m.funcs[4] = mul_mat_iq4_knn_q8_K<5>;
    m.funcs[5] = mul_mat_iq4_knn_q8_K<6>;
    m.funcs[6] = mul_mat_iq4_knn_q8_K<7>;
    m.funcs[7] = mul_mat_iq4_knn_q8_K<8>;
    break;
case GGML_TYPE_IQ2_KS:
    MulMat::set_functions<DequantizerIQ2KS>(m);
    break;
|
||||
|
||||
Reference in New Issue
Block a user