From 21a37bfc6c510ed470bf500ddbcf8e9d6eac3329 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Sun, 26 Oct 2025 16:59:31 +0200
Subject: [PATCH] Faster copy when tensors are contiguous

Relevant for storing data into the KV cache.
I see ~1% speedup for fast models (Ling-mini-2.0, gpt-oss-20b, etc.)
---
 ggml/src/ggml-cuda/cpy.cu | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 50 insertions(+), 6 deletions(-)

diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 68dec960..6734acb3 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -38,6 +38,25 @@ static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
+template <typename src_t, typename dst_t>
+static __global__ void cpy_flt_contiguous(const char * cx, char * cdst_direct, const int ne,
+                                          char ** cdst_indirect, int graph_cpynode_index) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    auto dst = (cdst_indirect != nullptr) ? (dst_t *)cdst_indirect[graph_cpynode_index] : (dst_t *)cdst_direct;
+    auto src = (const src_t *)cx;
+
+    if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+        dst[i] = __float2bfloat16(src[i]);
+    } else {
+        dst[i] = (dst_t)src[i];
+    }
+}
+
 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
     float * cdstf = (float *)(cdsti);
 
@@ -163,6 +182,16 @@ static void ggml_cpy_flt_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+template <typename src_t, typename dst_t>
+static void ggml_cpy_flt_contiguous_cuda(
+    const char * cx, char * cdst, const int ne,
+    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, cdst_indirect, graph_cpynode_index++);
+}
+
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -404,6 +433,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
+    bool fast_cpy = ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_are_same_shape(src0, src1);
+
     char ** dest_ptrs_d = nullptr;
     int graph_cpynode_index = -1;
 #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
@@ -429,11 +460,23 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
             }
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if (fast_cpy) {
+            ggml_cpy_flt_contiguous_cuda<float, float>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if (fast_cpy) {
+            ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        if (fast_cpy) {
+            ggml_cpy_flt_contiguous_cuda<float, half>(src0_ddc, src1_ddc, ne, main_stream, dest_ptrs_d, graph_cpynode_index);
+        } else {
+            ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -505,6 +548,7 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 }
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+    bool fast_cpy = ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_are_same_shape(src0, src1);
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
         // Prioritize CUDA graph compatibility over direct memory copy optimization.
         // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
@@ -514,11 +558,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
             return nullptr;
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<float, float>>;
+        return fast_cpy ? (void *)cpy_flt_contiguous<float, float> : (void*) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
+        return fast_cpy ? (void *)cpy_flt_contiguous<float, nv_bfloat16> : (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_flt<float, half>>;
+        return fast_cpy ? (void *)cpy_flt_contiguous<float, half> : (void*) cpy_flt<cpy_1_flt<float, half>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
         return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {