From 1c31b25380b9321227b85f4faf4cb3efa645fed0 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Fri, 7 Nov 2025 18:58:09 +0200
Subject: [PATCH] Fix PPL increase caused by mmq_id (#913)

Co-authored-by: Iwan Kawrakow
---
 ggml/src/ggml-cuda/mmq_id_common.cuh | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh
index 89baa31b..0bb1a1a8 100644
--- a/ggml/src/ggml-cuda/mmq_id_common.cuh
+++ b/ggml/src/ggml-cuda/mmq_id_common.cuh
@@ -3960,7 +3960,10 @@ template
 static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) {
     const int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
-    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int nsm_max = ggml_cuda_info().devices[id].nsm;
+    int nsm = 1;
+    //while (nsm*2 <= nsm_max) nsm *= 2;
+    while (nsm < nsm_max) nsm *= 2;
     const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size;
     const int nwarps = mmq_get_nwarps_host(cc, warp_size);
     const int mmq_y = get_mmq_y_host(cc);