mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-03-03 10:30:27 +00:00
Sadly, the previous commit was wrong
This commit is contained in:
@@ -61,7 +61,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
|
||||
}
|
||||
}
|
||||
|
||||
template <ggml_type type, int ncols_y>
|
||||
template <ggml_type type, int ncols_y, int nwarps>
|
||||
static __device__ void mul_mat_vec_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||
@@ -73,10 +73,8 @@ static __device__ void mul_mat_vec_q(
|
||||
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||
constexpr int nwarps = 1;
|
||||
constexpr int rows_per_cuda_block = 1;
|
||||
#else
|
||||
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||
|
||||
@@ -139,10 +137,10 @@ static __device__ void mul_mat_vec_q(
|
||||
}
|
||||
}
|
||||
|
||||
template <ggml_type type, int ncols_y>
|
||||
template <ggml_type type, int ncols_y, int nwarps>
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
// tell the compiler to use as many registers as it wants, see nwarps definition below
|
||||
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
|
||||
__launch_bounds__(nwarps*WARP_SIZE, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void mul_mat_vec_q(
|
||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data,
|
||||
@@ -153,11 +151,11 @@ static __global__ void mul_mat_vec_q(
|
||||
const char * cx = (const char *)vx + i02*nb02;
|
||||
const char * cy = (const char *)vy + i2*nb12;
|
||||
char * cdst = (char *)dst + i2*nb2;
|
||||
mul_mat_vec_q<type, ncols_y>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
mul_mat_vec_q<type, ncols_y, nwarps>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst);
|
||||
}
|
||||
|
||||
template <ggml_type type>
|
||||
static void mul_mat_vec_q_cuda(
|
||||
template <ggml_type type, int nwarps>
|
||||
static void mul_mat_vec_q_cuda_T(
|
||||
const void * vx, const void * vy, float * dst, const char * ids_data,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
|
||||
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) {
|
||||
@@ -167,61 +165,61 @@ static void mul_mat_vec_q_cuda(
|
||||
|
||||
int id = ggml_cuda_get_device();
|
||||
|
||||
int64_t nwarps = 1;
|
||||
int64_t rows_per_cuda_block = 1;
|
||||
int64_t rows_per_cuda_block = ggml_cuda_info().devices[id].cc < CC_RDNA2 ?
|
||||
ncols_y < 4 ? 1 : 2 : 1;
|
||||
|
||||
if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
|
||||
switch(ncols_y) {
|
||||
case 1:
|
||||
if (ne2 == 1) nwarps = 4;
|
||||
rows_per_cuda_block = 1;
|
||||
break;
|
||||
case 2:
|
||||
case 3:
|
||||
case 4:
|
||||
if (ne2 == 1) nwarps = 4;
|
||||
rows_per_cuda_block = 2;
|
||||
break;
|
||||
case 5:
|
||||
case 6:
|
||||
case 7:
|
||||
case 8:
|
||||
if (ne2 == 1) nwarps = 2;
|
||||
rows_per_cuda_block = 2;
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
//if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
|
||||
// switch(ncols_y) {
|
||||
// case 1:
|
||||
// nwarps = 4;
|
||||
// rows_per_cuda_block = 1;
|
||||
// break;
|
||||
// case 2:
|
||||
// case 3:
|
||||
// case 4:
|
||||
// nwarps = 4;
|
||||
// rows_per_cuda_block = 2;
|
||||
// break;
|
||||
// case 5:
|
||||
// case 6:
|
||||
// case 7:
|
||||
// case 8:
|
||||
// nwarps = 2;
|
||||
// rows_per_cuda_block = 2;
|
||||
// break;
|
||||
// default:
|
||||
// GGML_ABORT("fatal error");
|
||||
// break;
|
||||
// }
|
||||
//}
|
||||
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
|
||||
const dim3 block_nums(nblocks, ne2, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
switch (ncols_y) {
|
||||
case 1:
|
||||
mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 1, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 2:
|
||||
mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 2, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 3:
|
||||
mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 3, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 4:
|
||||
mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 4, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 5:
|
||||
mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 5, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 6:
|
||||
mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 6, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 7:
|
||||
mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 7, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
case 8:
|
||||
mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
mul_mat_vec_q<type, 8, nwarps><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0);
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
@@ -229,6 +227,31 @@ static void mul_mat_vec_q_cuda(
|
||||
}
|
||||
}
|
||||
|
||||
template <ggml_type type>
|
||||
static void mul_mat_vec_q_cuda(
|
||||
const void * vx, const void * vy, float * dst, const char * ids_data,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
|
||||
const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) {
|
||||
int nwarps = 1;
|
||||
int id = ggml_cuda_get_device();
|
||||
if (ne2 < 2 && ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
|
||||
nwarps = ncols_y <= 4 ? 4 : 2;
|
||||
}
|
||||
switch (nwarps) {
|
||||
case 1:
|
||||
mul_mat_vec_q_cuda_T<type, 1>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
|
||||
ne2, nb02, nb12, nb2, ids_nb0, stream);
|
||||
break;
|
||||
case 2:
|
||||
mul_mat_vec_q_cuda_T<type, 2>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
|
||||
ne2, nb02, nb12, nb2, ids_nb0, stream);
|
||||
break;
|
||||
default:
|
||||
mul_mat_vec_q_cuda_T<type, 4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst,
|
||||
ne2, nb02, nb12, nb2, ids_nb0, stream);
|
||||
}
|
||||
}
|
||||
|
||||
static void mul_mat_vec_q4_0_q8_1_cuda(
|
||||
const void * vx, const void * vy, float * dst, const char * ids_data,
|
||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
|
||||
|
||||
Reference in New Issue
Block a user