diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 10d16aeb..4bd99b65 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -9,6 +9,23 @@ #include "iqk_mmvq.cuh" #include "vecdotq.cuh" +struct mmvq_args { + const void * vx; + const void * vy; + float * dst; + const char * ids_data; + const int ncols_x; + const int nrows_x; + const int nrows_y; + const int ncols_y; + const int nrows_dst; + const int ne2; + const uint64_t nb02; + const uint64_t nb12; + const uint64_t nb2; + const uint64_t ids_nb0; +}; + typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { @@ -151,25 +168,11 @@ static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) { + int i2 = blockIdx.y; char * cdst = (char *)dst + i2*nb2; int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2; if (i02 < 0) { - // We clear the buffer via cudaMemset instead -//#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) -// constexpr int rows_per_cuda_block = 1; -//#else -// constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2; -//#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) -// const int row0 = rows_per_cuda_block*blockIdx.x; -// if (threadIdx.y == 0) { -// dst = (float *)cdst; -// for (int j = 0; j < ncols_y; ++j) { -// if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { -// dst[j*nrows_dst + row0 + threadIdx.x] = 0; -// } -// } -// } return; } const char * cx = (const char *)vx + i02*nb02; @@ -178,21 +181,18 @@ static __global__ void mul_mat_vec_q( } template -static void mul_mat_vec_q_cuda_T( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) { +static void mul_mat_vec_q_cuda_T(const mmvq_args & args, cudaStream_t stream) { - GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); - GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); + GGML_ASSERT(args.ncols_x % ggml_blck_size(type) == 0); + GGML_ASSERT(args.ncols_y <= MMVQ_MAX_BATCH_SIZE); int id = ggml_cuda_get_device(); int64_t rows_per_cuda_block = ggml_cuda_info().devices[id].cc < CC_RDNA2 ? - ncols_y < 4 ? 1 : 2 : 1; + args.ncols_y < 4 ? 1 : 2 : 1; //if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 - // switch(ncols_y) { + // switch(args.ncols_y) { // case 1: // nwarps = 4; // rows_per_cuda_block = 1; @@ -215,34 +215,42 @@ static void mul_mat_vec_q_cuda_T( // break; // } //} - const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; - const dim3 block_nums(nblocks, ne2, 1); + const int64_t nblocks = (args.nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; + const dim3 block_nums(nblocks, args.ne2, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); - switch (ncols_y) { + switch (args.ncols_y) { case 1: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 2: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 3: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 4: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 5: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 6: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 7: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 8: - mul_mat_vec_q<<>>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, nrows_dst, nb02, nb12, nb2, ids_nb0); + mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; default: GGML_ABORT("fatal error"); @@ -251,196 +259,106 @@ static void mul_mat_vec_q_cuda_T( } template -static void mul_mat_vec_q_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, cudaStream_t stream) { +static void mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) { int nwarps = 1; int id = ggml_cuda_get_device(); - if (ne2 < 2 && ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 - nwarps = ncols_y <= 4 ? 4 : 2; + if (args.ne2 < 2 && ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 + nwarps = args.ncols_y <= 4 ? 4 : 2; } switch (nwarps) { case 1: - mul_mat_vec_q_cuda_T(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, - ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q_cuda_T(args, stream); break; case 2: - mul_mat_vec_q_cuda_T(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, - ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q_cuda_T(args, stream); break; default: - mul_mat_vec_q_cuda_T(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, - ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q_cuda_T(args, stream); } } -static void mul_mat_vec_q4_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_q4_0_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_q4_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_q4_1_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_q5_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_q5_0_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_q5_1_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_q5_1_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_q6_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_q6_0_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_q8_0_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_q8_0_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_q2_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_q2_K_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_q3_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_q3_K_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_q4_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_q4_K_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_q5_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_q5_K_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_q6_K_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_q6_K_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_iq2_xxs_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_iq2_xxs_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_iq2_xs_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_iq2_xs_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_iq2_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_iq2_s_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_iq3_xxs_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_iq3_xxs_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_iq1_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_iq1_s_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_iq1_m_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_iq1_m_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_iq4_nl_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_iq4_nl_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_mxfp4_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_mxfp4_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_iq4_xs_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_iq4_xs_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } -static void mul_mat_vec_iq3_s_q8_1_cuda( - const void * vx, const void * vy, float * dst, const char * ids_data, - const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, - const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) { - - mul_mat_vec_q_cuda(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); +static void mul_mat_vec_iq3_s_q8_1_cuda(const mmvq_args & args, cudaStream_t stream) { + mul_mat_vec_q_cuda(args, stream); } static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type, @@ -458,57 +376,102 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; + +//struct mmvq_args { +// const void * vx; +// const void * vy; +// float * dst; +// const char * ids_data; +// const int ncols_x; +// const int nrows_x; +// const int nrows_y; +// const int ncols_y; +// const int nrows_dst; +// const int ne2; +// const uint64_t nb02; +// const uint64_t nb12; +// const uint64_t nb2; +// const uint64_t ids_nb0; +//}; + mmvq_args args{/* vx */ src0_dd_i, + /* vy */ src1_ddq_i, + /* dst */ dst_dd_i, + /* ids_data */ ids_data, + /* ncols_x */ int(ne00), + /* nrows_x */ int(row_diff), + /* nrows_y */ int(src1_padded_row_size), + /* ncols_y */ int(src1_ncols), + /* nrows_dst*/ int(nrows_dst), + /* ne2 */ int(ne2), + /* nb02 */ uint64_t(nb02), + /* nb12 */ uint64_t(nb12), + /* nb2 */ uint64_t(nb2), + /* ids_nb0 */ uint64_t(ids_nb0) + }; + switch (type) { case GGML_TYPE_Q4_0: - mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q4_0_q8_1_cuda(args, stream); break; case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q4_1_q8_1_cuda(args, stream); break; case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q5_0_q8_1_cuda(args, stream); break; case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q5_1_q8_1_cuda(args, stream); break; case GGML_TYPE_Q6_0: - mul_mat_vec_q6_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q6_0_q8_1_cuda(args, stream); break; case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q8_0_q8_1_cuda(args, stream); break; case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q2_K_q8_1_cuda(args, stream); break; case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q3_K_q8_1_cuda(args, stream); break; case GGML_TYPE_Q4_K: - mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q4_K_q8_1_cuda(args, stream); break; case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q5_K_q8_1_cuda(args, stream); break; case GGML_TYPE_Q6_K: - mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_q6_K_q8_1_cuda(args, stream); break; case GGML_TYPE_IQ2_XXS: - mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_iq2_xxs_q8_1_cuda(args, stream); break; case GGML_TYPE_IQ2_XS: - mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_iq2_xs_q8_1_cuda(args, stream); break; case GGML_TYPE_IQ2_S: - mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_iq2_s_q8_1_cuda(args, stream); break; case GGML_TYPE_IQ3_XXS: - mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_iq3_xxs_q8_1_cuda(args, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_vec_iq3_s_q8_1_cuda(args, stream); break; case GGML_TYPE_IQ1_S: - mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_iq1_s_q8_1_cuda(args, stream); break; case GGML_TYPE_IQ1_M: - mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); + mul_mat_vec_iq1_m_q8_1_cuda(args, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_vec_iq4_nl_q8_1_cuda(args, stream); + break; + case GGML_TYPE_MXFP4: + mul_mat_vec_mxfp4_q8_1_cuda(args, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_vec_iq4_xs_q8_1_cuda(args, stream); break; case GGML_TYPE_IQ1_BN: mul_mat_vec_iq1_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); @@ -516,15 +479,6 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm case GGML_TYPE_IQ2_BN: mul_mat_vec_iq2_bn_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; - case GGML_TYPE_IQ4_NL: - mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); - break; - case GGML_TYPE_MXFP4: - mul_mat_vec_mxfp4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); - break; - case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); - break; case GGML_TYPE_IQ2_K: mul_mat_vec_iq2_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; @@ -570,9 +524,6 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm case GGML_TYPE_IQ6_K: mul_mat_vec_iq6_k_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break; - case GGML_TYPE_IQ3_S: - mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); - break; case GGML_TYPE_IQ2_K_R4: mul_mat_vec_iq2_k_r4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream); break;