diff --git a/ggml/src/ggml-cuda/mmq_id_common.cuh b/ggml/src/ggml-cuda/mmq_id_common.cuh index 0bb1a1a8..01806a10 100644 --- a/ggml/src/ggml-cuda/mmq_id_common.cuh +++ b/ggml/src/ggml-cuda/mmq_id_common.cuh @@ -3907,7 +3907,7 @@ static __global__ void mul_mat_q_stream_k_fixup_id( const int col_diff = col_high - col_low; for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) { - ids_dst_shared[j] = ids_dst[col_low + j]; + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; } __syncthreads(); @@ -3960,10 +3960,11 @@ template static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) { const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; - const int nsm_max = ggml_cuda_info().devices[id].nsm; - int nsm = 1; - //while (nsm*2 <= nsm_max) nsm *= 2; - while (nsm < nsm_max) nsm *= 2; + const int nsm= ggml_cuda_info().devices[id].nsm; + //const int nsm_max = ggml_cuda_info().devices[id].nsm; + //int nsm = 1; + ////while (nsm*2 <= nsm_max) nsm *= 2; + //while (nsm < nsm_max) nsm *= 2; const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size; const int nwarps = mmq_get_nwarps_host(cc, warp_size); const int mmq_y = get_mmq_y_host(cc);