New iq4_kt: NEON implementation

We get very respectable PP-512 = 120 t/s.
TG-128 is a pathetic 5.3 t/s, i.e., 20+% slower than the f16 variant.
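(PP-512 and TG-128 are the usual llama-bench numbers; a run along the lines of ./bin/llama-bench -m <model>.gguf -p 512 -n 128 reproduces them — model and hardware are not stated here.)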
Iwan Kawrakow
2025-06-08 08:31:56 +03:00
parent 608e0f497b
commit 68ef8a7ae9
2 changed files with 205 additions and 4 deletions

@@ -1199,10 +1199,213 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
    }
}
struct Trellis3 {
    constexpr static uint32_t ka = 89226354;
    constexpr static uint32_t kb = 64248484;
    constexpr static uint32_t ka1 = ka*ka;
    constexpr static uint32_t kb1 = kb*ka+kb;
    constexpr static uint32_t ka2 = ka1*ka;
    constexpr static uint32_t kb2 = kb1*ka+kb;
    constexpr static uint32_t ka3 = ka2*ka;
    constexpr static uint32_t kb3 = kb2*ka+kb;
    const uint32x4_t mka = uint32x4_t{ka, ka1, ka2, ka3};
    const uint32x4_t mkb = uint32x4_t{kb, kb1, kb2, kb3};
    const uint8x16_t shuffle = load_shuffle();
    inline uint32x4x2_t next8(uint32_t val1, uint32_t val2) const {
        uint32x4x2_t result{vdupq_n_u32(val1), vdupq_n_u32(val2)};
        result.val[0] = vmlaq_u32(mkb, mka, result.val[0]);
        result.val[1] = vmlaq_u32(mkb, mka, result.val[1]);
        return result;
    }
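    // The trellis is a simple LCG, x <- ka*x + kb. The ka1/kb1, ka2/kb2, ka3/kb3
    // pairs are the precomputed 2-, 3- and 4-step constants, so the single
    // vmlaq_u32 against {ka, ka1, ka2, ka3} in next8 yields four consecutive
    // states per seed. Scalar sketch of what lane i computes (illustrative
    // helper only, not used by the kernels below):
    static inline uint32_t lcg_lane(uint32_t seed, int i) {
        constexpr uint32_t kas[4] = {ka, ka1, ka2, ka3};
        constexpr uint32_t kbs[4] = {kb, kb1, kb2, kb3};
        return kas[i]*seed + kbs[i]; // == (i+1)-fold application of x <- ka*x + kb
    }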
    //inline int8x16x2_t next32(const uint32_t * val) const {
    //    int8x16x4_t aux;
    //    int8x16x2_t result;
    //    for (int i = 0; i < 2; ++i) {
    //        auto i8 = next8(val[4*i+0], val[4*i+1]);
    //        i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f));
    //        i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
    //        aux.val[0] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0])));
    //        aux.val[1] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1])));
    //        i8 = next8(val[4*i+2], val[4*i+3]);
    //        i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f));
    //        i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
    //        aux.val[2] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0])));
    //        aux.val[3] = vreinterpretq_s8_s32(vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1])));
    //        result.val[i] = vqtbl4q_s8(aux, shuffle);
    //    }
    //    return result;
    //}
    // This works:
    inline int8x16x2_t next32(const uint32_t * val) const {
        int16x8_t aux[4];
        for (int i = 0; i < 4; ++i) {
            auto i8 = next8(val[2*i+0], val[2*i+1]);
            i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f));
            i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
            auto s1 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0]));
            auto s2 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1]));
            aux[i] = vcombine_s16(vmovn_s32(s1), vmovn_s32(s2));
        }
        int8x16x2_t result = {vcombine_s8(vmovn_s16(aux[0]), vmovn_s16(aux[1])), vcombine_s8(vmovn_s16(aux[2]), vmovn_s16(aux[3]))};
        return result;
    }
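    // Each masked 32-bit state keeps 6 bits per byte; the vdotq_s32 against an
    // all-ones s8 vector with a -126 accumulator maps it to the trellis value:
    // the sum of the four 6-bit fields minus 126, which always fits in int8.
    // Scalar equivalent for one state (illustrative only):
    static inline int8_t trellis_value(uint32_t x) {
        x &= 0x3f3f3f3f;
        int sum = (x & 0xff) + ((x >> 8) & 0xff) + ((x >> 16) & 0xff) + (x >> 24);
        return int8_t(sum - 126);
    }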
    static uint8x16_t load_shuffle() {
        static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60};
        return vld1q_u8(k_shuffle);
    }
};
void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);
    const int nb = n/QK_K;
    constexpr int kNumGroups = 64;
    Trellis3 trellis;
    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;
    const block_iq4_kt * x8[8];
    float dkt[8];
    int32_t ls[8];
    uint32_t idx0[8], idx[8];
    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) {
            const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
            dkt[k] = dptr[0];
            x8[k] = (const block_iq4_kt *)(dptr + 2);
        }
        auto vd = vld1q_f32_x2(dkt);
        for (int i = 0; i < nb; ++i) {
            for (int ib = 0; ib < QK_K/32; ++ib) {
                for (int k = 0; k < 8; ++k) {
                    ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64;
                    idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096;
                }
                auto scales1 = vmulq_f32(vd.val[0], vcvtq_f32_s32(vld1q_s32(ls+0)));
                auto scales2 = vmulq_f32(vd.val[1], vcvtq_f32_s32(vld1q_s32(ls+4)));
                vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1));
                vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2));
                int shift1 = 8 - 4*(ib/4);
                for (int j = 0; j < 8; ++j) {
                    for (int k = 0; k < 8; ++k) {
                        const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8);
                        const uint8_t * qh = ql + kNumGroups;
                        const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j);
                        idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k];
                    }
                    vst1q_s8_x2(y[ib].qs+32*j, trellis.next32(idx));
                }
            }
            y += 8; // = QK_K/32;
        }
    }
}
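// How a trellis seed is assembled in the loop above: the 32-bit control word
// shb = qs[ib] contributes bit 15 and a constant +4096 offset (bits 1..7 hold
// the block scale) plus one 3-bit field per group of 4 values; ql supplies the
// low byte and a qh nibble bits 8..11. Scalar sketch (helper named for
// exposition only, unused by the kernels):
static inline uint32_t iq4_kt_seed(uint32_t shb, const uint8_t * ql, const uint8_t * qh, int ib, int j) {
    uint32_t idx0 = ((shb & 1) << 15) + 4096;
    int shift = 8 - 4*(ib/4);          // low qh nibble for ib < 4, high nibble otherwise
    uint32_t sh = shb >> (8 + 3*j);    // eight 3-bit fields in bits 8..31
    return ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift) & 0xf00) + ((sh & 7) << 12) + idx0;
}
// trellis.next32(idx) then yields 4 consecutive int8 values for each of the 8
// rows, which is exactly the 32-byte row-interleaved group stored into y[ib].qs.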
template <int nrc_y>
void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n%QK_K == 0);
    const int nb = n/QK_K;
    constexpr int kNumGroups = 64;
    Trellis3 trellis;
    union { uint32x4x2_t vec; uint32_t val[8]; } o_helper;
    constexpr int k_acc = nrc_y;
    float32x4_t accd[k_acc];
    const block_q8_0_x4 * y[nrc_y];
    for (int iy = 0; iy < nrc_y; ++iy) {
        y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
    }
    uint32_t values[64];
    int8x16x2_t xv[4];
    int32x4x4_t dot;
    auto compute_dot = [&dot, &xv] (const int8_t * y) {
        for (int k = 0; k < 4; ++k) {
            auto yv = vld1q_s8_x2(y + 32*k);
            dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
        }
        dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]);
        dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]);
        return vpaddq_s32(dot.val[0], dot.val[2]);
    };
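    // Lane layout of compute_dot's result: dot.val[k] first holds four 32-bit
    // partial sums for 32-value block k; the two vpaddq_s32 stages fold these
    // so that lane k of the returned vector is the complete integer dot product
    // of block k, lining up with the four per-block scales computed below.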
    float32x4x2_t scales;
    for (int ix = 0; ix < nrc_x; ++ix) {
        const float * dptr = (const float *)((const char*)vx + ix*bx);
        auto d = vdupq_n_f32(dptr[0]);
        const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0);
        for (int i = 0; i < nb; ++i) {
            auto vshb = vld1q_u32_x2(x[i].qs);
            const uint32_t * shb = x[i].qs;
            const uint8_t * ql = (const uint8_t *)(shb + 8);
            const uint8_t * qh = ql + kNumGroups;
            auto iscales1 = vreinterpretq_s32_u32(vshrq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(0xff)), 1));
            auto iscales2 = vreinterpretq_s32_u32(vshrq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(0xff)), 1));
            iscales1 = vaddq_s32(iscales1, vdupq_n_s32(-64));
            iscales2 = vaddq_s32(iscales2, vdupq_n_s32(-64));
            scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(iscales1));
            scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(iscales2));
            o_helper.vec.val[0] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(1)), 15), vdupq_n_u32(4096));
            o_helper.vec.val[1] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(1)), 15), vdupq_n_u32(4096));
            for (int ib = 0; ib < 4; ++ib) {
                for (int j = 0; j < 4; ++j) {
                    const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
                    const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
                    values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
                    values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
                    values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
                    values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
                }
            }
            for (int i128 = 0; i128 < 2; ++i128) {
                for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    const block_q8_0_x4& yb = y[iy][2*i+i128];
                    auto dy = vmulq_f32(scales.val[i128], vcvt_f32_f16(vld1_f16((const float16_t *)yb.d)));
                    //auto dy = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16((const uint16_t *)yb.d)), 16));
                    //dy = vmulq_f32(scales.val[i128], dy);
                    auto sumi = compute_dot(yb.qs);
                    accd[iy] = vfmaq_f32(accd[iy], dy, vcvtq_f32_s32(sumi));
                }
            }
        }
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vaddvq_f32(accd[iy]));
        }
    }
}
}
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
    if (ne00%QK_K != 0) return false;
    func16 = nullptr;
    if (ggml_type(typeA) == GGML_TYPE_IQ4_KT) {
        if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_0_x4_T, kernels);
            return true;
        }
        return false;
    }
    //if (ne00%QK_K == 0 && ggml_type(typeB) == GGML_TYPE_F32 && ggml_type(typeA) == GGML_TYPE_IQ4_KT) {
    //    IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels);
    //    func16 = nullptr;
@@ -1213,8 +1416,6 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
        return false;
    }
    func16 = nullptr;
    switch (typeA) {
        case GGML_TYPE_IQ2_KT:
            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_F16_T, kernels);
@@ -1236,7 +1437,7 @@ bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void *
    switch (type) {
        case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
        case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
-       case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
+       case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;
        default: return false;
    }

@@ -273,7 +273,7 @@ struct MulMat {
    switch (type) {
        case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
        case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
-       case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
+       case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
        default: break;
    }
#endif
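// Net effect of this last hunk: for batched multiplications (nrc_y >= 32),
// IQ4_KT is now repacked to Q8_0_R8 (via iqk_dequantize_iq4_kt_q80_r8 above)
// instead of being dequantized to F16, which routes prompt processing through
// the new q8_0 dot-product kernels.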