From c4f2af9ccc42ee77fd9ae1ce29c8e0d5d69a75b3 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 25 Aug 2025 06:52:33 +0300 Subject: [PATCH] WIP --- ggml/src/ggml-cuda.cu | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index c234ec07..b0d5620f 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -2718,13 +2718,34 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor dst_row.data = dst_up_contiguous.get(); ggml_cuda_mul_mat_q_id(ctx, src0_1, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get()); + if (dst->src[4]) { + ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[4]->data, (const int32_t *)ids->data, + (float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1], + dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream); + CUDA_CHECK(cudaGetLastError()); + } dst_row.data = dst_gate_contiguous.get(); ggml_cuda_mul_mat_q_id(ctx, src0_2, src1, ids, &dst_row, (char *)ids_device.get(), src1_quantized.get()); + if (dst->src[5]) { + ggml_cuda_add_id((const float *)dst_row.data, (const float *)dst->src[5]->data, (const int32_t *)ids->data, + (float *)dst_row.data, dst_row.ne[0], dst_row.ne[1], dst_row.ne[2], dst_row.ne[0], dst_row.ne[1], + dst_row.nb[1], dst_row.nb[2], dst->src[4]->nb[1], ids->nb[1], stream); + CUDA_CHECK(cudaGetLastError()); + } - ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), + auto unary_op = (ggml_unary_op)dst->op_params[0]; + if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) { + ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), + (float *)dst->data, ggml_nelements(dst), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0], + 1.702f, 7.0f, stream); + } else { + ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row), (const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data); + } + CUDA_CHECK(cudaGetLastError()); + if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) && ggml_cuda_should_use_mmq(next->src[0]->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {