diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d8025a5a..5bb75d32 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1583,7 +1583,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq2_kt, .from_float_ref = (ggml_from_float_t)quantize_row_iq2_kt_ref, .vec_dot = vec_dot_iq2_kt_q8_k, - .vec_dot_type = GGML_TYPE_Q8_K, +#ifdef __ARM_NEON + .vec_dot_type = GGML_TYPE_F16, +#else + .vec_dot_type = GGML_TYPE_F32, +#endif .nrows = 1, .row_meta_size = 4, }, @@ -1596,7 +1600,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq3_kt, .from_float_ref = (ggml_from_float_t)quantize_row_iq3_kt_ref, .vec_dot = vec_dot_iq3_kt_q8_k, - .vec_dot_type = GGML_TYPE_Q8_K, +#ifdef __ARM_NEON + .vec_dot_type = GGML_TYPE_F16, +#else + .vec_dot_type = GGML_TYPE_F32, +#endif .nrows = 1, .row_meta_size = 4, }, @@ -1609,7 +1617,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .from_float = quantize_row_iq4_kt, .from_float_ref = (ggml_from_float_t)quantize_row_iq4_kt_ref, .vec_dot = vec_dot_iq4_kt_q8_k, - .vec_dot_type = GGML_TYPE_Q8_K, +#ifdef __ARM_NEON + .vec_dot_type = GGML_TYPE_F16, +#else + .vec_dot_type = GGML_TYPE_F32, +#endif .nrows = 1, .row_meta_size = 8, }, diff --git a/ggml/src/iqk/iqk_gemm_ktquants.cpp b/ggml/src/iqk/iqk_gemm_ktquants.cpp index 6604480d..07550fc5 100644 --- a/ggml/src/iqk/iqk_gemm_ktquants.cpp +++ b/ggml/src/iqk/iqk_gemm_ktquants.cpp @@ -1,3 +1,4 @@ +#include "iqk_common.h" #include "iqk_gemm_ktquants.h" #include "ggml.h" @@ -316,37 +317,13 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array; - kernels[1] = mul_mat_iq2_kt_F32_T<2>; - kernels[2] = mul_mat_iq2_kt_F32_T<3>; - kernels[3] = mul_mat_iq2_kt_F32_T<4>; - kernels[4] = mul_mat_iq2_kt_F32_T<5>; - kernels[5] = mul_mat_iq2_kt_F32_T<6>; - kernels[6] = mul_mat_iq2_kt_F32_T<7>; - kernels[7] = mul_mat_iq2_kt_F32_T<8>; + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_F32_T, kernels); break; case GGML_TYPE_IQ3_KT: - assert (ne00 % QK_K == 0); - kernels[0] = mul_mat_iq3_kt_F32_T<1>; - kernels[1] = mul_mat_iq3_kt_F32_T<2>; - kernels[2] = mul_mat_iq3_kt_F32_T<3>; - kernels[3] = mul_mat_iq3_kt_F32_T<4>; - kernels[4] = mul_mat_iq3_kt_F32_T<5>; - kernels[5] = mul_mat_iq3_kt_F32_T<6>; - kernels[6] = mul_mat_iq3_kt_F32_T<7>; - kernels[7] = mul_mat_iq3_kt_F32_T<8>; + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_F32_T, kernels); break; case GGML_TYPE_IQ4_KT: - assert (ne00 % QK_K == 0); - kernels[0] = mul_mat_iq4_kt_F32_T<1>; - kernels[1] = mul_mat_iq4_kt_F32_T<2>; - kernels[2] = mul_mat_iq4_kt_F32_T<3>; - kernels[3] = mul_mat_iq4_kt_F32_T<4>; - kernels[4] = mul_mat_iq4_kt_F32_T<5>; - kernels[5] = mul_mat_iq4_kt_F32_T<6>; - kernels[6] = mul_mat_iq4_kt_F32_T<7>; - kernels[7] = mul_mat_iq4_kt_F32_T<8>; + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels); break; default: return false; @@ -358,8 +335,137 @@ bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array +static void mul_mat_iq2_kt_F16_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { + assert(n%QK_K == 0); + const int nb = n/QK_K; + + Trellis1 trellis; + + auto values = vld1q_s8(iq4k_values); + + union { float16x8_t vec; float16_t val[8]; } s_helper; + + constexpr int k_acc = nrc_y == 1 ? 2 : nrc_y; + float16x8_t accd[k_acc]; + const float16_t * y[nrc_y]; + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const float16_t *)info.src1_row(iy); + + for (int ix = 0; ix < nrc_x; ++ix) { + const float * dptr = (const float *)((const char*)vx + ix*bx); + const float d = *dptr * 31.75f * 1.05f; + const block_iq2_kt * x = (const block_iq2_kt *)(dptr + 1); + + for (int iy = 0; iy < k_acc; ++iy) accd[iy] = vdupq_n_f16(0); + + for (int i = 0; i < nb; ++i) { + const uint16_t * ql = (const uint16_t *)x[i].ql; + auto u32 = *(const uint32_t *)x[i].scales; + auto s8_u32 = uint32x2_t{u32, u32 >> 4}; + s8_u32 = vand_u8(s8_u32, vdup_n_u32(0x0f0f0f0f)); + auto s8 = vqtbl1_s8(values, vreinterpret_u8_u32(s8_u32)); + auto s16 = vmovl_s8(s8); + s_helper.vec = vcvtq_f16_s16(s16); + for (int ib = 0; ib < QK_K/64; ++ib) { + auto scale1 = vdupq_n_f16(s_helper.val[2*ib+0]); + auto scale2 = vdupq_n_f16(s_helper.val[2*ib+1]); + for (int j = 0; j < 4; ++j) { + auto xval1 = vmulq_f16(scale1, trellis.gen8(ql[8*ib+j+0]+4096)); + auto xval2 = vmulq_f16(scale2, trellis.gen8(ql[8*ib+j+4]+4096)); + if constexpr (nrc_y == 1) { + accd[0] = vfmaq_f16(accd[0], xval1, vld1q_f16(y[0] + i*QK_K + 64*ib + 8*j + 0)); + accd[1] = vfmaq_f16(accd[1], xval2, vld1q_f16(y[0] + i*QK_K + 64*ib + 8*j + 32)); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + accd[iy] = vfmaq_f16(accd[iy], xval1, vld1q_f16(y[iy] + i*QK_K + 64*ib + 8*j + 0)); + accd[iy] = vfmaq_f16(accd[iy], xval2, vld1q_f16(y[iy] + i*QK_K + 64*ib + 8*j + 32)); + } + } + } + } + } + + if constexpr (nrc_y == 1) { + auto res16 = vpaddq_f16(accd[0], accd[1]); + auto res = vaddq_f32(vcvt_f32_f16(vget_low_f16(res16)), vcvt_f32_f16(vget_high_f16(res16))); + info.store(ix, 0, vaddvq_f32(res)*d); + } else { + for (int iy = 0; iy < nrc_y; ++iy) { + auto res = vaddq_f32(vcvt_f32_f16(vget_low_f16(accd[iy])), vcvt_f32_f16(vget_high_f16(accd[iy]))); + info.store(ix, iy, vaddvq_f32(res)*d); + } + } + } +} + +} + bool iqk_set_kernels_ktquants(int ne00, int typeA, int typeB, std::array& kernels, mul_mat_t& func16) { - return false; + + if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_F16) { + return false; + } + + func16 = nullptr; + + switch (typeA) { + case GGML_TYPE_IQ2_KT: + IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_kt_F16_T, kernels); + break; + //case GGML_TYPE_IQ3_KT: + // IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_kt_F32_T, kernels); + // break; + //case GGML_TYPE_IQ4_KT: + // IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_kt_F32_T, kernels); + // break; + default: + return false; + } + + return true; } #endif diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 43be0885..d6fc4d31 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -651,6 +651,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { case GGML_TYPE_IQ1_S_R4: case GGML_TYPE_IQ1_M_R4: return iqk_set_kernels_1bit(ne00, typeA, typeB, m.funcs, m.func16); + case GGML_TYPE_IQ2_KT: + case GGML_TYPE_IQ3_KT: + case GGML_TYPE_IQ4_KT: + return ggml_type(typeB) == GGML_TYPE_F16 ? iqk_set_kernels_ktquants(ne00, typeA, typeB, m.funcs, m.func16) : false; default: return false; } @@ -926,4 +930,4 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long /*Nx*/, long /*Ny*/, long /*n return false; } -#endif \ No newline at end of file +#endif