diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 8e5ff8ab..0c1c1625 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -5432,6 +5432,34 @@ struct DequantizerQ40 final : public BaseLegacyDequantizer { //ggml_half aux[4]; }; +struct DequantizerQ60 final : public BaseLegacyDequantizer { + + DequantizerQ60(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} + + inline void prepare1(int i, int8x16_t * q) const { + bits.prepare1(x[i].qs, q); + auto qh8 = vld1_u8(x[i].qh); + auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8); + q[0] = vaddq_s8(vorrq_u8(q[0], vandq_u8(qh, hmask)), m32); + q[1] = vaddq_s8(vorrq_u8(q[1], vandq_u8(vshrq_n_u8(qh, 2), hmask)), m32); + } + inline void prepare1(int i) { + prepare1(i, bits.b); + } + + inline float16x4_t new_block(int i) { + ggml_half aux[4]; + for (int k = 0; k < 4; ++k) { + aux[k] = x[4*i+k].d; + prepare1(4*i+k, bits.b + 2*k); + } + return vld1_f16((const float16_t *)aux); + } + + const int8x16_t m32 = vdupq_n_s8(-32); + const uint8x16_t hmask = vdupq_n_u8(0x30); +}; + struct DequantizerIQ4NL final : public BaseLegacyDequantizer { DequantizerIQ4NL(const void * vx, size_t bx) : BaseLegacyDequantizer(vx, bx) {} @@ -6340,7 +6368,8 @@ static void mul_mat_iq2bn_q8_K64(int n, const void * vx, size_t bx, const DataIn template void MulMat::set_functions(MulMat& m) { if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) { + std::is_same_v || std::is_same_v || + std::is_same_v) { m.funcs[0] = mul_mat_qX_0_q8_0; m.funcs[1] = mul_mat_qX_0_q8_0; m.funcs[2] = mul_mat_qX_0_q8_0; @@ -6507,6 +6536,10 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { MulMat::set_functions(m); expected_Btype = GGML_TYPE_Q8_1; break; + case GGML_TYPE_Q6_0: + MulMat::set_functions(m); + expected_Btype = GGML_TYPE_Q8_0; + break; case GGML_TYPE_Q8_0: MulMat::set_functions(m); expected_Btype = GGML_TYPE_Q8_0; @@ -7259,11 +7292,14 @@ struct HelperQ60 final : public BaseHelper { #ifdef __aarch64__ // TODO auto vd = F16::set1(*(const float16_t *)&dl->d); - auto q = vld1q_u8(dl->qs); - q = j%QK4_0 ? vshrq_n_u8(q, 4) : vandq_u8(q, mask); - q = vaddq_s8(q, m8); - v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(q)))); - v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(q)))); + auto qh8 = vld1_u8(dl->qh); + auto qh = vcombine_u8(vshl_n_u8(qh8, 4), qh8); + auto qs = vld1q_u8(dl->qs); + qs = j%QK4_0 ? vshrq_n_u8(qs, 4) : vandq_u8(qs, mask_l); + qs = vorrq_u8(qs, vandq_u8(mask_h, j%QK4_0 ? vshrq_n_u8(qh, 2) : qh)); + qs = vaddq_s8(qs, m32); + v1 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_low_s8(qs)))); + v2 = vmulq_f16(vd, vcvtq_f16_s16(vmovl_s8(vget_high_s8(qs)))); #else auto vd = F16::set1(GGML_FP16_TO_FP32(dl->d)); auto bl = _mm_loadu_si128((const __m128i *)dl->qs); @@ -7832,7 +7868,7 @@ struct FlashQKfp32 { else if constexpr (std::is_same_v>) { DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; #ifdef __aarch64__ - mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); + mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); #else mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); #endif @@ -7962,13 +7998,13 @@ struct FlashQKfp32 { DataInfo info{fms.cache, (const char *)q, k_step, (D/QK8_1)*sizeof(block_q8), 0, 1, nullptr}; switch (nq) { #ifdef __aarch64__ - case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 2: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 3: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 4: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 5: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 6: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; - case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 1: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 2: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 3: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 4: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 5: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 6: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; + case 7: mul_mat_qX_0_q8_0(D, kh.block, kh.stride, info, k_step); break; #else case 1: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break; case 2: mul_mat_qX_1_q8_1_T(D, kh.block, kh.stride, info, k_step); break;