qx_0_r4_q8_0 template

Applied to q4_0_r4 and q5_0_r4. It makes q5_0_r4 PP (prompt processing) ~7% faster.
Iwan Kawrakow
2024-12-07 18:47:37 +01:00
parent 12d3ea1e30
commit df139c5649
2 changed files with 239 additions and 126 deletions
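In outline, this commit folds the two hand-written kernels into one template, mul_mat_qx_r4_q8_0<Dequantizer, nrc_y>, which keeps the shared q8_0 dot-product loop and delegates per-type unpacking to a small Dequantizer object with new_row()/prepare()/result() methods. The toy program below is a hedged illustration of that compile-time pattern only; apart from the three method names taken from the diff, every name and all arithmetic are invented placeholders, not the kernel itself.

#include <cstdio>

// Toy analogue of mul_mat_qx_r4_q8_0: the outer loop is shared, and each
// "quant type" supplies only its unpack (prepare) and final fixup (result).
template <typename Dequantizer>
float toy_dot(const float * x, int n) {
    Dequantizer deq;
    float acc = 0.f;
    for (int i = 0; i < n; ++i) acc += deq.prepare(x[i]); // per-block unpack
    return deq.result(acc);                               // final fixup
}

struct ToyQ4 { // like Q4_0_R4_Dequantizer: values carried 16x too big
    float prepare(float v) const { return 16.f*v; }
    float result(float acc) const { return acc/16.f; }
};

struct ToyQ5 { // like Q5_0_R4_Dequantizer: no final normalization needed
    float prepare(float v) const { return v; }
    float result(float acc) const { return acc; }
};

int main() {
    const float x[4] = {1.f, 2.f, 3.f, 4.f};
    printf("%g %g\n", toy_dot<ToyQ4>(x, 4), toy_dot<ToyQ5>(x, 4)); // prints: 10 10
}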


@@ -7793,146 +7793,238 @@ void mul_mat_iq4_nl_x4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo&
}
}
//template <int nrc_y, int k>
//inline void do_1_block(int ib4, const Q8<nrc_y, block_q8_0_x4>& q8, const float32x4_t * d8, const block_iq4_nl_x4 * iq4,
// int8x16_t * qx, float32x4_t * acc, const uint8x16_t& m4, const uint8x16_t& m88) {
// auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
// auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
// for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]);
// qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows
// qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19
// qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7
// qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23
// qx[4] = vandq_u8(bits.val[0], m4); // 8..11
// qx[5] = vandq_u8(bits.val[1], m4); // 24..27
// qx[6] = vandq_u8(bits.val[2], m4); // 12..15
// qx[7] = vandq_u8(bits.val[3], m4); // 28..31
// for (int iy = 0; iy < nrc_y; ++iy) {
// auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
// auto sumi = vdupq_n_s32(0);
// sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
// sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
// sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
// sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
// sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
// sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
// sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
// sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
// //auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
// //auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
// auto d4d8 = vmulq_laneq_f32(scales, d8[iy], k);
//        acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
//    }
//}
//template <int nrc_y>
//void mul_mat_q4_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
// GGML_ASSERT(nrc_x%4 == 0);
// Q8<nrc_y, block_q8_0_x4> q8(info);
// auto m4 = vdupq_n_u8(0xf0);
// auto m88 = vdupq_n_u8(0x88);
// auto norm = vdupq_n_f32(1.f/16);
// int nb = n / QK4_NL;
// GGML_ASSERT(nb%4 == 0);
// int8x16_t qx[8];
// float d8[4*nrc_y];
// float32x4_t acc[nrc_y] = {};
// for (int ix = 0; ix < nrc_x; ix += 4) {
// const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx);
// for (int ib4 = 0; ib4 < nb/4; ++ib4) {
// for (int iy = 0; iy < nrc_y; ++iy) {
// vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
// }
// for (int k = 0; k < 4; ++k) {
// auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
// auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
// for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]);
// qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows
// qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19
// qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7
// qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23
// qx[4] = vandq_u8(bits.val[0], m4); // 8..11
// qx[5] = vandq_u8(bits.val[1], m4); // 24..27
// qx[6] = vandq_u8(bits.val[2], m4); // 12..15
// qx[7] = vandq_u8(bits.val[3], m4); // 28..31
// for (int iy = 0; iy < nrc_y; ++iy) {
// auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
// auto sumi = interleaved_dotq(qx, y);
// auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
// acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
// }
// }
// }
// for (int iy = 0; iy < nrc_y; ++iy) {
// info.store(ix, iy, vmulq_f32(norm, acc[iy]));
// acc[iy] = vdupq_n_f32(0.f);
// }
// }
//}
template <int nrc_y>
void mul_mat_q4_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
template <typename Dequantizer, int nrc_y>
void mul_mat_qx_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
auto m4 = vdupq_n_u8(0xf0);
auto m88 = vdupq_n_u8(0x88);
auto norm = vdupq_n_f32(1.f/16);
Dequantizer deq(vx, bx);
int nb = n / QK4_NL;
GGML_ASSERT(nb%4 == 0);
int8x16_t qx[8];
float d8[4*nrc_y];
//float32x4_t d8[nrc_y];
float32x4_t acc[nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_iq4_nl_x4 * iq4 = (const block_iq4_nl_x4 *)((const char *)vx + ix*bx);
deq.new_row(ix);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int iy = 0; iy < nrc_y; ++iy) {
//d8[iy] = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d));
vst1q_f32(d8+4*iy, vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[iy][ib4].d)));
}
//do_1_block<nrc_y, 0>(ib4, q8, d8, iq4, qx, acc, m4, m88);
//do_1_block<nrc_y, 1>(ib4, q8, d8, iq4, qx, acc, m4, m88);
//do_1_block<nrc_y, 2>(ib4, q8, d8, iq4, qx, acc, m4, m88);
//do_1_block<nrc_y, 3>(ib4, q8, d8, iq4, qx, acc, m4, m88);
for (int k = 0; k < 4; ++k) {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]);
qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows
qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19
qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7
qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23
qx[4] = vandq_u8(bits.val[0], m4); // 8..11
qx[5] = vandq_u8(bits.val[1], m4); // 24..27
qx[6] = vandq_u8(bits.val[2], m4); // 12..15
qx[7] = vandq_u8(bits.val[3], m4); // 28..31
auto scales = deq.prepare(ib4, k, qx);
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
//auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
auto sumi = interleaved_dotq(qx, y);
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(d8[4*iy+k]));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vmulq_f32(norm, acc[iy]));
info.store(ix, iy, deq.result(acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
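// Layout note: each *_x4/_r4 block interleaves one 32-value block from 4
// consecutive rows, and qx[0..7] hold its 4-byte runs (see the "0...3",
// "16..19", ... comments in prepare()). vdotq_laneq_s32(sumi, qx[i], y, lane)
// therefore dots one 4-byte lane of the activations against the matching
// 4 values of all 4 rows, so a single int32x4_t accumulates the four
// per-row sums in parallel and acc[iy] collects them scaled by d4*d8.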
template <int nrc_y>
void mul_mat_q5_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_0_x4> q8(info);
auto m4 = vdupq_n_u8(0x0f);
auto m5 = vdupq_n_u8(0x10);
auto m16 = vdupq_n_s8(-16);
int nb = n / QK5_0;
GGML_ASSERT(nb%4 == 0);
int8x16_t qx[8];
float32x4_t acc[nrc_y] = {};
for (int ix = 0; ix < nrc_x; ix += 4) {
const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx);
for (int ib4 = 0; ib4 < nb/4; ++ib4) {
for (int k = 0; k < 4; ++k) {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d));
auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs);
auto hbits = vld1q_u8(iq5[4*ib4+k].qh);
qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3
qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19
qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7
qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23
qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11
qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27
qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15
qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
auto sumi = vdupq_n_s32(0);
sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
}
//template <typename Dequantizer>
//void mul_mat_qx_r4_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
// GGML_ASSERT(nrc_x%4 == 0);
// Q8<1, block_q8_0_x4> q8(info);
// Dequantizer deq(vx, bx);
// int nb = n / QK4_NL;
// GGML_ASSERT(nb%4 == 0);
// int8x16_t qx[8];
// float32x4_t acc[1] = {};
// int32x4_t sumi[4];
// for (int ix = 0; ix < nrc_x; ix += 4) {
// deq.new_row(ix);
// for (int ib4 = 0; ib4 < nb/4; ++ib4) {
// auto d8 = vcvt_f32_f16(vld1_f16((const float16_t *)q8.y[0][ib4].d));
// for (int k = 0; k < 4; ++k) {
// deq.prepare_bits(ib4, k, qx);
// auto y = vld1q_s8_x2(q8.y[0][ib4].qs+32*k);
// sumi[k] = interleaved_dotq(qx, y);
// }
// auto scales = deq.all_scales(ib4);
// acc[0] = vfmaq_f32(acc[0], vcvtq_f32_s32(sumi[0]), vmulq_laneq_f32(scales.val[0], d8, 0));
// acc[0] = vfmaq_f32(acc[0], vcvtq_f32_s32(sumi[1]), vmulq_laneq_f32(scales.val[1], d8, 1));
// acc[0] = vfmaq_f32(acc[0], vcvtq_f32_s32(sumi[2]), vmulq_laneq_f32(scales.val[2], d8, 2));
// acc[0] = vfmaq_f32(acc[0], vcvtq_f32_s32(sumi[3]), vmulq_laneq_f32(scales.val[3], d8, 3));
// }
// info.store(ix, 0, deq.result(acc[0]));
// acc[0] = vdupq_n_f32(0.f);
// }
//}
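// The commented-out mul_mat_qx_r4_q8_0_1 above is a single-column (nrc_y == 1)
// variant that dots all four sub-blocks first and applies the scales afterwards
// via vmulq_laneq_f32; it is left for reference and only appears commented out
// in MulMat::prepare below.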
struct Q4_0_R4_Dequantizer {
Q4_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq4 = (const block_iq4_nl_x4 *)(cx + ix*bx); }
//inline void prepare_bits(int ib4, int k, int8x16_t * qx) const {
// auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
// for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]);
// qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows
// qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19
// qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7
// qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23
// qx[4] = vandq_u8(bits.val[0], m4); // 8..11
// qx[5] = vandq_u8(bits.val[1], m4); // 24..27
// qx[6] = vandq_u8(bits.val[2], m4); // 12..15
// qx[7] = vandq_u8(bits.val[3], m4); // 28..31
//}
//inline float32x4x4_t all_scales(int ib4) const {
// float32x4x4_t r;
// for (int k = 0; k < 4; ++k) r.val[k] = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
// return r;
//}
inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[4*ib4+k].d));
//prepare_bits(ib4, k, qx);
auto bits = vld1q_u8_x4(iq4[4*ib4+k].qs);
for (int j = 0; j < 4; ++j) bits.val[j] = veorq_u8(m88, bits.val[j]);
qx[0] = vshlq_n_u8(bits.val[0], 4); // 0...3 from the 4 rows
qx[1] = vshlq_n_u8(bits.val[1], 4); // 16..19
qx[2] = vshlq_n_u8(bits.val[2], 4); // 4...7
qx[3] = vshlq_n_u8(bits.val[3], 4); // 20..23
qx[4] = vandq_u8(bits.val[0], m4); // 8..11
qx[5] = vandq_u8(bits.val[1], m4); // 24..27
qx[6] = vandq_u8(bits.val[2], m4); // 12..15
qx[7] = vandq_u8(bits.val[3], m4); // 28..31
return scales;
}
inline float32x4_t result(float32x4_t acc) const {
return vmulq_f32(norm, acc);
}
const char * cx;
const size_t bx;
const block_iq4_nl_x4 * iq4;
const uint8x16_t m4 = vdupq_n_u8(0xf0);
const uint8x16_t m88 = vdupq_n_u8(0x88);
const float32x4_t norm = vdupq_n_f32(1.f/16);
};
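// Why the 0x88 XOR and the 1/16 norm: q4_0 nibbles encode q - 8 with q in
// [0,15]. XOR with 0x88 flips bit 3 of both nibbles, i.e. subtracts 8 mod 16,
// and keeping each nibble in the *high* half of its byte (vshlq_n_u8(., 4) or
// the 0xf0 mask) makes the byte read as the signed value 16*(q-8). All dot
// products are thus 16x too large, which result() undoes with norm = 1/16.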
struct Q5_0_R4_Dequantizer {
Q5_0_R4_Dequantizer(const void * vx, size_t bx) : cx((const char *)vx), bx(bx) {}
inline void new_row(int ix) { iq5 = (const block_q5_0_r4 *)(cx + ix*bx); }
inline float32x4_t prepare(int ib4, int k, int8x16_t * qx) const {
auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d));
auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs);
auto hbits = vld1q_u8(iq5[4*ib4+k].qh);
qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3
qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19
qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7
qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23
qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11
qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27
qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15
qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31
return scales;
}
inline float32x4_t result(float32x4_t acc) const {
return acc;
}
const char * cx;
const size_t bx;
const block_q5_0_r4 * iq5;
const uint8x16_t m4 = vdupq_n_u8(0x0f);
const uint8x16_t m5 = vdupq_n_u8(0x10);
const int8x16_t m16 = vdupq_n_s8(-16);
};
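// q5_0 unpack: lbits supply the low 4 bits and each byte of qh packs the 8
// high bits for the 8 qx groups at that byte position; the shifts move each
// high bit into the 0x10 slot before it is OR-ed in. Adding m16 = -16 then
// recenters the 5-bit value from [0,31] to [-16,15], so unlike q4_0 no final
// normalization is needed and result() returns acc unchanged.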
//template <int nrc_y>
//void mul_mat_q5_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
// GGML_ASSERT(nrc_x%4 == 0);
// Q8<nrc_y, block_q8_0_x4> q8(info);
// auto m4 = vdupq_n_u8(0x0f);
// auto m5 = vdupq_n_u8(0x10);
// auto m16 = vdupq_n_s8(-16);
// int nb = n / QK5_0;
// GGML_ASSERT(nb%4 == 0);
// int8x16_t qx[8];
// float32x4_t acc[nrc_y] = {};
// for (int ix = 0; ix < nrc_x; ix += 4) {
// const block_q5_0_r4 * iq5 = (const block_q5_0_r4 *)((const char *)vx + ix*bx);
// for (int ib4 = 0; ib4 < nb/4; ++ib4) {
// for (int k = 0; k < 4; ++k) {
// auto scales = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[4*ib4+k].d));
// auto lbits = vld1q_u8_x4(iq5[4*ib4+k].qs);
// auto hbits = vld1q_u8(iq5[4*ib4+k].qh);
// qx[0] = vaddq_s8(vandq_u8(lbits.val[0], m4) | vandq_u8(vshlq_n_u8(hbits, 4), m5), m16); // 0...3
// qx[1] = vaddq_s8(vandq_u8(lbits.val[1], m4) | vandq_u8(vshlq_n_u8(hbits, 3), m5), m16); // 16..19
// qx[2] = vaddq_s8(vandq_u8(lbits.val[2], m4) | vandq_u8(vshlq_n_u8(hbits, 2), m5), m16); // 4...7
// qx[3] = vaddq_s8(vandq_u8(lbits.val[3], m4) | vandq_u8(vshlq_n_u8(hbits, 1), m5), m16); // 20..23
// qx[4] = vaddq_s8(vshrq_n_u8(lbits.val[0], 4)| vandq_u8(hbits, m5), m16); // 8..11
// qx[5] = vaddq_s8(vshrq_n_u8(lbits.val[1], 4)| vandq_u8(vshrq_n_u8(hbits, 1), m5), m16); // 24..27
// qx[6] = vaddq_s8(vshrq_n_u8(lbits.val[2], 4)| vandq_u8(vshrq_n_u8(hbits, 2), m5), m16); // 12..15
// qx[7] = vaddq_s8(vshrq_n_u8(lbits.val[3], 4)| vandq_u8(vshrq_n_u8(hbits, 3), m5), m16); // 28..31
// for (int iy = 0; iy < nrc_y; ++iy) {
// auto y = vld1q_s8_x2(q8.y[iy][ib4].qs+32*k);
// auto sumi = vdupq_n_s32(0);
// sumi = vdotq_laneq_s32(sumi, qx[0], y.val[0], 0);
// sumi = vdotq_laneq_s32(sumi, qx[1], y.val[1], 0);
// sumi = vdotq_laneq_s32(sumi, qx[2], y.val[0], 1);
// sumi = vdotq_laneq_s32(sumi, qx[3], y.val[1], 1);
// sumi = vdotq_laneq_s32(sumi, qx[4], y.val[0], 2);
// sumi = vdotq_laneq_s32(sumi, qx[5], y.val[1], 2);
// sumi = vdotq_laneq_s32(sumi, qx[6], y.val[0], 3);
// sumi = vdotq_laneq_s32(sumi, qx[7], y.val[1], 3);
// auto d4d8 = vmulq_f32(scales, vdupq_n_f32(GGML_FP16_TO_FP32(q8.y[iy][ib4].d[k])));
// acc[iy] = vfmaq_f32(acc[iy], d4d8, vcvtq_f32_s32(sumi));
// }
// }
// }
// for (int iy = 0; iy < nrc_y; ++iy) {
// info.store(ix, iy, acc[iy]);
// acc[iy] = vdupq_n_f32(0.f);
// }
// }
//}
template <int nrc_y>
void mul_mat_q6_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
@@ -8223,25 +8315,42 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_Q4_0_R4:
m.funcs[0] = mul_mat_q4_0_r4_q8_0<1>;
m.funcs[1] = mul_mat_q4_0_r4_q8_0<2>;
m.funcs[2] = mul_mat_q4_0_r4_q8_0<3>;
m.funcs[3] = mul_mat_q4_0_r4_q8_0<4>;
m.funcs[4] = mul_mat_q4_0_r4_q8_0<5>;
m.funcs[5] = mul_mat_q4_0_r4_q8_0<6>;
m.funcs[6] = mul_mat_q4_0_r4_q8_0<7>;
m.funcs[7] = mul_mat_q4_0_r4_q8_0<8>;
//m.funcs[0] = mul_mat_q4_0_r4_q8_0<1>;
//m.funcs[1] = mul_mat_q4_0_r4_q8_0<2>;
//m.funcs[2] = mul_mat_q4_0_r4_q8_0<3>;
//m.funcs[3] = mul_mat_q4_0_r4_q8_0<4>;
//m.funcs[4] = mul_mat_q4_0_r4_q8_0<5>;
//m.funcs[5] = mul_mat_q4_0_r4_q8_0<6>;
//m.funcs[6] = mul_mat_q4_0_r4_q8_0<7>;
//m.funcs[7] = mul_mat_q4_0_r4_q8_0<8>;
//m.funcs[0] = mul_mat_qx_r4_q8_0_1<Q4_0_R4_Dequantizer>;
m.funcs[0] = mul_mat_qx_r4_q8_0<Q4_0_R4_Dequantizer, 1>;
m.funcs[1] = mul_mat_qx_r4_q8_0<Q4_0_R4_Dequantizer, 2>;
m.funcs[2] = mul_mat_qx_r4_q8_0<Q4_0_R4_Dequantizer, 3>;
m.funcs[3] = mul_mat_qx_r4_q8_0<Q4_0_R4_Dequantizer, 4>;
m.funcs[4] = mul_mat_qx_r4_q8_0<Q4_0_R4_Dequantizer, 5>;
m.funcs[5] = mul_mat_qx_r4_q8_0<Q4_0_R4_Dequantizer, 6>;
m.funcs[6] = mul_mat_qx_r4_q8_0<Q4_0_R4_Dequantizer, 7>;
m.funcs[7] = mul_mat_qx_r4_q8_0<Q4_0_R4_Dequantizer, 8>;
expected_Btype = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q5_0_R4:
m.funcs[0] = mul_mat_q5_0_r4_q8_0<1>;
m.funcs[1] = mul_mat_q5_0_r4_q8_0<2>;
m.funcs[2] = mul_mat_q5_0_r4_q8_0<3>;
m.funcs[3] = mul_mat_q5_0_r4_q8_0<4>;
m.funcs[4] = mul_mat_q5_0_r4_q8_0<5>;
m.funcs[5] = mul_mat_q5_0_r4_q8_0<6>;
m.funcs[6] = mul_mat_q5_0_r4_q8_0<7>;
m.funcs[7] = mul_mat_q5_0_r4_q8_0<8>;
//m.funcs[0] = mul_mat_q5_0_r4_q8_0<1>;
//m.funcs[1] = mul_mat_q5_0_r4_q8_0<2>;
//m.funcs[2] = mul_mat_q5_0_r4_q8_0<3>;
//m.funcs[3] = mul_mat_q5_0_r4_q8_0<4>;
//m.funcs[4] = mul_mat_q5_0_r4_q8_0<5>;
//m.funcs[5] = mul_mat_q5_0_r4_q8_0<6>;
//m.funcs[6] = mul_mat_q5_0_r4_q8_0<7>;
//m.funcs[7] = mul_mat_q5_0_r4_q8_0<8>;
m.funcs[0] = mul_mat_qx_r4_q8_0<Q5_0_R4_Dequantizer, 1>;
m.funcs[1] = mul_mat_qx_r4_q8_0<Q5_0_R4_Dequantizer, 2>;
m.funcs[2] = mul_mat_qx_r4_q8_0<Q5_0_R4_Dequantizer, 3>;
m.funcs[3] = mul_mat_qx_r4_q8_0<Q5_0_R4_Dequantizer, 4>;
m.funcs[4] = mul_mat_qx_r4_q8_0<Q5_0_R4_Dequantizer, 5>;
m.funcs[5] = mul_mat_qx_r4_q8_0<Q5_0_R4_Dequantizer, 6>;
m.funcs[6] = mul_mat_qx_r4_q8_0<Q5_0_R4_Dequantizer, 7>;
m.funcs[7] = mul_mat_qx_r4_q8_0<Q5_0_R4_Dequantizer, 8>;
expected_Btype = GGML_TYPE_Q8_0;
break;
case GGML_TYPE_Q6_0_R4:
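One plausible source of the ~7% PP gain is visible in the diff above: the old mul_mat_q5_0_r4_q8_0 converted a single fp16 activation scale per (iy, k) pair with GGML_FP16_TO_FP32 inside the inner loop, whereas the shared template pre-converts all four scales of a block_q8_0_x4 at once and caches them in d8[]. A minimal sketch of that batched conversion, assuming an AArch64 target with fp16 vector support (the helper name is illustrative, not from the diff):

#include <arm_neon.h>

// Convert the four fp16 block scales of a block_q8_0_x4 to fp32 in one shot,
// replacing four scalar GGML_FP16_TO_FP32 calls with a single vcvt_f32_f16.
static inline void convert_scales_x4(const float16_t * d4_fp16, float * d8) {
    vst1q_f32(d8, vcvt_f32_f16(vld1_f16(d4_fp16)));
}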


@@ -16569,6 +16569,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q4_0;
else chunk_size_multiplier = 4;
}
else if (new_type == GGML_TYPE_Q5_0_R4) {
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q5_0;
else chunk_size_multiplier = 4;
}
else if (new_type == GGML_TYPE_Q6_0_R4) {
if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q6_0;
else chunk_size_multiplier = 4;
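The second hunk, in llama_model_quantize_internal, extends the existing fallback chain to the new row-interleaved types: an _R4 type packs 4 rows per block, so it is only usable when the tensor's row count ne[1] is a multiple of 4, and otherwise quantization drops back to the plain variant. A hedged restatement with concrete shapes (names from the diff, control flow simplified; the chunk multiplier presumably keeps each parallel quantization chunk aligned to whole 4-row groups):

// ne[1] = 4096 rows -> keep GGML_TYPE_Q5_0_R4, quantize 4 rows per chunk
// ne[1] = 4095 rows -> not divisible by 4, fall back to plain GGML_TYPE_Q5_0
else if (new_type == GGML_TYPE_Q5_0_R4) {
    if (tensor->ne[1] % 4 != 0) new_type = GGML_TYPE_Q5_0;
    else chunk_size_multiplier = 4;
}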