From 2986d3c21fdc12ce3f53d453cb0d4e983d1ebd17 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 24 Oct 2025 11:27:18 +0300 Subject: [PATCH] Fused ffn_up*unary_op(ffn_gate) for MMVQ (no bias) We see nearly 2% TG speedup for Ling-mini-2.0 and about 1% for DeepSeek-Lite. --- ggml/src/ggml-cuda.cu | 35 ++++--- ggml/src/ggml-cuda/iqk_mmvq.cu | 16 ++-- ggml/src/ggml-cuda/mmvq-args.h | 23 +++++ ggml/src/ggml-cuda/mmvq.cu | 164 ++++++++++++++++++++++++--------- ggml/src/ggml-cuda/mmvq.cuh | 6 ++ 5 files changed, 183 insertions(+), 61 deletions(-) create mode 100644 ggml/src/ggml-cuda/mmvq-args.h diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 9f7fd33f..f507012a 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2644,6 +2644,27 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te local_src1.nb[1] = src_1_ddq_size; } + bool fuse_next = next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && + ggml_backend_buffer_is_cuda(next->src[0]->buffer) && + !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) && + ((ggml_backend_cuda_buffer_context *)next->src[0]->buffer->context)->device == device_id && + ggml_backend_buffer_is_cuda(next->buffer) && + ((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id; + + if (!dst->src[4] && !dst->src[5]) { + local_dst.data = dst_gate_contiguous.get(); + auto the_destination = fuse_next ? &local_dst : dst; + auto unary_op = (ggml_unary_op)dst->op_params[0]; + ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, the_destination, + (const char *)src0_1->data, (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), + (float *)dst_gate_contiguous.get(), + 0, src0_1->ne[1], 1, src1_padded_col_size, unary_op, stream); + CUDA_CHECK(cudaGetLastError()); + + if (!fuse_next) return false; + } + else { + local_dst.data = dst_up_contiguous.get(); ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_1, &local_src1, ids, &local_dst, (const char *)src0_1->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_up_contiguous.get(), @@ -2669,24 +2690,16 @@ static bool ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_te local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2], local_dst.nb[1], local_dst.nb[2], dst->src[5]->nb[1], ids->nb[2], stream); } + } - if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && - ggml_backend_buffer_is_cuda(next->src[0]->buffer) && - !ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) && - ((ggml_backend_cuda_buffer_context *)next->src[0]->buffer->context)->device == device_id && - ggml_backend_buffer_is_cuda(next->buffer) && - ((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) { + if (fuse_next) { auto unary_op = (ggml_unary_op)dst->op_params[0]; if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get(), dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream); - } else { - ggml_fused_mul_unary(ctx, unary_op, dst->ne[0]*n_ids, - (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), - (float *)dst_gate_contiguous.get()); + CUDA_CHECK(cudaGetLastError()); } - CUDA_CHECK(cudaGetLastError()); const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING); GGML_ASSERT(dst->ne[0] % QK8_1 == 0); diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 462200d5..2e53b622 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -272,28 +272,28 @@ void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) { switch (args.ncols_y) { case 1: - iqk_mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 2: - iqk_mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 3: - iqk_mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 4: - iqk_mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 5: - iqk_mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 6: - iqk_mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 7: - iqk_mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 8: - iqk_mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); + iqk_mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, row_size, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; default: GGML_ASSERT(false); diff --git a/ggml/src/ggml-cuda/mmvq-args.h b/ggml/src/ggml-cuda/mmvq-args.h new file mode 100644 index 00000000..3cf9189b --- /dev/null +++ b/ggml/src/ggml-cuda/mmvq-args.h @@ -0,0 +1,23 @@ +#pragma once + +#include "common.cuh" + +struct mmvq_args { + const void * vx_u; + const void * vx_g; + const void * vy; + float * dst; + const char * ids_data; + const int ncols_x; + const int nrows_x; + const int nrows_y; + const int ncols_y; + const int nrows_dst; + const int ne2; + const uint64_t nb02; + const uint64_t nb12; + const uint64_t nb2; + const uint64_t ids_nb0; + ggml_unary_op unary_op; +}; + diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index cef18bf0..d60edcbd 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -262,6 +262,29 @@ static __global__ void mul_mat_vec_q( mul_mat_vec_q(cx, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst); } +template +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +// tell the compiler to use as many registers as it wants, see nwarps definition below +__launch_bounds__(nwarps*WARP_SIZE, 1) +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) +static __global__ void fused_mul_mat_vec_q( + const void * __restrict__ vup, const void * __restrict__ vgate, + const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data, + const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, + const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, ggml_unary_op unary_op) { + + int i2 = blockIdx.y; + char * cdst = (char *)dst + i2*nb2; + int i02 = *(const int *)(ids_data + i2*ids_nb0); + if (i02 < 0) { + return; + } + const char * cx_u = (const char *)vup + i02*nb02; + const char * cx_g = (const char *)vgate + i02*nb02; + const char * cy = (const char *)vy + i2*nb12; + fused_mul_mat_vec_q(cx_u, cx_g, cy, (float *)cdst, ncols_x, nrows_x, nrows_y, nrows_dst, unary_op); +} + template static void mul_mat_vec_q_cuda_T(const mmvq_args & args, cudaStream_t stream) { @@ -273,71 +296,95 @@ static void mul_mat_vec_q_cuda_T(const mmvq_args & args, cudaStream_t stream) { int64_t rows_per_cuda_block = ggml_cuda_info().devices[id].cc < CC_RDNA2 ? args.ncols_y < 4 ? 1 : 2 : 1; - //if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2 - // switch(args.ncols_y) { - // case 1: - // nwarps = 4; - // rows_per_cuda_block = 1; - // break; - // case 2: - // case 3: - // case 4: - // nwarps = 4; - // rows_per_cuda_block = 2; - // break; - // case 5: - // case 6: - // case 7: - // case 8: - // nwarps = 2; - // rows_per_cuda_block = 2; - // break; - // default: - // GGML_ABORT("fatal error"); - // break; - // } - //} const int64_t nblocks = (args.nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; const dim3 block_nums(nblocks, args.ne2, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); + if (args.vx_u && args.vx_g && args.ids_data && args.unary_op != GGML_UNARY_OP_COUNT) { switch (args.ncols_y) { case 1: - mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + fused_mul_mat_vec_q<<>>(args.vx_u, args.vx_g, args.vy, + args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op); + break; + case 2: + fused_mul_mat_vec_q<<>>(args.vx_u, args.vx_g, args.vy, + args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op); + break; + case 3: + fused_mul_mat_vec_q<<>>(args.vx_u, args.vx_g, args.vy, + args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op); + break; + case 4: + fused_mul_mat_vec_q<<>>(args.vx_u, args.vx_g, args.vy, + args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op); + break; + case 5: + fused_mul_mat_vec_q<<>>(args.vx_u, args.vx_g, args.vy, + args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op); + break; + case 6: + fused_mul_mat_vec_q<<>>(args.vx_u, args.vx_g, args.vy, + args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op); + break; + case 7: + fused_mul_mat_vec_q<<>>(args.vx_u, args.vx_g, args.vy, + args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op); + break; + case 8: + fused_mul_mat_vec_q<<>>(args.vx_u, args.vx_g, args.vy, + args.dst, args.ids_data, + args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0, args.unary_op); + break; + default: + GGML_ABORT("fatal error"); + break; + } + } else { + switch (args.ncols_y) { + case 1: + mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 2: - mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 3: - mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 4: - mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 5: - mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 6: - mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 7: - mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; case 8: - mul_mat_vec_q<<>>(args.vx, args.vy, args.dst, args.ids_data, + mul_mat_vec_q<<>>(args.vx_u, args.vy, args.dst, args.ids_data, args.ncols_x, args.nrows_x, args.nrows_y, args.nrows_dst, args.nb02, args.nb12, args.nb2, args.ids_nb0); break; default: GGML_ABORT("fatal error"); break; } + } } template @@ -446,9 +493,9 @@ static void mul_mat_vec_iq3_s_q8_1_cuda(const mmvq_args & args, cudaStream_t str static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggml_type type, const int64_t ne00, const int64_t ne0, const int64_t ne2, const int64_t nb02, const int64_t nb12, const int64_t nb2, const int64_t ids_nb0, - const char * src0_dd_i, const char * src1_ddq_i, float * dst_dd_i, const char * ids_data, + const char * src0_dd_u, const char * src0_dd_g, const char * src1_ddq_i, float * dst_dd_i, const char * ids_data, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, cudaStream_t stream) { + const int64_t src1_padded_row_size, ggml_unary_op unary_op, cudaStream_t stream) { const int64_t row_diff = row_high - row_low; @@ -475,7 +522,8 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm // const uint64_t nb2; // const uint64_t ids_nb0; //}; - mmvq_args args{/* vx */ src0_dd_i, + mmvq_args args{/* vx_u */ src0_dd_u, + /* vx_g */ src0_dd_g, /* vy */ src1_ddq_i, /* dst */ dst_dd_i, /* ids_data */ ids_data, @@ -488,7 +536,8 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm /* nb02 */ uint64_t(nb02), /* nb12 */ uint64_t(nb12), /* nb2 */ uint64_t(nb2), - /* ids_nb0 */ uint64_t(ids_nb0) + /* ids_nb0 */ uint64_t(ids_nb0), + /* unary_op */ unary_op }; switch (type) { @@ -608,9 +657,9 @@ void ggml_cuda_op_mul_mat_vec_q_3D( ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, ne00, ne0, dst->ne[2], src0->nb[2], src1_row_size, dst->nb[2], 0, - src0_dd_i, src1_ddq_i, dst_dd_i, nullptr, + src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, nullptr, row_low, row_high, src1_ncols, - src1_padded_row_size, stream); + src1_padded_row_size, GGML_UNARY_OP_COUNT, stream); GGML_UNUSED(src1_ddf_i); } @@ -629,9 +678,9 @@ void ggml_cuda_op_mul_mat_vec_q( ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, ne00, ne0, 1, 0, 0, 0, 0, - src0_dd_i, src1_ddq_i, dst_dd_i, nullptr, + src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, nullptr, row_low, row_high, src1_ncols, - src1_padded_row_size, stream); + src1_padded_row_size, GGML_UNARY_OP_COUNT, stream); GGML_UNUSED(src1_ddf_i); } @@ -655,13 +704,44 @@ void ggml_cuda_op_mul_mat_vec_q_id( ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, ne00, ne0, dst->ne[2], src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0], - src0_dd_i, src1_ddq_i, dst_dd_i, (const char *)ids->data, + src0_dd_i, nullptr, src1_ddq_i, dst_dd_i, (const char *)ids->data, row_low, row_high, src1_ncols, - src1_padded_row_size, stream); + src1_padded_row_size, GGML_UNARY_OP_COUNT, stream); GGML_UNUSED(src1_ddf_i); } +void ggml_cuda_op_fused_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const char * src0_dd_u, const char * src0_dd_g, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, ggml_unary_op unary_op, cudaStream_t stream) { + + GGML_ASSERT(unary_op == GGML_UNARY_OP_SILU || + unary_op == GGML_UNARY_OP_RELU || + unary_op == GGML_UNARY_OP_GELU); + GGML_ASSERT(src0_dd_u && src0_dd_g && ids && ids->data); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1); + GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1); + GGML_ASSERT(ids->ne[0] == dst->ne[2]); + + const int64_t ne0 = dst->ne[0]; + + ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, + ne00, ne0, dst->ne[2], + src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0], + src0_dd_u, src0_dd_g, src1_ddq_i, dst_dd_i, (const char *)ids->data, + row_low, row_high, src1_ncols, + src1_padded_row_size, unary_op, stream); + + GGML_UNUSED(src1_ddf_i); +} + + bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) { switch (src0_type) { case GGML_TYPE_Q4_0: diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index d17765f1..94542f8a 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -25,3 +25,9 @@ void ggml_cuda_op_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); + +void ggml_cuda_op_fused_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const char * src0_dd_u, const char * src0_dd_g, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, ggml_unary_op unary_op, cudaStream_t stream);