From 99d9036365ceb41fdb3665b03d9e366c5195130a Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 11 Mar 2025 11:36:11 +0200 Subject: [PATCH] WIP --- ggml/src/ggml-cuda/cpy.cu | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 76bce4c0..7d5b3023 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -527,19 +527,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg if (src1->type == GGML_TYPE_F16) { auto to_fp16 = ggml_get_to_fp16_cuda(src0->type); if (to_fp16) { - to_fp16(src0->data, (half *)src1->data, ggml_nrows(src0), src0->ne[1], main_stream); + to_fp16(src0->data, (half *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); } } else if (src1->type == GGML_TYPE_F32) { auto to_fp32 = ggml_get_to_fp32_cuda(src0->type); if (to_fp32) { - to_fp32(src0->data, (float *)src1->data, ggml_nrows(src0), src0->ne[1], main_stream); + to_fp32(src0->data, (float *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); } } else if (src1->type == GGML_TYPE_BF16) { auto to_bf16 = ggml_get_to_bf16_cuda(src0->type); if (to_bf16) { - to_bf16(src0->data, (nv_bfloat16 *)src1->data, ggml_nrows(src0), src0->ne[1], main_stream); + to_bf16(src0->data, (nv_bfloat16 *)src1->data, ggml_nrows(src0), src0->ne[0], main_stream); } } } else { @@ -579,9 +579,21 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { return (void*) cpy_f32_f16; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { return (void*) cpy_f32_f16; - } else { - fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, - ggml_type_name(src0->type), ggml_type_name(src1->type)); - GGML_ABORT("fatal error"); + } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) { + if (src1->type == GGML_TYPE_F16) { + auto to_fp16 = ggml_get_to_fp16_cuda(src0->type); + if (to_fp16) return (void*)to_fp16; + } + else if (src1->type == GGML_TYPE_F32) { + auto to_fp32 = ggml_get_to_fp32_cuda(src0->type); + if (to_fp32) return (void*)to_fp32; + } + else if (src1->type == GGML_TYPE_BF16) { + auto to_bf16 = ggml_get_to_bf16_cuda(src0->type); + if (to_bf16) return (void*)to_bf16; + } } + fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); }