Fix bug in iqk_mul_mat

I recently added the possibility to have a matrix multiplication
kernel that processes 16 columns in the right matrix per iteration.
This introduced a bug that shows up when batch size is greater
than 16, is not a multiple of 16, and the remainder is not a multiple
of the maximum columns being processed by the regular kernels
(and so, never showed up in my testing using TG-128 and PP-512).

This commit fixes the issue.
This commit is contained in:
Iwan Kawrakow
2025-01-12 11:00:04 +02:00
parent 7553989dd8
commit 4e7ce22614
2 changed files with 10 additions and 6 deletions

View File

@@ -139,6 +139,8 @@ int main(int argc, char ** argv) {
const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg); const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
if (n_ctx_req > n_kv_max) { if (n_ctx_req > n_kv_max) {
printf("n_ctx_req = %d is greater than n_kv_max = %d for pp = %d, tg = %d, pl = %d\n",
n_ctx_req, n_kv_max, pp, tg, pl);
continue; continue;
} }

View File

@@ -142,13 +142,14 @@ struct MulMat {
} }
int ny = funcs.size(); int ny = funcs.size();
while (!funcs[ny-1] && ny > 0) --ny; while (!funcs[ny-1] && ny > 0) --ny;
int n_step = (nrc_y - info.cur_y)/ny; int n_left = nrc_y - info.cur_y;
int n_step = n_left/ny;
if (n_step > 0) { if (n_step > 0) {
if (n_step*ny != nrc_y) { if (n_step*ny != n_left) {
++n_step; ++n_step;
int ny1 = nrc_y/n_step; int ny1 = n_left/n_step;
int ny2 = ny1 + 1; int ny2 = ny1 + 1;
int my1 = n_step*ny2 - nrc_y; int my1 = n_step*ny2 - n_left;
int my2 = n_step - my1; int my2 = n_step - my1;
for (int ix = 0; ix < nrc_x; ix += k_x_step) { for (int ix = 0; ix < nrc_x; ix += k_x_step) {
auto this_info = info; auto this_info = info;
@@ -163,7 +164,7 @@ struct MulMat {
this_info.cur_y += ny2; this_info.cur_y += ny2;
} }
} }
info.cur_y += nrc_y; info.cur_y += n_left;
} }
else { else {
for (int ix = 0; ix < nrc_x; ix += k_x_step) { for (int ix = 0; ix < nrc_x; ix += k_x_step) {
@@ -178,7 +179,7 @@ struct MulMat {
info.cur_y += ny * n_step; info.cur_y += ny * n_step;
} }
} }
int n_left = nrc_y - info.cur_y; n_left = nrc_y - info.cur_y;
if (n_left > 0) { if (n_left > 0) {
funcs[n_left-1](n, vx, bx, info, nrc_x); funcs[n_left-1](n, vx, bx, info, nrc_x);
} }
@@ -13597,6 +13598,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
#ifdef __aarch64__ #ifdef __aarch64__
float16_t q_f16[D*q_step]; float16_t q_f16[D*q_step];
#endif #endif
for (int i1 = 0; i1 < nq1/q_step; ++i1) { for (int i1 = 0; i1 < nq1/q_step; ++i1) {
fms.init_qstep(); fms.init_qstep();
kh.reset_block(); kh.reset_block();