From a845e2bfd62c638beb163b4a86a62da1d4658746 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 10 Mar 2025 17:18:23 +0200 Subject: [PATCH] FlashMLA(CUDA): WIP to allow q8_0 quantized cache --- ggml/src/ggml-cuda/cpy.cu | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index fabe8843..76bce4c0 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,4 +1,5 @@ #include "cpy.cuh" +#include "convert.cuh" typedef void (*cpy_kernel_t)(const char * cx, char * cdst); @@ -522,6 +523,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) { + if (src1->type == GGML_TYPE_F16) { + auto to_fp16 = ggml_get_to_fp16_cuda(src0->type); + if (to_fp16) { + to_fp16(src0->data, (half *)src1->data, ggml_nrows(src0), src0->ne[1], main_stream); + } + } + else if (src1->type == GGML_TYPE_F32) { + auto to_fp32 = ggml_get_to_fp32_cuda(src0->type); + if (to_fp32) { + to_fp32(src0->data, (float *)src1->data, ggml_nrows(src0), src0->ne[1], main_stream); + } + } + else if (src1->type == GGML_TYPE_BF16) { + auto to_bf16 = ggml_get_to_bf16_cuda(src0->type); + if (to_bf16) { + to_bf16(src0->data, (nv_bfloat16 *)src1->data, ggml_nrows(src0), src0->ne[1], main_stream); + } + } } else { fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type));