Also do the dequantize approach for mul_mat_id

This commit is contained in:
Iwan Kawrakow
2025-06-03 10:50:09 +03:00
parent 7a8abe29f7
commit feccbe0b9d

View File

@@ -501,6 +501,48 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
assert(row_mapping != nullptr);
MulMat mm;
auto etypeA = ggml_type(typeA);
// Dequantize path: taken when MulMat::is_dequant_better() returns a type that
// differs from A's own type for this Ny. Rows of A are then converted in small
// chunks into a scratch buffer and multiplied from there using kernels prepared
// for dequant_type.
// NOTE(review): the criterion is_dequant_better applies is not visible here;
// presumably it decides based on Ny whether dequantizing first is faster — confirm.
if (auto dequant_type = MulMat::is_dequant_better(etypeA, Ny); dequant_type != etypeA) {
// Set up mm for (dequant_type x typeB) rather than (typeA x typeB).
if (!MulMat::prepare(dequant_type, typeB, ne00, mm, Ny)) {
return false;
}
// Maximum number of rows of A dequantized and multiplied per loop iteration.
constexpr int k_x_step = 32;
// Rows are handed out to threads in groups of num_rows, so Nx must be a
// multiple of it.
auto num_rows = MulMat::num_rows(ggml_type(dequant_type));
GGML_ASSERT(Nx%num_rows == 0);
// Split the Nx/num_rows row groups evenly (rounded up) over nth threads; this
// thread (ith) starts at group first_x. The count is clipped to what remains,
// so it can end up zero or negative for trailing threads, in which case the
// loop below does not execute.
auto nrc_x = (Nx/num_rows + nth - 1)/nth;
auto first_x = ith*nrc_x;
if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
// Convert from row-group units back to row units.
first_x *= num_rows;
nrc_x *= num_rows;
auto type_size = ggml_type_size(dequant_type);
// Per-thread scratch buffer for the dequantized rows; it persists across calls
// and is only ever grown here.
thread_local std::vector<char> f;
// Bytes per dequantized row, computed as ne00 * type_size.
// NOTE(review): this assumes type_size is the size of a single element of
// dequant_type (i.e. not a multi-element block type) — confirm.
size_t row_size_qx = ne00*type_size;
size_t row_size_qy = strideB;
//printf("Dequant mul mat %s x %s: ne00 = %d, row_size = %d\n", ggml_type_name(dequant_type), ggml_type_name(ggml_type(typeB)), (int)ne00, (int)row_size_qx);
// Result/argument descriptor for the kernel: the output pointer C is advanced
// by first_x, and nb1 and nb2 are divided by sizeof(float) before being stored.
// NOTE(review): the meaning of each DataInfo field is defined elsewhere —
// verify the argument order against its declaration.
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
// Work on a copy so that the per-chunk offset does not accumulate in info.
auto this_info = info;
// Advance by the number of rows already handled by this thread
// (presumably s is the output pointer — confirm).
this_info.s += ix;
// The final chunk may hold fewer than k_x_step rows.
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
if (f.size() < row_size_qx*this_nrc_x) f.resize(row_size_qx*this_nrc_x);
// Dequantize this_nrc_x rows of A, starting at row first_x + ix (rows are
// strideA bytes apart), into the scratch buffer; ne00 is passed as the
// destination stride (presumably in elements — confirm). Failure is fatal.
if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) {
GGML_ABORT("Fatal error");
}
// Multiply the dequantized chunk (row stride row_size_qx bytes) against B.
mm.mul_mat_NxM(ne00, f.data(), row_size_qx, this_info, this_nrc_x, Ny);
}
return true;
}
// Regular path: prepare kernels for A's original type.
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
return false;
}