mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-28 09:04:10 +00:00
Refactor iqk: factor out repacked iqk quants (NEON)
This commit is contained in:
@@ -2561,13 +2561,661 @@ struct DequantizerIQ2KS final : public BaseDequantizer<block_iq2_ks, true, true>
|
||||
|
||||
};
|
||||
|
||||
// Matmul kernel (NEON): 4 interleaved rows of IQ4_KS-quantized weights ("_r4"
// repacking) times nrc_y columns of q8_K-quantized activations.
// n     : number of columns of the weight matrix (must be a multiple of QK_K,
//         since nbl = n / QK_K below).
// vx/bx : base pointer / row stride (bytes) of the repacked weight data; each
//         4-row group starts with 4 floats (per-row scales d4) followed by
//         block_iq4_ks_r4 blocks.
// info  : destination/activation descriptor; results are written via info.store.
// nrc_x : number of weight rows to process; must be a multiple of 4.
template <int nrc_y>
void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);                 // low-nibble mask for 4-bit quants
    auto values = vld1q_s8(iq4k_values);       // 16-entry non-linear codebook
    int nbl = n / QK_K;                        // number of QK_K super-blocks per row
    int8x16_t qx[8];                           // dequantized weight vectors for one sub-block
    int16x8x4_t iscales;                       // per-block integer scales (16 bit)
    int32x4x4_t scales;                        // widened scales used in the dot-product loop
    float32x4_t acc[nrc_y] = {};               // float accumulators, one 4-row vector per y column
    int32x4_t isum[nrc_y] = {};                // integer accumulators for the current super-block
    for (int ix = 0; ix < nrc_x; ix += 4) {
        auto dptr = (const float *)((const char *)vx + ix*bx);
        auto d4 = vld1q_f32(dptr);             // per-row float scales for this 4-row group
        const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto sas = vld1q_u8_x2(iq4[ibl].scales);
            // Bit 0 of each scale byte is a flag (used below for the value shift);
            // the remaining 7 bits (mask 254) hold the scale, biased by 127.
            auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
            iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
            iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            // Adding the block shifts costs us ~9% in performance drop.
            // Is there a better way?
            // Extract the per-block shift flag (bit 0) and scale it to the shift
            // magnitude (4 for iq4_ks); the shift contribution is folded in via
            // the q8 block sums (bsums) instead of shifting every quant.
            sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 2);
            sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 2);
            {
                // s16_k = scale * shift for each block; multiplied by the q8 block
                // sums this yields the correction term sum(scale*shift*y).
                auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
                auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
                auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
                auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
                    // Pairwise-add 32-element bsums down to 16 per-block sums.
                    auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
                    auto b8 = vget_low_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
                    b8 = vget_high_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
                }
            }
            // Main dot-product: two halves (is) of 4 sub-blocks (ib) each.
            for (int is = 0; is < 2; ++is) {
                scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
                scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
                scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
                scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
                    // Split nibbles and map through the iq4 codebook into qx[0..7].
                    prepare_iq4_nl_quants(values, m4, bits, qx);
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
                    }
                }
            }
            // Fold the super-block integer sums into the float accumulators,
            // scaled by the q8 activation scale, then reset for the next block.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
                isum[iy] = vdupq_n_s32(0);
            }
        }
        // Apply the per-row weight scales and store 4 results per y column.
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vmulq_f32(d4, acc[iy]));
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
// Matmul kernel (NEON): 4 interleaved rows of IQ5_KS-quantized weights ("_r4"
// repacking) times nrc_y columns of q8_K activations. Same structure as
// mul_mat_iq4_ks_r4_q8_k, but quants are 5 bit: 4 low bits in qs plus one high
// bit packed in qh, looked up in a 32-entry codebook via vqtbl2q_s8.
// n must be a multiple of QK_K; nrc_x must be a multiple of 4.
template <int nrc_y>
void mul_mat_iq5_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);                 // low-nibble mask
    auto m10 = vdupq_n_u8(0x10);               // mask for the 5th (high) bit after shifting
    auto values = vld1q_s8_x2(iq5nl_values);   // 32-entry non-linear codebook (2 vectors)
    int nbl = n / QK_K;
    int8x16_t qx[8];
    int16x8x4_t iscales;
    int32x4x4_t scales;
    float32x4_t acc[nrc_y] = {};
    int32x4_t isum[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        auto dptr = (const float *)((const char *)vx + ix*bx);
        auto d4 = vld1q_f32(dptr);             // per-row float scales
        const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto sas = vld1q_u8_x2(iq5[ibl].scales);
            // Bit 0 of each scale byte is the shift flag; bits 1..7 hold the
            // scale, biased by 127.
            auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
            iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
            iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            // Adding the block shifts costs us ~9% in performance drop.
            // Is there a better way?
            // Shift magnitude is 2 here (flag << 1), vs 4 for iq4_ks.
            sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 1);
            sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 1);
            {
                // Correction term: sum over blocks of scale * shift * bsum(y).
                auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
                auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
                auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
                auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
                    auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
                    auto b8 = vget_low_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
                    b8 = vget_high_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
                }
            }
            for (int is = 0; is < 2; ++is) {
                scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
                scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
                scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
                scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
                    auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
                    // Assemble 5-bit indices: low nibble from qs, 5th bit from
                    // the appropriate bit position of hbits moved to 0x10.
                    qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4)));
                    qx[1] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2)));
                    qx[2] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits));
                    qx[3] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2)));
                    qx[4] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3)));
                    qx[5] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1)));
                    qx[6] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1)));
                    qx[7] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3)));
                    // Map 5-bit indices through the 32-entry codebook.
                    for (int l = 0; l < 8; ++l) qx[l] = vqtbl2q_s8(values, qx[l]);
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
                    }
                }
            }
            // Accumulate this super-block into the float totals and reset.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
                isum[iy] = vdupq_n_s32(0);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vmulq_f32(d4, acc[iy]));
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
// Helper shared by the iq2_k/iq3_k/iq4_k/iq5_k _r4 kernels (only invoked for
// nrc_y == 1 in this file): accumulates the "extra" per-block value-shift
// contribution into isum using the q8 block sums, i.e. adds
// scale * k_shift * bsum(y) for every block whose extra bit is set.
// k_shift selects the decoding of the packed 'extra' bits:
//   5 -> flag extracted with a compare (vceqq) and consumed 2 bits at a time;
//   4 -> flag bit moved into position 0x04 by shifts;
//   else (2) -> flag bit moved into position 0x02 by shifts.
// NOTE(review): the exact bit layout of 'extra' differs per quant type and is
// assumed to match the shifts used by each caller — verify against the packing
// code when modifying.
template <int nrc_y, int k_shift>
inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales, uint8x16_t extra,
        int32x4_t * isum) {
    auto ms = vdupq_n_s8(k_shift);
    int8x16_t s8_1, s8_2;
    // First half: blocks covered by i8scales.val[0] and val[1].
    if constexpr (k_shift == 5) {
        auto m1 = vdupq_n_u8(1);
        // vceqq produces 0xFF where the flag is set; AND with ms gives k_shift.
        s8_1 = vmulq_s8(i8scales.val[0], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
        s8_2 = vmulq_s8(i8scales.val[1], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
    } else {
        if constexpr (k_shift == 4) {
            s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 2)));
            s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, extra));
        } else {
            s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 1)));
            s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, vshrq_n_u8(extra, 1)));
        }
    }
    // Widen scale*shift products to 16 bit and multiply-accumulate against the
    // first 8 block sums of the q8 activations.
    auto s16_1 = vmovl_s8(vget_low_s8 (s8_1));
    auto s16_2 = vmovl_s8(vget_high_s8(s8_1));
    auto s16_3 = vmovl_s8(vget_low_s8 (s8_2));
    auto s16_4 = vmovl_s8(vget_high_s8(s8_2));
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto b8 = vld1_s16(q8.y[iy][ibl].bsums);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
        b8 = vld1_s16(q8.y[iy][ibl].bsums+4);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
    }
    // Second half: blocks covered by i8scales.val[2] and val[3]. 'extra' has
    // already been shifted (k_shift == 5 path) or is re-shifted here.
    if constexpr (k_shift == 5) {
        auto m1 = vdupq_n_u8(1);
        s8_1 = vmulq_s8(i8scales.val[2], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
        s8_2 = vmulq_s8(i8scales.val[3], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
    } else {
        if constexpr (k_shift == 4) {
            s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 2)));
            s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 4)));
        } else {
            s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 3)));
            s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 5)));
        }
    }
    s16_1 = vmovl_s8(vget_low_s8 (s8_1));
    s16_2 = vmovl_s8(vget_high_s8(s8_1));
    s16_3 = vmovl_s8(vget_low_s8 (s8_2));
    s16_4 = vmovl_s8(vget_high_s8(s8_2));
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto b8 = vld1_s16(q8.y[iy][ibl].bsums+8);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
        b8 = vld1_s16(q8.y[iy][ibl].bsums+12);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
    }
}
|
||||
|
||||
// Matmul kernel (NEON): 4 interleaved rows of IQ2_K-quantized weights ("_r4")
// times nrc_y columns of q8_K activations. 2-bit quants are looked up in a
// 4-entry codebook (iq2nl_values, duplicated across both table halves).
// For nrc_y == 1 the per-block value shift is folded in via bsums
// (iq3_4_add_shift); otherwise the shift is added to the table index so the
// lookup lands in the shifted half of 'values'.
template <int nrc_y>
void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);     // nibble mask for unpacking scales
    auto m03 = vdupq_n_u8(0x03);   // 2-bit quant mask
    auto ms = vdupq_n_u8(4);       // shift magnitude added to the table index
    // Broadcast each 'extra' byte across its 8-byte lane group so that one
    // shift value applies to all quants of the corresponding block.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    auto values8 = vld1_s8(iq2nl_values);
    auto values = vcombine_s8(values8, values8);
    int nbl = n / QK_K;
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));  // per-row fp16 scales
            auto extra8 = vld1_u8(iq2[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                // Layout expected by iq3_4_add_shift<..., 5>.
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            // Unpack 4-bit block scales (low/high nibbles), biased by -8.
            auto sl = vld1q_u8_x2(iq2[ibl].scales);
            i8scales.val[0] = vaddq_s8(vandq_u8(sl.val[0], m4), vdupq_n_s8(-8));
            i8scales.val[1] = vaddq_s8(vandq_u8(sl.val[1], m4), vdupq_n_s8(-8));
            i8scales.val[2] = vaddq_s8(vshrq_n_u8(sl.val[0], 4), vdupq_n_s8(-8));
            i8scales.val[3] = vaddq_s8(vshrq_n_u8(sl.val[1], 4), vdupq_n_s8(-8));
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                iq3_4_add_shift<nrc_y, 5>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib);
                    // Unpack four 2-bit quants per byte (first 16 bytes).
                    qx[0] = vandq_u8(            bits.val[0],     m03);
                    qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m03);
                    qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m03);
                    qx[3] = vandq_u8(vshrq_n_u8(bits.val[0], 6), m03);
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        // Shift already handled via bsums; plain codebook lookup.
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        // Add the per-block shift (0 or 4) to the index before lookup.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 2));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 bytes of this sub-block.
                    qx[0] = vandq_u8(            bits.val[1],     m03);
                    qx[1] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m03);
                    qx[2] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m03);
                    qx[3] = vandq_u8(vshrq_n_u8(bits.val[1], 6), m03);
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            // Scale by weight (d4) and activation (q8.scale) scales, accumulate.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
// Matmul kernel (NEON): 4 interleaved rows of IQ3_K-quantized weights ("_r4")
// times nrc_y columns of q8_K activations. 3-bit quants: 2 low bits from qs
// plus 1 high bit from qh, looked up in an 8-entry codebook (iq3nl_values).
// Block scales are 4 bits plus a sign bit unpacked from scales_h.
template <int nrc_y>
void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);
    // Table-index shift magnitude: the nrc_y == 1 path uses k_shift 4 via
    // iq3_4_add_shift, the general path adds 8 to the lookup index.
    auto ms = nrc_y == 1 ? vdupq_n_u8(4) : vdupq_n_u8(8);
    auto m03 = vdupq_n_u8(0x03);   // 2 low quant bits
    auto m04 = vdupq_n_u8(0x04);   // 3rd (high) quant bit after repositioning
    // Broadcast each 'extra' byte across its 8-byte lane group.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    // Per-block sign bits of the scales, tested one bit per 8-lane group.
    uint8x16x2_t smask = { vcombine_u8(vdup_n_u8(1), vdup_n_u8(2)), vcombine_u8(vdup_n_u8(4), vdup_n_u8(8)) };
    auto values = vld1q_s8(iq3nl_values);
    int nbl = n / QK_K;
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));
            auto extra8 = vld1_u8(iq3[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            auto sl = vld1q_u8_x2(iq3[ibl].scales_l);
            auto sh8 = vld1_u8(iq3[ibl].scales_h);
            auto sh = vcombine_u8(sh8, sh8);
            // Magnitude: 2*nibble + 1 (odd values 1..31).
            i8scales.val[0] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[0], m4), 1), vdupq_n_s8(1));
            i8scales.val[1] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[1], m4), 1), vdupq_n_s8(1));
            i8scales.val[2] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[0], 4), 1), vdupq_n_s8(1));
            i8scales.val[3] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[1], 4), 1), vdupq_n_s8(1));
            // Sign: vceqq yields 0xFF (== -1) where the sign bit is set; OR with
            // 1 turns that into -1/+1, which multiplies the magnitude.
            i8scales.val[0] = vmulq_s8(i8scales.val[0], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
            i8scales.val[1] = vmulq_s8(i8scales.val[1], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
            sh = vshrq_n_u8(sh, 4);
            i8scales.val[2] = vmulq_s8(i8scales.val[2], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
            i8scales.val[3] = vmulq_s8(i8scales.val[3], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib);
                    auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib);
                    // Assemble 3-bit indices: 2 bits from qs, 3rd bit from qh.
                    qx[0] = vorrq_u8(vandq_u8(            lbits.val[0],     m03), vandq_u8(m04, vshlq_n_u8(hbits, 2)));
                    qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03), vandq_u8(m04, vshlq_n_u8(hbits, 1)));
                    qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03), vandq_u8(m04, hbits));
                    qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 1)));
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        // Add per-block shift (0 or 8) to the index before lookup.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 3));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 bytes of this sub-block.
                    qx[0] = vorrq_u8(vandq_u8(            lbits.val[1],     m03), vandq_u8(m04, vshrq_n_u8(hbits, 2)));
                    qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03), vandq_u8(m04, vshrq_n_u8(hbits, 3)));
                    qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03), vandq_u8(m04, vshrq_n_u8(hbits, 4)));
                    qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 5)));
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
// Matmul kernel (NEON): 4 interleaved rows of IQ4_K-quantized weights ("_r4")
// times nrc_y columns of q8_K activations. 4-bit quants looked up in a
// 16-entry codebook; 6-bit block scales assembled from scales_l (low nibble)
// and scales_h (2 bits), biased by -32. Unlike the iq2/iq3 kernels, the
// per-block shift (0 or 4) is added to the dequantized VALUE after the table
// lookup, not to the table index.
template <int nrc_y>
void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);
    auto m3 = vdupq_n_u8(0x30);    // high 2 scale bits after repositioning
    auto ms = vdupq_n_u8(4);       // value-shift magnitude
    auto m32 = vdupq_n_s8(-32);    // scale bias
    // Broadcast each 'extra' byte across its 8-byte lane group.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    auto values = vld1q_s8(iq4k_values);
    int nbl = n / QK_K;
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d));
            auto extra8 = vld1_u8(iq4[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            // Assemble 6-bit scales: 4 low bits from scales_l nibbles, 2 high
            // bits from successive bit pairs of scales_h, minus 32.
            auto sl = vld1q_u8_x2(iq4[ibl].scales_l);
            auto sh = vld1q_u8(iq4[ibl].scales_h);
            i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
            i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
            i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
            i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        // Shift already folded in via bsums; plain lookups.
                        qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4));   //  0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4));   //  4...7
                        qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4));  //  8..11
                        qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4));  // 12..15
                    } else {
                        // Add the per-block shift to the looked-up value.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 2));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)));   //  0...3 from the 4 rows
                        qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)));   //  4...7
                        qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)));  //  8..11
                        qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)));  // 12..15
                    }
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 values of this sub-block.
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4));   // 16..19
                        qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4));   // 20..23
                        qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4));  // 24..27
                        qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4));  // 28..31
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)));   // 16..19
                        qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)));   // 20..23
                        qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)));  // 24..27
                        qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)));  // 28..31
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
// Matmul kernel (NEON): 4 interleaved rows of IQ5_K-quantized weights ("_r4")
// times nrc_y columns of q8_K activations. 5-bit quants: 4 low bits from qs
// plus one high bit from qh, looked up in a 32-entry codebook (vqtbl2q_s8).
// 6-bit block scales as in the iq4_k kernel; per-block shift magnitude is 2
// and is added to the dequantized value after the table lookup.
template <int nrc_y>
void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);
    auto m3 = vdupq_n_u8(0x30);    // high 2 scale bits after repositioning
    auto ms = vdupq_n_u8(2);       // value-shift magnitude
    auto m32 = vdupq_n_s8(-32);    // scale bias
    auto m10 = vdupq_n_u8(0x10);   // 5th quant bit position
    // Broadcast each 'extra' byte across its 8-byte lane group.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    auto values = vld1q_s8_x2(iq5nl_values);
    int nbl = n / QK_K;
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d));
            auto extra8 = vld1_u8(iq5[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            // 6-bit scales from scales_l nibbles + 2 bits of scales_h, minus 32.
            auto sl = vld1q_u8_x2(iq5[ibl].scales_l);
            auto sh = vld1q_u8(iq5[ibl].scales_h);
            i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
            i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
            i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
            i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                iq3_4_add_shift<nrc_y, 2>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
                    auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
                    // Assemble 5-bit indices (low nibble | repositioned high bit).
                    qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4))); // aligns with 1st half of qx[0] in AVX2
                    qx[1] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits));                // aligns with 1st half of qx[1] in AVX2
                    qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3))); // aligns with 1st half of qx[2] in AVX2
                    qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1))); // aligns with 1st half of qx[3] in AVX2
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        // Shift already folded in via bsums; plain lookups.
                        qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15
                    } else {
                        // Add the per-block shift (0 or 2) to the looked-up value.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 1));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7
                        qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11
                        qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15
                    }
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 values of this sub-block.
                    qx[0] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2))); // aligns with 2nd half of qx[0] in AVX2
                    qx[1] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2))); // aligns with 2nd half of qx[1] in AVX2
                    qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1))); // aligns with 2nd half of qx[2] in AVX2
                    qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3))); // aligns with 2nd half of qx[3] in AVX2
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7
                        qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11
                        qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
}
|
||||
|
||||
bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, [[maybe_unused]] mul_mat_t& func16) {
|
||||
|
||||
auto etypeA = ggml_type(typeA);
|
||||
auto expected_type_B = etypeA == GGML_TYPE_IQ4_KS_R4 || etypeA == GGML_TYPE_IQ5_KS_R4 ? GGML_TYPE_Q8_K32 : GGML_TYPE_Q8_K;
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != expected_type_B) {
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -2599,33 +3247,20 @@ bool iqk_set_kernels_iqk_quants(int ne00, int typeA, int typeB, std::array<mul_m
|
||||
case GGML_TYPE_IQ6_K:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ6K, kernels);
|
||||
break;
|
||||
// case GGML_TYPE_IQ2_K_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_k_r4_q8_k, kernels);
|
||||
// break;
|
||||
// case GGML_TYPE_IQ3_K_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_k_r4_q8_k, kernels);
|
||||
//#ifdef HAVE_FANCY_SIMD
|
||||
// func16 = mul_mat_iq3_k_r4_q8_k<16>;
|
||||
//#endif
|
||||
// break;
|
||||
// case GGML_TYPE_IQ4_K_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_k_r4_q8_k, kernels);
|
||||
// func16 = mul_mat_iq4_k_r4_q8_k<16>;
|
||||
// break;
|
||||
// case GGML_TYPE_IQ4_KS_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_ks_r4_q8_k, kernels);
|
||||
//#ifndef HAVE_FANCY_SIMD
|
||||
// // For some reason Zen4 does not like this particular function
|
||||
// func16 = mul_mat_iq4_ks_r4_q8_k<16>;
|
||||
//#endif
|
||||
// break;
|
||||
// case GGML_TYPE_IQ5_KS_R4:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_ks_r4_q8_k, kernels);
|
||||
//#ifndef HAVE_FANCY_SIMD
|
||||
// // For some reason Zen4 does not like this particular function
|
||||
// func16 = mul_mat_iq5_ks_r4_q8_k<16>;
|
||||
//#endif
|
||||
// break;
|
||||
case GGML_TYPE_IQ2_K_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_k_r4_q8_k, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_K_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_k_r4_q8_k, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_k_r4_q8_k, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_KS_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_ks_r4_q8_k, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ5_KS_R4:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq5_ks_r4_q8_k, kernels);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -878,44 +878,6 @@ template <int nrc> struct Q8_K64 {
|
||||
const int8_t * y[nrc_y];
|
||||
};
|
||||
|
||||
struct DequantizerIQ1BN {
    // Constant 1 per lane; subtracted at the end to map unsigned ternary digits {0,1,2}
    // onto signed weights {-1,0,1}.
    const uint8x16_t m1 = vdupq_n_u8(1);

    // Shuffle that replicates each packed source byte once per base-3 digit it carries
    // (5 digits per byte; lane 12 supplies the final, partially packed byte).
    static inline uint8x16x4_t load_shuffles() {
        static const uint8_t data[64] = {
            0, 0, 0, 0, 0,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2, 12,
            3, 3, 3, 3, 3,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5, 12,
            6, 6, 6, 6, 6,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8, 12,
            9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12
        };
        return vld1q_u8_x4(data);
    }
    // Powers of 3 (81,27,9,3,1) aligned with load_shuffles(): multiplying by 3^k moves
    // the k-th ternary digit of the packed byte into the top bits of the product.
    static inline uint8x16x4_t load_mult() {
        static const uint8_t data[64] = {
            81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81,
            81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 27,
            81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1,  9,
            81, 27, 9, 3, 1, 81, 27, 9, 3, 1, 81, 27, 9, 3, 1,  3
        };
        return vld1q_u8_x4(data);
    }
    const uint8x16x4_t shuff = load_shuffles();
    const uint8x16x4_t mult  = load_mult();

    // Unpack one block_iq1_bn into 64 signed ternary weights in {-1,0,1}.
    IQK_ALWAYS_INLINE void prepare_iq1bn_quants(const block_iq1_bn * x, int8x16x4_t& v) const {
        auto packed = vld1q_u8((const uint8_t *)x);
        for (int j = 0; j < 4; ++j) {
            auto q = vmulq_u8(vqtbl1q_u8(packed, shuff.val[j]), mult.val[j]);
            // (q + (q>>1)) >> 6  ==  approximately (3*q) >> 8: extracts the wanted
            // base-3 digit as an unsigned value in {0,1,2}.
            q = vshrq_n_u8(vhaddq_u8(q, vshrq_n_u8(q, 1)), 6);
            v.val[j] = vsubq_s8(vreinterpretq_s8_u8(q), m1);
        }
    }

    // Same as prepare_iq1bn_quants() but leaves the digits unsigned in {0,1,2};
    // the caller is expected to account for the -1 offset itself.
    IQK_ALWAYS_INLINE void prepare_iq1bn_quants_nosub(const block_iq1_bn * x, int8x16x4_t& v) const {
        auto packed = vld1q_u8((const uint8_t *)x);
        for (int j = 0; j < 4; ++j) {
            auto q = vmulq_u8(vqtbl1q_u8(packed, shuff.val[j]), mult.val[j]);
            v.val[j] = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(q, vshrq_n_u8(q, 1)), 6));
        }
    }
};
|
||||
|
||||
template <int nrc> struct Q8_16 {
|
||||
|
||||
constexpr static int nrc_y = nrc;
|
||||
@@ -938,166 +900,6 @@ template <int nrc> struct Q8_16 {
|
||||
const int8_t * y[nrc_y];
|
||||
};
|
||||
|
||||
// Multiply nrc_x rows of IQ4_KS quants, repacked 4 rows at a time (R4 layout),
// by nrc_y rows of Q8_K-quantized activations (NEON path).
//   n     - row length in elements; must be a multiple of QK_K
//   vx    - LHS data: each 4-row group starts with 4 float row scales followed
//           by its block_iq4_ks_r4 super-blocks
//   bx    - byte stride between consecutive 4-row groups
//   info  - provides the q8 activations and receives results via info.store()
//   nrc_x - number of LHS rows; must be a multiple of 4 (asserted)
template <int nrc_y>
void mul_mat_iq4_ks_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);                 // low-nibble mask for the 4-bit quants
    auto values = vld1q_s8(iq4k_values);       // 16-entry non-linear codebook
    int nbl = n / QK_K;                        // super-blocks per row
    int8x16_t qx[8];
    int16x8x4_t iscales;
    int32x4x4_t scales;
    float32x4_t acc[nrc_y] = {};
    int32x4_t isum[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        auto dptr = (const float *)((const char *)vx + ix*bx);
        auto d4 = vld1q_f32(dptr);             // 4 per-row float scales precede the quants
        const block_iq4_ks_r4 * iq4 = (const block_iq4_ks_r4 *)(dptr + 4);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            // Each scale byte: bits 1..7 are the block scale (biased by 127),
            // bit 0 flags an extra value shift for that block.
            auto sas = vld1q_u8_x2(iq4[ibl].scales);
            auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
            iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
            iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            // Adding the block shifts costs us ~9% in performance drop.
            // Is there a better way?
            // Isolate the shift flag: flag<<2, i.e. 4 where set, 0 otherwise.
            sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 2);
            sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 2);
            {
                // Fold the constant per-block shift into the result via the
                // precomputed Q8 block sums (bsums) instead of shifting the quants.
                auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
                auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
                auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
                auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
                    auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
                    auto b8 = vget_low_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
                    b8 = vget_high_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
                }
            }
            // Main dot products: 2 halves x 4 sub-blocks per super-block.
            for (int is = 0; is < 2; ++is) {
                scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
                scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
                scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
                scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
                    prepare_iq4_nl_quants(values, m4, bits, qx);
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
                    }
                }
            }
            // Apply the Q8 super-block scale; the per-row float scales d4 are
            // applied once at store time.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
                isum[iy] = vdupq_n_s32(0);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vmulq_f32(d4, acc[iy]));
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
// Multiply nrc_x rows of IQ5_KS quants, repacked 4 rows at a time (R4 layout),
// by nrc_y rows of Q8_K-quantized activations (NEON path).
// Same structure as mul_mat_iq4_ks_r4_q8_k, but with 5-bit quants: 4 low bits
// in qs plus 1 high bit in qh, looked up in the 32-entry iq5nl_values table.
template <int nrc_y>
void mul_mat_iq5_ks_r4_q8_k_neon(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);                 // low-nibble mask
    auto m10 = vdupq_n_u8(0x10);               // bit 4 (the high quant bit) after shifting qh
    auto values = vld1q_s8_x2(iq5nl_values);   // 32-entry codebook (2 x 16-byte tables)
    int nbl = n / QK_K;
    int8x16_t qx[8];
    int16x8x4_t iscales;
    int32x4x4_t scales;
    float32x4_t acc[nrc_y] = {};
    int32x4_t isum[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        auto dptr = (const float *)((const char *)vx + ix*bx);
        auto d4 = vld1q_f32(dptr);             // 4 per-row float scales precede the quants
        const block_iq5_ks_r4 * iq5 = (const block_iq5_ks_r4 *)(dptr + 4);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            // Scale bytes: bits 1..7 = block scale (biased by 127), bit 0 = shift flag.
            auto sas = vld1q_u8_x2(iq5[ibl].scales);
            auto scale = vandq_u8(sas.val[0], vdupq_n_u8(254));
            iscales.val[0] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[1] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            scale = vandq_u8(sas.val[1], vdupq_n_u8(254));
            iscales.val[2] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8 (scale))), vdupq_n_s16(-127));
            iscales.val[3] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(scale))), vdupq_n_s16(-127));
            // Adding the block shifts costs us ~9% in performance drop.
            // Is there a better way?
            // Isolate the shift flag: flag<<1, i.e. 2 where set (vs 4 for iq4_ks).
            sas.val[0] = vshlq_n_u8(vandq_u8(sas.val[0], vdupq_n_u8(1)), 1);
            sas.val[1] = vshlq_n_u8(vandq_u8(sas.val[1], vdupq_n_u8(1)), 1);
            {
                // Fold the per-block shift into the result via the Q8 block sums.
                auto s16_1 = vmulq_s16(iscales.val[0], vmovl_u8(vget_low_u8 (sas.val[0])));
                auto s16_2 = vmulq_s16(iscales.val[1], vmovl_u8(vget_high_u8(sas.val[0])));
                auto s16_3 = vmulq_s16(iscales.val[2], vmovl_u8(vget_low_u8 (sas.val[1])));
                auto s16_4 = vmulq_s16(iscales.val[3], vmovl_u8(vget_high_u8(sas.val[1])));
                for (int iy = 0; iy < nrc_y; ++iy) {
                    auto bsums = vld1q_s16_x2(q8.y[iy][ibl].bsums);
                    auto bs = vpaddq_s16(bsums.val[0], bsums.val[1]);
                    auto b8 = vget_low_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
                    b8 = vget_high_s16(bs);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
                    isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
                }
            }
            for (int is = 0; is < 2; ++is) {
                scales.val[0] = vmovl_s16(vget_low_s16 (iscales.val[2*is+0]));
                scales.val[1] = vmovl_s16(vget_high_s16(iscales.val[2*is+0]));
                scales.val[2] = vmovl_s16(vget_low_s16 (iscales.val[2*is+1]));
                scales.val[3] = vmovl_s16(vget_high_s16(iscales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
                    auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
                    // Combine 4 low bits with the appropriate high bit (as bit 4)
                    // to form 5-bit codebook indices for the 8 interleaved vectors.
                    qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4)));
                    qx[1] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2)));
                    qx[2] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits));
                    qx[3] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2)));
                    qx[4] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3)));
                    qx[5] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1)));
                    qx[6] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1)));
                    qx[7] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3)));
                    // Dequantize via the 32-entry table lookup.
                    for (int l = 0; l < 8; ++l) qx[l] = vqtbl2q_s8(values, qx[l]);
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8_x2(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales.val[ib], sumi);
                    }
                }
            }
            // Apply the Q8 super-block scale; per-row float scales d4 are applied at store time.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vdupq_n_f32(q8.scale(iy, ibl)), vcvtq_f32_s32(isum[iy]));
                isum[iy] = vdupq_n_s32(0);
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vmulq_f32(d4, acc[iy]));
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq2_xxs_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
@@ -1407,496 +1209,6 @@ static void mul_mat_iq3_s_r4_q8_k(int n, const void * vx, size_t bx, const DataI
|
||||
}
|
||||
}
|
||||
|
||||
// Folds the per-block quant shift (selected by the 'extra' flag bits) into the
// accumulators via the precomputed Q8 block sums (bsums), instead of shifting
// the quant values themselves. Used by the iq2_k/iq3_k/iq4_k R4 kernels on the
// nrc_y == 1 path, where 'extra' holds the flags in a layout that matches the
// bsums order.
//   k_shift  - magnitude of the shift: 5 (iq2_k), 4 (iq3_k/iq4_k), 2 (callers
//              with 2-bit shift); for k_shift == 5 the flags are spaced 2 bits
//              apart, otherwise they are consecutive bits.
//   i8scales - 4 x 16 int8 block scales for the 4 interleaved rows
//   isum     - per-y-row int32x4 accumulators, updated in place
template <int nrc_y, int k_shift>
inline void iq3_4_add_shift(int ibl, const Q8<nrc_y, block_q8_K>& q8, const int8x16x4_t& i8scales, uint8x16_t extra,
        int32x4_t * isum) {
    auto ms = vdupq_n_s8(k_shift);
    int8x16_t s8_1, s8_2;
    if constexpr (k_shift == 5) {
        // Flags live in every other bit; vceq turns a set flag into 0xFF so the
        // AND with ms yields k_shift where set, 0 otherwise.
        auto m1 = vdupq_n_u8(1);
        s8_1 = vmulq_s8(i8scales.val[0], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
        s8_2 = vmulq_s8(i8scales.val[1], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
    } else {
        // k_shift is a power of two, so a single AND against the shifted flag
        // bit produces k_shift-or-0 directly.
        if constexpr (k_shift == 4) {
            s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 2)));
            s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, extra));
        } else {
            s8_1 = vmulq_s8(i8scales.val[0], vandq_u8(ms, vshlq_n_u8(extra, 1)));
            s8_2 = vmulq_s8(i8scales.val[1], vandq_u8(ms, vshrq_n_u8(extra, 1)));
        }
    }
    // First 8 blocks: widen scale*shift to int16 and accumulate against bsums.
    auto s16_1 = vmovl_s8(vget_low_s8 (s8_1));
    auto s16_2 = vmovl_s8(vget_high_s8(s8_1));
    auto s16_3 = vmovl_s8(vget_low_s8 (s8_2));
    auto s16_4 = vmovl_s8(vget_high_s8(s8_2));
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto b8 = vld1_s16(q8.y[iy][ibl].bsums);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
        b8 = vld1_s16(q8.y[iy][ibl].bsums+4);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
    }
    // Second 8 blocks: same computation with the remaining flag bits and scales.
    if constexpr (k_shift == 5) {
        auto m1 = vdupq_n_u8(1);
        s8_1 = vmulq_s8(i8scales.val[2], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
        s8_2 = vmulq_s8(i8scales.val[3], vandq_s8(ms, vceqq_u8(vandq_u8(extra, m1), m1))); extra = vshrq_n_u8(extra, 2);
    } else {
        if constexpr (k_shift == 4) {
            s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 2)));
            s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 4)));
        } else {
            s8_1 = vmulq_s8(i8scales.val[2], vandq_u8(ms, vshrq_n_u8(extra, 3)));
            s8_2 = vmulq_s8(i8scales.val[3], vandq_u8(ms, vshrq_n_u8(extra, 5)));
        }
    }
    s16_1 = vmovl_s8(vget_low_s8 (s8_1));
    s16_2 = vmovl_s8(vget_high_s8(s8_1));
    s16_3 = vmovl_s8(vget_low_s8 (s8_2));
    s16_4 = vmovl_s8(vget_high_s8(s8_2));
    for (int iy = 0; iy < nrc_y; ++iy) {
        auto b8 = vld1_s16(q8.y[iy][ibl].bsums+8);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_1), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_1), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_2), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_2), b8, 3);
        b8 = vld1_s16(q8.y[iy][ibl].bsums+12);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_3), b8, 0);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_3), b8, 1);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_low_s16 (s16_4), b8, 2);
        isum[iy] = vmlal_lane_s16(isum[iy], vget_high_s16(s16_4), b8, 3);
    }
}
|
||||
|
||||
// Multiply nrc_x rows of IQ2_K quants, repacked 4 rows at a time (R4 layout),
// by nrc_y rows of Q8_K-quantized activations (NEON path).
// 2-bit quants are table indices into the 8-entry iq2nl_values codebook (the
// per-block 'extra' flag adds 4 to the index, selecting the shifted half).
template <int nrc_y>
void mul_mat_iq2_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);                 // low-nibble mask for the 4-bit scales
    auto m03 = vdupq_n_u8(0x03);               // 2-bit quant mask
    auto ms = vdupq_n_u8(4);                   // index offset selected by the extra flag
    // Broadcasts each per-row flag byte across the 4 lanes of its row group.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    auto values8 = vld1_s8(iq2nl_values);
    auto values = vcombine_s8(values8, values8);
    int nbl = n / QK_K;
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq2_k_r4 * iq2 = (const block_iq2_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq2[ibl].d));  // 4 fp16 row scales
            auto extra8 = vld1_u8(iq2[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                // Layout expected by iq3_4_add_shift (flags consumed 2 bits apart).
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            // 4-bit block scales, stored biased by 8.
            auto sl = vld1q_u8_x2(iq2[ibl].scales);
            i8scales.val[0] = vaddq_s8(vandq_u8(sl.val[0], m4), vdupq_n_s8(-8));
            i8scales.val[1] = vaddq_s8(vandq_u8(sl.val[1], m4), vdupq_n_s8(-8));
            i8scales.val[2] = vaddq_s8(vshrq_n_u8(sl.val[0], 4), vdupq_n_s8(-8));
            i8scales.val[3] = vaddq_s8(vshrq_n_u8(sl.val[1], 4), vdupq_n_s8(-8));
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                // Single-row path: fold the flag-selected shift in via bsums,
                // so the inner loop can skip the per-vector index adjustment.
                iq3_4_add_shift<nrc_y, 5>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    auto bits = vld1q_u8_x2(iq2[ibl].qs + 128*is + 32*ib);
                    // Unpack four 2-bit quants per byte, first 16 bytes.
                    qx[0] = vandq_u8(           bits.val[0],     m03);
                    qx[1] = vandq_u8(vshrq_n_u8(bits.val[0], 2), m03);
                    qx[2] = vandq_u8(vshrq_n_u8(bits.val[0], 4), m03);
                    qx[3] = vandq_u8(vshrq_n_u8(bits.val[0], 6), m03);
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        // Add 4 to the table index where the extra flag is set.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 2));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 bytes of this sub-block.
                    qx[0] = vandq_u8(           bits.val[1],     m03);
                    qx[1] = vandq_u8(vshrq_n_u8(bits.val[1], 2), m03);
                    qx[2] = vandq_u8(vshrq_n_u8(bits.val[1], 4), m03);
                    qx[3] = vandq_u8(vshrq_n_u8(bits.val[1], 6), m03);
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            // Apply both the fp16 row scales and the Q8 super-block scale here.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
// Multiply nrc_x rows of IQ3_K quants, repacked 4 rows at a time (R4 layout),
// by nrc_y rows of Q8_K-quantized activations (NEON path).
// 3-bit quants = 2 low bits (qs) + 1 high bit (qh) indexing the iq3nl_values
// table; block scales are 4-bit magnitudes with separate sign bits (scales_h).
template <int nrc_y>
void mul_mat_iq3_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);                 // low-nibble mask for the scale magnitudes
    auto ms = nrc_y == 1 ? vdupq_n_u8(4) : vdupq_n_u8(8);  // extra-flag index offset (multi-row path uses 8)
    auto m03 = vdupq_n_u8(0x03);               // 2 low quant bits
    auto m04 = vdupq_n_u8(0x04);               // high quant bit after positioning
    // Broadcasts each per-row flag byte across the 4 lanes of its row group.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    uint8x16x2_t smask = { vcombine_u8(vdup_n_u8(1), vdup_n_u8(2)), vcombine_u8(vdup_n_u8(4), vdup_n_u8(8)) };
    auto values = vld1q_s8(iq3nl_values);
    int nbl = n / QK_K;
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq3_k_r4 * iq3 = (const block_iq3_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq3[ibl].d));  // 4 fp16 row scales
            auto extra8 = vld1_u8(iq3[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                // Layout expected by iq3_4_add_shift.
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            auto sl = vld1q_u8_x2(iq3[ibl].scales_l);
            auto sh8 = vld1_u8(iq3[ibl].scales_h);
            auto sh = vcombine_u8(sh8, sh8);
            // Scale magnitude = 2*nibble + 1 (always odd, never zero).
            i8scales.val[0] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[0], m4), 1), vdupq_n_s8(1));
            i8scales.val[1] = vaddq_s8(vshlq_n_u8(vandq_u8(sl.val[1], m4), 1), vdupq_n_s8(1));
            i8scales.val[2] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[0], 4), 1), vdupq_n_s8(1));
            i8scales.val[3] = vaddq_s8(vshlq_n_u8(vshrq_n_u8(sl.val[1], 4), 1), vdupq_n_s8(1));
            // Apply the sign: vceq gives 0xFF (== -1 as int8) where the sign bit is set;
            // OR with 1 turns that into -1 / +1 multipliers.
            i8scales.val[0] = vmulq_s8(i8scales.val[0], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
            i8scales.val[1] = vmulq_s8(i8scales.val[1], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
            sh = vshrq_n_u8(sh, 4);
            i8scales.val[2] = vmulq_s8(i8scales.val[2], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[0]), smask.val[0]), vdupq_n_u8(1)));
            i8scales.val[3] = vmulq_s8(i8scales.val[3], vorrq_u8(vceqq_u8(vandq_u8(sh, smask.val[1]), smask.val[1]), vdupq_n_u8(1)));
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                // Single-row path: fold the flag-selected shift in via bsums.
                iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    auto lbits = vld1q_u8_x2(iq3[ibl].qs + 128*is + 32*ib);
                    auto hbits = vld1q_u8(iq3[ibl].qh + 64*is + 16*ib);
                    // Build 3-bit table indices: 2 low bits | high bit as bit 2.
                    qx[0] = vorrq_u8(vandq_u8(           lbits.val[0],     m03), vandq_u8(m04, vshlq_n_u8(hbits, 2)));
                    qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 2), m03), vandq_u8(m04, vshlq_n_u8(hbits, 1)));
                    qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 4), m03), vandq_u8(m04, hbits));
                    qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[0], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 1)));
                    uint8x16_t shifts;
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        // Add 8 to the table index where the extra flag is set.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 3));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second 16 bytes of this sub-block (high bits shifted further down).
                    qx[0] = vorrq_u8(vandq_u8(           lbits.val[1],     m03), vandq_u8(m04, vshrq_n_u8(hbits, 2)));
                    qx[1] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 2), m03), vandq_u8(m04, vshrq_n_u8(hbits, 3)));
                    qx[2] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 4), m03), vandq_u8(m04, vshrq_n_u8(hbits, 4)));
                    qx[3] = vorrq_u8(vandq_u8(vshrq_n_u8(lbits.val[1], 6), m03), vandq_u8(m04, vshrq_n_u8(hbits, 5)));
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl1q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl1q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl1q_s8(values, qx[3]); // 12..15
                    } else {
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vqtbl1q_s8(values, vaddq_u8(shift, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vqtbl1q_s8(values, vaddq_u8(shift, qx[1])); // 4...7
                        qx[2] = vqtbl1q_s8(values, vaddq_u8(shift, qx[2])); // 8..11
                        qx[3] = vqtbl1q_s8(values, vaddq_u8(shift, qx[3])); // 12..15
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            // Apply both the fp16 row scales and the Q8 super-block scale here.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
template <int nrc_y>
|
||||
void mul_mat_iq4_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
GGML_ASSERT(nrc_x%4 == 0);
|
||||
Q8<nrc_y, block_q8_K> q8(info);
|
||||
auto m4 = vdupq_n_u8(0xf);
|
||||
auto m3 = vdupq_n_u8(0x30);
|
||||
auto ms = vdupq_n_u8(4);
|
||||
auto m32 = vdupq_n_s8(-32);
|
||||
uint8x16x2_t shift_shuffle = {
|
||||
vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
|
||||
vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
|
||||
};
|
||||
auto values = vld1q_s8(iq4k_values);
|
||||
int nbl = n / QK_K;
|
||||
int8x16_t qx[4];
|
||||
int8x16x4_t i8scales;
|
||||
int16x8x4_t i16scales;
|
||||
float32x4_t acc[nrc_y] = {};
|
||||
for (int ix = 0; ix < nrc_x; ix += 4) {
|
||||
const block_iq4_k_r4 * iq4 = (const block_iq4_k_r4 *)((const char *)vx + ix*bx);
|
||||
for (int ibl = 0; ibl < nbl; ++ibl) {
|
||||
auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq4[ibl].d));
|
||||
auto extra8 = vld1_u8(iq4[ibl].extra);
|
||||
uint8x16_t extra;
|
||||
if constexpr (nrc_y == 1) {
|
||||
extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
|
||||
} else {
|
||||
extra = vcombine_u8(extra8, extra8);
|
||||
}
|
||||
auto sl = vld1q_u8_x2(iq4[ibl].scales_l);
|
||||
auto sh = vld1q_u8(iq4[ibl].scales_h);
|
||||
i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
|
||||
i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
|
||||
i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
|
||||
i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
|
||||
int32x4_t isum[nrc_y] = {};
|
||||
if constexpr (nrc_y == 1) {
|
||||
iq3_4_add_shift<nrc_y, 4>(ibl, q8, i8scales, extra, isum);
|
||||
}
|
||||
for (int is = 0; is < 2; ++is) {
|
||||
i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
|
||||
i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
|
||||
i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
|
||||
i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
|
||||
for (int ib = 0; ib < 4; ++ib) {
|
||||
auto bits = vld1q_u8_x4(iq4[ibl].qs + 256*is + 64*ib);
|
||||
uint8x16_t shifts;
|
||||
if constexpr (nrc_y == 1) {
|
||||
qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[0], m4)); // 0...3 from the 4 rows
|
||||
qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[2], m4)); // 4...7
|
||||
qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4)); // 8..11
|
||||
qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4)); // 12..15
|
||||
} else {
|
||||
shifts = vandq_u8(ms, vshlq_n_u8(extra, 2));
|
||||
auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
|
||||
extra = vshrq_n_u8(extra, 1);
|
||||
qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[0], m4))); // 0...3 from the 4 rows
|
||||
qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[2], m4))); // 4...7
|
||||
qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[0], 4))); // 8..11
|
||||
qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[2], 4))); // 12..15
|
||||
}
|
||||
auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
|
||||
auto sumi = interleaved_dotq(qx, y);
|
||||
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
|
||||
}
|
||||
if constexpr (nrc_y == 1) {
|
||||
qx[0] = vqtbl1q_s8(values, vandq_u8(bits.val[1], m4)); // 16..19
|
||||
qx[1] = vqtbl1q_s8(values, vandq_u8(bits.val[3], m4)); // 20..23
|
||||
qx[2] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4)); // 24..27
|
||||
qx[3] = vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4)); // 28..31
|
||||
} else {
|
||||
auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
|
||||
qx[0] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[1], m4))); // 16..19
|
||||
qx[1] = vaddq_s8(shift, vqtbl1q_s8(values, vandq_u8(bits.val[3], m4))); // 20..23
|
||||
qx[2] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[1], 4))); // 24..27
|
||||
qx[3] = vaddq_s8(shift, vqtbl1q_s8(values, vshrq_n_u8(bits.val[3], 4))); // 28..31
|
||||
}
|
||||
scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
|
||||
auto sumi = interleaved_dotq(qx, y);
|
||||
isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
|
||||
}
|
||||
}
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
info.store(ix, iy, acc[iy]);
|
||||
acc[iy] = vdupq_n_f32(0.f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Matrix multiplication kernel for IQ5_K_R4 (4-row interleaved 5-bit non-linear
// quants) times Q8_K activations, using ARM NEON intrinsics.
// Processes nrc_x rows of x (in groups of 4 interleaved rows per block_iq5_k_r4)
// against nrc_y rows of quantized activations, storing float results via info.store().
// Structure mirrors mul_mat_iq4_k_r4_q8_k above; the extra complexity is the 5th
// (high) quant bit taken from the qh[] array.
template <int nrc_y>
void mul_mat_iq5_k_r4_q8_k(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    GGML_ASSERT(nrc_x%4 == 0);
    Q8<nrc_y, block_q8_K> q8(info);
    auto m4 = vdupq_n_u8(0xf);       // mask for low 4 quant bits / low scale nibble
    auto m3 = vdupq_n_u8(0x30);      // mask for the 2 high scale bits (after shifting into bits 4..5)
    auto ms = vdupq_n_u8(2);         // per-block extra shift magnitude (+2 applied to the LUT values)
    auto m32 = vdupq_n_s8(-32);      // scale bias: 6-bit unsigned scale -> signed [-32, 31]
    auto m10 = vdupq_n_u8(0x10);     // mask that places the 5th quant bit at position 4
    // Shuffles that broadcast each of the 8 per-row shift bytes to the 4 lanes of its row,
    // for the 1st and 2nd group of 16 quants respectively.
    uint8x16x2_t shift_shuffle = {
        vreinterpretq_u8_u64(uint64x2_t{0x0101010100000000, 0x0303030302020202}),
        vreinterpretq_u8_u64(uint64x2_t{0x0505050504040404, 0x0707070706060606})
    };
    auto values = vld1q_s8_x2(iq5nl_values);   // 32-entry non-linear dequant LUT (two q-registers for vqtbl2)
    int nbl = n / QK_K;                        // number of QK_K super-blocks per row
    int8x16_t qx[4];
    int8x16x4_t i8scales;
    int16x8x4_t i16scales;
    float32x4_t acc[nrc_y] = {};
    for (int ix = 0; ix < nrc_x; ix += 4) {
        const block_iq5_k_r4 * iq5 = (const block_iq5_k_r4 *)((const char *)vx + ix*bx);
        for (int ibl = 0; ibl < nbl; ++ibl) {
            auto d4 = vcvt_f32_f16(vld1_f16((const float16_t *)iq5[ibl].d));  // 4 per-row f16 super-block scales
            auto extra8 = vld1_u8(iq5[ibl].extra);
            uint8x16_t extra;
            if constexpr (nrc_y == 1) {
                // nrc_y == 1 path handles the extra shifts separately via iq3_4_add_shift below,
                // so pre-arrange the bits for that helper.
                extra = vcombine_u8(extra8, vshr_n_u8(extra8,1));
            } else {
                extra = vcombine_u8(extra8, extra8);
            }
            // Unpack 6-bit block scales: low nibble from scales_l, 2 high bits from scales_h, bias by -32.
            auto sl = vld1q_u8_x2(iq5[ibl].scales_l);
            auto sh = vld1q_u8(iq5[ibl].scales_h);
            i8scales.val[0] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[0], m4), vandq_u8(vshlq_n_u8(sh, 4), m3)), m32);
            i8scales.val[1] = vaddq_s8(vorrq_u8(vandq_u8(sl.val[1], m4), vandq_u8(vshlq_n_u8(sh, 2), m3)), m32);
            i8scales.val[2] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[0], 4), vandq_u8(sh, m3)), m32);
            i8scales.val[3] = vaddq_s8(vorrq_u8(vshrq_n_u8(sl.val[1], 4), vandq_u8(vshrq_n_u8(sh, 2), m3)), m32);
            int32x4_t isum[nrc_y] = {};
            if constexpr (nrc_y == 1) {
                // Account for the "extra" +2 shifts once up-front using the q8 block sums
                // (cheaper than adding the shift to every quant as in the nrc_y > 1 path).
                iq3_4_add_shift<nrc_y, 2>(ibl, q8, i8scales, extra, isum);
            }
            for (int is = 0; is < 2; ++is) {    // two halves of the super-block (128 quants each)
                // Widen the 8-bit block scales to 16 bits for this half.
                i16scales.val[0] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+0]));
                i16scales.val[1] = vmovl_s8(vget_high_s8(i8scales.val[2*is+0]));
                i16scales.val[2] = vmovl_s8(vget_low_s8 (i8scales.val[2*is+1]));
                i16scales.val[3] = vmovl_s8(vget_high_s8(i8scales.val[2*is+1]));
                for (int ib = 0; ib < 4; ++ib) {   // 4 groups of 32 quants (x 4 interleaved rows)
                    auto lbits = vld1q_u8_x4(iq5[ibl].qs + 256*is + 64*ib);
                    auto hbits = vld1q_u8(iq5[ibl].qh + 64*is + 16*ib);
                    // Combine low nibble with the appropriate high bit into 5-bit LUT indices.
                    qx[0] = vorrq_u8(vandq_u8(lbits.val[0], m4), vandq_u8(m10, vshlq_n_u8(hbits, 4))); // aligns with 1st half of qx[0] in AVX2
                    qx[1] = vorrq_u8(vandq_u8(lbits.val[2], m4), vandq_u8(m10, hbits));                // aligns with 1st half of qx[1] in AVX2
                    qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[0], 4), vandq_u8(m10, vshlq_n_u8(hbits, 3))); // aligns with 1st half of qx[2] in AVX2
                    qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[2], 4), vandq_u8(m10, vshrq_n_u8(hbits, 1))); // aligns with 1st half of qx[3] in AVX2
                    uint8x16_t shifts;  // computed here, reused for the 2nd group of 16 below
                    if constexpr (nrc_y == 1) {
                        // Shifts already folded into isum via iq3_4_add_shift; just look up the values.
                        qx[0] = vqtbl2q_s8(values, qx[0]); // 0...3 from the 4 rows
                        qx[1] = vqtbl2q_s8(values, qx[1]); // 4...7
                        qx[2] = vqtbl2q_s8(values, qx[2]); // 8..11
                        qx[3] = vqtbl2q_s8(values, qx[3]); // 12..15
                    } else {
                        // Apply the per-block "extra" shift (+2 where the extra bit is set) directly to the values.
                        shifts = vandq_u8(ms, vshlq_n_u8(extra, 1));
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[0]);
                        extra = vshrq_n_u8(extra, 1);  // consume one extra bit per row for this 32-quant group
                        qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 0...3 from the 4 rows
                        qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 4...7
                        qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 8..11
                        qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 12..15
                    }
                    auto scales = vmovl_s16(vget_low_s16 (i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                    // Second group of 16 quant positions in this 32-quant block.
                    qx[0] = vorrq_u8(vandq_u8(lbits.val[1], m4), vandq_u8(m10, vshlq_n_u8(hbits, 2))); // aligns with 2nd half of qx[0] in AVX2
                    qx[1] = vorrq_u8(vandq_u8(lbits.val[3], m4), vandq_u8(m10, vshrq_n_u8(hbits, 2))); // aligns with 2nd half of qx[1] in AVX2
                    qx[2] = vorrq_u8(vshrq_n_u8(lbits.val[1], 4), vandq_u8(m10, vshlq_n_u8(hbits, 1))); // aligns with 2nd half of qx[2] in AVX2
                    qx[3] = vorrq_u8(vshrq_n_u8(lbits.val[3], 4), vandq_u8(m10, vshrq_n_u8(hbits, 3))); // aligns with 2nd half of qx[3] in AVX2
                    if constexpr (nrc_y == 1) {
                        qx[0] = vqtbl2q_s8(values, qx[0]); // 16..19 from the 4 rows
                        qx[1] = vqtbl2q_s8(values, qx[1]); // 20..23
                        qx[2] = vqtbl2q_s8(values, qx[2]); // 24..27
                        qx[3] = vqtbl2q_s8(values, qx[3]); // 28..31
                    } else {
                        // Same shifts as the 1st group, but broadcast for bytes 4..7 of each row.
                        auto shift = vqtbl1q_u8(shifts, shift_shuffle.val[1]);
                        qx[0] = vaddq_s8(shift, vqtbl2q_s8(values, qx[0])); // 16..19 from the 4 rows
                        qx[1] = vaddq_s8(shift, vqtbl2q_s8(values, qx[1])); // 20..23
                        qx[2] = vaddq_s8(shift, vqtbl2q_s8(values, qx[2])); // 24..27
                        qx[3] = vaddq_s8(shift, vqtbl2q_s8(values, qx[3])); // 28..31
                    }
                    scales = vmovl_s16(vget_high_s16(i16scales.val[ib]));
                    for (int iy = 0; iy < nrc_y; ++iy) {
                        auto y = vld1q_s8(q8.y[iy][ibl].qs+128*is+32*ib+16);
                        auto sumi = interleaved_dotq(qx, y);
                        isum[iy] = vmlaq_s32(isum[iy], scales, sumi);
                    }
                }
            }
            // Scale the integer accumulators by the combined x/y block scales into the float accumulators.
            for (int iy = 0; iy < nrc_y; ++iy) {
                acc[iy] = vfmaq_f32(acc[iy], vmulq_f32(d4, vdupq_n_f32(q8.scale(iy, ibl))), vcvtq_f32_s32(isum[iy]));
            }
        }
        // Store the 4 row results for each y column and reset the accumulators for the next x group.
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, acc[iy]);
            acc[iy] = vdupq_n_f32(0.f);
        }
    }
}
|
||||
|
||||
#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \
|
||||
m.funcs[0] = func<Dequantizer, 1>;\
|
||||
m.funcs[1] = func<Dequantizer, 2>;\
|
||||
@@ -1955,6 +1267,12 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
case GGML_TYPE_IQ5_KS:
|
||||
case GGML_TYPE_IQ5_K:
|
||||
case GGML_TYPE_IQ6_K:
|
||||
case GGML_TYPE_IQ2_K_R4:
|
||||
case GGML_TYPE_IQ3_K_R4:
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
case GGML_TYPE_IQ5_K_R4:
|
||||
case GGML_TYPE_IQ4_KS_R4:
|
||||
case GGML_TYPE_IQ5_KS_R4:
|
||||
return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, m.funcs, m.func16);
|
||||
case GGML_TYPE_IQ2_XXS:
|
||||
case GGML_TYPE_IQ2_XS:
|
||||
@@ -1975,10 +1293,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
case GGML_TYPE_Q8_0_R8:
|
||||
case GGML_TYPE_IQ4_NL_R4:
|
||||
return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
|
||||
case GGML_TYPE_IQ4_KS_R4:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_ks_r4_q8_k);
|
||||
expected_Btype = GGML_TYPE_Q8_K;
|
||||
break;
|
||||
case GGML_TYPE_IQ2_XXS_R4:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_xxs_r4_q8_k);
|
||||
m.func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
|
||||
@@ -2011,26 +1325,6 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
m.func16 = mul_mat_iq3_s_r4_q8_k<16>;
|
||||
expected_Btype = GGML_TYPE_Q8_K;
|
||||
break;
|
||||
case GGML_TYPE_IQ2_K_R4:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2_k_r4_q8_k);
|
||||
expected_Btype = GGML_TYPE_Q8_K;
|
||||
break;
|
||||
case GGML_TYPE_IQ3_K_R4:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq3_k_r4_q8_k);
|
||||
expected_Btype = GGML_TYPE_Q8_K;
|
||||
break;
|
||||
case GGML_TYPE_IQ4_K_R4:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_k_r4_q8_k);
|
||||
expected_Btype = GGML_TYPE_Q8_K;
|
||||
break;
|
||||
case GGML_TYPE_IQ5_K_R4:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_k_r4_q8_k);
|
||||
expected_Btype = GGML_TYPE_Q8_K;
|
||||
break;
|
||||
case GGML_TYPE_IQ5_KS_R4:
|
||||
SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq5_ks_r4_q8_k_neon);
|
||||
expected_Btype = GGML_TYPE_Q8_K;
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user