diff --git a/ggml/src/ggml-cuda/fattn-mma-f16-interface.cuh b/ggml/src/ggml-cuda/fattn-mma-f16-interface.cuh new file mode 100644 index 00000000..8d443a70 --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-mma-f16-interface.cuh @@ -0,0 +1,5 @@ +#pragma once + +#include "common.cuh" + +void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cu b/ggml/src/ggml-cuda/fattn-mma-f16.cu new file mode 100644 index 00000000..253808c3 --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cu @@ -0,0 +1,85 @@ +#include "fattn-mma-f16.cuh" +#include "fattn-mma-f16-interface.cuh" + +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + if (Q->ne[1] <= 8/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + if (Q->ne[1] <= 16/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + if (Q->ne[1] <= 32/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); +} + +template +static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + if (use_gqa_opt && gqa_ratio % 8 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst); +} diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 243211f7..1019de4e 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -12,95 +12,12 @@ #include "fattn-vec-f16.cuh" #include "fattn-vec-f32.cuh" #include "fattn-wmma-f16.cuh" -#include "fattn-mma-f16.cuh" +#include "fattn-mma-f16-interface.cuh" #include "fattn-new-mma.cuh" #include "fattn.cuh" #include -template -static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - - if (Q->ne[1] <= 8/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); - return; - } - - if (Q->ne[1] <= 16/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); - return; - } - - if (Q->ne[1] <= 32/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); - return; - } - - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); -} - -template -static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst); - break; - default: - GGML_ABORT("fatal error"); - break; - } -} - -static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * K = dst->src[1]; - const ggml_tensor * mask = dst->src[3]; - - float max_bias = 0.0f; - memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); - - const bool use_gqa_opt = mask && max_bias == 0.0f; - - GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); - const int gqa_ratio = Q->ne[2] / K->ne[2]; - - if (use_gqa_opt && gqa_ratio % 8 == 0) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); - return; - } - - if (use_gqa_opt && gqa_ratio % 4 == 0) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst); - return; - } - - if (use_gqa_opt && gqa_ratio % 2 == 0) { - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst); - return; - } - - ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst); -} - static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0];