mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-30 11:21:56 +00:00
Adopt fix from mainline PR 17089 (#920)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
@@ -3907,7 +3907,7 @@ static __global__ void mul_mat_q_stream_k_fixup_id(
|
|||||||
const int col_diff = col_high - col_low;
|
const int col_diff = col_high - col_low;
|
||||||
|
|
||||||
for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
|
for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
|
||||||
ids_dst_shared[j] = ids_dst[col_low + j];
|
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
@@ -3960,10 +3960,11 @@ template <ggml_type type, int mmq_x>
|
|||||||
static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) {
|
static void launch_mul_mat_q_id(ggml_backend_cuda_context & ctx, const mmq_args_id & args, cudaStream_t stream) {
|
||||||
const int id = ggml_cuda_get_device();
|
const int id = ggml_cuda_get_device();
|
||||||
const int cc = ggml_cuda_info().devices[id].cc;
|
const int cc = ggml_cuda_info().devices[id].cc;
|
||||||
const int nsm_max = ggml_cuda_info().devices[id].nsm;
|
const int nsm= ggml_cuda_info().devices[id].nsm;
|
||||||
int nsm = 1;
|
//const int nsm_max = ggml_cuda_info().devices[id].nsm;
|
||||||
//while (nsm*2 <= nsm_max) nsm *= 2;
|
//int nsm = 1;
|
||||||
while (nsm < nsm_max) nsm *= 2;
|
////while (nsm*2 <= nsm_max) nsm *= 2;
|
||||||
|
//while (nsm < nsm_max) nsm *= 2;
|
||||||
const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size;
|
const int warp_size = ggml_cuda_get_physical_warp_size_host(); //ggml_cuda_info().devices[id].warp_size;
|
||||||
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
|
const int nwarps = mmq_get_nwarps_host(cc, warp_size);
|
||||||
const int mmq_y = get_mmq_y_host(cc);
|
const int mmq_y = get_mmq_y_host(cc);
|
||||||
|
|||||||
Reference in New Issue
Block a user