From 7f5503244e7f1d6b1981d7594d8ab72856877bb7 Mon Sep 17 00:00:00 2001
From: Kawrakow
Date: Fri, 23 Jan 2026 06:47:29 +0000
Subject: [PATCH] Handle quantized cache

---
 ggml/src/ggml-cuda/fattn.cu | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index efd9cba0..b2c744ac 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -13,6 +13,7 @@
 #include "fattn-mma-f16-interface.cuh"
 #include "fattn-new-mma.cuh"
 #include "fattn.cuh"
+#include "convert.cuh"
 
 #include 
 
@@ -112,6 +113,31 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     auto local_Q = *Q;
     auto local_dst = *dst;
 
+    ggml_tensor local_K, local_V;
+    ggml_cuda_pool_alloc<half> K_f16(ctx.pool());
+    if (ggml_is_quantized(K->type)) {
+        // We need to dequantize here, else we will dequantize the same cache twice
+        K_f16.alloc(ggml_nelements(K));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+        to_fp16(K->data, K_f16.ptr, 1, ggml_nelements(K), ctx.stream());
+
+        auto bs = ggml_blck_size(K->type);
+        auto ts = ggml_type_size(K->type);
+
+        local_K = *K;
+        local_K.data = K_f16.get();
+        local_K.type = GGML_TYPE_F16;
+        local_K.nb[0] = sizeof(half);
+        local_K.nb[1] = local_K.nb[1]*bs*sizeof(half)/ts;
+        local_K.nb[2] = local_K.nb[2]*bs*sizeof(half)/ts;
+        local_K.nb[3] = local_K.nb[3]*bs*sizeof(half)/ts;
+        local_dst.src[1] = &local_K;
+
+        local_V = local_K;
+        local_V.ne[0] = V->ne[0];
+        local_dst.src[2] = &local_V;
+    }
+
     local_Q.ne[2] = 16;
     local_dst.ne[1] = 16;
     local_dst.src[0] = &local_Q;
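
Note on the stride arithmetic (not part of the patch, only a sketch of the reasoning):
for a quantized type, a row of ne[0] elements is stored as ne[0]/bs blocks of ts bytes
each, so the byte strides nb[1..3] count whole blocks. After dequantizing to F16 the
same row holds ne[0] half values, which is why the patch rescales every stride by
bs*sizeof(half)/ts. The standalone program below only illustrates that arithmetic;
the helper name is hypothetical, and it assumes sizeof(half) == 2 and a contiguous
cache, as the patch does.

    // stride_sketch.cpp -- standalone illustration of the nb[] rescaling in the patch
    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Hypothetical helper: convert a byte stride of a quantized tensor into the
    // byte stride of its F16 dequantization. Mirrors
    //     local_K.nb[i] = local_K.nb[i]*bs*sizeof(half)/ts;
    static size_t f16_stride(size_t nb_quant, int64_t blck_size, size_t type_size) {
        return nb_quant * blck_size * 2 /* sizeof(half) */ / type_size;
    }

    int main() {
        // Example with Q8_0-like parameters: 32 elements per block, 34 bytes per block.
        const int64_t bs  = 32;
        const size_t  ts  = 34;
        const int64_t ne0 = 128;                 // row length in elements
        const size_t  nb1 = (ne0 / bs) * ts;     // quantized row stride: 4 blocks * 34 B = 136 B
        assert(f16_stride(nb1, bs, ts) == (size_t) ne0 * 2); // F16 row stride: 128 * 2 B = 256 B
        printf("quantized nb[1] = %zu bytes -> F16 nb[1] = %zu bytes\n",
               nb1, f16_stride(nb1, bs, ts));
        return 0;
    }

A further observation from the patch itself: local_V is built as a copy of the
dequantized local_K with only ne[0] reduced to V->ne[0], i.e. V is treated as a
narrower view of the same cache buffer. That is also why a single dequantization
suffices ("else we will dequantize the same cache twice").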