diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index c355f354..0010d32e 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -28,7 +28,8 @@ struct ggml_cuda_type_traits { namespace { template __device__ void iqk_mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, + const void * __restrict__ vx, const void * __restrict__ vy, + const float * bias, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size) { constexpr int qk = ggml_cuda_type_traits::qk; @@ -102,7 +103,7 @@ __device__ void iqk_mul_mat_vec_q( } if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) { - dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x]; + dst[j*nrows_dst + row0 + threadIdx.x] = bias ? tmp[j][threadIdx.x] + bias[j*nrows_dst + row0 + threadIdx.x] : tmp[j][threadIdx.x]; } } } @@ -227,16 +228,18 @@ template (cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); + const float * b = (const float *)(bias ? ids_data ? (const char *)bias + i02*bias_nb1 : bias : nullptr); + iqk_mul_mat_vec_q(cx, cy, b, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, row_size); } template @@ -370,28 +373,52 @@ void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) { } else { switch (args.ncols_y) { case 1: - iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>( + args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, + row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1); break; case 2: - iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>( + args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, + row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1); break; case 3: - iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>( + args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, + row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1); break; case 4: - iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>( + args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, + row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1); break; case 5: - iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>( + args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, + row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1); break; case 6: - iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>( + args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, + row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1); break; case 7: - iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>( + args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, + row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1); break; case 8: - iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>( + args.vx_u, args.vy, args.dst, args.ids_data, args.bias_u, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, + row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.bias_nb1); break; default: GGML_ASSERT(false);