From 906a3bffd9ca661e2e6145c15b3fb25d6889ff02 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Wed, 5 Nov 2025 15:21:32 +0200
Subject: [PATCH] Fuse copies to K- and V-cache on CUDA

---
 ggml/src/ggml-cuda.cu      |  9 ++++-
 ggml/src/ggml-cuda/cpy.cu  | 72 ++++++++++++++++++++++++++++++++++++++
 ggml/src/ggml-cuda/cpy.cuh |  3 ++
 3 files changed, 83 insertions(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 143275e5..96f541c0 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -3082,7 +3082,14 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_dup(ctx, dst);
             break;
         case GGML_OP_CPY:
-            ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
+            // Fuse the K-cache copy (this node) with the V-cache copy two nodes
+            // ahead (CPY, VIEW, CPY pattern) into a single kernel launch when
+            // ggml_cuda_cpy_2 accepts the tensor pair; otherwise fall back.
+            if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
+                cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
+                cgraph->nodes[i+2]->op == GGML_OP_CPY &&
+                ggml_cuda_cpy_2(ctx, dst->src[0], cgraph->nodes[i+2]->src[0], dst->src[1], cgraph->nodes[i+2]->src[1])) {
+                i += 2;
+            } else {
+                ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
+            }
             break;
         case GGML_OP_CONT:
             ggml_cuda_dup(ctx, dst);
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 9eb2fb8d..02713c10 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -614,3 +614,75 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
     }
 }
+
+// Fused copy of two contiguous, same-size buffers in one launch (one thread
+// per element; caller guards the i < ne tail). When cdst_indirect is non-null
+// the two destination pointers are fetched from the indirection table at
+// graph_cpynode_index+0/+1 (CUDA-graph capture path); otherwise the direct
+// pointers are used.
+template <typename src_t, typename dst_t>
+static __global__ void cpy_flt_contiguous(const int ne, const char * cx1, const char * cx2, char * cdst_direct1, char * cdst_direct2,
+        char ** cdst_indirect, int graph_cpynode_index) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    auto dst1 = (cdst_indirect != nullptr) ? (dst_t *)cdst_indirect[graph_cpynode_index+0] : (dst_t *)cdst_direct1;
+    auto dst2 = (cdst_indirect != nullptr) ? (dst_t *)cdst_indirect[graph_cpynode_index+1] : (dst_t *)cdst_direct2;
+    auto src1 = (const src_t *)cx1;
+    auto src2 = (const src_t *)cx2;
+
+    if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+        dst1[i] = __float2bfloat16(src1[i]);
+        dst2[i] = __float2bfloat16(src2[i]);
+    } else {
+        dst1[i] = (dst_t)src1[i];
+        dst2[i] = (dst_t)src2[i];
+    }
+}
+
+// Host launcher for the fused two-buffer copy. Consumes two slots of the
+// indirection table, hence graph_cpynode_index advances by 2.
+template <typename src_t, typename dst_t>
+static void ggml_cpy_flt_contiguous_cuda_2(
+    const char * cx1, const char * cx2, char * cdst1, char * cdst2, const int ne,
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (ne, cx1, cx2, cdst1, cdst2, cdst_indirect, graph_cpynode_index);
+    graph_cpynode_index += 2;
+}
+
+// Try to perform the two cache copies (src1->dst1, src2->dst2) with a single
+// fused kernel. Returns false (and does nothing) unless both copies are
+// F32 -> F16/BF16, contiguous, shape-preserving, and of equal element count;
+// the caller then falls back to two separate ggml_cuda_cpy calls.
+bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src1, const ggml_tensor * src2,
+        ggml_tensor * dst1, ggml_tensor * dst2, bool disable_indirection) {
+    if (src1->type != GGML_TYPE_F32 || src2->type != GGML_TYPE_F32) return false;
+    if (dst1->type != GGML_TYPE_F16 && dst1->type != GGML_TYPE_BF16) return false;
+    if (dst2->type != GGML_TYPE_F16 && dst2->type != GGML_TYPE_BF16) return false;
+    bool fast_cpy_1 = ggml_is_contiguous(src1) && ggml_is_contiguous(dst1) && ggml_are_same_shape(src1, dst1);
+    bool fast_cpy_2 = ggml_is_contiguous(src2) && ggml_is_contiguous(dst2) && ggml_are_same_shape(src2, dst2);
+    if (!fast_cpy_1 || !fast_cpy_2) return false;
+    auto nelem = ggml_nelements(dst1);
+    if (ggml_nelements(dst2) != nelem) return false;
+
+    char ** dest_ptrs = nullptr;
+    int graph_cpynode_index = -1;
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
+    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
+        dest_ptrs = ctx.cuda_graph->dest_ptrs_d;
+        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
+    }
+#else
+    GGML_UNUSED(disable_indirection);
+#endif
+
+    if (dst1->type == GGML_TYPE_F16) {
+        ggml_cpy_flt_contiguous_cuda_2<float, half>((const char *)src1->data, (const char *)src2->data,
+            (char *)dst1->data, (char *)dst2->data, nelem, ctx.stream(), dest_ptrs, graph_cpynode_index);
+    } else {
+        ggml_cpy_flt_contiguous_cuda_2<float, nv_bfloat16>((const char *)src1->data, (const char *)src2->data,
+            (char *)dst1->data, (char *)dst2->data, nelem, ctx.stream(), dest_ptrs, graph_cpynode_index);
+    }
+
+#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
+    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection) {
+        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
+    }
+#endif
+    return true;
+}
diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh
index 0bd3c0c6..21f3c874 100644
--- a/ggml/src/ggml-cuda/cpy.cuh
+++ b/ggml/src/ggml-cuda/cpy.cuh
@@ -9,3 +9,6 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
 
 void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
+
+bool ggml_cuda_cpy_2(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
+        ggml_tensor * dst1, ggml_tensor * dst2, bool disable_indirection = false);