mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-24 07:04:11 +00:00
Fusing mmvq also in non-MoE up+gate
This commit is contained in:
@@ -3034,6 +3034,15 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
|
||||
src0_1->type, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
if (src1->ne[1] == 1 && src0_1->type == src0_2->type) {
|
||||
ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, src1, nullptr, dst,
|
||||
dst->src[4], dst->src[5],
|
||||
(const char *)src0_1->data, (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(),
|
||||
(float *)dst->data, 0, src0_1->ne[1], 1, ne10_padded,
|
||||
(ggml_unary_op)dst->op_params[0], stream);
|
||||
return;
|
||||
}
|
||||
|
||||
ggml_cuda_op_mul_mat_vec_q(ctx, src0_1, src1, dst, (const char *)src0_1->data, nullptr, src1_quantized.get(), dst_up.get(),
|
||||
0, src0_1->ne[1], src1->ne[1], ne10_padded, stream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
@@ -251,7 +251,7 @@ __global__ void iqk_fused_mul_mat_vec_q(
|
||||
const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, ggml_unary_op unary_op) {
|
||||
|
||||
int i2 = blockIdx.y;
|
||||
int i02 = *(const int *)(ids_data + i2*ids_nb0);
|
||||
int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
|
||||
if (i02 < 0) return;
|
||||
const char * cx_u = (const char *)vx_u + i02*nb02;
|
||||
const char * cx_g = (const char *)vx_g + i02*nb02;
|
||||
@@ -305,12 +305,7 @@ void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) {
|
||||
|
||||
const int64_t row_size = ggml_row_size(type, args.ncols_x);
|
||||
|
||||
//const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data,
|
||||
//const void * __restrict__ bias_u, const void * __restrict__ bias_g, const uint64_t bias_nb1,
|
||||
//const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size,
|
||||
//const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, ggml_unary_op unary_op) {
|
||||
|
||||
if (args.vx_u && args.vx_g && args.ids_data && args.unary_op != GGML_UNARY_OP_COUNT) {
|
||||
if (args.vx_u && args.vx_g && args.unary_op != GGML_UNARY_OP_COUNT) {
|
||||
switch (args.ncols_y) {
|
||||
case 1:
|
||||
iqk_fused_mul_mat_vec_q<type, vdr, vec_dot_q_cuda, 1, n_interleaved><<<block_nums, block_dims, 0, stream>>>(
|
||||
|
||||
@@ -287,7 +287,7 @@ static __global__ void fused_mul_mat_vec_q(
|
||||
|
||||
int i2 = blockIdx.y;
|
||||
char * cdst = (char *)dst + i2*nb2;
|
||||
int i02 = *(const int *)(ids_data + i2*ids_nb0);
|
||||
int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2;
|
||||
if (i02 < 0) {
|
||||
return;
|
||||
}
|
||||
@@ -314,7 +314,7 @@ static void mul_mat_vec_q_cuda_T(const mmvq_args & args, cudaStream_t stream) {
|
||||
const dim3 block_nums(nblocks, args.ne2, 1);
|
||||
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||
|
||||
if (args.vx_u && args.vx_g && args.ids_data && args.unary_op != GGML_UNARY_OP_COUNT) {
|
||||
if (args.vx_u && args.vx_g && args.unary_op != GGML_UNARY_OP_COUNT) {
|
||||
switch (args.ncols_y) {
|
||||
case 1:
|
||||
fused_mul_mat_vec_q<type, 1, nwarps><<<block_nums, block_dims, 0, stream>>>(args.vx_u, args.vx_g, args.vy,
|
||||
@@ -520,23 +520,6 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
|
||||
// nrows_dst == nrows of the matrix that the kernel writes into
|
||||
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
|
||||
|
||||
|
||||
//struct mmvq_args {
|
||||
// const void * vx;
|
||||
// const void * vy;
|
||||
// float * dst;
|
||||
// const char * ids_data;
|
||||
// const int ncols_x;
|
||||
// const int nrows_x;
|
||||
// const int nrows_y;
|
||||
// const int ncols_y;
|
||||
// const int nrows_dst;
|
||||
// const int ne2;
|
||||
// const uint64_t nb02;
|
||||
// const uint64_t nb12;
|
||||
// const uint64_t nb2;
|
||||
// const uint64_t ids_nb0;
|
||||
//};
|
||||
mmvq_args args{/* vx_u */ src0_dd_u,
|
||||
/* vx_g */ src0_dd_g,
|
||||
/* bias_u */ bias_u,
|
||||
@@ -748,21 +731,21 @@ void ggml_cuda_op_fused_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx,
|
||||
GGML_ASSERT(bias_u->ne[0] == dst->ne[0]);
|
||||
GGML_ASSERT(bias_g->ne[0] == dst->ne[0]);
|
||||
}
|
||||
GGML_ASSERT(src0_dd_u && src0_dd_g && ids && ids->data);
|
||||
GGML_ASSERT(src0_dd_u && src0_dd_g);
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne10 = src1->ne[0];
|
||||
GGML_ASSERT(ne10 % QK8_1 == 0);
|
||||
GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1);
|
||||
GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1);
|
||||
GGML_ASSERT(ids->ne[0] == dst->ne[2]);
|
||||
GGML_ASSERT(!ids || ids->ne[0] == dst->ne[2]);
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
|
||||
ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type,
|
||||
ne00, ne0, dst->ne[2],
|
||||
src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0], bias_u ? bias_u->nb[1] : 0,
|
||||
src0_dd_u, src0_dd_g, src1_ddq_i, dst_dd_i, (const char *)ids->data,
|
||||
src0->nb[2], src1->nb[2], dst->nb[2], ids ? ids->nb[0] : 0, bias_u ? bias_u->nb[1] : 0,
|
||||
src0_dd_u, src0_dd_g, src1_ddq_i, dst_dd_i, ids ? (const char *)ids->data : nullptr,
|
||||
bias_u ? bias_u->data : nullptr, bias_g ? bias_g->data : nullptr,
|
||||
row_low, row_high, src1_ncols,
|
||||
src1_padded_row_size, unary_op, stream);
|
||||
|
||||
Reference in New Issue
Block a user