Fix PPL increase caused by mmq_id (#913)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow
2025-11-07 18:58:09 +02:00
committed by GitHub
parent f9a411e5db
commit 1c31b25380

View File

@@ -3960,7 +3960,10 @@ template <ggml_type type, int mmq_x>
static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
-const int nsm = ggml_cuda_info().devices[id].nsm;
+const int nsm_max = ggml_cuda_info().devices[id].nsm;
+int nsm = 1;
+//while (nsm*2 <= nsm_max) nsm *= 2;
+while (nsm < nsm_max) nsm *= 2;
const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size;
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
const int mmq_y = get_mmq_y_host(cc);