diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 41d4ba3a..da11a427 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3034,6 +3034,15 @@ static void ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor src0_1->type, stream); CUDA_CHECK(cudaGetLastError()); + if (src1->ne[1] == 1 && src0_1->type == src0_2->type) { + ggml_cuda_op_fused_mul_mat_vec_q_id(ctx, src0_1, src1, nullptr, dst, + dst->src[4], dst->src[5], + (const char *)src0_1->data, (const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), + (float *)dst->data, 0, src0_1->ne[1], 1, ne10_padded, + (ggml_unary_op)dst->op_params[0], stream); + return; + } + ggml_cuda_op_mul_mat_vec_q(ctx, src0_1, src1, dst, (const char *)src0_1->data, nullptr, src1_quantized.get(), dst_up.get(), 0, src0_1->ne[1], src1->ne[1], ne10_padded, stream); CUDA_CHECK(cudaGetLastError()); diff --git a/ggml/src/ggml-cuda/iqk_mmvq.cu b/ggml/src/ggml-cuda/iqk_mmvq.cu index 770bad95..c355f354 100644 --- a/ggml/src/ggml-cuda/iqk_mmvq.cu +++ b/ggml/src/ggml-cuda/iqk_mmvq.cu @@ -251,7 +251,7 @@ __global__ void iqk_fused_mul_mat_vec_q( const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, ggml_unary_op unary_op) { int i2 = blockIdx.y; - int i02 = *(const int *)(ids_data + i2*ids_nb0); + int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2; if (i02 < 0) return; const char * cx_u = (const char *)vx_u + i02*nb02; const char * cx_g = (const char *)vx_g + i02*nb02; @@ -305,12 +305,7 @@ void iqk_mul_mat_vec_q_cuda(const mmvq_args & args, cudaStream_t stream) { const int64_t row_size = ggml_row_size(type, args.ncols_x); - //const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const char * __restrict__ ids_data, - //const void * __restrict__ bias_u, const void * __restrict__ bias_g, const uint64_t bias_nb1, - //const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst, const int64_t row_size, - //const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, const int64_t ids_nb0, ggml_unary_op unary_op) { - - if (args.vx_u && args.vx_g && args.ids_data && args.unary_op != GGML_UNARY_OP_COUNT) { + if (args.vx_u && args.vx_g && args.unary_op != GGML_UNARY_OP_COUNT) { switch (args.ncols_y) { case 1: iqk_fused_mul_mat_vec_q<<>>( diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 96ca17be..bd6758a9 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -287,7 +287,7 @@ static __global__ void fused_mul_mat_vec_q( int i2 = blockIdx.y; char * cdst = (char *)dst + i2*nb2; - int i02 = *(const int *)(ids_data + i2*ids_nb0); + int i02 = ids_data ? *(const int *)(ids_data + i2*ids_nb0) : i2; if (i02 < 0) { return; } @@ -314,7 +314,7 @@ static void mul_mat_vec_q_cuda_T(const mmvq_args & args, cudaStream_t stream) { const dim3 block_nums(nblocks, args.ne2, 1); const dim3 block_dims(WARP_SIZE, nwarps, 1); - if (args.vx_u && args.vx_g && args.ids_data && args.unary_op != GGML_UNARY_OP_COUNT) { + if (args.vx_u && args.vx_g && args.unary_op != GGML_UNARY_OP_COUNT) { switch (args.ncols_y) { case 1: fused_mul_mat_vec_q<<>>(args.vx_u, args.vx_g, args.vy, @@ -520,23 +520,6 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm // nrows_dst == nrows of the matrix that the kernel writes into const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; - -//struct mmvq_args { -// const void * vx; -// const void * vy; -// float * dst; -// const char * ids_data; -// const int ncols_x; -// const int nrows_x; -// const int nrows_y; -// const int ncols_y; -// const int nrows_dst; -// const int ne2; -// const uint64_t nb02; -// const uint64_t nb12; -// const uint64_t nb2; -// const uint64_t ids_nb0; -//}; mmvq_args args{/* vx_u */ src0_dd_u, /* vx_g */ src0_dd_g, /* bias_u */ bias_u, @@ -748,21 +731,21 @@ void ggml_cuda_op_fused_mul_mat_vec_q_id(ggml_backend_cuda_context & ctx, GGML_ASSERT(bias_u->ne[0] == dst->ne[0]); GGML_ASSERT(bias_g->ne[0] == dst->ne[0]); } - GGML_ASSERT(src0_dd_u && src0_dd_g && ids && ids->data); + GGML_ASSERT(src0_dd_u && src0_dd_g); const int64_t ne00 = src0->ne[0]; const int64_t ne10 = src1->ne[0]; GGML_ASSERT(ne10 % QK8_1 == 0); GGML_ASSERT(src0->ne[3] == 1 && src1->ne[3] == 1 && dst->ne[3] == 1); GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1); - GGML_ASSERT(ids->ne[0] == dst->ne[2]); + GGML_ASSERT(!ids || ids->ne[0] == dst->ne[2]); const int64_t ne0 = dst->ne[0]; ggml_cuda_op_mul_mat_vec_q_impl(ctx, src0->type, ne00, ne0, dst->ne[2], - src0->nb[2], src1->nb[2], dst->nb[2], ids->nb[0], bias_u ? bias_u->nb[1] : 0, - src0_dd_u, src0_dd_g, src1_ddq_i, dst_dd_i, (const char *)ids->data, + src0->nb[2], src1->nb[2], dst->nb[2], ids ? ids->nb[0] : 0, bias_u ? bias_u->nb[1] : 0, + src0_dd_u, src0_dd_g, src1_ddq_i, dst_dd_i, ids ? (const char *)ids->data : nullptr, bias_u ? bias_u->data : nullptr, bias_g ? bias_g->data : nullptr, row_low, row_high, src1_ncols, src1_padded_row_size, unary_op, stream);