diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index ed147c4f..6375ee32 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -6292,7 +6292,7 @@ struct DequantizerIQ3S final : public BaseDequantizer { }; -template +template void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInfo& info, int nrc_x) { assert(n % QK_K == 0); const int nb = n / QK_K; @@ -7942,38 +7942,37 @@ void mul_mat_q8_0_r4_q8_0(int n, const void * vx, size_t bx, const DataInfo& inf } } +#define SET_MUL_MAT_FUNCTIONS_T(m, func, Dequantizer) \ + m.funcs[0] = func;\ + m.funcs[1] = func;\ + m.funcs[2] = func;\ + m.funcs[3] = func;\ + m.funcs[4] = func;\ + m.funcs[5] = func;\ + m.funcs[6] = func;\ + m.funcs[7] = func;\ + +#define SET_MUL_MAT_FUNCTIONS(m, func) \ + m.funcs[0] = func<1>;\ + m.funcs[1] = func<2>;\ + m.funcs[2] = func<3>;\ + m.funcs[3] = func<4>;\ + m.funcs[4] = func<5>;\ + m.funcs[5] = func<6>;\ + m.funcs[6] = func<7>;\ + m.funcs[7] = func<8>;\ + template void MulMat::set_functions(MulMat& m) { if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { - m.funcs[0] = mul_mat_qX_0_q8_0; - m.funcs[1] = mul_mat_qX_0_q8_0; - m.funcs[2] = mul_mat_qX_0_q8_0; - m.funcs[3] = mul_mat_qX_0_q8_0; - m.funcs[4] = mul_mat_qX_0_q8_0; - m.funcs[5] = mul_mat_qX_0_q8_0; - m.funcs[6] = mul_mat_qX_0_q8_0; - m.funcs[7] = mul_mat_qX_0_q8_0; + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_0_q8_0, Dequantizer); } else if constexpr (std::is_same_v || std::is_same_v) { - m.funcs[0] = mul_mat_qX_1_q8_1; - m.funcs[1] = mul_mat_qX_1_q8_1; - m.funcs[2] = mul_mat_qX_1_q8_1; - m.funcs[3] = mul_mat_qX_1_q8_1; - m.funcs[4] = mul_mat_qX_1_q8_1; - m.funcs[5] = mul_mat_qX_1_q8_1; - m.funcs[6] = mul_mat_qX_1_q8_1; - m.funcs[7] = mul_mat_qX_1_q8_1; + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_1_q8_1, Dequantizer); } else { - m.funcs[0] = mul_mat_qX_K_q8_K_T<1, Dequantizer>; - m.funcs[1] = mul_mat_qX_K_q8_K_T<2, Dequantizer>; - m.funcs[2] = mul_mat_qX_K_q8_K_T<3, Dequantizer>; - m.funcs[3] = mul_mat_qX_K_q8_K_T<4, Dequantizer>; - m.funcs[4] = mul_mat_qX_K_q8_K_T<5, Dequantizer>; - m.funcs[5] = mul_mat_qX_K_q8_K_T<6, Dequantizer>; - m.funcs[6] = mul_mat_qX_K_q8_K_T<7, Dequantizer>; - m.funcs[7] = mul_mat_qX_K_q8_K_T<8, Dequantizer>; + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qX_K_q8_K_T, Dequantizer); } } @@ -8062,25 +8061,11 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { MulMat::set_functions(m); break; case GGML_TYPE_IQ1_BN: - m.funcs[0] = mul_mat_iq1bn_q8_K64<1>; - m.funcs[1] = mul_mat_iq1bn_q8_K64<2>; - m.funcs[2] = mul_mat_iq1bn_q8_K64<3>; - m.funcs[3] = mul_mat_iq1bn_q8_K64<4>; - m.funcs[4] = mul_mat_iq1bn_q8_K64<5>; - m.funcs[5] = mul_mat_iq1bn_q8_K64<6>; - m.funcs[6] = mul_mat_iq1bn_q8_K64<7>; - m.funcs[7] = mul_mat_iq1bn_q8_K64<8>; + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq1bn_q8_K64); expected_Btype = GGML_TYPE_Q8_K64; break; case GGML_TYPE_IQ2_BN: - m.funcs[0] = mul_mat_iq2bn_q8_K64<1>; - m.funcs[1] = mul_mat_iq2bn_q8_K64<2>; - m.funcs[2] = mul_mat_iq2bn_q8_K64<3>; - m.funcs[3] = mul_mat_iq2bn_q8_K64<4>; - m.funcs[4] = mul_mat_iq2bn_q8_K64<5>; - m.funcs[5] = mul_mat_iq2bn_q8_K64<6>; - m.funcs[6] = mul_mat_iq2bn_q8_K64<7>; - m.funcs[7] = mul_mat_iq2bn_q8_K64<8>; + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq2bn_q8_K64); expected_Btype = GGML_TYPE_Q8_K64; break; case GGML_TYPE_IQ2_BN_R4: @@ -8123,69 +8108,27 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) { expected_Btype = GGML_TYPE_Q8_0; break; case GGML_TYPE_IQ4_NL_X4: - m.funcs[0] = mul_mat_qx_r4_q8_0; - m.funcs[1] = mul_mat_qx_r4_q8_0; - m.funcs[2] = mul_mat_qx_r4_q8_0; - m.funcs[3] = mul_mat_qx_r4_q8_0; - m.funcs[4] = mul_mat_qx_r4_q8_0; - m.funcs[5] = mul_mat_qx_r4_q8_0; - m.funcs[6] = mul_mat_qx_r4_q8_0; - m.funcs[7] = mul_mat_qx_r4_q8_0; + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, IQ4_NL_R4_Dequantizer); expected_Btype = GGML_TYPE_Q8_0; break; case GGML_TYPE_IQ4_XS_R4: - m.funcs[0] = mul_mat_iq4_xs_r4_q8_k<1>; - m.funcs[1] = mul_mat_iq4_xs_r4_q8_k<2>; - m.funcs[2] = mul_mat_iq4_xs_r4_q8_k<3>; - m.funcs[3] = mul_mat_iq4_xs_r4_q8_k<4>; - m.funcs[4] = mul_mat_iq4_xs_r4_q8_k<5>; - m.funcs[5] = mul_mat_iq4_xs_r4_q8_k<6>; - m.funcs[6] = mul_mat_iq4_xs_r4_q8_k<7>; - m.funcs[7] = mul_mat_iq4_xs_r4_q8_k<8>; + SET_MUL_MAT_FUNCTIONS(m, mul_mat_iq4_xs_r4_q8_k); expected_Btype = GGML_TYPE_Q8_K; break; case GGML_TYPE_Q4_0_R4: - m.funcs[0] = mul_mat_qx_r4_q8_0; - m.funcs[1] = mul_mat_qx_r4_q8_0; - m.funcs[2] = mul_mat_qx_r4_q8_0; - m.funcs[3] = mul_mat_qx_r4_q8_0; - m.funcs[4] = mul_mat_qx_r4_q8_0; - m.funcs[5] = mul_mat_qx_r4_q8_0; - m.funcs[6] = mul_mat_qx_r4_q8_0; - m.funcs[7] = mul_mat_qx_r4_q8_0; + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q4_0_R4_Dequantizer); expected_Btype = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q5_0_R4: - m.funcs[0] = mul_mat_qx_r4_q8_0; - m.funcs[1] = mul_mat_qx_r4_q8_0; - m.funcs[2] = mul_mat_qx_r4_q8_0; - m.funcs[3] = mul_mat_qx_r4_q8_0; - m.funcs[4] = mul_mat_qx_r4_q8_0; - m.funcs[5] = mul_mat_qx_r4_q8_0; - m.funcs[6] = mul_mat_qx_r4_q8_0; - m.funcs[7] = mul_mat_qx_r4_q8_0; + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q5_0_R4_Dequantizer); expected_Btype = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q6_0_R4: - m.funcs[0] = mul_mat_qx_r4_q8_0; - m.funcs[1] = mul_mat_qx_r4_q8_0; - m.funcs[2] = mul_mat_qx_r4_q8_0; - m.funcs[3] = mul_mat_qx_r4_q8_0; - m.funcs[4] = mul_mat_qx_r4_q8_0; - m.funcs[5] = mul_mat_qx_r4_q8_0; - m.funcs[6] = mul_mat_qx_r4_q8_0; - m.funcs[7] = mul_mat_qx_r4_q8_0; + SET_MUL_MAT_FUNCTIONS_T(m, mul_mat_qx_r4_q8_0, Q6_0_R4_Dequantizer); expected_Btype = GGML_TYPE_Q8_0; break; case GGML_TYPE_Q8_0_R4: - m.funcs[0] = mul_mat_q8_0_r4_q8_0<1>; - m.funcs[1] = mul_mat_q8_0_r4_q8_0<2>; - m.funcs[2] = mul_mat_q8_0_r4_q8_0<3>; - m.funcs[3] = mul_mat_q8_0_r4_q8_0<4>; - m.funcs[4] = mul_mat_q8_0_r4_q8_0<5>; - m.funcs[5] = mul_mat_q8_0_r4_q8_0<6>; - m.funcs[6] = mul_mat_q8_0_r4_q8_0<7>; - m.funcs[7] = mul_mat_q8_0_r4_q8_0<8>; + SET_MUL_MAT_FUNCTIONS(m, mul_mat_q8_0_r4_q8_0); expected_Btype = GGML_TYPE_Q8_0; break; default: @@ -8402,7 +8345,7 @@ struct F16 { #else using Data = float16x8_t; constexpr static int block_size = 8; - constexpr static int num_registers = 32; + //constexpr static int num_registers = 32; constexpr static int q_step = 8; static inline Data zero() { return vdupq_n_f16(0); } static inline Data load(const char * ptr, int i) { return vld1q_f16((const float16_t *)ptr + block_size*i); }