mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-28 18:32:04 +00:00
MoE fix for R4 quants (#170)
* Fix bug in iqk_mul_mat I recently added the possibility to have a matrix multiplication kernel that processes 16 columns in the right matrix per iteration. This introduced a bug that shows up when batch size is greater than 16, is not a multiple of 16, and the remainder is not a multiple of the maximum columns being processed by the regular kernels (and so, never showed up in my testing using TG-128 and PP-512). This commit fixes the issue. * Make sure rows per thread is a multiple of 4 also for MoE when using _r4 quants --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -139,6 +139,8 @@ int main(int argc, char ** argv) {
|
|||||||
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
|
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
|
||||||
|
|
||||||
if (n_ctx_req > n_kv_max) {
|
if (n_ctx_req > n_kv_max) {
|
||||||
|
printf("n_ctx_req = %d is greater than n_kv_max = %d for pp = %d, tg = %d, pl = %d\n",
|
||||||
|
n_ctx_req, n_kv_max, pp, tg, pl);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -142,13 +142,14 @@ struct MulMat {
|
|||||||
}
|
}
|
||||||
int ny = funcs.size();
|
int ny = funcs.size();
|
||||||
while (!funcs[ny-1] && ny > 0) --ny;
|
while (!funcs[ny-1] && ny > 0) --ny;
|
||||||
int n_step = (nrc_y - info.cur_y)/ny;
|
int n_left = nrc_y - info.cur_y;
|
||||||
|
int n_step = n_left/ny;
|
||||||
if (n_step > 0) {
|
if (n_step > 0) {
|
||||||
if (n_step*ny != nrc_y) {
|
if (n_step*ny != n_left) {
|
||||||
++n_step;
|
++n_step;
|
||||||
int ny1 = nrc_y/n_step;
|
int ny1 = n_left/n_step;
|
||||||
int ny2 = ny1 + 1;
|
int ny2 = ny1 + 1;
|
||||||
int my1 = n_step*ny2 - nrc_y;
|
int my1 = n_step*ny2 - n_left;
|
||||||
int my2 = n_step - my1;
|
int my2 = n_step - my1;
|
||||||
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
|
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
|
||||||
auto this_info = info;
|
auto this_info = info;
|
||||||
@@ -163,7 +164,7 @@ struct MulMat {
|
|||||||
this_info.cur_y += ny2;
|
this_info.cur_y += ny2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
info.cur_y += nrc_y;
|
info.cur_y += n_left;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
|
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
|
||||||
@@ -178,7 +179,7 @@ struct MulMat {
|
|||||||
info.cur_y += ny * n_step;
|
info.cur_y += ny * n_step;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
int n_left = nrc_y - info.cur_y;
|
n_left = nrc_y - info.cur_y;
|
||||||
if (n_left > 0) {
|
if (n_left > 0) {
|
||||||
funcs[n_left-1](n, vx, bx, info, nrc_x);
|
funcs[n_left-1](n, vx, bx, info, nrc_x);
|
||||||
}
|
}
|
||||||
@@ -203,6 +204,8 @@ struct MulMat {
|
|||||||
case GGML_TYPE_IQ5_K_R4:
|
case GGML_TYPE_IQ5_K_R4:
|
||||||
case GGML_TYPE_IQ4_KS_R4:
|
case GGML_TYPE_IQ4_KS_R4:
|
||||||
case GGML_TYPE_IQ2_XXS_R4:
|
case GGML_TYPE_IQ2_XXS_R4:
|
||||||
|
case GGML_TYPE_IQ2_XS_R4:
|
||||||
|
case GGML_TYPE_IQ2_S_R4:
|
||||||
case GGML_TYPE_IQ3_XXS_R4:
|
case GGML_TYPE_IQ3_XXS_R4:
|
||||||
case GGML_TYPE_IQ3_S_R4:
|
case GGML_TYPE_IQ3_S_R4:
|
||||||
case GGML_TYPE_IQ2_BN_R4: return 4;
|
case GGML_TYPE_IQ2_BN_R4: return 4;
|
||||||
@@ -255,11 +258,15 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
|
|||||||
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
|
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
size_t row_size_qx = strideA; //*ggml_type_size(ggml_type(typeA));
|
size_t row_size_qx = strideA;
|
||||||
size_t row_size_qy = strideB; //*ggml_type_size(ggml_type(typeB));
|
size_t row_size_qy = strideB;
|
||||||
int nrc_x = (Nx + nth - 1)/nth;
|
auto num_rows = MulMat::num_rows(ggml_type(typeA));
|
||||||
int first_x = ith*nrc_x;
|
GGML_ASSERT(Nx%num_rows == 0);
|
||||||
if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
|
auto nrc_x = (Nx/num_rows + nth - 1)/nth;
|
||||||
|
auto first_x = ith*nrc_x;
|
||||||
|
if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
|
||||||
|
first_x *= num_rows;
|
||||||
|
nrc_x *= num_rows;
|
||||||
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
|
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
|
||||||
row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
|
row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
|
||||||
mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
|
mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
|
||||||
@@ -13597,6 +13604,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
|
|||||||
#ifdef __aarch64__
|
#ifdef __aarch64__
|
||||||
float16_t q_f16[D*q_step];
|
float16_t q_f16[D*q_step];
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
for (int i1 = 0; i1 < nq1/q_step; ++i1) {
|
||||||
fms.init_qstep();
|
fms.init_qstep();
|
||||||
kh.reset_block();
|
kh.reset_block();
|
||||||
|
|||||||
Reference in New Issue
Block a user