Also iqk quants

This commit is contained in:
Iwan Kawrakow
2025-10-24 18:55:44 +03:00
parent efa858cb4a
commit 17dedc8cba

View File

@@ -28,7 +28,8 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M_R4> {
namespace {
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
__device__ void iqk_mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const void * __restrict__ vx, const void * __restrict__ vy,
const float * bias, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
@@ -102,7 +103,7 @@ __device__ void iqk_mul_mat_vec_q(
}
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
dst[j*nrows_dst + row0 + threadIdx.x] = bias ? tmp[j][threadIdx.x] + bias[j*nrows_dst + row0 + threadIdx.x] : tmp[j][threadIdx.x];
}
}
}
@@ -227,16 +228,18 @@ template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y,
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__global__ void iqk_mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data,
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const char * __restrict__ ids_data, const void * __restrict__ bias,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0) {
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, const int64_t bias_nb1) {
int i2 = blockIdx.y;
int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
if (i02 < 0) return;
const char * cx = (const char *)vx + i02*nb02;
const char * cy = (const char *)vy + i2*nb12;
char * cdst = (char *)dst + i2*nb2;
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, ncols_y, n_interleaved>(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
const float * b = (const float *)(bias ? ids_data ? (const char *)bias + i02*bias_nb1 : bias : nullptr);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, ncols_y, n_interleaved>(cx, cy, b, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size);
}
template <ggml_type type, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda, int ncols_y, int n_interleaved = 1>
@@ -370,28 +373,52 @@ void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
} else {
switch (args.ncols_y) {
case 1:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst,
row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 2:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 2, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst,
row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 3:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 3, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst,
row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 4:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 4, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst,
row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 5:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 5, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst,
row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 6:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 6, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst,
row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 7:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 7, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst,
row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
case 8:
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8, n_interleaved><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0);
iqk_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 8, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u,
args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst,
row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1);
break;
default:
GGML_ASSERT(false);