From df139c5649bc4a68f2f452cf8b2374588c03d2a1 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 7 Dec 2024 18:47:37 +0100 Subject: [PATCH] qx_0_r4_q8_0 template Applied to q4_0_r4 and q5_0_r4. It makes q5_0_r4 PP ~7% faster. --- ggml/src/iqk/iqk_mul_mat.cpp | 361 +++++++++++++++++++++++------------ src/llama.cpp | 4 + 2 files changed, 239 insertions(+), 126 deletions(-) diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index b3c6d004..3164f02f 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -7793,146 +7793,238 @@ void mul_mat_iq4_nl_x4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& } } -//template -//inline void do_1_block(int ib4, const Q8& q8, const float32x4_t * d8, const block_iq4_nl_x4 * iq4, -// int8x16_t * qx, float32x4_t * acc, const uint8x16_t& m4, const uint8x16_t& m88) { -// auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); -// auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); -// for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]); -// qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows -// qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19 -// qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7 -// qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23 -// qx[4] = vandq_u8(bits.val[0], m4); // 8..11 -// qx[5] = vandq_u8(bits.val[1], m4); // 24..27 -// qx[6] = vandq_u8(bits.val[2], m4); // 12..15 -// qx[7] = vandq_u8(bits.val[3], m4); // 28..31 -// for (int iy = 0; iy < nrc_y; ++iy) { -// auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); -// auto sumi = vdupq_n_s32(0); -// sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); -// sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1); -// sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2); -// sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3); -// sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0); -// sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1); -// sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2); -// sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); -// //auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); -// //auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k])); -// auto d4d8 = vmulq_laneq_f32(scales, d8[iy], k); -// acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); +//template +//void mul_mat_q4_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +// GGML_ASSERT(nrc_x%4 == 0); +// Q8 q8(info); +// auto m4 = vdupq_n_u8(0xf0); +// auto m88 = vdupq_n_u8(0x88); +// auto norm = vdupq_n_f32(1.f/16); +// int nb = n / QK4_NL; +// GGML_ASSERT(nb%4 == 0); +// int8x16_t qx[8]; +// float d8[4*nrc_y]; +// float32x4_t acc[nrc_y] = {}; +// for (int ix = 0; ix < nrc_x; ix += 4) { +// const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx); +// for (int ib4 = 0; ib4 < nb/4; ++ib4) { +// for (int iy = 0; iy < nrc_y; ++iy) { +// vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d))); +// } +// for (int k = 0; k < 4; ++k) { +// auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); +// auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); +// for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]); +// qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows +// qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19 +// qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7 +// qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23 +// qx[4] = vandq_u8(bits.val[0], m4); // 8..11 +// qx[5] = vandq_u8(bits.val[1], m4); // 24..27 +// qx[6] = 
vandq_u8(bits.val[2], m4); // 12..15 +// qx[7] = vandq_u8(bits.val[3], m4); // 28..31 +// for (int iy = 0; iy < nrc_y; ++iy) { +// auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); +// auto sumi = interleaved_dotq(qx, y); +// auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k])); +// acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); +// } +// } +// } +// for (int iy = 0; iy < nrc_y; ++iy) { +// info.store(ix, iy, vmulq_f32(norm, acc[iy])); +// acc[iy] = vdupq_n_f32(0.f); +// } // } //} -template -void mul_mat_q4_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +template +void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { GGML_ASSERT(nrc_x%4 == 0); Q8 q8(info); - auto m4 = vdupq_n_u8(0xf0); - auto m88 = vdupq_n_u8(0x88); - auto norm = vdupq_n_f32(1.f/16); + Dequantizer deq(vx, bx); int nb = n / QK4_NL; GGML_ASSERT(nb%4 == 0); int8x16_t qx[8]; float d8[4*nrc_y]; - //float32x4_t d8[nrc_y]; float32x4_t acc[nrc_y] = {}; for (int ix = 0; ix < nrc_x; ix += 4) { - const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx); + deq.new_row(ix); for (int ib4 = 0; ib4 < nb/4; ++ib4) { for (int iy = 0; iy < nrc_y; ++iy) { - //d8[iy] = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)); vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d))); } - //do_1_block(ib4, q8, d8, iq4, qx, acc, m4, m88); - //do_1_block(ib4, q8, d8, iq4, qx, acc, m4, m88); - //do_1_block(ib4, q8, d8, iq4, qx, acc, m4, m88); - //do_1_block(ib4, q8, d8, iq4, qx, acc, m4, m88); for (int k = 0; k < 4; ++k) { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); - auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); - for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]); - qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows - qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19 - qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7 - qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23 - qx[4] = vandq_u8(bits.val[0], m4); // 8..11 - qx[5] = vandq_u8(bits.val[1], m4); // 24..27 - qx[6] = vandq_u8(bits.val[2], m4); // 12..15 - qx[7] = vandq_u8(bits.val[3], m4); // 28..31 + auto scales = deq.prepare(ib4, k, qx); for (int iy = 0; iy < nrc_y; ++iy) { auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); - auto sumi = vdupq_n_s32(0); - sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); - sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1); - sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2); - sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3); - sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0); - sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1); - sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2); - sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); - //auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); + auto sumi = interleaved_dotq(qx, y); auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k])); acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); } } } for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, vmulq_f32(norm, acc[iy])); + info.store(ix, iy, deq.result(acc[iy])); acc[iy] = vdupq_n_f32(0.f); } } } -template -void mul_mat_q5_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { - GGML_ASSERT(nrc_x%4 == 0); - Q8 q8(info); - auto m4 = vdupq_n_u8(0x0f); - auto m5 = vdupq_n_u8(0x10); - auto m16 = vdupq_n_s8(-16); - int nb = n / QK5_0; - GGML_ASSERT(nb%4 == 0); - int8x16_t qx[8]; - float32x4_t acc[nrc_y] = {}; - for (int ix = 0; ix < nrc_x; ix += 
4) { - const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx); - for (int ib4 = 0; ib4 < nb/4; ++ib4) { - for (int k = 0; k < 4; ++k) { - auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d)); - auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs); - auto hbits = vld1q_u8(iq5[4*ib4+k].qh); - qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3 - qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19 - qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7 - qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23 - qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11 - qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27 - qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15 - qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31 - for (int iy = 0; iy < nrc_y; ++iy) { - auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); - auto sumi = vdupq_n_s32(0); - sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); - sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0); - sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1); - sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1); - sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2); - sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2); - sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3); - sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); - auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); - acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); - } - } - } - for (int iy = 0; iy < nrc_y; ++iy) { - info.store(ix, iy, acc[iy]); - acc[iy] = vdupq_n_f32(0.f); - } +//template +//void mul_mat_qx_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +// GGML_ASSERT(nrc_x%4 == 0); +// Q8<1, block_q8_0_x4> q8(info); +// Dequantizer deq(vx, bx); +// int nb = n / QK4_NL; +// GGML_ASSERT(nb%4 == 0); +// int8x16_t qx[8]; +// float32x4_t acc[1] = {}; +// int32x4_t sumi[4]; +// for (int ix = 0; ix < nrc_x; ix += 4) { +// deq.new_row(ix); +// for (int ib4 = 0; ib4 < nb/4; ++ib4) { +// auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib4].d)); +// for (int k = 0; k < 4; ++k) { +// deq.prepare_bits(ib4, k, qx); +// auto y = vld1q_s8_x2(q8.y[0][ib4].qs+32*k); +// sumi[k] = interleaved_dotq(qx, y); +// } +// auto scales = deq.all_scales(ib4); +// acc[0] = vfmaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vmulq_laneq_f32(scales.val[0], d8, 0)); +// acc[0] = vfmaq_f32(acc[0], vcvtq_f32_s32(sumi[1]), vmulq_laneq_f32(scales.val[1], d8, 1)); +// acc[0] = vfmaq_f32(acc[0], vcvtq_f32_s32(sumi[2]), vmulq_laneq_f32(scales.val[2], d8, 2)); +// acc[0] = vfmaq_f32(acc[0], vcvtq_f32_s32(sumi[3]), vmulq_laneq_f32(scales.val[3], d8, 3)); +// } +// info.store(ix, 0, deq.result(acc[0])); +// acc[0] = vdupq_n_f32(0.f); +// } +//} + +struct Q4_0_R4_Dequantizer { + Q4_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} + inline void new_row(int ix) { iq4 = (const block_iq4_nl_x4 *)(cx + ix*bx); } + //inline void prepare_bits(int ib4, int k, int8x16_t * qx) const { + // auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); + // for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]); + // qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows + // qx[1] = 
vshlq_n_u8(bits.val[1], 4); // 16..19 + // qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7 + // qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23 + // qx[4] = vandq_u8(bits.val[0], m4); // 8..11 + // qx[5] = vandq_u8(bits.val[1], m4); // 24..27 + // qx[6] = vandq_u8(bits.val[2], m4); // 12..15 + // qx[7] = vandq_u8(bits.val[3], m4); // 28..31 + //} + //inline float32x4x4_t all_scales(int ib4) const { + // float32x4x4_t r; + // for (int k = 0; k < 4; ++k) r.val[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); + // return r; + //} + inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d)); + //prepare_bits(ib4, k, qx); + auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs); + for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]); + qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows + qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19 + qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7 + qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23 + qx[4] = vandq_u8(bits.val[0], m4); // 8..11 + qx[5] = vandq_u8(bits.val[1], m4); // 24..27 + qx[6] = vandq_u8(bits.val[2], m4); // 12..15 + qx[7] = vandq_u8(bits.val[3], m4); // 28..31 + return scales; } -} + inline float32x4_t result(float32x4_t acc) const { + return vmulq_f32(norm, acc); + } + + const char * cx; + const size_t bx; + const block_iq4_nl_x4 * iq4; + const uint8x16_t m4 = vdupq_n_u8(0xf0); + const uint8x16_t m88 = vdupq_n_u8(0x88); + const float32x4_t norm = vdupq_n_f32(1.f/16); +}; + +struct Q5_0_R4_Dequantizer { + Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {} + inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); } + inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const { + auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d)); + auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs); + auto hbits = vld1q_u8(iq5[4*ib4+k].qh); + qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3 + qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19 + qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7 + qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23 + qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11 + qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27 + qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15 + qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31 + return scales; + } + inline float32x4_t result(float32x4_t acc) const { + return acc; + } + + const char * cx; + const size_t bx; + const block_q5_0_r4 * iq5; + const uint8x16_t m4 = vdupq_n_u8(0x0f); + const uint8x16_t m5 = vdupq_n_u8(0x10); + const int8x16_t m16 = vdupq_n_s8(-16); +}; + +//template +//void mul_mat_q5_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { +// GGML_ASSERT(nrc_x%4 == 0); +// Q8 q8(info); +// auto m4 = vdupq_n_u8(0x0f); +// auto m5 = vdupq_n_u8(0x10); +// auto m16 = vdupq_n_s8(-16); +// int nb = n / QK5_0; +// GGML_ASSERT(nb%4 == 0); +// int8x16_t qx[8]; +// float32x4_t acc[nrc_y] = {}; +// for (int ix = 0; ix < nrc_x; ix += 4) { +// const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx); +// for (int ib4 = 0; 
ib4 < nb/4; ++ib4) { +// for (int k = 0; k < 4; ++k) { +// auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d)); +// auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs); +// auto hbits = vld1q_u8(iq5[4*ib4+k].qh); +// qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3 +// qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19 +// qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7 +// qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23 +// qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11 +// qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27 +// qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15 +// qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31 +// for (int iy = 0; iy < nrc_y; ++iy) { +// auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k); +// auto sumi = vdupq_n_s32(0); +// sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); +// sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0); +// sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1); +// sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1); +// sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2); +// sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2); +// sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3); +// sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); +// auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k]))); +// acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi)); +// } +// } +// } +// for (int iy = 0; iy < nrc_y; ++iy) { +// info.store(ix, iy, acc[iy]); +// acc[iy] = vdupq_n_f32(0.f); +// } +// } +//} template void mul_mat_q6_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { @@ -8223,25 +8315,42 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { expected_Btype = GGML_TYPE_Q8_K; break; case GGML_TYPE_Q4_0_R4: - m.funcs[0] = mul_mat_q4_0_r4_q8_0<1>; - m.funcs[1] = mul_mat_q4_0_r4_q8_0<2>; - m.funcs[2] = mul_mat_q4_0_r4_q8_0<3>; - m.funcs[3] = mul_mat_q4_0_r4_q8_0<4>; - m.funcs[4] = mul_mat_q4_0_r4_q8_0<5>; - m.funcs[5] = mul_mat_q4_0_r4_q8_0<6>; - m.funcs[6] = mul_mat_q4_0_r4_q8_0<7>; - m.funcs[7] = mul_mat_q4_0_r4_q8_0<8>; + //m.funcs[0] = mul_mat_q4_0_r4_q8_0<1>; + //m.funcs[1] = mul_mat_q4_0_r4_q8_0<2>; + //m.funcs[2] = mul_mat_q4_0_r4_q8_0<3>; + //m.funcs[3] = mul_mat_q4_0_r4_q8_0<4>; + //m.funcs[4] = mul_mat_q4_0_r4_q8_0<5>; + //m.funcs[5] = mul_mat_q4_0_r4_q8_0<6>; + //m.funcs[6] = mul_mat_q4_0_r4_q8_0<7>; + //m.funcs[7] = mul_mat_q4_0_r4_q8_0<8>; + //m.funcs[0] = mul_mat_qx_r4_q8_0_1; + m.funcs[0] = mul_mat_qx_r4_q8_0; + m.funcs[1] = mul_mat_qx_r4_q8_0; + m.funcs[2] = mul_mat_qx_r4_q8_0; + m.funcs[3] = mul_mat_qx_r4_q8_0; + m.funcs[4] = mul_mat_qx_r4_q8_0; + m.funcs[5] = mul_mat_qx_r4_q8_0; + m.funcs[6] = mul_mat_qx_r4_q8_0; + m.funcs[7] = mul_mat_qx_r4_q8_0; expected_Btype = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q5_0_R4: - m.funcs[0] = mul_mat_q5_0_r4_q8_0<1>; - m.funcs[1] = mul_mat_q5_0_r4_q8_0<2>; - m.funcs[2] = mul_mat_q5_0_r4_q8_0<3>; - m.funcs[3] = mul_mat_q5_0_r4_q8_0<4>; - m.funcs[4] = mul_mat_q5_0_r4_q8_0<5>; - m.funcs[5] = mul_mat_q5_0_r4_q8_0<6>; - m.funcs[6] = mul_mat_q5_0_r4_q8_0<7>; - m.funcs[7] = mul_mat_q5_0_r4_q8_0<8>; + //m.funcs[0] = mul_mat_q5_0_r4_q8_0<1>; + //m.funcs[1] = 
mul_mat_q5_0_r4_q8_0<2>; + //m.funcs[2] = mul_mat_q5_0_r4_q8_0<3>; + //m.funcs[3] = mul_mat_q5_0_r4_q8_0<4>; + //m.funcs[4] = mul_mat_q5_0_r4_q8_0<5>; + //m.funcs[5] = mul_mat_q5_0_r4_q8_0<6>; + //m.funcs[6] = mul_mat_q5_0_r4_q8_0<7>; + //m.funcs[7] = mul_mat_q5_0_r4_q8_0<8>; + m.funcs[0] = mul_mat_qx_r4_q8_0; + m.funcs[1] = mul_mat_qx_r4_q8_0; + m.funcs[2] = mul_mat_qx_r4_q8_0; + m.funcs[3] = mul_mat_qx_r4_q8_0; + m.funcs[4] = mul_mat_qx_r4_q8_0; + m.funcs[5] = mul_mat_qx_r4_q8_0; + m.funcs[6] = mul_mat_qx_r4_q8_0; + m.funcs[7] = mul_mat_qx_r4_q8_0; expected_Btype = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q6_0_R4: diff --git a/src/llama.cpp b/src/llama.cpp index ad76a7b8..0e1aadbd 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16569,6 +16569,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0; else chunk_size_multiplier = 4; } + else if (new_type == GGML_TYPE_Q5_0_R4) { + if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q5_0; + else chunk_size_multiplier = 4; + } else if (new_type == GGML_TYPE_Q6_0_R4) { if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q6_0; else chunk_size_multiplier = 4;
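
Notes on the refactor:

The patch replaces the two hand-duplicated kernels mul_mat_q4_0_r4_q8_0 and mul_mat_q5_0_r4_q8_0 with a single template, mul_mat_qx_r4_q8_0<Dequantizer, nrc_y>, instantiated with a per-type Dequantizer. The Dequantizer contract is implicit in the template body; a hypothetical C++20 concept spelling it out (the concept and its name are illustrative, not part of the patch):

    #include <arm_neon.h>
    #include <concepts>

    template <typename D>
    concept R4Dequantizer = requires(D d, int ix, int ib4, int k,
                                     int8x16_t *qx, float32x4_t acc) {
        // Bind the dequantizer to the group of 4 interleaved rows
        // starting at row ix.
        d.new_row(ix);
        // Unpack block k of super-block ib4 into qx[0..7] and
        // return the four per-row scales.
        { d.prepare(ib4, k, qx) } -> std::same_as<float32x4_t>;
        // Apply any final per-type normalization to the accumulator.
        { d.result(acc) } -> std::same_as<float32x4_t>;
    };

Both Q4_0_R4_Dequantizer and Q5_0_R4_Dequantizer fit this shape (e.g. static_assert(R4Dequantizer<Q4_0_R4_Dequantizer>) would hold): new_row() rebinds the row pointer, prepare() fills the eight int8x16_t registers with 4 quants from each of the 4 rows, and result() is where Q4_0_R4 applies its 1/16 normalization while Q5_0_R4 returns the accumulator unchanged.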
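On why Q4_0_R4_Dequantizer carries norm = 1/16 while Q5_0_R4_Dequantizer::result() is the identity: the q4 path never isolates the nibbles into the low half of each byte. XOR with 0x88 flips bit 3 of each nibble, turning the unsigned quant q (0..15) into (q - 8) mod 16; shifting the low nibble into the high half (or masking the high nibble with 0xf0) then yields 16*(q - 8) as a signed byte, so every dot product comes out 16x too large and result() divides that back out. A scalar model of the trick (a restatement of the SIMD code above, not code from the source):

    #include <cassert>
    #include <cstdint>

    // Decode the low nibble of one q4_0_r4 byte the way the kernel
    // does: XOR 0x88, shift into the high half, reinterpret as signed.
    static inline int q4_0_decode_lo(uint8_t byte) {
        uint8_t t = byte ^ 0x88;            // nibble q -> (q - 8) mod 16
        return (int8_t)(uint8_t)(t << 4);   // 16 * (q - 8), in [-128, 112]
    }

    int main() {
        for (int q = 0; q < 16; ++q)
            assert(q4_0_decode_lo((uint8_t)q) == 16 * (q - 8));
        return 0;
    }

The q5 path, by contrast, assembles each 5-bit quant explicitly and subtracts 16 (the m16 constant) during prepare(), so its accumulator is already correctly scaled.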
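Both kernels now funnel through interleaved_dotq(qx, y), whose definition sits outside this hunk. Judging from the unrolled vdotq_laneq_s32 sequences still visible in the commented-out kernels above, it presumably looks like the sketch below (reconstructed, requires the ARMv8.2 dotprod extension; the _sketch suffix flags that the real definition may differ):

    #include <arm_neon.h>

    // Dot the 8 interleaved quant registers against one 32-quant
    // column of activations. qx[i] holds 4 consecutive quants from
    // each of the 4 rows; lane j of y.val[0]/y.val[1] holds the
    // matching 4 activation bytes.
    static inline int32x4_t interleaved_dotq_sketch(const int8x16_t *qx,
                                                    const int8x16x2_t &y) {
        auto sumi = vdupq_n_s32(0);
        sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0); // quants  0.. 3
        sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1); // quants  4.. 7
        sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2); // quants  8..11
        sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3); // quants 12..15
        sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0); // quants 16..19
        sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1); // quants 20..23
        sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2); // quants 24..27
        sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3); // quants 28..31
        return sumi; // one int32 partial sum per interleaved row
    }

Hoisting this pattern into a helper is what lets the template stay type-agnostic: prepare() owns the per-type bit manipulation, and the dot-product loop over the q8 activations is identical for every row-interleaved quant.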