diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 621ccffc..f87ebb96 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -61,7 +61,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { } } -template +template static __device__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { @@ -73,10 +73,8 @@ static __device__ void mul_mat_vec_q( constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) - constexpr int nwarps = 1; constexpr int rows_per_cuda_block = 1; #else - constexpr int nwarps = ncols_y <= 4 ? 4 : 2; constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2; #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) @@ -139,10 +137,10 @@ static __device__ void mul_mat_vec_q( } } -template +template #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) // tell the compiler to use as many registers as it wants, see nwarps definition below -__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) +__launch_bounds__(nwarps*WARP_SIZE, 1) #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data, @@ -153,11 +151,11 @@ static __global__ void mul_mat_vec_q( const char * cx = (const char *)vx + i02*nb02; const char * cy = (const char *)vy + i2*nb12; char * cdst = (char *)dst + i2*nb2; - mul_mat_vec_q(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); + mul_mat_vec_q(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); } -template -static void mul_mat_vec_q_cuda( +template +static void mul_mat_vec_q_cuda_T( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) { @@ -167,61 +165,61 @@ static void mul_mat_vec_q_cuda( int id = ggml_cuda_get_device(); - int64_t nwarps = 1; - int64_t rows_per_cuda_block = 1; + int64_t rows_per_cuda_block = ggml_cuda_info().devices[id].cc < CC_RDNA2 ? + ncols_y < 4 ? 1 : 2 : 1; - if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 - switch(ncols_y) { - case 1: - if (ne2 == 1) nwarps = 4; - rows_per_cuda_block = 1; - break; - case 2: - case 3: - case 4: - if (ne2 == 1) nwarps = 4; - rows_per_cuda_block = 2; - break; - case 5: - case 6: - case 7: - case 8: - if (ne2 == 1) nwarps = 2; - rows_per_cuda_block = 2; - break; - default: - GGML_ABORT("fatal error"); - break; - } - } + //if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 + // switch(ncols_y) { + // case 1: + // nwarps = 4; + // rows_per_cuda_block = 1; + // break; + // case 2: + // case 3: + // case 4: + // nwarps = 4; + // rows_per_cuda_block = 2; + // break; + // case 5: + // case 6: + // case 7: + // case 8: + // nwarps = 2; + // rows_per_cuda_block = 2; + // break; + // default: + // GGML_ABORT("fatal error"); + // break; + // } + //} const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; const dim3 block_nums(nblocks, ne2, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); switch (ncols_y) { case 1: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 2: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 3: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 4: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 5: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 6: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 7: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; case 8: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); break; default: GGML_ABORT("fatal error"); @@ -229,6 +227,31 @@ static void mul_mat_vec_q_cuda( } } +template +static void mul_mat_vec_q_cuda( + const void * vx, const void * vy, float * dst, const char * ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, + const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) { + int nwarps = 1; + int id = ggml_cuda_get_device(); + if (ne2 < 2 && ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 + nwarps = ncols_y <= 4 ? 4 : 2; + } + switch (nwarps) { + case 1: + mul_mat_vec_q_cuda_T(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, + ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + case 2: + mul_mat_vec_q_cuda_T(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, + ne2, nb02, nb12, nb2, ids_nb0, stream); + break; + default: + mul_mat_vec_q_cuda_T(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, + ne2, nb02, nb12, nb2, ids_nb0, stream); + } +} + static void mul_mat_vec_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const char * ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,