mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-22 23:49:23 +00:00
New iq4_kt: NEON implementation
We get very respectable PP-512 = 120 t/s. TG-128 is pathetic at 5.3 t/s, so 20+% slower than the f16 variant.
This commit is contained in:
@@ -1199,10 +1199,213 @@ void mul_mat_iq4_kt_F32_T(int n, const void * vx, size_t bx, const DataInfo& inf
|
||||
}
|
||||
}
|
||||
|
||||
// Pseudo-random "trellis" generator used to decode iq4_kt quants.
// The underlying scalar recurrence is x -> ka*x + kb (mod 2^32); the
// ka1/kb1 ... ka3/kb3 constants pre-compose 2..4 applications of that
// recurrence so that four consecutive states are produced from one seed
// with a single vmlaq_u32.
struct Trellis3 {
    constexpr static uint32_t ka = 89226354;
    constexpr static uint32_t kb = 64248484;
    constexpr static uint32_t ka1 = ka*ka;     // multiplier after 2 steps
    constexpr static uint32_t kb1 = kb*ka+kb;  // offset after 2 steps
    constexpr static uint32_t ka2 = ka1*ka;    // 3 steps
    constexpr static uint32_t kb2 = kb1*ka+kb;
    constexpr static uint32_t ka3 = ka2*ka;    // 4 steps
    constexpr static uint32_t kb3 = kb2*ka+kb;
    const uint32x4_t mka = uint32x4_t{ka, ka1, ka2, ka3};
    const uint32x4_t mkb = uint32x4_t{kb, kb1, kb2, kb3};
    // Byte shuffle picking byte 0 of each 32-bit lane out of a 64-byte
    // table; only needed by a vqtbl4q_s8-based variant of next32, kept so
    // that variant can be restored without touching the struct layout.
    const uint8x16_t shuffle = load_shuffle();

    // For each of the two seeds val1, val2, produce the next 4 states of
    // the recurrence: lane i of result.val[j] holds ka^(i+1)*val_j + kb_i.
    inline uint32x4x2_t next8(uint32_t val1, uint32_t val2) const {
        uint32x4x2_t result{vdupq_n_u32(val1), vdupq_n_u32(val2)};
        result.val[0] = vmlaq_u32(mkb, mka, result.val[0]);
        result.val[1] = vmlaq_u32(mkb, mka, result.val[1]);
        return result;
    }

    // Decode 32 int8 values from 8 seeds: each seed yields 4 states, every
    // state is masked to 4 bytes in [0, 63], and the sum of those 4 bytes
    // minus 126 (range [-126, 126]) becomes one output value.
    inline int8x16x2_t next32(const uint32_t * val) const {
        // Fix: this was uint16x8_t, which only compiled thanks to lax
        // vector conversions — vcombine_s16/vmovn_s32 produce and
        // vmovn_s16 consumes signed vectors. int16x8_t is the correct type.
        int16x8_t aux[4];
        for (int i = 0; i < 4; ++i) {
            auto i8 = next8(val[2*i+0], val[2*i+1]);
            // Keep 6 bits per byte so every byte is in [0, 63].
            i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f));
            i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
            // vdotq_s32 with an all-ones vector sums the 4 bytes of each
            // 32-bit lane; the -126 accumulator applies the bias.
            auto s1 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[0]));
            auto s2 = vdotq_s32(vdupq_n_s32(-126), vdupq_n_s8(1), vreinterpretq_s8_u32(i8.val[1]));
            aux[i] = vcombine_s16(vmovn_s32(s1), vmovn_s32(s2));
        }
        // Each sum is at most 4*63 - 126 = 126, so narrowing to int8 is lossless.
        int8x16x2_t result = {vcombine_s8(vmovn_s16(aux[0]), vmovn_s16(aux[1])),
                              vcombine_s8(vmovn_s16(aux[2]), vmovn_s16(aux[3]))};
        return result;
    }

    // Indices 0,4,8,...,60 select the least significant byte of each
    // 32-bit lane from a 4-register (64-byte) vqtbl4q_s8 table.
    static uint8x16_t load_shuffle() {
        static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60};
        return vld1q_u8(k_shuffle);
    }
};
|
||||
|
||||
// Dequantize iq4_kt rows into interleaved q8_0_r8 blocks, 8 rows at a time.
// n      : number of values per row, multiple of QK_K
// vx, bx : base pointer and byte stride of the quantized rows
// vy     : destination block_q8_0_r8 array
// nrc_x  : number of rows, multiple of 8
void iqk_dequantize_iq4_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);
    const int nb = n/QK_K;
    constexpr int kNumGroups = 64;

    Trellis3 trellis;

    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;

    const block_iq4_kt * x8[8];  // the 8 source rows processed together
    float dkt[8];                // per-row scale (first float of each row)
    int32_t ls[8];               // per-block integer scales, one per row
    uint32_t idx0[8], idx[8];    // per-row seed offset / full trellis seeds

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) {
            // Each row starts with 2 floats; dptr[0] is the row scale and
            // the quantized blocks follow them.
            const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
            dkt[k] = dptr[0];
            x8[k] = (const block_iq4_kt *)(dptr + 2);
        }
        auto vd = vld1q_f32_x2(dkt);

        for (int i = 0; i < nb; ++i) {
            for (int ib = 0; ib < QK_K/32; ++ib) {
                for (int k = 0; k < 8; ++k) {
                    // Low byte of qs[ib]: bits 1-7 are the block scale
                    // biased by 64, bit 0 contributes bit 15 of the seed
                    // offset; 4096 sets the implicit high bit.
                    ls[k] = ((x8[k][i].qs[ib] & 0xff) >> 1) - 64;
                    idx0[k] = ((x8[k][i].qs[ib] & 1) << 15) + 4096;
                }
                // Combined (row scale * block scale) stored as f16, one
                // per row, in the q8_0_r8 block.
                auto scales1 = vmulq_f32(vd.val[0], vcvtq_f32_s32(vld1q_s32(ls+0)));
                auto scales2 = vmulq_f32(vd.val[1], vcvtq_f32_s32(vld1q_s32(ls+4)));
                vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1));
                vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2));
                // qh packs two 4-bit groups per byte: low nibble for
                // ib < 4 (shift 8), high nibble for ib >= 4 (shift 4).
                int shift1 = 8 - 4*(ib/4);
                for (int j = 0; j < 8; ++j) {
                    for (int k = 0; k < 8; ++k) {
                        const uint8_t * ql = (const uint8_t *)(x8[k][i].qs + 8); // low 8 bits of each group index
                        const uint8_t * qh = ql + kNumGroups;                    // 4 high bits, two groups per byte
                        const uint32_t sh = x8[k][i].qs[ib] >> (8 + 3*j);        // 3 extra bits per group from qs[ib]
                        // Assemble the 16-bit trellis seed: 8 low bits +
                        // 4 middle bits + 3 high bits + per-row offset.
                        idx[k+0] = ql[8*ib+j] + ((qh[8*(ib%4)+j] << shift1) & 0xf00) + ((sh & 7) << 12) + idx0[k];
                    }
                    // next32 emits 4 values per seed, i.e. 4 consecutive
                    // values from each of the 8 rows -> one 32-byte chunk
                    // of the interleaved q8_0_r8 layout.
                    vst1q_s8_x2(y[ib].qs+32*j, trellis.next32(idx));
                }
            }
            y += 8; // = QK_K/32; advance past the 8 blocks just written
        }
    }
}
|
||||
|
||||
// Matrix multiplication kernel: iq4_kt quantized rows (vx) times nrc_y
// q8_0_x4 activation rows, accumulating one float per (row, y) pair which
// is written out through info.store.
template <int nrc_y>
void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n%QK_K == 0);
    const int nb = n/QK_K;
    constexpr int kNumGroups = 64;

    Trellis3 trellis;

    // Computed with NEON, then read back lane-by-lane as scalars when
    // assembling the per-group trellis seeds below.
    union { uint32x4x2_t vec; uint32_t val[8]; } o_helper;

    constexpr int k_acc = nrc_y;

    float32x4_t accd[k_acc]; // one accumulator per activation row

    const block_q8_0_x4 * y[nrc_y];
    for (int iy = 0; iy < nrc_y; ++iy) {
        y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
    }

    uint32_t values[64];  // trellis seeds for one super-block (64 groups of 4)
    int8x16x2_t xv[4];    // 128 dequantized int8 values (half a super-block)
    int32x4x4_t dot;

    // Dot product of the 128 values in xv with 128 int8 y values, pair-added
    // down to 4 lanes — one partial sum per 32-value block, so each lane can
    // be scaled by its own block scale.
    auto compute_dot = [&dot, &xv] (const int8_t * y) {
        for (int k = 0; k < 4; ++k) {
            auto yv = vld1q_s8_x2(y + 32*k);
            dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
        }
        dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]);
        dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]);
        return vpaddq_s32(dot.val[0], dot.val[2]);
    };

    float32x4x2_t scales; // 8 per-block scales of the current super-block

    for (int ix = 0; ix < nrc_x; ++ix) {
        // Row layout: one float scale, one float skipped, then the blocks.
        const float * dptr = (const float *)((const char*)vx + ix*bx);
        auto d = vdupq_n_f32(dptr[0]);
        const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);

        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0);

        for (int i = 0; i < nb; ++i) {
            auto vshb = vld1q_u32_x2(x[i].qs);
            const uint32_t * shb = x[i].qs;
            const uint8_t * ql = (const uint8_t *)(shb + 8); // low 8 index bits per group
            const uint8_t * qh = ql + kNumGroups;            // 4 high bits, two groups per byte
            // Bits 1-7 of each shb word: block scale biased by 64.
            auto iscales1 = vreinterpretq_s32_u32(vshrq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(0xff)), 1));
            auto iscales2 = vreinterpretq_s32_u32(vshrq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(0xff)), 1));
            iscales1 = vaddq_s32(iscales1, vdupq_n_s32(-64));
            iscales2 = vaddq_s32(iscales2, vdupq_n_s32(-64));
            scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(iscales1));
            scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(iscales2));
            // Bit 0 of each shb word becomes bit 15 of the seed offset;
            // 4096 sets the implicit high bit (same scheme as the
            // dequantization path).
            o_helper.vec.val[0] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[0], vdupq_n_u32(1)), 15), vdupq_n_u32(4096));
            o_helper.vec.val[1] = vaddq_u32(vshlq_n_u32(vandq_u32(vshb.val[1], vdupq_n_u32(1)), 15), vdupq_n_u32(4096));
            // Assemble all 64 seeds of the super-block: low 8 bits from ql,
            // 4 middle bits from the proper nibble of qh, 3 top bits from
            // shb (6 bits per j, 3 for each of the two groups), plus offset.
            for (int ib = 0; ib < 4; ++ib) {
                for (int j = 0; j < 4; ++j) {
                    const uint32_t sh1 = shb[ib+0] >> (8 + 6*j);
                    const uint32_t sh2 = shb[ib+4] >> (8 + 6*j);
                    values[8*ib+2*j+ 0] = ql[8*ib+2*j+ 0] + ((qh[8*ib+2*j+0] << 8) & 0xf00) + ((sh1 & 7) << 12) + o_helper.val[ib+0];
                    values[8*ib+2*j+ 1] = ql[8*ib+2*j+ 1] + ((qh[8*ib+2*j+1] << 8) & 0xf00) + ((sh1 & 56) << 9) + o_helper.val[ib+0];
                    values[8*ib+2*j+32] = ql[8*ib+2*j+32] + ((qh[8*ib+2*j+0] << 4) & 0xf00) + ((sh2 & 7) << 12) + o_helper.val[ib+4];
                    values[8*ib+2*j+33] = ql[8*ib+2*j+33] + ((qh[8*ib+2*j+1] << 4) & 0xf00) + ((sh2 & 56) << 9) + o_helper.val[ib+4];
                }
            }
            // Process the super-block in two halves of 128 values each.
            for (int i128 = 0; i128 < 2; ++i128) {
                for (int k = 0; k < 4; ++k) xv[k] = trellis.next32(values + 32*i128 + 8*k);
                for (int iy = 0; iy < nrc_y; ++iy) {
                    // Two block_q8_0_x4 (128 values each) per super-block i.
                    const block_q8_0_x4& yb = y[iy][2*i+i128];
                    auto dy = vmulq_f32(scales.val[i128], vcvt_f32_f16(vld1_f16((const float16_t *)yb.d)));
                    //auto dy = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(vld1_u16((const uint16_t *)yb.d)), 16));
                    //dy = vmulq_f32(scales.val[i128], dy);
                    auto sumi = compute_dot(yb.qs);
                    accd[iy] = vfmaq_f32(accd[iy], dy, vcvtq_f32_s32(sumi));
                }
            }
        }

        // Horizontal sum of the 4 per-block partial accumulators gives the
        // final dot product for this (row, activation) pair.
        for (int iy = 0; iy < nrc_y; ++iy) {
            info.store(ix, iy, vaddvq_f32(accd[iy]));
        }
    }
}
|
||||
|
||||
}
|
||||
|
||||
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
|
||||
|
||||
|
||||
if (ne00%QK_K != 0) return false;
|
||||
|
||||
func16 = nullptr;
|
||||
|
||||
if (ggml_type(typeA) == GGML_TYPE_IQ4_KT) {
|
||||
if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_q8_0_x4_T, kernels);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
//if (ne00%QK_K == 0 && ggml_type(typeB) == GGML_TYPE_F32 && ggml_type(typeA) == GGML_TYPE_IQ4_KT) {
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels);
|
||||
// func16 = nullptr;
|
||||
@@ -1213,8 +1416,6 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
|
||||
return false;
|
||||
}
|
||||
|
||||
func16 = nullptr;
|
||||
|
||||
switch (typeA) {
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_F16_T, kernels);
|
||||
@@ -1236,7 +1437,7 @@ bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void *
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
|
||||
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
|
||||
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
|
||||
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;
|
||||
default: return false;
|
||||
}
|
||||
|
||||
|
||||
@@ -273,7 +273,7 @@ struct MulMat {
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
|
||||
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
|
||||
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
|
||||
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
|
||||
default: break;
|
||||
}
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user