iq4_kt: not working NEON implementation

This commit is contained in:
Iwan Kawrakow
2025-05-25 08:58:29 +03:00
parent e0fedaeb07
commit 465fe3b78d

View File

@@ -364,18 +364,35 @@ struct Trellis1 {
inline uint32x4x2_t next8(uint32_t val) const {
auto mval = vdupq_n_u32(val);
uint32x4x2_t mres;
// This does not seem to be faster
//mres.val[0] = vmlaq_u32(mkb.val[0], mka.val[0], mval);
//mres.val[1] = vmlaq_u32(mkb.val[1], mka.val[1], mval);
mres.val[0] = vaddq_u32(vmulq_u32(mval, mka.val[0]), mkb.val[0]);
mres.val[1] = vaddq_u32(vmulq_u32(mval, mka.val[1]), mkb.val[1]);
mres.val[0] = veorq_u32(vandq_u32(mres.val[0], mask1), mask2);
mres.val[1] = veorq_u32(vandq_u32(mres.val[1], mask1), mask2);
return mres;
}
inline uint32x4x2_t next8(uint32_t val1, uint32_t val2) const {
auto mval1 = vdupq_n_u32(val1);
auto mval2 = vdupq_n_u32(val2);
uint32x4x2_t mres;
// This does not seem to be faster
//mres.val[0] = vmlaq_u32(mkb.val[0], mka.val[0], mval1);
//mres.val[1] = vmlaq_u32(mkb.val[0], mka.val[0], mval2);
mres.val[0] = vaddq_u32(vmulq_u32(mval1, mka.val[0]), mkb.val[0]);
mres.val[1] = vaddq_u32(vmulq_u32(mval2, mka.val[0]), mkb.val[0]);
mres.val[0] = veorq_u32(vandq_u32(mres.val[0], mask1), mask2);
mres.val[1] = veorq_u32(vandq_u32(mres.val[1], mask1), mask2);
return mres;
}
static inline float16x8_t gen8(const uint32x4x2_t& i8) {
auto fv1 = vreinterpretq_f16_u32(i8.val[0]);
auto fv2 = vreinterpretq_f16_u32(i8.val[1]);
return vpaddq_f16(fv1, fv2);
}
inline float16x8_t gen8(uint32_t val) const { return gen8(next8(val)); }
inline float16x8_t gen8(uint32_t val1, uint32_t val2) const { return gen8(next8(val1, val2)); }
};
template <int nrc_y>
@@ -502,6 +519,90 @@ static void mul_mat_iq3_kt_F16_T(int n, const void * vx, size_t bx, const DataIn
}
}
template <int nrc_y>
static void mul_mat_iq4_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
assert(n%QK_K == 0);
const int nb = n/QK_K;
constexpr int kNumGroups = 64;
Trellis1 trellis;
union { float16x8_t vec; float16_t val[8]; } s_helper;
union { uint16x8_t vec; uint16_t val[8]; } o_helper;
constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
float16x8_t accd[k_acc];
const float16_t * y[nrc_y];
float row_sum[nrc_y];
for (int iy = 0; iy < nrc_y; ++iy) {
y[iy] = (const float16_t *)info.src1_row(iy);
auto sum = vdupq_n_f16(0);
for (int i = 0; i < n/8; ++i) sum = vaddq_f16(sum, vld1q_f16(y[iy] + 8*i));
auto sum32 = vaddq_f32(vcvt_f32_f16(vget_low_f16(sum)), vcvt_f32_f16(vget_high_f16(sum)));
//auto sum32 = vdupq_n_f32(0);
//for (int i = 0; i < n/4; ++i) sum32 = vaddq_f32(sum32, vcvt_f32_f16(vld1_f16(y[iy] + 4*i)));
row_sum[iy] = vaddvq_f32(sum32);
}
for (int ix = 0; ix < nrc_x; ++ix) {
const float * dptr = (const float *)((const char*)vx + ix*bx);
auto d = dptr[0] * 31.75f * 1.01f;
auto dav = dptr[1];
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f16(0);
for (int i = 0; i < nb; ++i) {
const uint32_t * shb = x[i].qs;
auto vshb = vld1q_u32_x2(shb);
auto vshb16 = vcombine_u16(vmovn_u32(vandq_u32(vshb.val[0], vdupq_n_u32(0xff))), vmovn_u32(vandq_u32(vshb.val[1], vdupq_n_u32(0xff))));
const uint8_t * ql = (const uint8_t *)(shb + 8);
const uint8_t * qh = ql + kNumGroups;
auto iscales = vsubq_s16(vreinterpretq_s16_u16(vshrq_n_u16(vshb16, 1)), vdupq_n_s16(64));
s_helper.vec = vcvtq_f16_s16(iscales);
o_helper.vec = vaddq_u16(vshlq_n_u16(vandq_u16(vshb16, vdupq_n_u16(1)), 15), vdupq_n_u16(4096));
for (int ib = 0; ib < 4; ++ib) {
auto scale1 = vdupq_n_f16(s_helper.val[ib+0]);
auto scale2 = vdupq_n_f16(s_helper.val[ib+4]);
for (int j = 0; j < 4; ++j) {
const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
uint32_t val1 = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
uint32_t val2 = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
uint32_t val3 = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
uint32_t val4 = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
auto x_val1 = vmulq_f16(scale1, trellis.gen8(val1, val3));
auto x_val2 = vmulq_f16(scale2, trellis.gen8(val2, val4));
if constexpr (nrc_y == 1) {
auto y1 = vld1q_f16(y[0] + i*QK_K+32*ib+8*j+ 0);
auto y2 = vld1q_f16(y[0] + i*QK_K+32*ib+8*j+128);
accd[0] = vfmaq_f16(accd[0], y1, x_val1);
accd[1] = vfmaq_f16(accd[1], y2, x_val2);
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
auto y1 = vld1q_f16(y[iy] + i*QK_K+32*ib+8*j+ 0);
auto y2 = vld1q_f16(y[iy] + i*QK_K+32*ib+8*j+128);
accd[iy] = vfmaq_f16(accd[iy], y1, x_val1);
accd[iy] = vfmaq_f16(accd[iy], y2, x_val2);
}
}
}
}
}
if constexpr (nrc_y == 1) {
auto sum16 = vaddq_f16(accd[0], accd[1]);
auto sum = vaddq_f32(vcvt_f32_f16(vget_low_f16(sum16)), vcvt_f32_f16(vget_high_f16(sum16)));
info.store(ix, 0, d*vaddvq_f32(sum) + dav*row_sum[0]);
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum = vaddq_f32(vcvt_f32_f16(vget_low_f16(accd[iy])), vcvt_f32_f16(vget_high_f16(accd[iy])));
info.store(ix, iy, d*vaddvq_f32(sum) + dav*row_sum[iy]);
}
}
}
}
}
@@ -520,9 +621,9 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
case GGML_TYPE_IQ3_KT:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_F16_T, kernels);
break;
//case GGML_TYPE_IQ4_KT:
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels);
// break;
case GGML_TYPE_IQ4_KT:
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F16_T, kernels);
break;
default:
return false;
}