From e2f8747555984f92d998b8131656c46764f5fc2a Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sun, 12 Jan 2025 11:39:52 +0200
Subject: [PATCH] Make sure rows per thread is a multiple of 4 also for MoE when using _r4 quants

---
 ggml/src/iqk/iqk_mul_mat.cpp | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp
index 826643ad..cfca477d 100644
--- a/ggml/src/iqk/iqk_mul_mat.cpp
+++ b/ggml/src/iqk/iqk_mul_mat.cpp
@@ -204,6 +204,8 @@ struct MulMat {
             case GGML_TYPE_IQ5_K_R4:
             case GGML_TYPE_IQ4_KS_R4:
             case GGML_TYPE_IQ2_XXS_R4:
+            case GGML_TYPE_IQ2_XS_R4:
+            case GGML_TYPE_IQ2_S_R4:
             case GGML_TYPE_IQ3_XXS_R4:
             case GGML_TYPE_IQ3_S_R4:
             case GGML_TYPE_IQ2_BN_R4: return 4;
@@ -256,11 +258,15 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
     if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
         return false;
     }
-    size_t row_size_qx = strideA; //*ggml_type_size(ggml_type(typeA));
-    size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB));
-    int nrc_x = (Nx + nth - 1)/nth;
-    int first_x = ith*nrc_x;
-    if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
+    size_t row_size_qx = strideA;
+    size_t row_size_qy = strideB;
+    auto num_rows = MulMat::num_rows(ggml_type(typeA));
+    GGML_ASSERT(Nx%num_rows == 0);
+    auto nrc_x = (Nx/num_rows + nth - 1)/nth;
+    auto first_x = ith*nrc_x;
+    if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
+    first_x *= num_rows;
+    nrc_x *= num_rows;
     DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
         row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
     mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);