mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-05-11 08:30:19 +00:00
iq2_kt: NEON implementation
This commit is contained in:
@@ -1583,7 +1583,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.from_float = quantize_row_iq2_kt,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq2_kt_ref,
|
||||
.vec_dot = vec_dot_iq2_kt_q8_k,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
#ifdef __ARM_NEON
|
||||
.vec_dot_type = GGML_TYPE_F16,
|
||||
#else
|
||||
.vec_dot_type = GGML_TYPE_F32,
|
||||
#endif
|
||||
.nrows = 1,
|
||||
.row_meta_size = 4,
|
||||
},
|
||||
@@ -1596,7 +1600,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.from_float = quantize_row_iq3_kt,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq3_kt_ref,
|
||||
.vec_dot = vec_dot_iq3_kt_q8_k,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
#ifdef __ARM_NEON
|
||||
.vec_dot_type = GGML_TYPE_F16,
|
||||
#else
|
||||
.vec_dot_type = GGML_TYPE_F32,
|
||||
#endif
|
||||
.nrows = 1,
|
||||
.row_meta_size = 4,
|
||||
},
|
||||
@@ -1609,7 +1617,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.from_float = quantize_row_iq4_kt,
|
||||
.from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref,
|
||||
.vec_dot = vec_dot_iq4_kt_q8_k,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
#ifdef __ARM_NEON
|
||||
.vec_dot_type = GGML_TYPE_F16,
|
||||
#else
|
||||
.vec_dot_type = GGML_TYPE_F32,
|
||||
#endif
|
||||
.nrows = 1,
|
||||
.row_meta_size = 8,
|
||||
},
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "iqk_common.h"
|
||||
#include "iqk_gemm_ktquants.h"
|
||||
#include "ggml.h"
|
||||
|
||||
@@ -316,37 +317,13 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
|
||||
|
||||
switch (typeA) {
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
assert (ne00 % QK_K == 0);
|
||||
kernels[0] = mul_mat_iq2_kt_F32_T<1>;
|
||||
kernels[1] = mul_mat_iq2_kt_F32_T<2>;
|
||||
kernels[2] = mul_mat_iq2_kt_F32_T<3>;
|
||||
kernels[3] = mul_mat_iq2_kt_F32_T<4>;
|
||||
kernels[4] = mul_mat_iq2_kt_F32_T<5>;
|
||||
kernels[5] = mul_mat_iq2_kt_F32_T<6>;
|
||||
kernels[6] = mul_mat_iq2_kt_F32_T<7>;
|
||||
kernels[7] = mul_mat_iq2_kt_F32_T<8>;
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_F32_T, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ3_KT:
|
||||
assert (ne00 % QK_K == 0);
|
||||
kernels[0] = mul_mat_iq3_kt_F32_T<1>;
|
||||
kernels[1] = mul_mat_iq3_kt_F32_T<2>;
|
||||
kernels[2] = mul_mat_iq3_kt_F32_T<3>;
|
||||
kernels[3] = mul_mat_iq3_kt_F32_T<4>;
|
||||
kernels[4] = mul_mat_iq3_kt_F32_T<5>;
|
||||
kernels[5] = mul_mat_iq3_kt_F32_T<6>;
|
||||
kernels[6] = mul_mat_iq3_kt_F32_T<7>;
|
||||
kernels[7] = mul_mat_iq3_kt_F32_T<8>;
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_F32_T, kernels);
|
||||
break;
|
||||
case GGML_TYPE_IQ4_KT:
|
||||
assert (ne00 % QK_K == 0);
|
||||
kernels[0] = mul_mat_iq4_kt_F32_T<1>;
|
||||
kernels[1] = mul_mat_iq4_kt_F32_T<2>;
|
||||
kernels[2] = mul_mat_iq4_kt_F32_T<3>;
|
||||
kernels[3] = mul_mat_iq4_kt_F32_T<4>;
|
||||
kernels[4] = mul_mat_iq4_kt_F32_T<5>;
|
||||
kernels[5] = mul_mat_iq4_kt_F32_T<6>;
|
||||
kernels[6] = mul_mat_iq4_kt_F32_T<7>;
|
||||
kernels[7] = mul_mat_iq4_kt_F32_T<8>;
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
@@ -358,8 +335,137 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat
|
||||
|
||||
#else // !__x86_64__
|
||||
|
||||
namespace {
|
||||
|
||||
struct Trellis1 {
|
||||
constexpr static uint32_t kmask = 0x8fff8fff;
|
||||
constexpr static uint32_t km32 = 0x3b603b60;
|
||||
constexpr static uint32_t ka = 89226354;
|
||||
constexpr static uint32_t kb = 64248484;
|
||||
constexpr static uint32_t ka1 = ka*ka;
|
||||
constexpr static uint32_t kb1 = kb*ka+kb;
|
||||
constexpr static uint32_t ka2 = ka1*ka;
|
||||
constexpr static uint32_t kb2 = kb1*ka+kb;
|
||||
constexpr static uint32_t ka3 = ka2*ka;
|
||||
constexpr static uint32_t kb3 = kb2*ka+kb;
|
||||
constexpr static uint32_t ka4 = ka3*ka;
|
||||
constexpr static uint32_t kb4 = kb3*ka+kb;
|
||||
constexpr static uint32_t ka5 = ka4*ka;
|
||||
constexpr static uint32_t kb5 = kb4*ka+kb;
|
||||
constexpr static uint32_t ka6 = ka5*ka;
|
||||
constexpr static uint32_t kb6 = kb5*ka+kb;
|
||||
constexpr static uint32_t ka7 = ka6*ka;
|
||||
constexpr static uint32_t kb7 = kb6*ka+kb;
|
||||
const uint32x4x2_t mka = {uint32x4_t{ka, ka1, ka2, ka3}, uint32x4_t{ka4, ka5, ka6, ka7}};
|
||||
const uint32x4x2_t mkb = {uint32x4_t{kb, kb1, kb2, kb3}, uint32x4_t{kb4, kb5, kb6, kb7}};
|
||||
const uint32x4_t mask1 = vdupq_n_u32(kmask);
|
||||
const uint32x4_t mask2 = vdupq_n_u32(km32);
|
||||
|
||||
inline uint32x4x2_t next8(uint32_t val) const {
|
||||
auto mval = vdupq_n_u32(val);
|
||||
uint32x4x2_t mres;
|
||||
mres.val[0] = vaddq_u32(vmulq_u32(mval, mka.val[0]), mkb.val[0]);
|
||||
mres.val[1] = vaddq_u32(vmulq_u32(mval, mka.val[1]), mkb.val[1]);
|
||||
mres.val[0] = veorq_u32(vandq_u32(mres.val[0], mask1), mask2);
|
||||
mres.val[1] = veorq_u32(vandq_u32(mres.val[1], mask1), mask2);
|
||||
return mres;
|
||||
}
|
||||
static inline float16x8_t gen8(const uint32x4x2_t& i8) {
|
||||
auto fv1 = vreinterpretq_f16_u32(i8.val[0]);
|
||||
auto fv2 = vreinterpretq_f16_u32(i8.val[1]);
|
||||
return vpaddq_f16(fv1, fv2);
|
||||
}
|
||||
inline float16x8_t gen8(uint32_t val) const { return gen8(next8(val)); }
|
||||
};
|
||||
|
||||
template <int nrc_y>
|
||||
static void mul_mat_iq2_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
|
||||
assert(n%QK_K == 0);
|
||||
const int nb = n/QK_K;
|
||||
|
||||
Trellis1 trellis;
|
||||
|
||||
auto values = vld1q_s8(iq4k_values);
|
||||
|
||||
union { float16x8_t vec; float16_t val[8]; } s_helper;
|
||||
|
||||
constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y;
|
||||
float16x8_t accd[k_acc];
|
||||
const float16_t * y[nrc_y];
|
||||
for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float16_t *)info.src1_row(iy);
|
||||
|
||||
for (int ix = 0; ix < nrc_x; ++ix) {
|
||||
const float * dptr = (const float *)((const char*)vx + ix*bx);
|
||||
const float d = *dptr * 31.75f * 1.05f;
|
||||
const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1);
|
||||
|
||||
for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f16(0);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const uint16_t * ql = (const uint16_t *)x[i].ql;
|
||||
auto u32 = *(const uint32_t *)x[i].scales;
|
||||
auto s8_u32 = uint32x2_t{u32, u32 >> 4};
|
||||
s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f));
|
||||
auto s8 = vqtbl1_s8(values, vreinterpret_u8_u32(s8_u32));
|
||||
auto s16 = vmovl_s8(s8);
|
||||
s_helper.vec = vcvtq_f16_s16(s16);
|
||||
for (int ib = 0; ib < QK_K/64; ++ib) {
|
||||
auto scale1 = vdupq_n_f16(s_helper.val[2*ib+0]);
|
||||
auto scale2 = vdupq_n_f16(s_helper.val[2*ib+1]);
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
auto xval1 = vmulq_f16(scale1, trellis.gen8(ql[8*ib+j+0]+4096));
|
||||
auto xval2 = vmulq_f16(scale2, trellis.gen8(ql[8*ib+j+4]+4096));
|
||||
if constexpr (nrc_y == 1) {
|
||||
accd[0] = vfmaq_f16(accd[0], xval1, vld1q_f16(y[0] + i*QK_K + 64*ib + 8*j + 0));
|
||||
accd[1] = vfmaq_f16(accd[1], xval2, vld1q_f16(y[0] + i*QK_K + 64*ib + 8*j + 32));
|
||||
} else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
accd[iy] = vfmaq_f16(accd[iy], xval1, vld1q_f16(y[iy] + i*QK_K + 64*ib + 8*j + 0));
|
||||
accd[iy] = vfmaq_f16(accd[iy], xval2, vld1q_f16(y[iy] + i*QK_K + 64*ib + 8*j + 32));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (nrc_y == 1) {
|
||||
auto res16 = vpaddq_f16(accd[0], accd[1]);
|
||||
auto res = vaddq_f32(vcvt_f32_f16(vget_low_f16(res16)), vcvt_f32_f16(vget_high_f16(res16)));
|
||||
info.store(ix, 0, vaddvq_f32(res)*d);
|
||||
} else {
|
||||
for (int iy = 0; iy < nrc_y; ++iy) {
|
||||
auto res = vaddq_f32(vcvt_f32_f16(vget_low_f16(accd[iy])), vcvt_f32_f16(vget_high_f16(accd[iy])));
|
||||
info.store(ix, iy, vaddvq_f32(res)*d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
|
||||
return false;
|
||||
|
||||
if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F16) {
|
||||
return false;
|
||||
}
|
||||
|
||||
func16 = nullptr;
|
||||
|
||||
switch (typeA) {
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_F16_T, kernels);
|
||||
break;
|
||||
//case GGML_TYPE_IQ3_KT:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_F32_T, kernels);
|
||||
// break;
|
||||
//case GGML_TYPE_IQ4_KT:
|
||||
// IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels);
|
||||
// break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -651,6 +651,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
|
||||
case GGML_TYPE_IQ1_S_R4:
|
||||
case GGML_TYPE_IQ1_M_R4:
|
||||
return iqk_set_kernels_1bit(ne00, typeA, typeB, m.funcs, m.func16);
|
||||
case GGML_TYPE_IQ2_KT:
|
||||
case GGML_TYPE_IQ3_KT:
|
||||
case GGML_TYPE_IQ4_KT:
|
||||
return ggml_type(typeB) == GGML_TYPE_F16 ? iqk_set_kernels_ktquants(ne00, typeA, typeB, m.funcs, m.func16) : false;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
@@ -926,4 +930,4 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*n
|
||||
return false;
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user