diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 40ac388d..c27b2a3c 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3224,7 +3224,14 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_group_norm(ctx, dst);
             break;
         case GGML_OP_CONCAT:
-            ggml_cuda_op_concat(ctx, dst);
+            if (fusion && i + 2 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
+                cgraph->nodes[i+2]->op == GGML_OP_CPY &&
+                ggml_cuda_concat_cpy(ctx, dst, cgraph->nodes[i+2])) {
+                i += 2;
+            } else {
+                ggml_cuda_op_concat(ctx, dst);
+            }
             break;
         case GGML_OP_UPSCALE:
             ggml_cuda_op_upscale(ctx, dst);
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 02713c10..d3194b61 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -686,3 +686,73 @@ bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src1,
 #endif
     return true;
 }
+
+// Fused CONCAT(+VIEW)+CPY kernel: writes src1 followed by src2 into dst,
+// converting the src_t input to dst_t. Launched as a single block; each
+// thread strides over the ne output elements.
+template <typename src_t, typename dst_t>
+static __global__ void concat_cpy(const char * csrc1, const char * csrc2, char * cdst, int ne1, int ne,
+        char ** dest_ptrs, int copy_index) {
+
+    auto dst  = (dst_t *)(dest_ptrs ? dest_ptrs[copy_index] : cdst);
+    auto src1 = (const src_t *)csrc1;
+    auto src2 = (const src_t *)csrc2;
+
+    for (int i = threadIdx.x; i < ne; i += blockDim.x) {
+        if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+            dst[i] = __float2bfloat16(i < ne1 ? src1[i] : src2[i - ne1]);
+        } else {
+            dst[i] = (dst_t)(i < ne1 ? src1[i] : src2[i - ne1]);
+        }
+    }
+}
+
+// Host-side launcher; bumps copy_index so that CUDA-graph indirection picks
+// the next destination pointer on each call.
+template <typename src_t, typename dst_t>
+static void ggml_concat_cpy_cuda(const char * src1, const char * src2, char * dst, int ne1, int ne, cudaStream_t stream,
+        char ** dest_ptrs, int& copy_index) {
+
+    int block_dim = std::min(ne, 768);
+    concat_cpy<src_t, dst_t><<<1, block_dim, 0, stream>>>(src1, src2, dst, ne1, ne, dest_ptrs, copy_index);
+    ++copy_index;
+}
+
+// Tries to execute a CONCAT -> VIEW -> CPY chain as a single fused kernel.
+// Returns false (caller falls back to the unfused ops) unless both concat
+// sources are single-row F32 and dst is F16/BF16 of the combined width.
+bool ggml_cuda_concat_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * concat, const ggml_tensor * dst,
+        [[maybe_unused]] bool disable_indirection) {
+
+    if (dst->type != GGML_TYPE_F16 && dst->type != GGML_TYPE_BF16) return false;
+    //if (ggml_nrows(dst) > 1) return false;
+    if (dst->src[0] != concat) return false;
+    if (ggml_nrows(concat->src[0]) != 1 || ggml_nrows(concat->src[1]) != 1) return false;
+    if (concat->src[0]->type != GGML_TYPE_F32 || concat->src[1]->type != GGML_TYPE_F32) return false;
+    if (dst->ne[0] != concat->src[0]->ne[0] + concat->src[1]->ne[0]) return false;
+
+    char ** dest_ptrs = nullptr;
+    int graph_cpynode_index = -1;
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
+    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
+        dest_ptrs = ctx.cuda_graph->dest_ptrs_d;
+        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
+    }
+#endif
+
+    if (dst->type == GGML_TYPE_F16) {
+        ggml_concat_cpy_cuda<float, half>((const char *)concat->src[0]->data, (const char *)concat->src[1]->data,
+                (char *)dst->data, concat->src[0]->ne[0], dst->ne[0], ctx.stream(), dest_ptrs, graph_cpynode_index);
+    } else {
+        ggml_concat_cpy_cuda<float, nv_bfloat16>((const char *)concat->src[0]->data, (const char *)concat->src[1]->data,
+                (char *)dst->data, concat->src[0]->ne[0], dst->ne[0], ctx.stream(), dest_ptrs, graph_cpynode_index);
+    }
+
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
+    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
+        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
+    }
+#endif
+    return true;
+
+}
diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh
index 21f3c874..b2aa7ad5 100644
--- a/ggml/src/ggml-cuda/cpy.cuh
+++ b/ggml/src/ggml-cuda/cpy.cuh
@@ -12,3 +12,6 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des
 
 bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
         ggml_tensor * dst1, ggml_tensor * dst2, bool disable_indirection = false);
+
+bool ggml_cuda_concat_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * concat, const ggml_tensor * dst,
+        bool disable_indirection = false);
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 154870e3..0475995a 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -6268,11 +6268,13 @@ ggml_cgraph * llm_build_context::build_deepseek2() {
                 kqv = ggml_mul_mat(ctx0, wv_b, kqv_compressed);
                 cb(kqv, "kqv", il);
 
-                kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
-                cb(kqv, "kqv_perm", il);
-
-                cur = ggml_view_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens, ggml_row_size(kqv->type, n_embd_head_v*n_head), 0);
+                if (n_tokens > 1) {
+                    kqv = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3));
+                    cb(kqv, "kqv_perm", il);
+                }
+                cur = ggml_reshape_2d(ctx0, kqv, n_embd_head_v*n_head, n_tokens);
                 cb(cur, "kqv_2d", il);
+            }
 
             ggml_build_forward_expand(gf, cur);
 