From e86c0333a5b088288847c2d31f4b0e45c6db732b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 15 Apr 2025 12:43:41 +0300 Subject: [PATCH] Allow q8_0 KV cache for head size 256 --- ggml/src/ggml-cuda.cu | 5 +++++ ggml/src/ggml-cuda/fattn.cu | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 63e16d52..1f62b882 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -3578,6 +3578,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons if (op->src[0]->ne[0] == 128) { return true; } + if (op->src[1]->ne[0] == 256 && op->src[2]->ne[0] == 256 && + (op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_Q8_0) && + (op->src[2]->type == GGML_TYPE_F16 || op->src[2]->type == GGML_TYPE_Q8_0)) { + return true; + } if (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) { return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) || (op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 8c3f4000..55116058 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -248,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) @@ -265,6 +266,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F16_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) FATTN_VEC_F16_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL) FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) @@ -347,6 +349,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) @@ -358,6 +361,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)