Mirror of https://github.com/ikawrakow/ik_llama.cpp.git (synced 2026-02-24 07:04:11 +00:00).
Commit: "New iq2_kt: NEON GEMM/GEMV".
This commit is contained in:
@@ -1452,6 +1452,51 @@ struct Trellis3 {
|
||||
}
|
||||
return result;
|
||||
}
|
||||
// Decode 32 trellis codes into two int8x16 vectors.
// val points at four 16-bit seeds; v0 is an offset added to each seed
// (callers pass 4096 for iq2_kt). For every seed, two LCG steps
// (mka/mkb, then ka3/kb3) produce 8 bytes masked to 6 bits each; the
// pairwise-add chain sums groups of those bytes, and the -126 bias
// re-centers the unsigned 6-bit sums into signed int8 range.
inline int8x16x2_t next32(const uint16_t * val, uint32_t v0) const {
    auto vka3 = vdupq_n_u32(ka3), vkb3 = vdupq_n_u32(kb3);
    int8x16x2_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126)};
    // Fix: the LCG state must be an unsigned 32-bit vector pair. The
    // original declared this int8x16x2_t and assigned vmlaq_u32/vandq_u32
    // results to it, which is a type error under strict NEON typing
    // (only accepted with lax vector conversions). The existing
    // vreinterpretq_s8_u32 calls below already provide the s8 view.
    uint32x4x2_t i8;
    for (int i = 0; i < 2; ++i) {
        i8.val[0] = vmlaq_u32(mkb, mka, vdupq_n_u32(val[2*i+0]+v0));
        i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]);
        i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f));
        i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
        auto s1 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
        i8.val[0] = vmlaq_u32(mkb, mka, vdupq_n_u32(val[2*i+1]+v0));
        i8.val[1] = vmlaq_u32(vkb3, vka3, i8.val[0]);
        i8.val[0] = vandq_u32(i8.val[0], vdupq_n_u32(0x3f3f3f3f));
        i8.val[1] = vandq_u32(i8.val[1], vdupq_n_u32(0x3f3f3f3f));
        auto s2 = vpaddq_s8(vreinterpretq_s8_u32(i8.val[0]), vreinterpretq_s8_u32(i8.val[1]));
        result.val[i] = vaddq_s8(result.val[i], vpaddq_s8(s1, s2));
    }
    return result;
}
|
||||
// Decode 64 trellis codes into four int8x16 vectors from eight 32-bit
// seeds. Each iteration expands two seed pairs via next8() (first LCG
// stage), applies the second LCG stage (ka3/kb3), masks everything to
// 6 bits per byte, and pairwise-adds into the -126-biased accumulators.
// result.val[0..1] hold the first-stage sums, result.val[2..3] the
// second-stage sums.
inline int8x16x4_t next64(const uint32_t * val) const {
    auto vka3 = vdupq_n_u32(ka3), vkb3 = vdupq_n_u32(kb3);
    int8x16x4_t result = {vdupq_n_s8(-126), vdupq_n_s8(-126), vdupq_n_s8(-126), vdupq_n_s8(-126)};
    for (int i = 0; i < 2; ++i) {
        auto i8_1 = next8(val[4*i+0], val[4*i+1]);
        // Fix: i8_2 holds raw vmlaq_u32 results, so it must be
        // uint32x4x2_t; the original's int8x16x2_t only compiles with
        // lax vector conversions. Bit pattern is unchanged — the
        // vreinterpretq_s8_u32 calls below give the s8 view.
        uint32x4x2_t i8_2{vmlaq_u32(vkb3, vka3, i8_1.val[0]), vmlaq_u32(vkb3, vka3, i8_1.val[1])};
        i8_1.val[0] = vandq_u32(i8_1.val[0], vdupq_n_u32(0x3f3f3f3f));
        i8_1.val[1] = vandq_u32(i8_1.val[1], vdupq_n_u32(0x3f3f3f3f));
        i8_2.val[0] = vandq_u32(i8_2.val[0], vdupq_n_u32(0x3f3f3f3f));
        i8_2.val[1] = vandq_u32(i8_2.val[1], vdupq_n_u32(0x3f3f3f3f));
        auto s1_1 = vpaddq_s8(vreinterpretq_s8_u32(i8_1.val[0]), vreinterpretq_s8_u32(i8_1.val[1]));
        auto s1_2 = vpaddq_s8(vreinterpretq_s8_u32(i8_2.val[0]), vreinterpretq_s8_u32(i8_2.val[1]));
        i8_1 = next8(val[4*i+2], val[4*i+3]);
        i8_2.val[0] = vmlaq_u32(vkb3, vka3, i8_1.val[0]);
        i8_2.val[1] = vmlaq_u32(vkb3, vka3, i8_1.val[1]);
        i8_1.val[0] = vandq_u32(i8_1.val[0], vdupq_n_u32(0x3f3f3f3f));
        i8_1.val[1] = vandq_u32(i8_1.val[1], vdupq_n_u32(0x3f3f3f3f));
        i8_2.val[0] = vandq_u32(i8_2.val[0], vdupq_n_u32(0x3f3f3f3f));
        i8_2.val[1] = vandq_u32(i8_2.val[1], vdupq_n_u32(0x3f3f3f3f));
        auto s2_1 = vpaddq_s8(vreinterpretq_s8_u32(i8_1.val[0]), vreinterpretq_s8_u32(i8_1.val[1]));
        auto s2_2 = vpaddq_s8(vreinterpretq_s8_u32(i8_2.val[0]), vreinterpretq_s8_u32(i8_2.val[1]));
        result.val[i+0] = vaddq_s8(result.val[i+0], vpaddq_s8(s1_1, s2_1));
        result.val[i+2] = vaddq_s8(result.val[i+2], vpaddq_s8(s1_2, s2_2));
    }
    return result;
}
|
||||
static uint8x16_t load_shuffle() {
|
||||
static const uint8_t k_shuffle[16] = {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60};
|
||||
return vld1q_u8(k_shuffle);
|
||||
@@ -1612,6 +1657,136 @@ void mul_mat_iq4_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo&
|
||||
}
|
||||
}
|
||||
|
||||
// Dequantize iq2_kt rows into the interleaved q8_0_r8 layout, 8 source
// rows at a time.
//   n     - number of elements per row (multiple of QK_K)
//   vx/bx - base pointer and row stride of the iq2_kt data; each row
//           starts with one float row scale, followed by block_iq2_kt
//   vy    - destination block_q8_0_r8 blocks
//   nrc_x - number of rows (multiple of 8)
// Per super-block: the packed 4-bit scale codes are mapped through the
// iq4k_values table to int8 scales, converted to float and combined
// with the row scale (times the 1.05f trellis correction factor); the
// quants themselves are regenerated with Trellis3::next64 from the
// 16-bit codes in ql (offset by 4096).
void iqk_dequantize_iq2_kt_q80_r8(int n, const void * vx, size_t bx, void * vy, int nrc_x) {
    GGML_ASSERT(n%QK_K == 0);
    GGML_ASSERT(nrc_x%8 == 0);
    const int nb = n/QK_K;

    Trellis3 trellis;

    auto values = vld1q_s8(iq4k_values);

    block_q8_0_r8 * y = (block_q8_0_r8 *)vy;

    const block_iq2_kt * x8[8];
    float dkt[8];
    float ls[8], ls_all[64];
    uint32_t idx[8];

    for (int ix = 0; ix < nrc_x; ix += 8) {
        for (int k = 0; k < 8; ++k) {
            const float * dptr = (const float *)((const char*)vx + (ix+k)*bx);
            dkt[k] = dptr[0] * 1.05f;   // 1.05f: trellis scale correction
            x8[k] = (const block_iq2_kt *)(dptr + 1);
        }
        auto vd = vld1q_f32_x2(dkt);

        for (int i = 0; i < nb; ++i) {
            // Expand the 8 packed 4-bit scale codes of each row to floats.
            for (int k = 0; k < 8; ++k) {
                auto u32 = *(const uint32_t *)x8[k][i].scales;
                auto s8_u32 = uint32x2_t{u32, u32 >> 4};
                // Fix: operands are uint32x2_t, so the AND must be
                // vand_u32 — the original's vand_u8 is a strict-typing
                // error (bitwise result is identical either way).
                s8_u32 = vand_u32(s8_u32, vdup_n_u32(0x0f0f0f0f));
                auto s8 = vqtbl1_s8(values, vreinterpret_u8_u32(s8_u32));
                auto s16 = vmovl_s8(s8);
                vst1q_f32(ls_all + 8*k + 0, vcvtq_f32_s32(vmovl_s16(vget_low_s16(s16))));
                vst1q_f32(ls_all + 8*k + 4, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
            }
            for (int ib = 0; ib < QK_K/32; ++ib) {
                // Gather this 32-element block's scale from each of the 8 rows.
                for (int k = 0; k < 8; ++k) ls[k] = ls_all[8*k+ib];
                auto scales1 = vmulq_f32(vd.val[0], vld1q_f32(ls+0));
                auto scales2 = vmulq_f32(vd.val[1], vld1q_f32(ls+4));
                vst1_f16((float16_t *)y[ib].d+0, vcvt_f16_f32(scales1));
                vst1_f16((float16_t *)y[ib].d+4, vcvt_f16_f32(scales2));
                for (int j = 0; j < 4; ++j) {
                    for (int k = 0; k < 8; ++k) {
                        const uint16_t * ql = (const uint16_t *)x8[k][i].ql;
                        idx[k] = ql[4*ib+j] + 4096;   // 4096: iq2_kt code offset
                    }
                    vst1q_s8_x4(y[ib].qs+64*j, trellis.next64(idx));
                }
            }
            y += 8; // = QK_K/32;
        }
    }
}
|
||||
|
||||
// GEMM/GEMV kernel: iq2_kt left-hand rows times nrc_y q8_0_x4 columns.
//   n     - row length (multiple of QK_K)
//   vx/bx - iq2_kt data base pointer and row stride (float row scale
//           first, then block_iq2_kt blocks)
//   info  - provides the q8_0_x4 activations (src1_row) and the result
//           store callback
//   nrc_x - number of left-hand rows to process
// For each super-block the trellis quants are regenerated on the fly
// (Trellis3::next32) and multiplied against the activations with
// vdotq_s32; per-32-element block scales come from the iq4k_values
// lookup, combined with the row scale and the q8_0 column scales.
template <int nrc_y>
void mul_mat_iq2_kt_q8_0_x4_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
    assert(n%QK_K == 0);
    const int nb = n/QK_K;

    Trellis3 trellis;

    auto values = vld1q_s8(iq4k_values);

    // With a single column, split the accumulator in two so the low and
    // high halves of the super-block can accumulate independently.
    constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;

    float32x4_t accd[k_acc];

    const block_q8_0_x4 * y[nrc_y];
    for (int iy = 0; iy < nrc_y; ++iy) {
        y[iy] = (const block_q8_0_x4 *)info.src1_row(iy);
    }

    int8x16x2_t xv[8];
    int32x4x4_t dot;

    // Dot of four 32-element sub-blocks; returns the four sums packed
    // into one int32x4.
    auto compute_dot = [&dot] (const int8_t * y, const int8x16x2_t * xv) {
        for (int k = 0; k < 4; ++k) {
            auto yv = vld1q_s8_x2(y + 32*k);
            dot.val[k] = vdotq_s32(vdotq_s32(vdupq_n_s32(0), xv[k].val[0], yv.val[0]), xv[k].val[1], yv.val[1]);
        }
        dot.val[0] = vpaddq_s32(dot.val[0], dot.val[1]);
        dot.val[2] = vpaddq_s32(dot.val[2], dot.val[3]);
        return vpaddq_s32(dot.val[0], dot.val[2]);
    };

    float32x4x2_t scales;

    for (int ix = 0; ix < nrc_x; ++ix) {
        const float * dptr = (const float *)((const char*)vx + ix*bx);
        auto d = vdupq_n_f32(dptr[0]*1.05f);   // 1.05f: trellis scale correction
        const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);

        for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f32(0);

        for (int i = 0; i < nb; ++i) {
            // Unpack the 8 4-bit scale codes and map them to int8 via iq4k_values.
            auto u32 = *(const uint32_t *)x[i].scales;
            auto s8_u32 = uint32x2_t{u32, u32 >> 4};
            // Fix: operands are uint32x2_t, so this must be vand_u32 —
            // the original's vand_u8 is a strict-typing error (the
            // bitwise result is identical).
            s8_u32 = vand_u32(s8_u32, vdup_n_u32(0x0f0f0f0f));
            auto s8 = vqtbl1_s8(values, vreinterpret_u8_u32(s8_u32));
            auto s16 = vmovl_s8(s8);
            scales.val[0] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_low_s16 (s16))));
            scales.val[1] = vmulq_f32(d, vcvtq_f32_s32(vmovl_s16(vget_high_s16(s16))));
            // Regenerate the 256 quants of this super-block from the trellis.
            const uint16_t * ql = (const uint16_t *)x[i].ql;
            for (int k = 0; k < 8; ++k) xv[k] = trellis.next32(ql + 4*k, 4096);
            for (int iy = 0; iy < nrc_y; ++iy) {
                const block_q8_0_x4& ybl = y[iy][2*i+0];
                const block_q8_0_x4& ybh = y[iy][2*i+1];
                auto dyl = vmulq_f32(scales.val[0], vcvt_f32_f16(vld1_f16((const float16_t *)ybl.d)));
                auto dyh = vmulq_f32(scales.val[1], vcvt_f32_f16(vld1_f16((const float16_t *)ybh.d)));
                auto sumil = compute_dot(ybl.qs, xv+0);
                auto sumih = compute_dot(ybh.qs, xv+4);
                if constexpr (nrc_y == 1) {
                    accd[2*iy+0] = vfmaq_f32(accd[2*iy+0], dyl, vcvtq_f32_s32(sumil));
                    accd[2*iy+1] = vfmaq_f32(accd[2*iy+1], dyh, vcvtq_f32_s32(sumih));
                } else {
                    accd[iy] = vfmaq_f32(accd[iy], dyl, vcvtq_f32_s32(sumil));
                    accd[iy] = vfmaq_f32(accd[iy], dyh, vcvtq_f32_s32(sumih));
                }
            }
        }

        if constexpr (nrc_y == 1) {
            info.store(ix, 0, vaddvq_f32(vaddq_f32(accd[0], accd[1])));
        } else {
            for (int iy = 0; iy < nrc_y; ++iy) {
                info.store(ix, iy, vaddvq_f32(accd[iy]));
            }
        }
    }
}
|
||||
|
||||
}
|
||||
|
||||
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
|
||||
@@ -1628,6 +1803,15 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ggml_type(typeA) == GGML_TYPE_IQ2_KT) {
|
||||
if (ggml_type(typeB) == GGML_TYPE_Q8_0_X4) {
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_q8_0_x4_T, kernels);
|
||||
func16 = nullptr;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (ggml_type(typeB) != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
@@ -1653,7 +1837,7 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
|
||||
|
||||
bool iqk_dequantize_ktquants(int type, int n, const void * vx, size_t bx, void * y, size_t stride_y, int nrc_x) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
|
||||
case GGML_TYPE_IQ2_KT: iqk_dequantize_iq2_kt_q80_r8(n, vx, bx, y, nrc_x); break;
|
||||
case GGML_TYPE_IQ3_KT: iqk_dequantize_iq3_kt(n, vx, bx, (float16_t *)y, stride_y, nrc_x); break;
|
||||
case GGML_TYPE_IQ4_KT: iqk_dequantize_iq4_kt_q80_r8(n, vx, bx, y, nrc_x); break;
|
||||
default: return false;
|
||||
|
||||
@@ -271,7 +271,7 @@ struct MulMat {
|
||||
}
|
||||
#else
|
||||
switch (type) {
|
||||
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
|
||||
case GGML_TYPE_IQ2_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
|
||||
case GGML_TYPE_IQ3_KT: return nrc_y >= 32 ? GGML_TYPE_F16 : type;
|
||||
case GGML_TYPE_IQ4_KT: return nrc_y >= 32 ? GGML_TYPE_Q8_0_R8 : type;
|
||||
default: break;
|
||||
|
||||
Reference in New Issue
Block a user