diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index 041ad165..2277349b 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -737,46 +737,45 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n auto etypeA = ggml_type(typeA); if (auto dequant_type = MulMat::is_dequant_better(etypeA, Ny); dequant_type != etypeA) { - if (!MulMat::prepare(dequant_type, typeB, ne00, mm, Ny)) { - return false; - } + if (MulMat::prepare(dequant_type, typeB, ne00, mm, Ny)) { - constexpr int k_x_step = 64; + constexpr int k_x_step = 64; - auto num_rows = MulMat::num_rows(ggml_type(dequant_type)); - GGML_ASSERT(Nx%num_rows == 0); - auto nrc_x = (Nx/num_rows + nth - 1)/nth; - auto first_x = ith*nrc_x; - if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x; - first_x *= num_rows; - nrc_x *= num_rows; + auto num_rows = MulMat::num_rows(ggml_type(dequant_type)); + GGML_ASSERT(Nx%num_rows == 0); + auto nrc_x = (Nx/num_rows + nth - 1)/nth; + auto first_x = ith*nrc_x; + if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x; + first_x *= num_rows; + nrc_x *= num_rows; - size_t row_size_qx = ggml_row_size(dequant_type, ne00); - size_t row_size_qy = strideB; + size_t row_size_qx = ggml_row_size(dequant_type, ne00); + size_t row_size_qy = strideB; - DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; + DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; - auto& f = thread_local_work_buffer(); + auto& f = thread_local_work_buffer(); - for (int ix = 0; ix < nrc_x; ix += k_x_step) { - auto this_info = info; - this_info.s += ix; - int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; - if (f.size() < 2*row_size_qx*this_nrc_x) f.resize(2*row_size_qx*this_nrc_x); - auto Xu = f.data(); - auto Xg = f.data() + row_size_qx*this_nrc_x; - if (!iqk_convert_repack(typeA, ne00, (const char *)Aup + (first_x + ix)*strideA, strideA, Xu, ne00, this_nrc_x)) { - GGML_ABORT("Fatal error"); + for (int ix = 0; ix < nrc_x; ix += k_x_step) { + auto this_info = info; + this_info.s += ix; + int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix; + if (f.size() < 2*row_size_qx*this_nrc_x) f.resize(2*row_size_qx*this_nrc_x); + auto Xu = f.data(); + auto Xg = f.data() + row_size_qx*this_nrc_x; + if (!iqk_convert_repack(typeA, ne00, (const char *)Aup + (first_x + ix)*strideA, strideA, Xu, ne00, this_nrc_x)) { + GGML_ABORT("Fatal error"); + } + if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) { + GGML_ABORT("Fatal error"); + } + auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr; + auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr; + mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op); } - if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) { - GGML_ABORT("Fatal error"); - } - auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr; - auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr; - mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op); - } - return true; + return true; + } }