From bd1e4d490922b08a84af8be83df1a3f8c331cc75 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sun, 18 May 2025 19:47:53 +0300
Subject: [PATCH] Refactor iqk: factor out legacy quants (NEON)

---
 ggml/src/iqk/iqk_common.h               |  17 +
 ggml/src/iqk/iqk_gemm_iquants.cpp       | 309 +++++++++
 ggml/src/iqk/iqk_gemm_legacy_quants.cpp | 665 ++++++++++++++++++
 ggml/src/iqk/iqk_mul_mat.cpp            | 887 +-----------------------
 4 files changed, 993 insertions(+), 885 deletions(-)

diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h
index 620bd7f9..8b9db67d 100644
--- a/ggml/src/iqk/iqk_common.h
+++ b/ggml/src/iqk/iqk_common.h
@@ -696,6 +696,23 @@ static inline void compute_16_blocks(const uint8x16x4_t& qx_1, const uint8x16x4_
         sumi = vmlaq_s32(sumi, scales.val[2*j+1], p34);
     }
 
+struct SignHelper {
+
+    inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); }
+
+    inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) {
+        auto aux = vqtbl1q_u8(signs16, shuffle);
+        auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1));
+        b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s));
+        shuffle = vaddq_u8(shuffle, step);
+    }
+
+    const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201));
+    const uint8x16_t m1 = vdupq_n_u8(1);
+    const uint8x16_t step = vdupq_n_u8(2);
+    uint8x16_t shuffle;
+};
+
 template <typename Dequantizer, int nrc_y>
 static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
     assert(n % QK_K == 0);
diff --git a/ggml/src/iqk/iqk_gemm_iquants.cpp b/ggml/src/iqk/iqk_gemm_iquants.cpp
index 0f63e241..152b6247 100644
--- a/ggml/src/iqk/iqk_gemm_iquants.cpp
+++ b/ggml/src/iqk/iqk_gemm_iquants.cpp
@@ -1632,6 +1632,315 @@ bool iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
 
+namespace {
+
+inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx) {
+    auto s1 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >> 0) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >> 7) & 127))));
+    auto s2 = vcombine_s8(vld1_s8((const int8_t *)(signs + ((sidx >>14) & 127))), vld1_s8((const int8_t *)(signs + ((sidx >>21) & 127))));
+    b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s1));
+    b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));
+}
+
+struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
+    DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
+
+    constexpr static int num_blocks() { return 8; }
+    constexpr static bool should_scale_quants() { return false; }
+
+    template <typename Q8>
+    inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
+        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
+
+        auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs);
+        data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]);  // codebook indices for blocks 0...3
+        data.val[1] = vuzp2q_u32(tmp.val[0], tmp.val[1]);  // scales and signs for blocks 0...3
+        data.val[2] = vuzp1q_u32(tmp.val[2], tmp.val[3]);  // codebook indices for blocks 4...7
+        data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]);  // scales and signs for blocks 4...7
+
+        return prepare_scales_8(data.val[1], data.val[3]);
+    }
+
+    static inline void prepare2(uint8x16_t * b, const uint8_t * idx, const uint64_t * signs, uint32_t sidx) {
+        b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]});
+        b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]});
+        apply_signs_2(b, signs, sidx);
+    }
+
+    inline void prepare(int /*i*/, int j) {
+        const uint8_t  * idx  = (const uint8_t  *)(data.val + 2*j);
+        const uint32_t * sidx = (const uint32_t *)(data.val + 2*j+1);
+        prepare2(bits.b1.val + 0, idx, 
keven_signs, sidx[0]); idx += 4; + prepare2(bits.b1.val + 2, idx, keven_signs, sidx[1]); idx += 4; + prepare2(bits.b2.val + 0, idx, keven_signs, sidx[2]); idx += 4; + prepare2(bits.b2.val + 2, idx, keven_signs, sidx[3]); + } + + uint32x4x4_t data; + SimpleBits bits; + +}; + +inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) { + auto aux = vld1_u8(sc); + auto scales_l = vand_u8(aux, vdup_n_u8(0xf)); + auto scales_h = vshr_n_u8(aux, 4); + auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); + + auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1))); + int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) }; + return make_wider(scales16); +} + +struct DequantizerIQ2XS final : public BaseDequantizer { + DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + return prepare_4bit_scales16(x[i].scales); + } + + inline static uint8x16_t make1(const uint16_t * qs) { + auto b = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_u8((const uint8_t *)(iq2xs_grid + (qs[1] & 511)))); + auto s = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9)))); + return vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b), s)); + } + + inline static void make4(const uint16_t * qs, uint8x16_t * b) { + b[0] = make1(qs + 0); + b[1] = make1(qs + 2); + b[2] = make1(qs + 4); + b[3] = make1(qs + 6); + } + + inline void prepare(int i, int j) { + make4(x[i].qs + 16*j + 0, bits.b1.val); + make4(x[i].qs + 16*j + 8, bits.b2.val); + } + + SimpleBits bits; + + +}; + +struct DequantizerIQ2S final : public BaseDequantizer { + DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 16; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = 0.125f * GGML_FP16_TO_FP32(x[i].d); + return prepare_4bit_scales16(x[i].scales); + } + + static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { + uint32_t aux32[2]; + const uint16_t * aux16 = (const uint16_t *)aux32; + for (int k = 0; k < 2; ++k) { + aux32[1] = (qh[k] << 4) | (qh[k] << 18); + aux32[0] = (aux32[1] << 4) & 0x03000300; + aux32[1] &= 0x03000300; + b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))), + vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1])))); + sh.apply_signs_1(b+2*k+0, signs16); + + b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))), + vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3])))); + sh.apply_signs_1(b+2*k+1, signs16); + } + } + + inline void prepare(int i, int j) { + + const auto * qs = x[i].qs + 16*j; + const auto * qh = x[i].qh + 4*j; + const auto signs16 = vld1q_u8(qs + QK_K/8); + + sh.init(); + make4(sh, signs16, qs+0, qh+0, bits.b1.val); + make4(sh, signs16, qs+8, qh+2, bits.b2.val); + } + + SimpleBits bits; + SignHelper sh; + + +}; + +struct DequantizerIQ3XXS final : public BaseDequantizer { + DequantizerIQ3XXS(const void * vx, size_t bx, int 
nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = 0.25f * GGML_FP16_TO_FP32(x[i].d); + gas = vld1q_u32_x2((const uint32_t *)(x[i].qs + QK_K/4)); + return prepare_scales_8(gas.val[0], gas.val[1]); + } + + inline static void make2(const uint8_t * q3, uint32_t sidx, uint8x16_t * b) { + b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]}); + b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]}); + apply_signs_2(b, keven_signs, sidx); + } + inline void prepare(int i, int j) { + const auto * q3 = x[i].qs + 32*j; + const auto * signs = (const uint32_t *)(gas.val + j); + make2(q3, signs[0], bits.b1.val + 0); q3 += 8; + make2(q3, signs[1], bits.b1.val + 2); q3 += 8; + make2(q3, signs[2], bits.b2.val + 0); q3 += 8; + make2(q3, signs[3], bits.b2.val + 2); + } + + SimpleBits bits; + uint32x4x2_t gas; + +}; + +struct DequantizerIQ3S final : public BaseDequantizer { + DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} + + constexpr static int num_blocks() { return 8; } + constexpr static bool should_scale_quants() { return false; } + + template + inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { + d = GGML_FP16_TO_FP32(x[i].d); + uint32_t scales32[2]; + std::memcpy(scales32, x[i].scales, 4); + scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; + scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; + auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7 + scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400))); + auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8)); + int32x4x2_t scales; + scales.val[0] = vmovl_s16(vget_low_s16(scales16)); + scales.val[1] = vmovl_s16(vget_high_s16(scales16)); + return scales; + } + + static inline void make2(SignHelper& sh, const uint8x16_t& signs16, const uint16x8_t& idx_l, uint8_t qh, + const int8x16_t& hshift, uint8x16_t * b) { + auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256))); + const uint16_t * idx = (const uint16_t *)&vindex; + b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]}); + b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]}); + sh.apply_signs_1(b+0, signs16); + sh.apply_signs_1(b+1, signs16); + } + static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, + const int8x16_t& hshift, uint8x16_t * b) { + auto idx_l = vld1q_u8(qs); + make2(sh, signs16, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0); + make2(sh, signs16, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2); + } + + inline void prepare(int i, int j) { + + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + const auto hshift = vld1q_s16(k_shift); + + const auto * qs = x[i].qs + 32*j; + const auto * qh = x[i].qh + 4*j; + const auto signs16 = vld1q_u8(x[i].signs + 16*j); + + sh.init(); + make4(sh, signs16, qs+ 0, qh+0, hshift, bits.b1.val); + make4(sh, signs16, qs+16, qh+2, hshift, bits.b2.val); + } + + SimpleBits bits; + SignHelper sh; + uint32x4x2_t gas; + +}; + +} + +bool 
iqk_set_kernels_iquants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
+
+    if (ne00%QK_K != 0 || ggml_type(typeB) != GGML_TYPE_Q8_K) {
+        return false;
+    }
+
+    func16 = nullptr;
+
+    switch (typeA) {
+        case GGML_TYPE_IQ2_XXS:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2XXS, kernels);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2XS, kernels);
+            break;
+        case GGML_TYPE_IQ2_S:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ2S, kernels);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ3XXS, kernels);
+            break;
+        case GGML_TYPE_IQ3_S:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_K_q8_K_T, DequantizerIQ3S, kernels);
+            break;
+//        case GGML_TYPE_IQ2_XXS_R4:
+//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xxs_r4_q8_k, kernels);
+//            func16 = mul_mat_iq2_xxs_r4_q8_k<16>;
+//            break;
+//        case GGML_TYPE_IQ2_XS_R4:
+//            assert (ne00 % QK_K == 0);
+//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_xs_r4_q8_k, kernels);
+//#ifndef HAVE_FANCY_SIMD
+//            // For some reason Zen4 does not like this particular function
+//            func16 = mul_mat_iq2_xs_r4_q8_k_16;
+//#endif
+//            break;
+//        case GGML_TYPE_IQ2_S_R4:
+//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq2_s_r4_q8_k, kernels);
+//            func16 = mul_mat_iq2_s_r4_q8_k_16;
+//            break;
+//        case GGML_TYPE_IQ3_XXS_R4:
+//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_xxs_r4_q8_k, kernels);
+//            func16 = mul_mat_iq3_xxs_r4_q8_k<16>;
+//            break;
+//        case GGML_TYPE_IQ3_S_R4:
+//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq3_s_r4_q8_k, kernels);
+//            func16 = mul_mat_iq3_s_r4_q8_k<16>;
+//            break;
+        default:
+            return false;
+    }
+
+    return true;
+
+}
 #endif
diff --git a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
index b65ba06f..40055e48 100644
--- a/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
+++ b/ggml/src/iqk/iqk_gemm_legacy_quants.cpp
@@ -1704,6 +1704,671 @@ bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
 
+namespace {
+
+template <typename Block>
+inline float16x4_t load_scales_q0(const Block * x, ggml_half * aux) {
+    for (int k = 0; k < 4; ++k) aux[k] = x[k].d;
+    return vld1_f16((const float16_t *)aux);
+}
+
+template <typename Block>
+inline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) {
+    if constexpr (std::is_same_v<Block, block_q8_1>) {
+        for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; }
+    } else {
+        for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; }
+    }
+    return vld1q_f16((const float16_t *)aux);
+}
+
+struct Q4LegacyBits {
+    template <typename Block>
+    inline void prepare(const Block * x) {
+        for (int i = 0; i < 4; ++i) {
+            auto q4bits = vld1q_u8(x[i].qs);
+            b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));
+            b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));
+        }
+    }
+    inline void prepare1(const uint8_t * qs, int8x16_t * q) const {
+        auto q4bits = vld1q_u8(qs);
+        q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b));
+        q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4));
+    }
+    inline void prepare1(const uint8_t * qs) {
+        prepare1(qs, b);
+    }
+    const uint8x16_t m4b = vdupq_n_u8(0xf);
+    int8x16_t b[8];
+};
+
+// One would think this commented out version would do better than the one below
+// because it offers more opportunities to execute instructions in parallel.
+// Instead, it runs significantly slower. Why? If the compiler is running out of vector registers,
+// can't it just do the sequential version below on its own?
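+// A plausible explanation, assuming the usual AArch64 budget of 32 128-bit vector
+// registers (an estimate, not a measurement): the interleaved version keeps b[0..7]
+// (8 registers), all four q8b pairs (8 registers) and four partial dot products live
+// at once, on top of the accumulators and scale vectors that stay live across the
+// inlined call in sum_4(), so the compiler must spill. The sequential version below
+// reuses a single q8b pair (2 registers) and never exhausts the register file.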
+//inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) { +// const auto q8b_1 = vld1q_s8_x2(qs + 0); +// auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]); +// const auto q8b_2 = vld1q_s8_x2(qs + 32); +// auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]); +// auto p1234 = vpaddq_s32(p12, p34); +// const auto q8b_3 = vld1q_s8_x2(qs + 64); +// auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]); +// const auto q8b_4 = vld1q_s8_x2(qs + 96); +// auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]); +// return vpaddq_s32(p1234, vpaddq_s32(p56, p78)); +//} + +inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) { + auto q8b = vld1q_s8_x2(qs + 0); + auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]); + q8b = vld1q_s8_x2(qs + 32); + auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]); + auto p1234 = vpaddq_s32(p12, p34); + q8b = vld1q_s8_x2(qs + 64); + auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]); + q8b = vld1q_s8_x2(qs + 96); + auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]); + return vpaddq_s32(p1234, vpaddq_s32(p56, p78)); +} + +inline int32x4x2_t sum_4_blocks(const int8x16_t * b1, const int8x16_t * b2, const int8_t * qs) { + auto q8b = vld1q_s8_x2(qs + 0); + auto p12_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q8b.val[0]), b1[1], q8b.val[1]); + auto p12_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q8b.val[0]), b2[1], q8b.val[1]); + q8b = vld1q_s8_x2(qs + 32); + auto p34_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q8b.val[0]), b1[3], q8b.val[1]); + auto p34_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q8b.val[0]), b2[3], q8b.val[1]); + auto p1234_1 = vpaddq_s32(p12_1, p34_1); + auto p1234_2 = vpaddq_s32(p12_2, p34_2); + q8b = vld1q_s8_x2(qs + 64); + auto p56_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[4], q8b.val[0]), b1[5], q8b.val[1]); + auto p56_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[4], q8b.val[0]), b2[5], q8b.val[1]); + q8b = vld1q_s8_x2(qs + 96); + auto p78_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[6], q8b.val[0]), b1[7], q8b.val[1]); + auto p78_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[6], q8b.val[0]), b2[7], q8b.val[1]); + auto p5678_1 = vpaddq_s32(p56_1, p78_1); + auto p5678_2 = vpaddq_s32(p56_2, p78_2); + return { vpaddq_s32(p1234_1, p5678_1), vpaddq_s32(p1234_2, p5678_2)}; +} + +template struct Q80 { + + constexpr static int nrc_y = nrc; + + Q80(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy); + } + + inline const int8_t * quant_data(int iy, int i) const { + const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i; + return y4->qs; + } + + inline float16x4_t load_scales(int iy, int i) const { + const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i; + return vld1_f16((const float16_t *)y4->d); + } + + template + inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * /*acc*/) const { + auto qx_scales = deq.new_block(i); + for (int iy = 0; iy < nrc; ++iy) { + auto q8_scales = load_scales(iy, i); + sc16[iy] = vmul_f16(qx_scales, q8_scales); + } + } + + template + inline void process_scales(int i, 
Dequantizer& deq1, Dequantizer& deq2, float16x4_t * sc16, float32x4_t * /*acc*/) const { + auto qx_scales_1 = deq1.new_block(i); + auto qx_scales_2 = deq2.new_block(i); + for (int iy = 0; iy < nrc; ++iy) { + auto q8_scales = load_scales(iy, i); + sc16[iy ] = vmul_f16(qx_scales_1, q8_scales); + sc16[iy+nrc_y] = vmul_f16(qx_scales_2, q8_scales); + } + } + + template + inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const { + deq.prepare1(i); + float d = GGML_FP16_TO_FP32(deq.x[i].d); + for (int iy = 0; iy < nrc; ++iy) { + auto q8b = vld1q_s8_x2(y[iy][i].qs); + auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]); + acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p)); + } + } + + const block_q8_0 * y[nrc_y]; +}; + +template struct Q81 { + + constexpr static int nrc_y = nrc; + + Q81(const DataInfo& info) { + for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy); + } + + inline const int8_t * quant_data(int iy, int i) const { + const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i; + return y4->qs; + } + + inline float16x8_t load_scales(int iy, int i) const { + const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i; + return vld1q_f16((const float16_t *)y4->d); + } + + template + inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const { + auto qx_scales = deq.new_block(i); + for (int iy = 0; iy < nrc; ++iy) { + auto q8_scales = load_scales(iy, i); + auto m = vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales)); + acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m)); + sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales)); + } + } + + template + inline void process_scales(int i, Dequantizer& deq1, Dequantizer& deq2, float16x4_t * sc16, float32x4_t * acc) const { + auto qx_scales_1 = deq1.new_block(i); + auto qx_scales_2 = deq2.new_block(i); + for (int iy = 0; iy < nrc; ++iy) { + auto q8_scales = load_scales(iy, i); + auto q8_scales_l = vget_low_f16(q8_scales); + auto q8_scales_h = vget_high_f16(q8_scales); + auto m1 = vmul_f16(vget_high_f16(qx_scales_1), q8_scales_h); + auto m2 = vmul_f16(vget_high_f16(qx_scales_2), q8_scales_h); + acc[iy ] = vaddq_f32(acc[iy ], vcvt_f32_f16(m1)); + acc[iy+nrc_y ] = vaddq_f32(acc[iy+nrc_y], vcvt_f32_f16(m2)); + sc16[iy ] = vmul_f16(vget_low_f16(qx_scales_1), q8_scales_l); + sc16[iy+nrc_y] = vmul_f16(vget_low_f16(qx_scales_2), q8_scales_l); + } + } + + template + inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const { + deq.prepare1(i); + float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m); + for (int iy = 0; iy < nrc; ++iy) { + auto q8b = vld1q_s8_x2(y[iy][i].qs); + auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]); + acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p)); + acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s))); + } + } + + const block_q8_1 * y[nrc_y]; +}; + +template +struct BaseLegacyDequantizer { + + BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {} + + inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); } + + Q4LegacyBits bits; + + const void * vx; + const block_q * x; + size_t bx; +}; + +struct DequantizerQ40 final : public BaseLegacyDequantizer { + + DequantizerQ40(const void * vx, size_t bx) : 
BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + q[0] = vaddq_s8(q[0], m8); + q[1] = vaddq_s8(q[1], m8); + } + inline void prepare1(int i) { + prepare1(i, bits.b); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + prepare1(4*i+k, bits.b + 2*k); + } + return vld1_f16((const float16_t *)aux); + } + + const int8x16_t m8 = vdupq_n_s8(-8); + //ggml_half aux[4]; +}; + +struct DequantizerQ60 final : public BaseLegacyDequantizer { + + DequantizerQ60(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + auto qh8 = vld1_u8(x[i].qh); + auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8); + q[0] = vaddq_s8(vorrq_u8(q[0], vandq_u8(qh, hmask)), m32); + q[1] = vaddq_s8(vorrq_u8(q[1], vandq_u8(vshrq_n_u8(qh, 2), hmask)), m32); + } + inline void prepare1(int i) { + prepare1(i, bits.b); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + prepare1(4*i+k, bits.b + 2*k); + } + return vld1_f16((const float16_t *)aux); + } + + const int8x16_t m32 = vdupq_n_s8(-32); + const uint8x16_t hmask = vdupq_n_u8(0x30); +}; + +struct DequantizerIQ4NL final : public BaseLegacyDequantizer { + + DequantizerIQ4NL(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + q[0] = vqtbl1q_s8(values, q[0]); + q[1] = vqtbl1q_s8(values, q[1]); + } + inline void prepare1(int i) { + prepare1(i, bits.b); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + prepare1(4*i+k, bits.b + 2*k); + } + return vld1_f16((const float16_t *)aux); + } + static int8x16_t load_values() { + static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; + return vld1q_s8(iq4nl_values); + } + + const int8x16_t values = load_values(); +}; + +struct DequantizerQ41 : public BaseLegacyDequantizer { + + DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i) { + bits.prepare1(x[i].qs); + } + + inline float16x8_t new_block(int i) { + uint32_t aux32[4]; + const uint32_t * s32 = (const uint32_t *)&x[4*i].d; + for (int k = 0; k < 4; ++k) { + aux32[k] = *s32; s32 += sizeof(block_q4_1)/4; + bits.prepare1(x[4*i+k].qs, bits.b + 2*k); + } + return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle))); + } + // Leaving this commented out attempt to be reminded that I already tried this. + // It has basically the same performance as the version above. 
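+    // For reference (assuming the standard block_q4_1 layout: ggml_half d, m followed by qs[16]):
+    // each aux32[k] in new_block() above holds the {d, m} half-pair of sub-block k, and the
+    // vqtbl1q_u8 shuffle {0x0d0c090805040100, 0x0f0e0b0a07060302} gathers bytes
+    // {0,1, 4,5, 8,9, 12,13} and {2,3, 6,7, 10,11, 14,15}, i.e. it returns the float16x8_t
+    // {d0, d1, d2, d3, m0, m1, m2, m3} that Q81::process_scales() splits into scales and offsets.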
+ //inline float16x8_t new_block(int i) { + // uint32x4_t scales = {}; + // const block_q4_1 * xi = x + 4*i; + // const uint32_t * s32 = (const uint32_t *)&xi->d; + // scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4; + // bits.prepare1(xi[0].qs, bits.b + 0); + // scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4; + // bits.prepare1(xi[1].qs, bits.b + 2); + // scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4; + // bits.prepare1(xi[2].qs, bits.b + 4); + // scales = vsetq_lane_u32(*s32, scales, 3); + // bits.prepare1(xi[3].qs, bits.b + 6); + // return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle))); + //} + + const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302}; +}; + +struct HighBit5Legacy { + inline uint8x16_t to_bytes(const uint8_t * qh) const { + uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle); + return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask)); + } + inline uint8x16_t to_negated_bytes(const uint8_t * qh) const { + uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle); + return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0)); + } + const uint64x2_t mask = vdupq_n_u64(0x8040201008040201); + const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); +}; + +struct DequantizerQ50 final : public BaseLegacyDequantizer { + + DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + auto qh = x[i].qh; + q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0)))); + q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2)))); + } + inline void prepare1(int i) { + prepare1(i, bits.b); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + prepare1(4*i+k, bits.b + 2*k); + } + return vld1_f16((const float16_t *)aux); + } + + HighBit5Legacy hbits; + + const uint8x16_t mh = vdupq_n_u8(0xf0); + +}; + +struct DequantizerQ80 final : public BaseLegacyDequantizer { + + DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i) { + bits.b[0] = vld1q_s8(x[i].qs); + bits.b[1] = vld1q_s8(x[i].qs+16); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs); + bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16); + } + return vld1_f16((const float16_t *)aux); + } + +}; + +// TODO: handle case where row size is not a multiple of 128 +struct DequantizerQ80_x4 final : public BaseLegacyDequantizer { + + DequantizerQ80_x4(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i) { + bits.b[0] = vld1q_s8(x[i].qs); + bits.b[1] = vld1q_s8(x[i].qs+16); + } + + inline float16x4_t new_block(int i) { + auto scale = vld1_f16((const float16_t *)x[i].d); + for (int k = 0; k < 4; ++k) { + bits.b[2*k+0] = vld1q_s8(x[i].qs+32*k); + bits.b[2*k+1] = vld1q_s8(x[i].qs+32*k+16); + } + return scale; + } + +}; + +struct DequantizerQ51 final : public BaseLegacyDequantizer { + + DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + 
bits.prepare1(x[i].qs, q); + auto qh = x[i].qh; + q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0)))); + q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2)))); + } + inline void prepare1(int i) { + bits.prepare1(x[i].qs, bits.b); + } + + inline float16x8_t new_block(int i) { + uint32_t aux32[4]; + const uint32_t * s32 = (const uint32_t *)&x[4*i].d; + for (int k = 0; k < 4; ++k) { + aux32[k] = *s32; s32 += sizeof(block_q5_1)/4; + prepare1(4*i+k, bits.b + 2*k); + } + return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle))); + } + + HighBit5Legacy hbits; + + const uint8x16_t mh = vdupq_n_u8(0x10); + const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302}; + +}; + +template +inline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i)); + auto scale = vcvt_f32_f16(sc16[iy]); + acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall)); + } +} + +template +inline void sum_4(int i, Dequantizer& deq1, Dequantizer& deq2, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) { + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + auto pall = sum_4_blocks(deq1.bits.b, deq2.bits.b, q8.quant_data(iy, i)); + auto scale1 = vcvt_f32_f16(sc16[iy]); + auto scale2 = vcvt_f32_f16(sc16[iy+Q8::nrc_y]); + acc[iy] = vmlaq_f32(acc[iy], scale1, vcvtq_f32_s32(pall.val[0])); + acc[iy+Q8::nrc_y] = vmlaq_f32(acc[iy+Q8::nrc_y], scale2, vcvtq_f32_s32(pall.val[1])); + } +} + +template +inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) { + const int nb = n / QK4_1; + + float16x4_t sc16[Q8::nrc_y]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq.new_row(ix); + + float32x4_t acc[Q8::nrc_y]; + for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); + + for (int i = 0; i < nb/4; ++i) { + q8.process_scales(i, deq, sc16, acc); + sum_4(i, deq, q8, sc16, acc); + } + for (int i = 4*(nb/4); i < nb; ++i) { + q8.process_1_block(i, deq, acc); + } + + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + info.store(ix, iy, vaddvq_f32(acc[iy])); + } + } +} + +template +inline void mul_mat_qX_Y_q8_Y_IK(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) { + const int nb = n / QK4_1; + + float16x4_t sc16[2*Q8::nrc_y]; + float32x4_t acc[2*Q8::nrc_y]; + + for (int ix = 0; ix < nrc_x; ix += 2) { + + deq1.new_row(ix+0); + deq2.new_row(ix+1); + + for (int iy = 0; iy < 2*Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f); + + for (int i = 0; i < nb/4; ++i) { + q8.process_scales(i, deq1, deq2, sc16, acc); + sum_4(i, deq1, deq2, q8, sc16, acc); + } + //for (int i = 4*(nb/4); i < nb; ++i) { + // q8.process_1_block(i, deq, acc); + //} + + for (int iy = 0; iy < Q8::nrc_y; ++iy) { + info.store(ix+0, iy, vaddvq_f32(acc[iy])); + info.store(ix+1, iy, vaddvq_f32(acc[iy+Q8::nrc_y])); + } + } +} + +template +inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) { + const int nb = n / QK4_1; + + float16x4_t sc16[2]; + + for (int ix = 0; ix < nrc_x; ++ix) { + + deq1.new_row(ix); + deq2.new_row(ix); + + float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) }; + + for (int i = 0; i < nb/8; ++i) { + q8.process_scales(2*i+0, deq1, sc16+0, acc+0); + q8.process_scales(2*i+1, deq2, sc16+1, acc+1); + sum_4(2*i+0, deq1, q8, sc16+0, acc+0); 
+            sum_4(2*i+1, deq2, q8, sc16+1, acc+1);
+        }
+        for (int i = 2*(nb/8); i < nb/4; ++i) {
+            q8.process_scales(i, deq1, sc16, acc);
+            sum_4(i, deq1, q8, sc16, acc);
+        }
+        //for (int i = 4*(nb/4); i < nb; ++i) {
+        //    q8.process_1_block(i, deq1, acc);
+        //}
+
+        info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1])));
+    }
+}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    Q81<nrc_y> q8(info);
+    if constexpr (nrc_y == 1) {
+        Dequantizer deq1(vx, bx), deq2(vx, bx);
+        mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
+    } else {
+        if (nrc_x%2 == 0 && n%128 == 0) {
+            Dequantizer deq1(vx, bx), deq2(vx, bx);
+            mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x);
+        } else {
+            Dequantizer deq(vx, bx);
+            mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+        }
+    }
+}
+
+template <typename Dequantizer, int nrc_y>
+static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    Q80<nrc_y> q8(info);
+    if constexpr (nrc_y == 1) {
+        Dequantizer deq1(vx, bx), deq2(vx, bx);
+        mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
+    } else {
+        if (nrc_x%2 == 0 && n%128 == 0) {
+            Dequantizer deq1(vx, bx), deq2(vx, bx);
+            mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x);
+        } else {
+            Dequantizer deq(vx, bx);
+            mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
+        }
+    }
+}
+
+template <typename Dequantizer>
+static void mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    Dequantizer deq1(vx, bx), deq2(vx, bx);
+    Q81<1> q8(info);
+    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
+}
+
+template <typename Dequantizer>
+static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
+    Dequantizer deq1(vx, bx), deq2(vx, bx);
+    Q80<1> q8(info);
+    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
+}
+
+}
+
+bool iqk_set_kernels_legacy_quants(int ne00, int typeA, int typeB, std::array<mul_mat_t, IQK_MAX_NY>& kernels, mul_mat_t& func16) {
+
+    if (ne00%QK8_0 != 0) return false;
+
+    auto etypeA = ggml_type(typeA);
+    auto expected_typeB = etypeA == GGML_TYPE_Q4_1 || etypeA == GGML_TYPE_Q5_1 ? 
GGML_TYPE_Q8_1_X4 : GGML_TYPE_Q8_0_X4;
+    if (ggml_type(typeB) != expected_typeB) return false;
+
+    func16 = nullptr;
+
+    switch (typeA) {
+        case GGML_TYPE_Q4_0:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0, DequantizerQ40, kernels);
+            break;
+        case GGML_TYPE_Q4_1:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_1, DequantizerQ41, kernels);
+            break;
+        case GGML_TYPE_Q5_0:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0, DequantizerQ50, kernels);
+            break;
+        case GGML_TYPE_Q5_1:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_1_q8_1, DequantizerQ51, kernels);
+            break;
+        case GGML_TYPE_Q6_0:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0, DequantizerQ60, kernels);
+            break;
+        case GGML_TYPE_Q8_0:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0, DequantizerQ80_x4, kernels);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            IQK_SET_MUL_MAT_FUNCTIONS_T(mul_mat_qX_0_q8_0, DequantizerIQ4NL, kernels);
+            break;
+//        case GGML_TYPE_Q4_0_R8:
+//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q4_0_r8_q8_2, kernels)
+//#ifdef HAVE_FANCY_SIMD
+//            func16 = mul_mat_q4_0_r8_q8_2<16>;
+//#endif
+//            break;
+//        case GGML_TYPE_Q5_0_R4:
+//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q5_0_r4_q8_2, kernels)
+//            break;
+//        case GGML_TYPE_Q6_0_R4:
+//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q6_0_r4_q8_2, kernels)
+//            break;
+//        case GGML_TYPE_Q8_0_R8:
+//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_q8_0_r8_q8_2, kernels)
+//            break;
+//        case GGML_TYPE_IQ4_NL_R4:
+//            IQK_SET_MUL_MAT_FUNCTIONS(mul_mat_iq4_nl_r4_q8_2, kernels)
+//            break;
+        default:
+            return false;
+    }
+
+    return ggml_type(typeB) == expected_typeB;
+
+}
 #endif
 #endif
diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index d8698617..163ac526 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -857,850 +857,6 @@ inline void apply_signs_2(uint8x16_t * b, const uint64_t * signs, uint32_t sidx)
     b[1] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[1]), s2));
 }
 
-struct DequantizerIQ2XXS final : public BaseDequantizer<block_iq2_xxs> {
-    DequantizerIQ2XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {}
-
-    constexpr static int num_blocks() { return 8; }
-    constexpr static bool should_scale_quants() { return false; }
-
-    template <typename Q8>
-    inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) {
-        d = 0.125f * GGML_FP16_TO_FP32(x[i].d);
-
-        auto tmp = vld1q_u32_x4((const uint32_t *)x[i].qs);
-        data.val[0] = vuzp1q_u32(tmp.val[0], tmp.val[1]);  // codebook indices for blocks 0...3
-        data.val[1] = vuzp2q_u32(tmp.val[0], tmp.val[1]);  // scales and signs for blocks 0...3
-        data.val[2] = vuzp1q_u32(tmp.val[2], tmp.val[3]);  // codebook indices for blocks 4...7
-        data.val[3] = vuzp2q_u32(tmp.val[2], tmp.val[3]);  // scales and signs for blocks 4...7
-
-        return prepare_scales_8(data.val[1], data.val[3]);
-    }
-
-    static inline void prepare2(uint8x16_t * b, const uint8_t * idx, const uint64_t * signs, uint32_t sidx) {
-        b[0] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[0]], iq2xxs_grid[idx[1]]});
-        b[1] = vreinterpretq_u8_u64(uint64x2_t{iq2xxs_grid[idx[2]], iq2xxs_grid[idx[3]]});
-        apply_signs_2(b, signs, sidx);
-    }
-
-    inline void prepare(int /*i*/, int j) {
-        const uint8_t  * idx  = (const uint8_t  *)(data.val + 2*j);
-        const uint32_t * sidx = (const uint32_t *)(data.val + 2*j+1);
-        prepare2(bits.b1.val + 0, idx, keven_signs, sidx[0]); idx += 4;
-        prepare2(bits.b1.val + 2, idx, keven_signs, sidx[1]); idx += 4;
-        prepare2(bits.b2.val + 0, idx, keven_signs, sidx[2]); idx += 4;
-        prepare2(bits.b2.val + 2, idx, keven_signs, sidx[3]);
-    }
-
-    uint32x4x4_t data;
- SimpleBits bits; - -}; - -inline int32x4x4_t prepare_4bit_scales16(const uint8_t * sc) { - auto aux = vld1_u8(sc); - auto scales_l = vand_u8(aux, vdup_n_u8(0xf)); - auto scales_h = vshr_n_u8(aux, 4); - auto aux1 = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); - - auto scales8 = vreinterpretq_s8_u8(vorrq_u8(vshlq_n_u8(aux1, 1), vdupq_n_u8(1))); - int16x8x2_t scales16 = { vmovl_s8(vget_low_s8(scales8)), vmovl_s8(vget_high_s8(scales8)) }; - return make_wider(scales16); -} - -struct DequantizerIQ2XS final : public BaseDequantizer { - DequantizerIQ2XS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - - constexpr static int num_blocks() { return 16; } - constexpr static bool should_scale_quants() { return false; } - - template - inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { - d = 0.125f * GGML_FP16_TO_FP32(x[i].d); - return prepare_4bit_scales16(x[i].scales); - } - - inline static uint8x16_t make1(const uint16_t * qs) { - auto b = vcombine_u8(vld1_u8((const uint8_t *)(iq2xs_grid + (qs[0] & 511))), vld1_u8((const uint8_t *)(iq2xs_grid + (qs[1] & 511)))); - auto s = vcombine_s8(vld1_s8((const int8_t *)(keven_signs + (qs[0] >> 9))), vld1_s8((const int8_t *)(keven_signs + (qs[1] >> 9)))); - return vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b), s)); - } - - inline static void make4(const uint16_t * qs, uint8x16_t * b) { - b[0] = make1(qs + 0); - b[1] = make1(qs + 2); - b[2] = make1(qs + 4); - b[3] = make1(qs + 6); - } - - inline void prepare(int i, int j) { - make4(x[i].qs + 16*j + 0, bits.b1.val); - make4(x[i].qs + 16*j + 8, bits.b2.val); - } - - SimpleBits bits; - - -}; - -struct SignHelper { - - inline void init() { shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); } - - inline void apply_signs_1(uint8x16_t * b, const uint8x16_t& signs16) { - auto aux = vqtbl1q_u8(signs16, shuffle); - auto s = vreinterpretq_s8_u8(vorrq_u8(vceqq_u8(vandq_u8(aux, smask), smask), m1)); - b[0] = vreinterpretq_u8_s8(vmulq_s8(vreinterpretq_s8_u8(b[0]), s)); - shuffle = vaddq_u8(shuffle, step); - } - - const uint8x16_t smask = vreinterpretq_u8_u64(vdupq_n_u64(0x8040201008040201)); - const uint8x16_t m1 = vdupq_n_u8(1); - const uint8x16_t step = vdupq_n_u8(2); - uint8x16_t shuffle; -}; - -struct DequantizerIQ2S final : public BaseDequantizer { - DequantizerIQ2S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - - constexpr static int num_blocks() { return 16; } - constexpr static bool should_scale_quants() { return false; } - - template - inline int32x4x4_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { - d = 0.125f * GGML_FP16_TO_FP32(x[i].d); - return prepare_4bit_scales16(x[i].scales); - } - - static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, uint8x16_t * b) { - uint32_t aux32[2]; - const uint16_t * aux16 = (const uint16_t *)aux32; - for (int k = 0; k < 2; ++k) { - aux32[1] = (qh[k] << 4) | (qh[k] << 18); - aux32[0] = (aux32[1] << 4) & 0x03000300; - aux32[1] &= 0x03000300; - b[2*k+0] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+0] | aux16[0]))), - vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+1] | aux16[1])))); - sh.apply_signs_1(b+2*k+0, signs16); - - b[2*k+1] = vcombine_u8(vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+2] | aux16[2]))), - vld1_u8((const uint8_t *)(iq2s_grid + (qs[4*k+3] | aux16[3])))); - sh.apply_signs_1(b+2*k+1, signs16); - } - } - - inline void prepare(int i, int j) { - - const auto * qs = x[i].qs + 
16*j; - const auto * qh = x[i].qh + 4*j; - const auto signs16 = vld1q_u8(qs + QK_K/8); - - sh.init(); - make4(sh, signs16, qs+0, qh+0, bits.b1.val); - make4(sh, signs16, qs+8, qh+2, bits.b2.val); - } - - SimpleBits bits; - SignHelper sh; - - -}; - -struct DequantizerIQ3XXS final : public BaseDequantizer { - DequantizerIQ3XXS(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - - constexpr static int num_blocks() { return 8; } - constexpr static bool should_scale_quants() { return false; } - - template - inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { - d = 0.25f * GGML_FP16_TO_FP32(x[i].d); - gas = vld1q_u32_x2((const uint32_t *)(x[i].qs + QK_K/4)); - return prepare_scales_8(gas.val[0], gas.val[1]); - } - - inline static void make2(const uint8_t * q3, uint32_t sidx, uint8x16_t * b) { - b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[0]], iq3xxs_grid[q3[1]], iq3xxs_grid[q3[2]], iq3xxs_grid[q3[3]]}); - b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3xxs_grid[q3[4]], iq3xxs_grid[q3[5]], iq3xxs_grid[q3[6]], iq3xxs_grid[q3[7]]}); - apply_signs_2(b, keven_signs, sidx); - } - inline void prepare(int i, int j) { - const auto * q3 = x[i].qs + 32*j; - const auto * signs = (const uint32_t *)(gas.val + j); - make2(q3, signs[0], bits.b1.val + 0); q3 += 8; - make2(q3, signs[1], bits.b1.val + 2); q3 += 8; - make2(q3, signs[2], bits.b2.val + 0); q3 += 8; - make2(q3, signs[3], bits.b2.val + 2); - } - - SimpleBits bits; - uint32x4x2_t gas; - -}; - -struct DequantizerIQ3S final : public BaseDequantizer { - DequantizerIQ3S(const void * vx, size_t bx, int nrc) : BaseDequantizer(vx, bx, nrc) {} - - constexpr static int num_blocks() { return 8; } - constexpr static bool should_scale_quants() { return false; } - - template - inline int32x4x2_t new_block(int i, const Q8& /*q8*/, float32x4_t * /*acc*/) { - d = GGML_FP16_TO_FP32(x[i].d); - uint32_t scales32[2]; - std::memcpy(scales32, x[i].scales, 4); - scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; - scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; - auto scales8 = vld1_u8((const uint8_t *)scales32); // 0, 2, 4, 6, 1, 3, 5, 7 - scales8 = vtbl1_u8(scales8, vreinterpret_u8_u64(vdup_n_u64(0x0703060205010400))); - auto scales16 = vreinterpretq_s16_u16(vmovl_u8(scales8)); - int32x4x2_t scales; - scales.val[0] = vmovl_s16(vget_low_s16(scales16)); - scales.val[1] = vmovl_s16(vget_high_s16(scales16)); - return scales; - } - - static inline void make2(SignHelper& sh, const uint8x16_t& signs16, const uint16x8_t& idx_l, uint8_t qh, - const int8x16_t& hshift, uint8x16_t * b) { - auto vindex = vorrq_u16(idx_l, vandq_u16(vshlq_u16(vdupq_n_u16(qh), hshift), vdupq_n_u16(256))); - const uint16_t * idx = (const uint16_t *)&vindex; - b[0] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[0]], iq3s_grid[idx[1]], iq3s_grid[idx[2]], iq3s_grid[idx[3]]}); - b[1] = vreinterpretq_u8_u32(uint32x4_t{iq3s_grid[idx[4]], iq3s_grid[idx[5]], iq3s_grid[idx[6]], iq3s_grid[idx[7]]}); - sh.apply_signs_1(b+0, signs16); - sh.apply_signs_1(b+1, signs16); - } - static inline void make4(SignHelper& sh, const uint8x16_t& signs16, const uint8_t * qs, const uint8_t * qh, - const int8x16_t& hshift, uint8x16_t * b) { - auto idx_l = vld1q_u8(qs); - make2(sh, signs16, vmovl_u8(vget_low_u8 (idx_l)), qh[0], hshift, b+0); - make2(sh, signs16, vmovl_u8(vget_high_u8(idx_l)), qh[1], hshift, b+2); - } - - inline void prepare(int i, int j) { - - static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; - const auto hshift = 
vld1q_s16(k_shift); - - const auto * qs = x[i].qs + 32*j; - const auto * qh = x[i].qh + 4*j; - const auto signs16 = vld1q_u8(x[i].signs + 16*j); - - sh.init(); - make4(sh, signs16, qs+ 0, qh+0, hshift, bits.b1.val); - make4(sh, signs16, qs+16, qh+2, hshift, bits.b2.val); - } - - SimpleBits bits; - SignHelper sh; - uint32x4x2_t gas; - -}; - -// =========================================== Legacy quants - -template -inline float16x4_t load_scales_q0(const Block * x, ggml_half * aux) { - for (int k = 0; k < 4; ++k) aux[k] = x[k].d; - return vld1_f16((const float16_t *)aux); -} - -template -inline float16x8_t load_scales_q1(const Block * x, ggml_half * aux) { - if constexpr (std::is_same_v) { - for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].s; } - } else { - for (int k = 0; k < 4; ++k) { aux[k] = x[k].d; aux[k+4] = x[k].m; } - } - return vld1q_f16((const float16_t *)aux); -} - -struct Q4LegacyBits { - template - inline void prepare(const Block * x) { - for (int i = 0; i < 4; ++i) { - auto q4bits = vld1q_u8(x[i].qs); - b[2*i+0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b)); - b[2*i+1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4)); - } - } - inline void prepare1(const uint8_t * qs, int8x16_t * q) const { - auto q4bits = vld1q_u8(qs); - q[0] = vreinterpretq_s8_u8(vandq_u8(q4bits, m4b)); - q[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits, 4)); - } - inline void prepare1(const uint8_t * qs) { - prepare1(qs, b); - } - const uint8x16_t m4b = vdupq_n_u8(0xf); - int8x16_t b[8]; -}; - -// One would think this commented out version would do better than the one below -// because it offers more opportunities to execute instructions in parallel. -// Instead, it runs significantly slower. Why? If the compiler is running out of vector registers -// cannot it just do the sequential version below on its own? 
-//inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) { -// const auto q8b_1 = vld1q_s8_x2(qs + 0); -// auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b_1.val[0]), b[1], q8b_1.val[1]); -// const auto q8b_2 = vld1q_s8_x2(qs + 32); -// auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b_2.val[0]), b[3], q8b_2.val[1]); -// auto p1234 = vpaddq_s32(p12, p34); -// const auto q8b_3 = vld1q_s8_x2(qs + 64); -// auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b_3.val[0]), b[5], q8b_3.val[1]); -// const auto q8b_4 = vld1q_s8_x2(qs + 96); -// auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b_4.val[0]), b[7], q8b_4.val[1]); -// return vpaddq_s32(p1234, vpaddq_s32(p56, p78)); -//} - -inline int32x4_t sum_4_blocks(const int8x16_t * b, const int8_t * qs) { - auto q8b = vld1q_s8_x2(qs + 0); - auto p12 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[0], q8b.val[0]), b[1], q8b.val[1]); - q8b = vld1q_s8_x2(qs + 32); - auto p34 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[2], q8b.val[0]), b[3], q8b.val[1]); - auto p1234 = vpaddq_s32(p12, p34); - q8b = vld1q_s8_x2(qs + 64); - auto p56 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[4], q8b.val[0]), b[5], q8b.val[1]); - q8b = vld1q_s8_x2(qs + 96); - auto p78 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b[6], q8b.val[0]), b[7], q8b.val[1]); - return vpaddq_s32(p1234, vpaddq_s32(p56, p78)); -} - -inline int32x4x2_t sum_4_blocks(const int8x16_t * b1, const int8x16_t * b2, const int8_t * qs) { - auto q8b = vld1q_s8_x2(qs + 0); - auto p12_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[0], q8b.val[0]), b1[1], q8b.val[1]); - auto p12_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[0], q8b.val[0]), b2[1], q8b.val[1]); - q8b = vld1q_s8_x2(qs + 32); - auto p34_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[2], q8b.val[0]), b1[3], q8b.val[1]); - auto p34_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[2], q8b.val[0]), b2[3], q8b.val[1]); - auto p1234_1 = vpaddq_s32(p12_1, p34_1); - auto p1234_2 = vpaddq_s32(p12_2, p34_2); - q8b = vld1q_s8_x2(qs + 64); - auto p56_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[4], q8b.val[0]), b1[5], q8b.val[1]); - auto p56_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[4], q8b.val[0]), b2[5], q8b.val[1]); - q8b = vld1q_s8_x2(qs + 96); - auto p78_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b1[6], q8b.val[0]), b1[7], q8b.val[1]); - auto p78_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), b2[6], q8b.val[0]), b2[7], q8b.val[1]); - auto p5678_1 = vpaddq_s32(p56_1, p78_1); - auto p5678_2 = vpaddq_s32(p56_2, p78_2); - return { vpaddq_s32(p1234_1, p5678_1), vpaddq_s32(p1234_2, p5678_2)}; -} - -template struct Q80 { - - constexpr static int nrc_y = nrc; - - Q80(const DataInfo& info) { - for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_0 *)info.src1_row(iy); - } - - inline const int8_t * quant_data(int iy, int i) const { - const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i; - return y4->qs; - } - - inline float16x4_t load_scales(int iy, int i) const { - const block_q8_0_x4 * y4 = (const block_q8_0_x4 *)y[iy] + i; - return vld1_f16((const float16_t *)y4->d); - } - - template - inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * /*acc*/) const { - auto qx_scales = deq.new_block(i); - for (int iy = 0; iy < nrc; ++iy) { - auto q8_scales = load_scales(iy, i); - sc16[iy] = vmul_f16(qx_scales, q8_scales); - } - } - - template - inline void process_scales(int i, 
Dequantizer& deq1, Dequantizer& deq2, float16x4_t * sc16, float32x4_t * /*acc*/) const { - auto qx_scales_1 = deq1.new_block(i); - auto qx_scales_2 = deq2.new_block(i); - for (int iy = 0; iy < nrc; ++iy) { - auto q8_scales = load_scales(iy, i); - sc16[iy ] = vmul_f16(qx_scales_1, q8_scales); - sc16[iy+nrc_y] = vmul_f16(qx_scales_2, q8_scales); - } - } - - template - inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const { - deq.prepare1(i); - float d = GGML_FP16_TO_FP32(deq.x[i].d); - for (int iy = 0; iy < nrc; ++iy) { - auto q8b = vld1q_s8_x2(y[iy][i].qs); - auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]); - acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p)); - } - } - - const block_q8_0 * y[nrc_y]; -}; - -template struct Q81 { - - constexpr static int nrc_y = nrc; - - Q81(const DataInfo& info) { - for (int iy = 0; iy < nrc_y; ++iy) y[iy] = (const block_q8_1 *)info.src1_row(iy); - } - - inline const int8_t * quant_data(int iy, int i) const { - const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i; - return y4->qs; - } - - inline float16x8_t load_scales(int iy, int i) const { - const block_q8_1_x4 * y4 = (const block_q8_1_x4 *)y[iy] + i; - return vld1q_f16((const float16_t *)y4->d); - } - - template - inline void process_scales(int i, Dequantizer& deq, float16x4_t * sc16, float32x4_t * acc) const { - auto qx_scales = deq.new_block(i); - for (int iy = 0; iy < nrc; ++iy) { - auto q8_scales = load_scales(iy, i); - auto m = vmul_f16(vget_high_f16(qx_scales), vget_high_f16(q8_scales)); - acc[iy] = vaddq_f32(acc[iy], vcvt_f32_f16(m)); - sc16[iy] = vmul_f16(vget_low_f16(qx_scales), vget_low_f16(q8_scales)); - } - } - - template - inline void process_scales(int i, Dequantizer& deq1, Dequantizer& deq2, float16x4_t * sc16, float32x4_t * acc) const { - auto qx_scales_1 = deq1.new_block(i); - auto qx_scales_2 = deq2.new_block(i); - for (int iy = 0; iy < nrc; ++iy) { - auto q8_scales = load_scales(iy, i); - auto q8_scales_l = vget_low_f16(q8_scales); - auto q8_scales_h = vget_high_f16(q8_scales); - auto m1 = vmul_f16(vget_high_f16(qx_scales_1), q8_scales_h); - auto m2 = vmul_f16(vget_high_f16(qx_scales_2), q8_scales_h); - acc[iy ] = vaddq_f32(acc[iy ], vcvt_f32_f16(m1)); - acc[iy+nrc_y ] = vaddq_f32(acc[iy+nrc_y], vcvt_f32_f16(m2)); - sc16[iy ] = vmul_f16(vget_low_f16(qx_scales_1), q8_scales_l); - sc16[iy+nrc_y] = vmul_f16(vget_low_f16(qx_scales_2), q8_scales_l); - } - } - - template - inline void process_1_block(int i, Dequantizer& deq, float32x4_t * acc) const { - deq.prepare1(i); - float d = GGML_FP16_TO_FP32(deq.x[i].d), m = 0.25f*GGML_FP16_TO_FP32(deq.x[i].m); - for (int iy = 0; iy < nrc; ++iy) { - auto q8b = vld1q_s8_x2(y[iy][i].qs); - auto p = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), deq.bits.b[0], q8b.val[0]), deq.bits.b[1], q8b.val[1]); - acc[iy] = vmlaq_f32(acc[iy], vdupq_n_f32(d*GGML_FP16_TO_FP32(y[iy][i].d)), vcvtq_f32_s32(p)); - acc[iy] = vaddq_f32(acc[iy], vdupq_n_f32(m*GGML_FP16_TO_FP32(y[iy][i].s))); - } - } - - const block_q8_1 * y[nrc_y]; -}; - -template -struct BaseLegacyDequantizer { - - BaseLegacyDequantizer(const void * vx, size_t bx) : vx(vx), x(nullptr), bx(bx) {} - - inline void new_row(int ix) { x = (const block_q *)((const char *)vx + bx*ix); } - - Q4LegacyBits bits; - - const void * vx; - const block_q * x; - size_t bx; -}; - -struct DequantizerQ40 final : public BaseLegacyDequantizer { - - DequantizerQ40(const void * vx, size_t bx) : 
BaseLegacyDequantizer(vx, bx) {} - - inline void prepare1(int i, int8x16_t * q) const { - bits.prepare1(x[i].qs, q); - q[0] = vaddq_s8(q[0], m8); - q[1] = vaddq_s8(q[1], m8); - } - inline void prepare1(int i) { - prepare1(i, bits.b); - } - - inline float16x4_t new_block(int i) { - ggml_half aux[4]; - for (int k = 0; k < 4; ++k) { - aux[k] = x[4*i+k].d; - prepare1(4*i+k, bits.b + 2*k); - } - return vld1_f16((const float16_t *)aux); - } - - const int8x16_t m8 = vdupq_n_s8(-8); - //ggml_half aux[4]; -}; - -struct DequantizerQ60 final : public BaseLegacyDequantizer { - - DequantizerQ60(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} - - inline void prepare1(int i, int8x16_t * q) const { - bits.prepare1(x[i].qs, q); - auto qh8 = vld1_u8(x[i].qh); - auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8); - q[0] = vaddq_s8(vorrq_u8(q[0], vandq_u8(qh, hmask)), m32); - q[1] = vaddq_s8(vorrq_u8(q[1], vandq_u8(vshrq_n_u8(qh, 2), hmask)), m32); - } - inline void prepare1(int i) { - prepare1(i, bits.b); - } - - inline float16x4_t new_block(int i) { - ggml_half aux[4]; - for (int k = 0; k < 4; ++k) { - aux[k] = x[4*i+k].d; - prepare1(4*i+k, bits.b + 2*k); - } - return vld1_f16((const float16_t *)aux); - } - - const int8x16_t m32 = vdupq_n_s8(-32); - const uint8x16_t hmask = vdupq_n_u8(0x30); -}; - -struct DequantizerIQ4NL final : public BaseLegacyDequantizer { - - DequantizerIQ4NL(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} - - inline void prepare1(int i, int8x16_t * q) const { - bits.prepare1(x[i].qs, q); - q[0] = vqtbl1q_s8(values, q[0]); - q[1] = vqtbl1q_s8(values, q[1]); - } - inline void prepare1(int i) { - prepare1(i, bits.b); - } - - inline float16x4_t new_block(int i) { - ggml_half aux[4]; - for (int k = 0; k < 4; ++k) { - aux[k] = x[4*i+k].d; - prepare1(4*i+k, bits.b + 2*k); - } - return vld1_f16((const float16_t *)aux); - } - static int8x16_t load_values() { - static const int8_t iq4nl_values[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113}; - return vld1q_s8(iq4nl_values); - } - - const int8x16_t values = load_values(); -}; - -struct DequantizerQ41 : public BaseLegacyDequantizer { - - DequantizerQ41(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} - - inline void prepare1(int i) { - bits.prepare1(x[i].qs); - } - - inline float16x8_t new_block(int i) { - uint32_t aux32[4]; - const uint32_t * s32 = (const uint32_t *)&x[4*i].d; - for (int k = 0; k < 4; ++k) { - aux32[k] = *s32; s32 += sizeof(block_q4_1)/4; - bits.prepare1(x[4*i+k].qs, bits.b + 2*k); - } - return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle))); - } - // Leaving this commented out attempt to be reminded that I already tried this. - // It has basically the same performance as the version above. 
- //inline float16x8_t new_block(int i) { - // uint32x4_t scales = {}; - // const block_q4_1 * xi = x + 4*i; - // const uint32_t * s32 = (const uint32_t *)&xi->d; - // scales = vsetq_lane_u32(*s32, scales, 0); s32 += sizeof(block_q4_1)/4; - // bits.prepare1(xi[0].qs, bits.b + 0); - // scales = vsetq_lane_u32(*s32, scales, 1); s32 += sizeof(block_q4_1)/4; - // bits.prepare1(xi[1].qs, bits.b + 2); - // scales = vsetq_lane_u32(*s32, scales, 2); s32 += sizeof(block_q4_1)/4; - // bits.prepare1(xi[2].qs, bits.b + 4); - // scales = vsetq_lane_u32(*s32, scales, 3); - // bits.prepare1(xi[3].qs, bits.b + 6); - // return vreinterpretq_f16_u8(vqtbl1q_u8(vreinterpretq_u8_u32(scales), vreinterpretq_u8_u64(shuffle))); - //} - - const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302}; -}; - -struct HighBit5Legacy { - inline uint8x16_t to_bytes(const uint8_t * qh) const { - uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle); - return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vreinterpretq_u8_u64(mask)); - } - inline uint8x16_t to_negated_bytes(const uint8_t * qh) const { - uint8x16_t h = vqtbl1q_u8(vreinterpretq_u8_u16(vdupq_n_u16(*(const uint16_t *)qh)), shuffle); - return vceqq_u8(vandq_u8(h, vreinterpretq_u8_u64(mask)), vdupq_n_u8(0)); - } - const uint64x2_t mask = vdupq_n_u64(0x8040201008040201); - const uint8x16_t shuffle = vcombine_u8(vdup_n_u8(0), vdup_n_u8(1)); -}; - -struct DequantizerQ50 final : public BaseLegacyDequantizer { - - DequantizerQ50(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} - - inline void prepare1(int i, int8x16_t * q) const { - bits.prepare1(x[i].qs, q); - auto qh = x[i].qh; - q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_negated_bytes(qh+0)))); - q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_negated_bytes(qh+2)))); - } - inline void prepare1(int i) { - prepare1(i, bits.b); - } - - inline float16x4_t new_block(int i) { - ggml_half aux[4]; - for (int k = 0; k < 4; ++k) { - aux[k] = x[4*i+k].d; - prepare1(4*i+k, bits.b + 2*k); - } - return vld1_f16((const float16_t *)aux); - } - - HighBit5Legacy hbits; - - const uint8x16_t mh = vdupq_n_u8(0xf0); - -}; - -struct DequantizerQ80 final : public BaseLegacyDequantizer { - - DequantizerQ80(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} - - inline void prepare1(int i) { - bits.b[0] = vld1q_s8(x[i].qs); - bits.b[1] = vld1q_s8(x[i].qs+16); - } - - inline float16x4_t new_block(int i) { - ggml_half aux[4]; - for (int k = 0; k < 4; ++k) { - aux[k] = x[4*i+k].d; - bits.b[2*k+0] = vld1q_s8(x[4*i+k].qs); - bits.b[2*k+1] = vld1q_s8(x[4*i+k].qs+16); - } - return vld1_f16((const float16_t *)aux); - } - -}; - -// TODO: handle case where row size is not a multiple of 128 -struct DequantizerQ80_x4 final : public BaseLegacyDequantizer { - - DequantizerQ80_x4(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} - - inline void prepare1(int i) { - bits.b[0] = vld1q_s8(x[i].qs); - bits.b[1] = vld1q_s8(x[i].qs+16); - } - - inline float16x4_t new_block(int i) { - auto scale = vld1_f16((const float16_t *)x[i].d); - for (int k = 0; k < 4; ++k) { - bits.b[2*k+0] = vld1q_s8(x[i].qs+32*k); - bits.b[2*k+1] = vld1q_s8(x[i].qs+32*k+16); - } - return scale; - } - -}; - -struct DequantizerQ51 final : public BaseLegacyDequantizer { - - DequantizerQ51(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} - - inline void prepare1(int i, int8x16_t * q) const { - 
-        bits.prepare1(x[i].qs, q);
-        auto qh = x[i].qh;
-        q[0] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[0]), vandq_u8(mh, hbits.to_bytes(qh+0))));
-        q[1] = vreinterpretq_s8_u8(vorrq_u8(vreinterpretq_u8_s8(q[1]), vandq_u8(mh, hbits.to_bytes(qh+2))));
-    }
-    inline void prepare1(int i) {
-        bits.prepare1(x[i].qs, bits.b);
-    }
-
-    inline float16x8_t new_block(int i) {
-        uint32_t aux32[4];
-        const uint32_t * s32 = (const uint32_t *)&x[4*i].d;
-        for (int k = 0; k < 4; ++k) {
-            aux32[k] = *s32; s32 += sizeof(block_q5_1)/4;
-            prepare1(4*i+k, bits.b + 2*k);
-        }
-        return vreinterpretq_f16_u8(vqtbl1q_u8(vld1q_u8((const uint8_t *)aux32), vreinterpretq_u8_u64(shuffle)));
-    }
-
-    HighBit5Legacy hbits;
-
-    const uint8x16_t mh = vdupq_n_u8(0x10);
-    const uint64x2_t shuffle = {0x0d0c090805040100, 0x0f0e0b0a07060302};
-
-};
-
-template <typename Dequantizer, typename Q8>
-inline void sum_4(int i, Dequantizer& deq, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {
-    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
-        auto pall = sum_4_blocks(deq.bits.b, q8.quant_data(iy, i));
-        auto scale = vcvt_f32_f16(sc16[iy]);
-        acc[iy] = vmlaq_f32(acc[iy], scale, vcvtq_f32_s32(pall));
-    }
-}
-
-template <typename Dequantizer, typename Q8>
-inline void sum_4(int i, Dequantizer& deq1, Dequantizer& deq2, const Q8& q8, const float16x4_t * sc16, float32x4_t * acc) {
-    for (int iy = 0; iy < Q8::nrc_y; ++iy) {
-        auto pall = sum_4_blocks(deq1.bits.b, deq2.bits.b, q8.quant_data(iy, i));
-        auto scale1 = vcvt_f32_f16(sc16[iy]);
-        auto scale2 = vcvt_f32_f16(sc16[iy+Q8::nrc_y]);
-        acc[iy] = vmlaq_f32(acc[iy], scale1, vcvtq_f32_s32(pall.val[0]));
-        acc[iy+Q8::nrc_y] = vmlaq_f32(acc[iy+Q8::nrc_y], scale2, vcvtq_f32_s32(pall.val[1]));
-    }
-}
-
-template <typename Dequantizer, typename Q8>
-inline void mul_mat_qX_Y_q8_Y(int n, Dequantizer& deq, Q8& q8, const DataInfo& info, int nrc_x) {
-    const int nb = n / QK4_1;
-
-    float16x4_t sc16[Q8::nrc_y];
-
-    for (int ix = 0; ix < nrc_x; ++ix) {
-
-        deq.new_row(ix);
-
-        float32x4_t acc[Q8::nrc_y];
-        for (int iy = 0; iy < Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
-
-        for (int i = 0; i < nb/4; ++i) {
-            q8.process_scales(i, deq, sc16, acc);
-            sum_4(i, deq, q8, sc16, acc);
-        }
-        for (int i = 4*(nb/4); i < nb; ++i) {
-            q8.process_1_block(i, deq, acc);
-        }
-
-        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
-            info.store(ix, iy, vaddvq_f32(acc[iy]));
-        }
-    }
-}
-
-template <typename Dequantizer, typename Q8>
-inline void mul_mat_qX_Y_q8_Y_IK(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {
-    const int nb = n / QK4_1;
-
-    float16x4_t sc16[2*Q8::nrc_y];
-    float32x4_t acc[2*Q8::nrc_y];
-
-    for (int ix = 0; ix < nrc_x; ix += 2) {
-
-        deq1.new_row(ix+0);
-        deq2.new_row(ix+1);
-
-        for (int iy = 0; iy < 2*Q8::nrc_y; ++iy) acc[iy] = vdupq_n_f32(0.f);
-
-        for (int i = 0; i < nb/4; ++i) {
-            q8.process_scales(i, deq1, deq2, sc16, acc);
-            sum_4(i, deq1, deq2, q8, sc16, acc);
-        }
-        //for (int i = 4*(nb/4); i < nb; ++i) {
-        //    q8.process_1_block(i, deq, acc);
-        //}
-
-        for (int iy = 0; iy < Q8::nrc_y; ++iy) {
-            info.store(ix+0, iy, vaddvq_f32(acc[iy]));
-            info.store(ix+1, iy, vaddvq_f32(acc[iy+Q8::nrc_y]));
-        }
-    }
-}
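// [Editor's illustration, not part of the patch] The mul_mat_qX_Y_q8_Y variants above
// all compute the same dot products, just organized differently: the plain version walks
// one row in groups of 4 blocks with a 1-block remainder loop, _IK walks two rows at once,
// and _1 below pipelines two dequantizers over a single row. For a "type 0" legacy quant
// against Q8_0 the reference value per row is simply the following scalar sum, shown here
// for Q4_0 under the standard ggml block layouts (the function name is hypothetical):
static float dot_q4_0_q8_0(int nb, const block_q4_0 * x, const block_q8_0 * y) {
    float sum = 0.f;
    for (int i = 0; i < nb; ++i) {   // one fp16 scale per 32-weight block on both sides
        int isum = 0;
        for (int j = 0; j < 16; ++j) {
            isum += ((x[i].qs[j] & 0xf) - 8) * y[i].qs[j];       // low nibbles, weights 0..15
            isum += ((x[i].qs[j] >>  4) - 8) * y[i].qs[j + 16];  // high nibbles, weights 16..31
        }
        sum += GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) * isum;
    }
    return sum;
}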
-
-template <typename Dequantizer, typename Q8>
-inline void mul_mat_qX_Y_q8_Y_1(int n, Dequantizer& deq1, Dequantizer& deq2, Q8& q8, const DataInfo& info, int nrc_x) {
-    const int nb = n / QK4_1;
-
-    float16x4_t sc16[2];
-
-    for (int ix = 0; ix < nrc_x; ++ix) {
-
-        deq1.new_row(ix);
-        deq2.new_row(ix);
-
-        float32x4_t acc[2] = { vdupq_n_f32(0.f), vdupq_n_f32(0.f) };
-
-        for (int i = 0; i < nb/8; ++i) {
-            q8.process_scales(2*i+0, deq1, sc16+0, acc+0);
-            q8.process_scales(2*i+1, deq2, sc16+1, acc+1);
-            sum_4(2*i+0, deq1, q8, sc16+0, acc+0);
-            sum_4(2*i+1, deq2, q8, sc16+1, acc+1);
-        }
-        for (int i = 2*(nb/8); i < nb/4; ++i) {
-            q8.process_scales(i, deq1, sc16, acc);
-            sum_4(i, deq1, q8, sc16, acc);
-        }
-        //for (int i = 4*(nb/4); i < nb; ++i) {
-        //    q8.process_1_block(i, deq1, acc);
-        //}
-
-        info.store(ix, 0, vaddvq_f32(vaddq_f32(acc[0], acc[1])));
-    }
-}
-
-template <typename Dequantizer, int nrc_y>
-static void mul_mat_qX_1_q8_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    Q81<nrc_y> q8(info);
-    if constexpr (nrc_y == 1) {
-        Dequantizer deq1(vx, bx), deq2(vx, bx);
-        mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
-    } else {
-        if (nrc_x%2 == 0 && n%128 == 0) {
-            Dequantizer deq1(vx, bx), deq2(vx, bx);
-            mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x);
-        } else {
-            Dequantizer deq(vx, bx);
-            mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
-        }
-    }
-}
-
-template <typename Dequantizer, int nrc_y>
-static void mul_mat_qX_0_q8_0(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    Q80<nrc_y> q8(info);
-    if constexpr (nrc_y == 1) {
-        Dequantizer deq1(vx, bx), deq2(vx, bx);
-        mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
-    } else {
-        if (nrc_x%2 == 0 && n%128 == 0) {
-            Dequantizer deq1(vx, bx), deq2(vx, bx);
-            mul_mat_qX_Y_q8_Y_IK(n, deq1, deq2, q8, info, nrc_x);
-        } else {
-            Dequantizer deq(vx, bx);
-            mul_mat_qX_Y_q8_Y(n, deq, q8, info, nrc_x);
-        }
-    }
-}
-
-template <typename Dequantizer>
-static void mul_mat_qX_1_q8_1_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    Dequantizer deq1(vx, bx), deq2(vx, bx);
-    Q81<1> q8(info);
-    mul_mat_qX_Y_q8_Y_1(n, deq1, deq2, q8, info, nrc_x);
-}
-
-template <typename Dequantizer>
-static void mul_mat_qX_0_q8_0_1(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) {
-    Dequantizer deq1(vx, bx), deq2(vx, bx);
-    Q80<1> q8(info);
-    mul_mat_qX_Y_q8_Y(n, deq1, deq2, q8, info, nrc_x);
-}
-
 template <int nrc> struct Q8_K64 {
 
     constexpr static int nrc_y = nrc;
@@ -3777,17 +2933,7 @@ void mul_mat_q8_0_r8_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf
     m.funcs[7] = func<8>;\
 
 template <typename Dequantizer> void MulMat::set_functions(MulMat& m) {
-    if constexpr (std::is_same_v<Dequantizer, DequantizerQ40> || std::is_same_v<Dequantizer, DequantizerQ50> ||
-                  std::is_same_v<Dequantizer, DequantizerQ60> || std::is_same_v<Dequantizer, DequantizerQ80> ||
-                  std::is_same_v<Dequantizer, DequantizerIQ4NL>) {
-        SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_0_q8_0, Dequantizer);
-    }
-    else if constexpr (std::is_same_v<Dequantizer, DequantizerQ41> || std::is_same_v<Dequantizer, DequantizerQ51>) {
-        SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_1_q8_1, Dequantizer);
-    }
-    else {
         SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_K_q8_K_T, Dequantizer);
-    }
 }
 
 bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
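// [Editor's illustration, not part of the patch] With the legacy-quant kernels factored
// out, set_functions() above always routes through mul_mat_qX_K_q8_K_T, and the
// SET_MUL_MAT_FUNCTIONS_T macro fills one kernel instantiation per right-hand-side
// column count (m.funcs[k] handles k+1 columns, as the visible m.funcs[7] = func<8>
// suggests). A hypothetical un-macro'd equivalent, assuming the <Dequantizer, nrc_y>
// template-parameter order used by these kernels:
template <typename Dequantizer>
static void set_functions_unrolled(MulMat& m) {
    m.funcs[0] = mul_mat_qX_K_q8_K_T<Dequantizer, 1>;  // B has 1 column
    m.funcs[1] = mul_mat_qX_K_q8_K_T<Dequantizer, 2>;
    m.funcs[2] = mul_mat_qX_K_q8_K_T<Dequantizer, 3>;
    m.funcs[3] = mul_mat_qX_K_q8_K_T<Dequantizer, 4>;
    m.funcs[4] = mul_mat_qX_K_q8_K_T<Dequantizer, 5>;
    m.funcs[5] = mul_mat_qX_K_q8_K_T<Dequantizer, 6>;
    m.funcs[6] = mul_mat_qX_K_q8_K_T<Dequantizer, 7>;
    m.funcs[7] = mul_mat_qX_K_q8_K_T<Dequantizer, 8>;  // B has 8 columns
}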
@@ -3817,48 +2963,19 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
         case GGML_TYPE_IQ6_K:
             return iqk_set_kernels_iqk_quants(ne00, typeA, typeB, m.funcs, m.func16);
         case GGML_TYPE_IQ2_XXS:
-            MulMat::set_functions<DequantizerIQ2XXS>(m);
-            break;
         case GGML_TYPE_IQ2_XS:
-            MulMat::set_functions<DequantizerIQ2XS>(m);
-            break;
         case GGML_TYPE_IQ2_S:
-            MulMat::set_functions<DequantizerIQ2S>(m);
-            break;
         case GGML_TYPE_IQ3_XXS:
-            MulMat::set_functions<DequantizerIQ3XXS>(m);
-            break;
         case GGML_TYPE_IQ3_S:
-            MulMat::set_functions<DequantizerIQ3S>(m);
-            break;
+            return iqk_set_kernels_iquants(ne00, typeA, typeB, m.funcs, m.func16);
         case GGML_TYPE_Q4_0:
-            MulMat::set_functions<DequantizerQ40>(m);
-            expected_Btype = GGML_TYPE_Q8_0_X4;
-            break;
         case GGML_TYPE_Q4_1:
-            MulMat::set_functions<DequantizerQ41>(m);
-            expected_Btype = GGML_TYPE_Q8_1_X4;
-            break;
         case GGML_TYPE_Q5_0:
-            MulMat::set_functions<DequantizerQ50>(m);
-            expected_Btype = GGML_TYPE_Q8_0_X4;
-            break;
         case GGML_TYPE_Q5_1:
-            MulMat::set_functions<DequantizerQ51>(m);
-            expected_Btype = GGML_TYPE_Q8_1_X4;
-            break;
         case GGML_TYPE_Q6_0:
-            MulMat::set_functions<DequantizerQ60>(m);
-            expected_Btype = GGML_TYPE_Q8_0_X4;
-            break;
         case GGML_TYPE_Q8_0:
-            MulMat::set_functions<DequantizerQ80>(m);
-            expected_Btype = GGML_TYPE_Q8_0_X4;
-            break;
         case GGML_TYPE_IQ4_NL:
-            MulMat::set_functions<DequantizerIQ4NL>(m);
-            expected_Btype = GGML_TYPE_Q8_0_X4;
-            break;
+            return iqk_set_kernels_legacy_quants(ne00, typeA, typeB, m.funcs, m.func16);
         case GGML_TYPE_IQ4_NL_R4:
             SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer);
             expected_Btype = GGML_TYPE_Q8_0_X4;