mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-25 07:34:10 +00:00
Make q5,6_0_r4, iq4_nl_e4 work with row size that are not a multiple of 128
also on NEON.
This commit is contained in:
@@ -12215,7 +12215,6 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
|
||||
Q8<nrc_y, block_q8_0_x4> q8(info);
|
||||
Dequantizer deq(vx, bx);
|
||||
int nb = n / QK4_NL;
|
||||
GGML_ASSERT(nb%4 == 0);
|
||||
int8x16_t qx[8];
|
||||
float d8[4*nrc_y];
|
||||
float32x4_t acc[nrc_y] = {};
|
||||
@@ -12226,7 +12225,7 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
|
||||
vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
|
||||
}
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
auto scales = deq.prepare(ib4, k, qx);
|
||||
auto scales = deq.prepare(4*ib4+k, qx);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
|
||||
auto sumi = interleaved_dotq(qx, y);
|
||||
@@ -12235,6 +12234,16 @@ void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int ib = 4*(nb/4); ib < nb; ++ib) {
|
||||
auto scales = deq.prepare(ib, qx);
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto qy = (const block_q8_0 *)q8.y[iy];
|
||||
auto y = vld1q_s8_x2(qy[ib].qs);
|
||||
auto sumi = interleaved_dotq(qx, y);
|
||||
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(qy[ib].d)));
|
||||
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, deq.result(acc[iy]));
|
||||
acc[iy] = vdupq_n_f32(0.f);
|
||||
@@ -12292,9 +12301,9 @@ void mul_mat_qx_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& info,
|
||||
struct IQ4_NL_R4_Dequantizer {
|
||||
IQ4_NL_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx), values(vld1q_s8(iq4k_values)) {}
|
||||
inline void new_row(int ix) { iq4 = (const block_iq4_nl_r4 *)(cx + ix*bx); }
|
||||
inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
|
||||
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
|
||||
auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
|
||||
inline float32x4_t prepare(int ib, int8x16_t * qx) const {
|
||||
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ib].d));
|
||||
auto bits = vld1q_u8_x4(iq4[ib].qs);
|
||||
prepare_iq4_nl_quants(values, m4, bits, qx);
|
||||
return scales;
|
||||
}
|
||||
@@ -12370,10 +12379,10 @@ struct Q4_0_R8_Dequantizer {
|
||||
struct Q5_0_R4_Dequantizer {
|
||||
Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
|
||||
inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); }
|
||||
inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
|
||||
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d));
|
||||
auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs);
|
||||
auto hbits = vld1q_u8(iq5[4*ib4+k].qh);
|
||||
inline float32x4_t prepare(int ib, int8x16_t * qx) const {
|
||||
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ib].d));
|
||||
auto lbits = vld1q_u8_x4(iq5[ib].qs);
|
||||
auto hbits = vld1q_u8(iq5[ib].qh);
|
||||
qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3
|
||||
qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19
|
||||
qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7
|
||||
@@ -12399,10 +12408,10 @@ struct Q5_0_R4_Dequantizer {
|
||||
struct Q6_0_R4_Dequantizer {
|
||||
Q6_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
|
||||
inline void new_row(int ix) { iq6 = (const block_q6_0_r4 *)(cx + ix*bx); }
|
||||
inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
|
||||
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[4*ib4+k].d));
|
||||
auto lbits = vld1q_u8_x4(iq6[4*ib4+k].qs);
|
||||
auto hbits = vld1q_u8_x2(iq6[4*ib4+k].qh);
|
||||
inline float32x4_t prepare(int ib, int8x16_t * qx) const {
|
||||
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq6[ib].d));
|
||||
auto lbits = vld1q_u8_x4(iq6[ib].qs);
|
||||
auto hbits = vld1q_u8_x2(iq6[ib].qh);
|
||||
qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 4), m6), m32); // 0...3
|
||||
qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits.val[1], 4), m6), m32); // 16..19
|
||||
qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits.val[0], 2), m6), m32); // 4...7
|
||||
|
||||
Reference in New Issue
Block a user