From 50fbde85dc79c12cd0c25c018fd2e76f03d11fa8 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 12 Dec 2025 13:22:02 +0000 Subject: [PATCH] Fix overflow in offset calculation in mmq --- ggml/src/ggml-cuda/mmq.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 9bf16427..d0c8233e 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3875,7 +3875,7 @@ static __device__ void mul_mat_q_process_tile( const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int)); for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { - load_tiles(x + stride01*it*mmq_y, tile_x, kb0, tile_x_max_i, stride01); + load_tiles(x + int64_t(stride01)*it*mmq_y, tile_x, kb0, tile_x_max_i, stride01); { const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));