Refactor iqk: GEMM kernels are refactored on NEON

Iwan Kawrakow
2025-05-19 08:36:16 +03:00
parent 7aa2de6d5a
commit 4b4b4fdcac
2 changed files with 334 additions and 525 deletions


@@ -1884,6 +1884,315 @@ struct DequantizerIQ3S final : public BaseDequantizer<block_iq3_s> {
};
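// NEON kernel for IQ2_XXS_R4: 4 interleaved IQ2_XXS rows against nrc_y rows of Q8_K.
// Each sas byte carries a scale bit (bit 0) and sign bits (bits 1-7); the four per-row block
// scales decode as 2*s+1, with the final 0.125f factor applied at the store.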
template <int nrc_y>
static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
float32x4_t acc[nrc_y] = {};
int32x4_t isum[nrc_y] = {};
int8x16_t qx[8];
SignHelper sh;
for (int ix = 0; ix < nrc_x; ix += 4) {
auto iq2 = (const block_iq2_xxs_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
auto qs = iq2[ibl].qs;
for (int ib = 0; ib < QK_K/32; ++ib) {
auto sas = vld1q_u8(iq2[ibl].sas + 16*ib);
auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
sh.init();
for (int i = 0; i < 8; ++i) {
qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xxs_grid[qs[2*i+0]], iq2xxs_grid[qs[2*i+1]]});
sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
auto sumi12 = vpaddq_s32(sumi1, sumi2);
auto sumi34 = vpaddq_s32(sumi3, sumi4);
auto sumi = vpaddq_s32(sumi12, sumi34);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
qs += 16;
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
isum[iy] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
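// IQ2_XS_R4 kernel: each 16-bit qs entry packs a 9-bit grid index in its low bits and sign bits
// in its high byte (gathered with the k_shuff table lookup); the 4-bit block scales decode as 2*s+1.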
template <int nrc_y>
static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
static const uint8_t k_shuff[16] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31};
auto shuff = vld1q_u8(k_shuff);
float32x4_t acc[nrc_y] = {};
int32x4_t isum[2*nrc_y] = {};
int8x16_t qx[8];
uint16x8x4_t scales16;
SignHelper sh;
for (int ix = 0; ix < nrc_x; ix += 4) {
auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
auto qs = iq2[ibl].qs;
for (int is = 0; is < 2; ++is) {
auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
auto scales2 = vshrq_n_u8(scale_bits, 4);
scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
auto s1 = vzip1q_u8(scales1, scales2);
auto s2 = vzip2q_u8(scales1, scales2);
scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
scales16.val[1] = vmovl_u8(vget_high_u8(s1));
scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
scales16.val[3] = vmovl_u8(vget_high_u8(s2));
for (int ib = 0; ib < QK_K/64; ++ib) {
auto v = vld1q_u8_x2((const uint8_t *)qs);
auto signs128 = vandq_u8(vqtbl2q_u8(v, shuff), vdupq_n_u8(254));
signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
sh.init();
for (int i = 0; i < 8; ++i) {
qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xs_grid[qs[2*i+0] & 511], iq2xs_grid[qs[2*i+1] & 511]});
sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
}
auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
}
qs += 16;
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
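// IQ2_S_R4 kernel: grid indices combine a low byte from qs with two high bits from qh, and the
// sign masks are stored separately in iq2[ibl].signs; scale handling matches the IQ2_XS_R4 kernel.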
template <int nrc_y>
static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
float32x4_t acc[nrc_y] = {};
int32x4_t isum[2*nrc_y] = {};
int8x16_t qx[8];
uint16x8x4_t scales16;
SignHelper sh;
for (int ix = 0; ix < nrc_x; ix += 4) {
auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
auto qs = iq2[ibl].qs;
auto qh = iq2[ibl].qh;
for (int is = 0; is < 2; ++is) {
auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
auto scales2 = vshrq_n_u8(scale_bits, 4);
scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
auto s1 = vzip1q_u8(scales1, scales2);
auto s2 = vzip2q_u8(scales1, scales2);
scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
scales16.val[1] = vmovl_u8(vget_high_u8(s1));
scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
scales16.val[3] = vmovl_u8(vget_high_u8(s2));
for (int ib = 0; ib < QK_K/64; ++ib) {
auto signs128 = vld1q_u8(iq2[ibl].signs + 64*is + 16*ib);
sh.init();
for (int i = 0; i < 4; ++i) {
qx[2*i+0] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+0] | ((qh[i] << 8) & 0x300)], iq2s_grid[qs[4*i+1] | ((qh[i] << 6) & 0x300)]});
sh.apply_signs_1((uint8x16_t *)qx+2*i+0, signs128);
qx[2*i+1] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+2] | ((qh[i] << 4) & 0x300)], iq2s_grid[qs[4*i+3] | ((qh[i] << 2) & 0x300)]});
sh.apply_signs_1((uint8x16_t *)qx+2*i+1, signs128);
}
qs += 16; qh += 4;
auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
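// IQ3_XXS_R4 kernel: same structure as IQ2_XXS_R4 above, but each qs byte indexes a 4-value entry
// of iq3xxs_grid, and the 0.25f factor folded into d4 replaces the final 0.125f scaling.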
template <int nrc_y>
static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
float32x4_t acc[nrc_y] = {};
int32x4_t isum[nrc_y] = {};
int8x16_t qx[8];
SignHelper sh;
for (int ix = 0; ix < nrc_x; ix += 4) {
auto iq3 = (const block_iq3_xxs_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = vmulq_f32(vdupq_n_f32(0.25f), vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)));
auto qs = iq3[ibl].qs;
for (int ib = 0; ib < QK_K/32; ++ib) {
auto sas = vld1q_u8(iq3[ibl].sas + 16*ib);
auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
sh.init();
for (int i = 0; i < 8; ++i) {
qx[i] = vreinterpretq_s8_u32(uint32x4_t{iq3xxs_grid[qs[4*i+0]], iq3xxs_grid[qs[4*i+1]], iq3xxs_grid[qs[4*i+2]], iq3xxs_grid[qs[4*i+3]]});
sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
auto sumi12 = vpaddq_s32(sumi1, sumi2);
auto sumi34 = vpaddq_s32(sumi3, sumi4);
auto sumi = vpaddq_s32(sumi12, sumi34);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
qs += 32;
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
isum[iy] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
}
}
}
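// IQ3_S_R4 kernel: 4-bit block scales (decoded as 2*s+1) are expanded once per super-block into
// stored_scales; grid indices take 8 bits from qs plus one high bit from qh, and signs are applied
// via per-bit compares against m1.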
template <int nrc_y>
static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
float32x4_t acc[nrc_y] = {};
int32x4_t isum[nrc_y] = {};
int8x16_t qx[8];
auto m1 = vdupq_n_u8(1);
auto shuff = vreinterpretq_u8_u32(uint32x4_t{0xffffff00, 0xffffff01, 0xffffff02, 0xffffff03});
uint32_t stored_scales[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
auto qs = iq3[ibl].qs;
auto qh = iq3[ibl].qh;
auto scale_bits = vld1q_u8(iq3[ibl].scales);
uint8x16x2_t scales8 = { vandq_u8(scale_bits, vdupq_n_u8(0xf)), vshrq_n_u8(scale_bits, 4) };
scales8.val[0] = vorrq_u8(vshlq_n_u8(scales8.val[0], 1), m1);
scales8.val[1] = vorrq_u8(vshlq_n_u8(scales8.val[1], 1), m1);
vst1q_u8_x2((uint8_t *)stored_scales, scales8);
for (int ib = 0; ib < QK_K/32; ++ib) {
auto signs128 = vld1q_u8(iq3[ibl].signs+16*ib);
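// Single-row fast path: spread the qh high bits across a vector (idx_h) so the 9-bit grid
// indices can be formed with vector ORs instead of per-lane scalar shifts.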
if constexpr (nrc_y == 1) {
auto qh32 = (const uint32_t *)qh;
auto idx_h = vreinterpretq_u16_u64(vshlq_u64(vreinterpretq_u64_u16(vmovl_u8(vreinterpret_u8_u32(vdup_n_u32(qh32[0])))), int64x2_t{8, 4}));
union { uint16x8_t vec; uint16_t val[8]; } hidx;
for (int i = 0; i < 4; ++i) {
auto idx_l = vmovl_u8(vld1_u8(qs));
hidx.vec = vorrq_u16(idx_l, vandq_u16(idx_h, vdupq_n_u16(0x100))); idx_h = vshrq_n_u16(idx_h, 1);
qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[0]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[3]]});
auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[4]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[7]]});
signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
signs128 = vshrq_n_u8(signs128, 1);
qs += 8;
}
} else {
for (int i = 0; i < 4; ++i) {
qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[0] | ((qh[0] << (8-i)) & 0x100)], iq3s_grid[qs[1] | ((qh[1] << (8-i)) & 0x100)],
iq3s_grid[qs[2] | ((qh[2] << (8-i)) & 0x100)], iq3s_grid[qs[3] | ((qh[3] << (8-i)) & 0x100)]});
auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[4] | ((qh[0] << (4-i)) & 0x100)], iq3s_grid[qs[5] | ((qh[1] << (4-i)) & 0x100)],
iq3s_grid[qs[6] | ((qh[2] << (4-i)) & 0x100)], iq3s_grid[qs[7] | ((qh[3] << (4-i)) & 0x100)]});
signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
qs += 8;
signs128 = vshrq_n_u8(signs128, 1);
}
}
auto scales = vreinterpretq_s32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(vdupq_n_u32(stored_scales[ib])), shuff));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
auto sumi = interleaved_dotq(qx, y);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
qh += 4;
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
isum[iy] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
}
}
}
}
bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
@@ -1910,30 +2219,26 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_
case GGML_TYPE_IQ3_S:
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ3S, kernels);
break;
// case GGML_TYPE_IQ2_XXS_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_r4_q8_k, kernels);
// func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
// break;
// case GGML_TYPE_IQ2_XS_R4:
// assert (ne00 % QK_K == 0);
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xs_r4_q8_k, kernels);
//#ifndef HAVE_FANCY_SIMD
// // For some reason Zen4 does not like this particular function
// func16 = mul_mat_iq2_xs_r4_q8_k_16;
//#endif
// break;
// case GGML_TYPE_IQ2_S_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_s_r4_q8_k, kernels);
// func16 = mul_mat_iq2_s_r4_q8_k_16;
// break;
// case GGML_TYPE_IQ3_XXS_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_xxs_r4_q8_k, kernels);
// func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
// break;
// case GGML_TYPE_IQ3_S_R4:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_s_r4_q8_k, kernels);
// func16 = mul_mat_iq3_s_r4_q8_k<16>;
// break;
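// Repacked R4 i-quant types now dispatch to the NEON kernels defined above.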
case GGML_TYPE_IQ2_XXS_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_r4_q8_k, kernels);
func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
break;
case GGML_TYPE_IQ2_XS_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xs_r4_q8_k, kernels);
func16 = mul_mat_iq2_xs_r4_q8_k<16>;
break;
case GGML_TYPE_IQ2_S_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_s_r4_q8_k, kernels);
func16 = mul_mat_iq2_s_r4_q8_k<16>;
break;
case GGML_TYPE_IQ3_XXS_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_xxs_r4_q8_k, kernels);
func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
break;
case GGML_TYPE_IQ3_S_R4:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_s_r4_q8_k, kernels);
func16 = mul_mat_iq3_s_r4_q8_k<16>;
break;
default:
return false;
}


@@ -300,8 +300,6 @@ struct MulMat {
}
#endif
}
private:
template <typename Dequantizer> static void set_functions(MulMat& m);
};
}
@@ -674,9 +672,6 @@ static void mul_mat_q8_KV_q8_KV_8(int n, const void * vx, size_t bx, const DataI
}
#endif
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
}
bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
(void)Ny;
@@ -765,478 +760,8 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& mm, int Ny) {
namespace {
template <int nrc, typename block_q8 = block_q8_K> struct Q8 {
constexpr static int nrc_y = nrc;
Q8(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8 *)info.src1_row(iy);
}
inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy][i].qs + 32*j); }
inline int8x16x4_t load_quants_64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy][i].qs + 64*j); }
inline int16x8x2_t load_bsums(int iy, int i) const { return vld1q_s16_x2(y[iy][i].bsums); }
inline int16x8_t load_bsums8(int iy, int i) const {
auto q8s = vld1q_s16_x2(y[iy][i].bsums);
return vpaddq_s16(q8s.val[0], q8s.val[1]);
}
inline float scale(int iy, int i) const { return y[iy][i].d; }
const block_q8 * y[nrc_y];
};
template <typename Q8>
inline void accum_mins_16(const int16x8x2_t& mins, const Q8& q8, float32x4_t * acc, int i, float c) {
for (int iy = 0; iy < Q8::nrc_y; ++iy) {
auto q8s = q8.load_bsums(iy, i);
int32x4_t b1 = vmull_s16(vget_low_s16 (mins.val[0]), vget_low_s16 (q8s.val[0]));
int32x4_t b2 = vmull_s16(vget_high_s16(mins.val[0]), vget_high_s16(q8s.val[0]));
int32x4_t b3 = vmull_s16(vget_low_s16 (mins.val[1]), vget_low_s16 (q8s.val[1]));
int32x4_t b4 = vmull_s16(vget_high_s16(mins.val[1]), vget_high_s16(q8s.val[1]));
float32x4_t prod = vcvtq_f32_s32(vaddq_s32(vaddq_s32(b1, b2), vaddq_s32(b3, b4)));
acc[iy] = vmlaq_f32(acc[iy], prod, vdupq_n_f32(c*q8.scale(iy, i)));
}
}
inline int32x4x4_t make_wider(const int16x8x2_t& scales16) {
int32x4x4_t scales = {
vmovl_s16(vget_low_s16 (scales16.val[0])),
vmovl_s16(vget_high_s16(scales16.val[0])),
vmovl_s16(vget_low_s16 (scales16.val[1])),
vmovl_s16(vget_high_s16(scales16.val[1])),
};
return scales;
}
// ============================= i-quants
inline int32x4x4_t make_wider_8(const int8x16_t& scales8) {
int16x8x2_t scales16{vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8))};
return make_wider(scales16);
}
struct Scale16Extra {
template <typename Q8>
static inline int32x4x4_t new_block(int i, float d, uint16_t extra, uint8_t val,
const int8x16_t& scales8, const Q8& q8, float32x4_t * acc) {
uint8x16_t e8 = vreinterpretq_u8_u16(vdupq_n_u16(extra));
e8 = vceqq_u8(vandq_u8(e8, emask), emask);
e8 = vqtbl1q_u8(vandq_u8(e8, vdupq_n_u8(val)), eshuff);
int16x8x2_t extra16 = {vmull_s8(vget_low_s8 (e8), vget_low_s8 (scales8)),
vmull_s8(vget_high_s8(e8), vget_high_s8(scales8))};
accum_mins_16(extra16, q8, acc, i, d);
return make_wider_8(scales8);
}
constexpr static uint32x4_t emask = {0x02020101, 0x08080404, 0x20201010, 0x80804040};
constexpr static uint32x4_t eshuff = {0x06040200, 0x0e0c0a08, 0x07050301, 0x0f0d0b09};
};
// Note: on ARM_NEON we cannot use the values shifted into the uint8_t range because
// ARM_NEON only has vdotq_s32 and vdotq_u32, where both operands must be signed or
// both unsigned. As the Q8_K quants are signed, the iq4_s quants must also be signed.
// The k-quants can use unsigned values only because they all fit within the valid
// int8_t range.
struct SimpleBits {
uint8x16x4_t b1;
uint8x16x4_t b2;
};
inline int32x4x2_t prepare_scales_8(const uint32x4_t& v1, const uint32x4_t& v2) {
int32x4x2_t scales;
scales.val[0] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v1, 28), 1), vdupq_n_u32(1)));
scales.val[1] = vreinterpretq_s32_u32(vorrq_u32(vshlq_n_u32(vshrq_n_u32(v2, 28), 1), vdupq_n_u32(1)));
return scales;
}
inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) {
auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));
auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127))));
b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));
b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));
}
template <int nrc> struct Q8_K64 {
constexpr static int nrc_y = nrc;
Q8_K64(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto dptr = (const float *)info.src1_row(iy);
std::memcpy(d + 8*iy, dptr, 8*sizeof(float));
y[iy] = (const int8_t *)(dptr + 8);
}
}
inline int8x16x4_t load_quants64(int iy, int i, int j) const { return vld1q_s8_x4(y[iy] + 128*i + 64*j); }
inline int8x16x2_t load_quants(int iy, int i, int j) const { return vld1q_s8_x2(y[iy] + 128*i + 32*j); }
inline float32x4_t scale(int iy) const { return vld1q_f32(d + 8*iy); }
inline float32x4_t minus(int iy) const { return vld1q_f32(d + 8*iy + 4); }
float d[8*nrc_y];
const int8_t * y[nrc_y];
};
template <int nrc> struct Q8_16 {
constexpr static int nrc_y = nrc;
Q8_16(const DataInfo& info) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto ptr = (const float *)info.src1_row(iy);
std::memcpy(d + 5*iy, ptr, 5*sizeof(float));
y[iy] = (const int8_t *)(ptr + 5);
}
}
inline int8x16x4_t load_quants(int iy, int i) const { return vld1q_s8_x4(y[iy] + 64*i); }
inline int8x16x2_t load_quants_32(int iy, int i) const { return vld1q_s8_x2(y[iy] + 32*i); }
inline float scale(int iy, int k) const { return d[5*iy+k]; }
inline float sum_row(int iy) const { return d[5*iy + 4]; }
inline float32x4_t scale(int iy) const { return vld1q_f32(d + 5*iy); }
float d[5*nrc_y];
const int8_t * y[nrc_y];
};
template <int nrc_y>
static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
float32x4_t acc[nrc_y] = {};
int32x4_t isum[nrc_y] = {};
int8x16_t qx[8];
SignHelper sh;
for (int ix = 0; ix < nrc_x; ix += 4) {
auto iq2 = (const block_iq2_xxs_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
auto qs = iq2[ibl].qs;
for (int ib = 0; ib < QK_K/32; ++ib) {
auto sas = vld1q_u8(iq2[ibl].sas + 16*ib);
auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
sh.init();
for (int i = 0; i < 8; ++i) {
qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xxs_grid[qs[2*i+0]], iq2xxs_grid[qs[2*i+1]]});
sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
auto sumi12 = vpaddq_s32(sumi1, sumi2);
auto sumi34 = vpaddq_s32(sumi3, sumi4);
auto sumi = vpaddq_s32(sumi12, sumi34);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
qs += 16;
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
isum[iy] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
template <int nrc_y>
static void mul_mat_iq2_xs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
static const uint8_t k_shuff[16] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31};
auto shuff = vld1q_u8(k_shuff);
float32x4_t acc[nrc_y] = {};
int32x4_t isum[2*nrc_y] = {};
int8x16_t qx[8];
uint16x8x4_t scales16;
SignHelper sh;
for (int ix = 0; ix < nrc_x; ix += 4) {
auto iq2 = (const block_iq2_xs_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
auto qs = iq2[ibl].qs;
for (int is = 0; is < 2; ++is) {
auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
auto scales2 = vshrq_n_u8(scale_bits, 4);
scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
auto s1 = vzip1q_u8(scales1, scales2);
auto s2 = vzip2q_u8(scales1, scales2);
scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
scales16.val[1] = vmovl_u8(vget_high_u8(s1));
scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
scales16.val[3] = vmovl_u8(vget_high_u8(s2));
for (int ib = 0; ib < QK_K/64; ++ib) {
auto v = vld1q_u8_x2((const uint8_t *)qs);
auto signs128 = vandq_u8(vqtbl2q_u8(v, shuff), vdupq_n_u8(254));
signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
sh.init();
for (int i = 0; i < 8; ++i) {
qx[i] = vreinterpretq_s8_u64(uint64x2_t{iq2xs_grid[qs[2*i+0] & 511], iq2xs_grid[qs[2*i+1] & 511]});
sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
}
auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
}
qs += 16;
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
template <int nrc_y>
static void mul_mat_iq2_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
float32x4_t acc[nrc_y] = {};
int32x4_t isum[2*nrc_y] = {};
int8x16_t qx[8];
uint16x8x4_t scales16;
SignHelper sh;
for (int ix = 0; ix < nrc_x; ix += 4) {
auto iq2 = (const block_iq2_s_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));
auto qs = iq2[ibl].qs;
auto qh = iq2[ibl].qh;
for (int is = 0; is < 2; ++is) {
auto scale_bits = vld1q_u8(iq2[ibl].scales + 16*is);
auto scales1 = vandq_u8(scale_bits, vdupq_n_u8(0xf));
auto scales2 = vshrq_n_u8(scale_bits, 4);
scales1 = vorrq_u8(vshlq_n_u8(scales1, 1), vdupq_n_u8(1));
scales2 = vorrq_u8(vshlq_n_u8(scales2, 1), vdupq_n_u8(1));
auto s1 = vzip1q_u8(scales1, scales2);
auto s2 = vzip2q_u8(scales1, scales2);
scales16.val[0] = vmovl_u8(vget_low_u8 (s1));
scales16.val[1] = vmovl_u8(vget_high_u8(s1));
scales16.val[2] = vmovl_u8(vget_low_u8 (s2));
scales16.val[3] = vmovl_u8(vget_high_u8(s2));
for (int ib = 0; ib < QK_K/64; ++ib) {
auto signs128 = vld1q_u8(iq2[ibl].signs + 64*is + 16*ib);
sh.init();
for (int i = 0; i < 4; ++i) {
qx[2*i+0] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+0] | ((qh[i] << 8) & 0x300)], iq2s_grid[qs[4*i+1] | ((qh[i] << 6) & 0x300)]});
sh.apply_signs_1((uint8x16_t *)qx+2*i+0, signs128);
qx[2*i+1] = vreinterpretq_s8_u64(uint64x2_t{iq2s_grid[qs[4*i+2] | ((qh[i] << 4) & 0x300)], iq2s_grid[qs[4*i+3] | ((qh[i] << 2) & 0x300)]});
sh.apply_signs_1((uint8x16_t *)qx+2*i+1, signs128);
}
qs += 16; qh += 4;
auto s32_1 = vmovl_u16(vget_low_u16 (scales16.val[ib]));
auto s32_2 = vmovl_u16(vget_high_u16(scales16.val[ib]));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 128*is + 32*ib);
auto sumi1 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[1], y.val[1]));
auto sumi2 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[3], y.val[1]));
auto sumi3 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[5], y.val[1]));
auto sumi4 = vpaddq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), ggml_vdotq_s32(vdupq_n_s32(0), qx[7], y.val[1]));
auto sumi12 = vpaddq_s32(sumi1, sumi2); // blocks 0,1,2,3 in rows 0,1
auto sumi34 = vpaddq_s32(sumi3, sumi4); // blocks 4,5,6,7 in rows 2,3
isum[2*iy+0] = vmlaq_s32(isum[2*iy+0], s32_1, sumi12);
isum[2*iy+1] = vmlaq_s32(isum[2*iy+1], s32_2, sumi34);
}
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto sumi = vpaddq_s32(isum[2*iy+0], isum[2*iy+1]);
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(sumi));
isum[2*iy] = isum[2*iy+1] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, vmulq_f32(vdupq_n_f32(0.125f), acc[iy]));
acc[iy] = vdupq_n_f32(0.f);
}
}
}
template <int nrc_y>
static void mul_mat_iq3_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
float32x4_t acc[nrc_y] = {};
int32x4_t isum[nrc_y] = {};
int8x16_t qx[8];
SignHelper sh;
for (int ix = 0; ix < nrc_x; ix += 4) {
auto iq3 = (const block_iq3_xxs_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = vmulq_f32(vdupq_n_f32(0.25f), vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d)));
auto qs = iq3[ibl].qs;
for (int ib = 0; ib < QK_K/32; ++ib) {
auto sas = vld1q_u8(iq3[ibl].sas + 16*ib);
auto scale_bits = vandq_u8(sas, vdupq_n_u8(1));
auto scales = ggml_vdotq_s32(vdupq_n_s32(1), scale_bits, vreinterpretq_s8_u32(vdupq_n_u32(0x10080402)));
auto signs128 = vandq_u8(sas, vdupq_n_u8(254));
signs128 = veorq_u8(signs128, vshrq_n_u8(signs128, 1));
sh.init();
for (int i = 0; i < 8; ++i) {
qx[i] = vreinterpretq_s8_u32(uint32x4_t{iq3xxs_grid[qs[4*i+0]], iq3xxs_grid[qs[4*i+1]], iq3xxs_grid[qs[4*i+2]], iq3xxs_grid[qs[4*i+3]]});
sh.apply_signs_1((uint8x16_t *)qx+i, signs128);
}
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
auto sumi1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[0], y.val[0]), qx[1], y.val[1]);
auto sumi2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[2], y.val[0]), qx[3], y.val[1]);
auto sumi3 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[4], y.val[0]), qx[5], y.val[1]);
auto sumi4 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), qx[6], y.val[0]), qx[7], y.val[1]);
auto sumi12 = vpaddq_s32(sumi1, sumi2);
auto sumi34 = vpaddq_s32(sumi3, sumi4);
auto sumi = vpaddq_s32(sumi12, sumi34);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
qs += 32;
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
isum[iy] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
}
}
}
template <int nrc_y>
static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
GGML_ASSERT(nrc_x%4 == 0);
Q8<nrc_y, block_q8_K> q8(info);
int nbl = n / QK_K;
float32x4_t acc[nrc_y] = {};
int32x4_t isum[nrc_y] = {};
int8x16_t qx[8];
auto m1 = vdupq_n_u8(1);
auto shuff = vreinterpretq_u8_u32(uint32x4_t{0xffffff00, 0xffffff01, 0xffffff02, 0xffffff03});
uint32_t stored_scales[8];
for (int ix = 0; ix < nrc_x; ix += 4) {
auto iq3 = (const block_iq3_s_r4 *)((const char *)vx + (ix+0)*bx);
for (int ibl = 0; ibl < nbl; ++ibl) { // Block of 256
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
auto qs = iq3[ibl].qs;
auto qh = iq3[ibl].qh;
auto scale_bits = vld1q_u8(iq3[ibl].scales);
uint8x16x2_t scales8 = { vandq_u8(scale_bits, vdupq_n_u8(0xf)), vshrq_n_u8(scale_bits, 4) };
scales8.val[0] = vorrq_u8(vshlq_n_u8(scales8.val[0], 1), m1);
scales8.val[1] = vorrq_u8(vshlq_n_u8(scales8.val[1], 1), m1);
vst1q_u8_x2((uint8_t *)stored_scales, scales8);
for (int ib = 0; ib < QK_K/32; ++ib) {
auto signs128 = vld1q_u8(iq3[ibl].signs+16*ib);
if constexpr (nrc_y == 1) {
auto qh32 = (const uint32_t *)qh;
auto idx_h = vreinterpretq_u16_u64(vshlq_u64(vreinterpretq_u64_u16(vmovl_u8(vreinterpret_u8_u32(vdup_n_u32(qh32[0])))), int64x2_t{8, 4}));
union { uint16x8_t vec; uint16_t val[8]; } hidx;
for (int i = 0; i < 4; ++i) {
auto idx_l = vmovl_u8(vld1_u8(qs));
hidx.vec = vorrq_u16(idx_l, vandq_u16(idx_h, vdupq_n_u16(0x100))); idx_h = vshrq_n_u16(idx_h, 1);
qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[0]], iq3s_grid[hidx.val[1]], iq3s_grid[hidx.val[2]], iq3s_grid[hidx.val[3]]});
auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[hidx.val[4]], iq3s_grid[hidx.val[5]], iq3s_grid[hidx.val[6]], iq3s_grid[hidx.val[7]]});
signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
signs128 = vshrq_n_u8(signs128, 1);
qs += 8;
}
} else {
for (int i = 0; i < 4; ++i) {
qx[2*i+0] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[0] | ((qh[0] << (8-i)) & 0x100)], iq3s_grid[qs[1] | ((qh[1] << (8-i)) & 0x100)],
iq3s_grid[qs[2] | ((qh[2] << (8-i)) & 0x100)], iq3s_grid[qs[3] | ((qh[3] << (8-i)) & 0x100)]});
auto signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(signs128, m1), m1), m1));
qx[2*i+0] = vmulq_s8(qx[2*i+0], signs);
qx[2*i+1] = vreinterpretq_s8_u32(uint32x4_t{iq3s_grid[qs[4] | ((qh[0] << (4-i)) & 0x100)], iq3s_grid[qs[5] | ((qh[1] << (4-i)) & 0x100)],
iq3s_grid[qs[6] | ((qh[2] << (4-i)) & 0x100)], iq3s_grid[qs[7] | ((qh[3] << (4-i)) & 0x100)]});
signs = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(vshrq_n_u8(signs128, 4), m1), m1), m1));
qx[2*i+1] = vmulq_s8(qx[2*i+1], signs);
qs += 8;
signs128 = vshrq_n_u8(signs128, 1);
}
}
auto scales = vreinterpretq_s32_u8(vqtbl1q_u8(vreinterpretq_u8_u32(vdupq_n_u32(stored_scales[ib])), shuff));
for (int iy = 0; iy < nrc_y; ++iy) {
auto y = vld1q_s8_x2(q8.y[iy][ibl].qs + 32*ib);
auto sumi = interleaved_dotq(qx, y);
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
}
qh += 4;
}
for (int iy = 0; iy < nrc_y; ++iy) {
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
isum[iy] = vdupq_n_s32(0);
}
}
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, acc[iy]);
acc[iy] = vdupq_n_f32(0.f);
}
}
}
#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
m.funcs[0] = func<Dequantizer, 1>;\
m.funcs[1] = func<Dequantizer, 2>;\
m.funcs[2] = func<Dequantizer, 3>;\
m.funcs[3] = func<Dequantizer, 4>;\
m.funcs[4] = func<Dequantizer, 5>;\
m.funcs[5] = func<Dequantizer, 6>;\
m.funcs[6] = func<Dequantizer, 7>;\
m.funcs[7] = func<Dequantizer, 8>;\
#define SET_MUL_MAT_FUNCTIONS(m, func) \
m.funcs[0] = func<1>;\
m.funcs[1] = func<2>;\
m.funcs[2] = func<3>;\
m.funcs[3] = func<4>;\
m.funcs[4] = func<5>;\
m.funcs[5] = func<6>;\
m.funcs[6] = func<7>;\
m.funcs[7] = func<8>;\
template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_K_q8_K_T, Dequantizer);
}
bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
auto expected_Btype = GGML_TYPE_Q8_K;
switch (typeA) {
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
@@ -1279,6 +804,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ2_XXS_R4:
case GGML_TYPE_IQ2_XS_R4:
case GGML_TYPE_IQ2_S_R4:
case GGML_TYPE_IQ3_XXS_R4:
case GGML_TYPE_IQ3_S_R4:
return iqk_set_kernels_iquants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
@@ -1293,21 +823,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_Q8_0_R8:
case GGML_TYPE_IQ4_NL_R4:
return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ2_XXS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xxs_r4_q8_k);
m.func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ2_XS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xs_r4_q8_k);
m.func16 = mul_mat_iq2_xs_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ2_S_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_s_r4_q8_k);
m.func16 = mul_mat_iq2_s_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ1_BN:
case GGML_TYPE_IQ2_BN:
case GGML_TYPE_IQ2_BN_R4:
@@ -1315,21 +830,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
case GGML_TYPE_IQ1_S_R4:
case GGML_TYPE_IQ1_M_R4:
return iqk_set_kernels_1bit(ne00, typeA, typeB, m.funcs, m.func16);
case GGML_TYPE_IQ3_XXS_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_xxs_r4_q8_k);
m.func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
case GGML_TYPE_IQ3_S_R4:
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_s_r4_q8_k);
m.func16 = mul_mat_iq3_s_r4_q8_k<16>;
expected_Btype = GGML_TYPE_Q8_K;
break;
default:
return false;
}
return typeB == expected_Btype;
}
}