From 0459f595d7d8d03ee71d2061cdbfbbc42fd6c49e Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Wed, 29 Oct 2025 13:56:16 +0200 Subject: [PATCH] CUDA: correctly detect if flash attention is supported (#875) * Don't use vector kernels if K or V are quantized * Correctly determine if FA is supported * Also wmma * Minor --------- Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda.cu | 26 +------ .../src/ggml-cuda/fattn-mma-f16-interface.cuh | 2 + ggml/src/ggml-cuda/fattn-mma-f16.cu | 7 ++ ggml/src/ggml-cuda/fattn-tile-f16.cu | 7 ++ ggml/src/ggml-cuda/fattn-tile-f16.cuh | 2 + ggml/src/ggml-cuda/fattn-tile-f32.cu | 7 ++ ggml/src/ggml-cuda/fattn-tile-f32.cuh | 2 + .../src/ggml-cuda/fattn-vec-f16-interface.cuh | 2 + ggml/src/ggml-cuda/fattn-vec-f16.cu | 48 ++++++++++++- .../src/ggml-cuda/fattn-vec-f32-interface.cuh | 2 + ggml/src/ggml-cuda/fattn-vec-f32.cu | 61 +++++++++++++++- .../ggml-cuda/fattn-wmma-f16-interface.cuh | 2 + ggml/src/ggml-cuda/fattn-wmma-f16.cu | 6 ++ ggml/src/ggml-cuda/fattn.cu | 69 +++++++++++++++++-- ggml/src/ggml-cuda/fattn.cuh | 2 + 15 files changed, 212 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index b47b7dd5..abd2479b 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -4386,31 +4386,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128; #else - if (op->src[0]->ne[0] == 128) { - return true; - } - if (op->src[1]->ne[0] == 256 && op->src[2]->ne[0] == 256 && - (op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_Q8_0) && - (op->src[2]->type == GGML_TYPE_F16 || op->src[2]->type == GGML_TYPE_Q8_0)) { - return true; - } - if (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) { - return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) || - (op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0); - } - if (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512) { - const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc; - int gqa = op->src[0]->ne[2]/op->src[1]->ne[2]; - return (new_mma_available(cc) && cc >= CC_AMPERE && op->src[3] && gqa%16 == 0); - } - if (op->src[1]->ne[0] > 256) { - return false; - } - if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) { - return true; - } - return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA && - op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; + return ggml_cuda_fattn_is_supported(*cuda_ctx, op); #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) default: return false; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16-interface.cuh b/ggml/src/ggml-cuda/fattn-mma-f16-interface.cuh index 8d443a70..ef0107c1 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16-interface.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16-interface.cuh @@ -3,3 +3,5 @@ #include "common.cuh" void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +bool ggml_cuda_fattn_mma_f16_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cu b/ggml/src/ggml-cuda/fattn-mma-f16.cu index 253808c3..01e63541 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cu @@ -83,3 +83,10 @@ void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tens
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst); } + +bool ggml_cuda_fattn_mma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) { + auto K = dst->src[1]; + auto V = dst->src[2]; + if (K->ne[0] != V->ne[0]) return false; + return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 256; +} diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index c79b4821..6dcb12c2 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -353,3 +353,10 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten launch_fattn_tile_f16_64_128(ctx, dst); } } + +bool ggml_cuda_fattn_tile_f16_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) { + auto K = dst->src[1]; + auto V = dst->src[2]; + if (K->ne[0] != V->ne[0]) return false; + return K->ne[0] == 64 || K->ne[0] == 128; +} diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cuh b/ggml/src/ggml-cuda/fattn-tile-f16.cuh index ffc58784..5e375795 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cuh @@ -1,3 +1,5 @@ #include "common.cuh" void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +bool ggml_cuda_fattn_tile_f16_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index 3d1926ce..a3a22904 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -352,3 +352,10 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten launch_fattn_tile_f32_64_128(ctx, dst); } } + +bool ggml_cuda_fattn_tile_f32_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) { + auto K = dst->src[1]; + auto V = dst->src[2]; + if (K->ne[0] != V->ne[0]) return false; + return K->ne[0] == 64 || K->ne[0] == 128; +} diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cuh b/ggml/src/ggml-cuda/fattn-tile-f32.cuh index b1c546c8..490cb28d 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cuh @@ -1,3 +1,5 @@ #include "common.cuh" void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +bool ggml_cuda_fattn_tile_f32_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16-interface.cuh b/ggml/src/ggml-cuda/fattn-vec-f16-interface.cuh index 7fa99ea1..a37a69c2 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16-interface.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16-interface.cuh @@ -3,3 +3,5 @@ #include "common.cuh" void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +bool ggml_cuda_fattn_vec_f16_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cu b/ggml/src/ggml-cuda/fattn-vec-f16.cu index cadc58a3..d75a5601 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cu +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cu @@ -102,4 +102,50 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens on_no_fattn_vec_case(Q->ne[0], V->ne[0]); } - +bool ggml_cuda_fattn_vec_f16_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) { + auto K = dst->src[1]; + auto V = dst->src[2]; + if (K->ne[0] != V->ne[0]) { + if (K->ne[0]
!= 192 || V->ne[0] != 128) return false; + if (K->type != V->type) return false; + return K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0; + } +#ifdef GGML_CUDA_FA_ALL_QUANTS + if (K->ne[0] == 64) { + return K->type == GGML_TYPE_F16 && + (V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q4_1 || + V->type == GGML_TYPE_Q5_0 || V->type == GGML_TYPE_Q5_1 || V->type == GGML_TYPE_Q8_0); + } + if (K->ne[0] == 256) { + return K->type == V->type && (K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0); + } + if (K->ne[0] != 128 || V->ne[0] != 128) return false; + if ((K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q4_1 || K->type == GGML_TYPE_Q5_0 || K->type == GGML_TYPE_Q5_1 || + K->type == GGML_TYPE_Q8_0 || K->type == GGML_TYPE_F16) && + (V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q4_1 || V->type == GGML_TYPE_Q5_0 || V->type == GGML_TYPE_Q5_1 || + V->type == GGML_TYPE_Q8_0 || V->type == GGML_TYPE_F16)) return true; + return (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL) || + (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q5_0) || + (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q6_0) || + (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_Q6_0) || + (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL); +#else + if (K->ne[0] == 128) { + if (K->type == V->type) { + return K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_IQ4_NL; + } + return (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL) || + (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q5_0) || + (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_Q6_0) || + (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL); + } + if (K->type != V->type) return false; + if (K->ne[0] == 64) { + return K->type == GGML_TYPE_F16; + } + if (K->ne[0] == 256) { + return K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0; + } + return false; +#endif +} diff --git a/ggml/src/ggml-cuda/fattn-vec-f32-interface.cuh b/ggml/src/ggml-cuda/fattn-vec-f32-interface.cuh index 31f11003..eb9afbf1 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32-interface.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32-interface.cuh @@ -3,3 +3,5 @@ #include "common.cuh" void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +bool ggml_cuda_fattn_vec_f32_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cu b/ggml/src/ggml-cuda/fattn-vec-f32.cu index 0fe07c1d..c1a5ad2c 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cu +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cu @@ -24,7 +24,7 @@ void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tens FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0) FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1) FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0) - FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16 ) FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0) @@ -71,6 +71,12 @@ void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tens FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0) + FATTN_VEC_F32_CASE(128,
GGML_TYPE_Q6_0, GGML_TYPE_Q6_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0) + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) #else @@ -83,10 +89,63 @@ void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tens FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE(256, GGML_TYPE_Q8_0,GGML_TYPE_Q8_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ4_NL) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_IQ4_NL) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q6_0, GGML_TYPE_Q5_0) + FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q6_0) + FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_F32_CASE_DKDV(192, 128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + #endif // GGML_CUDA_FA_ALL_QUANTS on_no_fattn_vec_case(Q->ne[0], V->ne[0]); } +bool ggml_cuda_fattn_vec_f32_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) { + auto K = dst->src[1]; + auto V = dst->src[2]; + if (K->ne[0] != V->ne[0]) { + if (K->ne[0] != 192 || V->ne[0] != 128) return false; + if (K->type != V->type) return false; + return K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0; + } +#ifdef GGML_CUDA_FA_ALL_QUANTS + if (K->ne[0] == 64) { + return K->type == GGML_TYPE_F16 && + (V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q4_1 || + V->type == GGML_TYPE_Q5_0 || V->type == GGML_TYPE_Q5_1 || V->type == GGML_TYPE_Q8_0); + } + if (K->ne[0] == 256) { + return K->type == V->type && (K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0); + } + if (K->ne[0] != 128 || V->ne[0] != 128) return false; + if ((K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q4_1 || K->type == GGML_TYPE_Q5_0 || K->type == GGML_TYPE_Q5_1 || + K->type == GGML_TYPE_Q8_0 || K->type == GGML_TYPE_F16) && + (V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q4_1 || V->type == GGML_TYPE_Q5_0 || V->type == GGML_TYPE_Q5_1 || + V->type == GGML_TYPE_Q8_0 || V->type == GGML_TYPE_F16)) return true; + return (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL) || + (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q5_0) || + (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q6_0) || + (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_Q6_0) || + (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL); +#else + if (K->ne[0] == 128) { + if (K->type == V->type) { + return K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_IQ4_NL; + } + return (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL) || + (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q5_0) || + (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_Q6_0) || + (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL); + } + if (K->type != V->type) return false; + if (K->ne[0] == 64) { + return K->type == GGML_TYPE_F16; + } + if (K->ne[0] == 256) { + return K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0; + } + return false; +#endif +} diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16-interface.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16-interface.cuh index 34946c59..895b7f66 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16-interface.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16-interface.cuh @@ -3,3 +3,5 @@ #include "common.cuh" void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +bool ggml_cuda_fattn_wmma_f16_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst); diff
--git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 5cc51a46..ca3ae61f 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -165,3 +165,9 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten } } +bool ggml_cuda_fattn_wmma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) { + auto K = dst->src[1]; + auto V = dst->src[2]; + if (K->ne[0] != V->ne[0]) return K->ne[0] == 192 && V->ne[0] == 128; + return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 256; +} diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 90a369d2..a1544e41 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -91,12 +91,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion; const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0); const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0; - if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) { - //if (precision == GGML_PREC_DEFAULT) { - // ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - //} else { - ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); - //} + if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1 && !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { + ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); return; } @@ -125,3 +121,64 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); //ggml_cuda_flash_attn_ext_mma_new(ctx, dst); } + +bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int32_t precision = KQV->op_params[3]; + const int32_t n_swa = KQV->op_params[4]; + if (cc >= CC_OFFSET_AMD) { + return precision == GGML_PREC_DEFAULT ? ggml_cuda_fattn_vec_f16_is_supported(ctx, dst) + : ggml_cuda_fattn_vec_f32_is_supported(ctx, dst); + } + + if (!fast_fp16_available(cc)) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { + return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst); + } else { + return ggml_cuda_fattn_tile_f32_is_supported(ctx, dst); + } + } + + if (!fp16_mma_available(cc)) { + if (precision == GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { + return ggml_cuda_fattn_vec_f16_is_supported(ctx, dst); + } else { + return ggml_cuda_fattn_tile_f16_is_supported(ctx, dst); + } + } else { + if (Q->ne[1] <= 8 || Q->ne[0] == 256) { + return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst); + } else { + return ggml_cuda_fattn_tile_f32_is_supported(ctx, dst); + } + } + } + + const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations + // So, not sure why in mainline they thought that the vector kernels are faster for CC_ADA_LOVELACE or when the KV cache is not f16. + // On my GPU (RTX-4080) MMA is definitely faster for GQA, both for f16 and for quantized KV cache.
+ //const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16; + //const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < CC_ADA_LOVELACE && !mma_needs_data_conversion; + const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !(Q->ne[1] == 1 && n_swa > 0); + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*WARP_SIZE) == 0; + if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1 && !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { + return ggml_cuda_fattn_vec_f32_is_supported(ctx, dst); + } + + if (new_mma_available(cc) && Q->ne[0] == 576) { + return V->ne[0] == 512; + } + + if (!new_mma_available(cc) || K->ne[0] != V->ne[0]) { + return ggml_cuda_fattn_wmma_f16_is_supported(ctx, dst); + } + + return ggml_cuda_fattn_mma_f16_is_supported(ctx, dst); +} diff --git a/ggml/src/ggml-cuda/fattn.cuh b/ggml/src/ggml-cuda/fattn.cuh index ad3ca7a8..bf4b64e3 100644 --- a/ggml/src/ggml-cuda/fattn.cuh +++ b/ggml/src/ggml-cuda/fattn.cuh @@ -1,3 +1,5 @@ #include "common.cuh" void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst);
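
With this change, ggml_backend_cuda_supports_op() delegates the GGML_OP_FLASH_ATTN_EXT case to ggml_cuda_fattn_is_supported(), which mirrors the kernel selection done in ggml_cuda_flash_attn_ext(). Below is a minimal host-side sketch of how the new check is reached through the public backend API; it assumes `fa_node` is a GGML_OP_FLASH_ATTN_EXT tensor that was already built elsewhere (e.g. via ggml_flash_attn_ext()) with the Q/K/V head sizes and KV-cache types of interest.

// Illustrative sketch only: ask the CUDA backend whether it accepts a given
// flash-attention node. ggml_backend_supports_op() routes to
// ggml_backend_cuda_supports_op(), which now calls ggml_cuda_fattn_is_supported().
#include "ggml-backend.h"
#include "ggml-cuda.h"

static bool cuda_supports_flash_attn(const struct ggml_tensor * fa_node) {
    ggml_backend_t backend = ggml_backend_cuda_init(0); // device 0
    if (backend == NULL) {
        return false;
    }
    const bool ok = ggml_backend_supports_op(backend, fa_node);
    ggml_backend_free(backend);
    return ok;
}

If the check returns false, the graph scheduler assigns the op to another backend (typically the CPU) instead of hitting an unsupported-kernel abort inside the CUDA flash-attention dispatch at run time.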