diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16-interface.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16-interface.cuh new file mode 100644 index 00000000..34946c59 --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-wmma-f16-interface.cuh @@ -0,0 +1,5 @@ +#pragma once + +#include "common.cuh" + +void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu new file mode 100644 index 00000000..5cc51a46 --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -0,0 +1,167 @@ +#include "fattn-wmma-f16.cuh" +#include "fattn-wmma-f16-interface.cuh" + +void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * V = dst->src[2]; + + if (Q->ne[0] != V->ne[0]) { + if (!((Q->ne[0] == 192 && V->ne[0] == 128) || (Q->ne[0] == 576 && V->ne[0] == 512))) { + fprintf(stderr, "======================= %s: Unhandled head size combination %d, %d\n", __func__, (int)Q->ne[0], (int)V->ne[0]); + GGML_ABORT("fatal error"); + } + } + + const int32_t precision = KQV->op_params[3]; + + if (precision != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, float>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst); + break; + default: + fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); + GGML_ABORT("fatal error"); + break; + } + } else { + constexpr int cols_per_block = 32; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst); + break; + // case 256: + // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + // break; + default: + fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); + GGML_ABORT("fatal error"); + break; + } + } + return; + } + + if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { + constexpr int cols_per_block = 8; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); + break; + default: + fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); + GGML_ABORT("fatal error"); + break; + } + return; + } + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); + break; + default: + fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); + GGML_ABORT("fatal error"); + break; + } + return; + } + constexpr int cols_per_block = 32; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); + break; + case 192: + ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); + break; + default: + fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); + GGML_ABORT("fatal error"); + break; + } +} + diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 1019de4e..34a59827 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -11,177 +11,13 @@ #include "fattn-tile-f32.cuh" #include "fattn-vec-f16.cuh" #include "fattn-vec-f32.cuh" -#include "fattn-wmma-f16.cuh" +#include "fattn-wmma-f16-interface.cuh" #include "fattn-mma-f16-interface.cuh" #include "fattn-new-mma.cuh" #include "fattn.cuh" #include -static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - const ggml_tensor * V = dst->src[2]; - - if (Q->ne[0] != V->ne[0]) { - if (!((Q->ne[0] == 192 && V->ne[0] == 128) || (Q->ne[0] == 576 && V->ne[0] == 512))) { - fprintf(stderr, "======================= %s: Unhandled head size combination %d, %d\n", __func__, (int)Q->ne[0], (int)V->ne[0]); - GGML_ABORT("fatal error"); - } - } - - const int32_t precision = KQV->op_params[3]; - - if (precision != GGML_PREC_DEFAULT) { - if (Q->ne[1] <= 32 || Q->ne[0] > 128) { - constexpr int cols_per_block = 16; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, float>(ctx, dst); - break; - case 192: - ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst); - break; - default: - fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); - GGML_ABORT("fatal error"); - break; - } - } else { - constexpr int cols_per_block = 32; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, float>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, float>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, float>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, float>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, float>(ctx, dst); - break; - case 192: - ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, float>(ctx, dst); - break; - // case 256: - // ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); - // break; - default: - fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); - GGML_ABORT("fatal error"); - break; - } - } - return; - } - - if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { - constexpr int cols_per_block = 8; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); - break; - case 192: - ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); - break; - default: - fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); - GGML_ABORT("fatal error"); - break; - } - return; - } - - if (Q->ne[1] <= 32) { - constexpr int cols_per_block = 16; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); - break; - case 192: - ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); - break; - default: - fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); - GGML_ABORT("fatal error"); - break; - } - return; - } - - constexpr int cols_per_block = 32; - switch (Q->ne[0]) { - case 64: - ggml_cuda_flash_attn_ext_wmma_f16_case< 64, 64, cols_per_block, half>(ctx, dst); - break; - case 80: - ggml_cuda_flash_attn_ext_wmma_f16_case< 80, 80, cols_per_block, half>(ctx, dst); - break; - case 96: - ggml_cuda_flash_attn_ext_wmma_f16_case< 96, 96, cols_per_block, half>(ctx, dst); - break; - case 112: - ggml_cuda_flash_attn_ext_wmma_f16_case<112, 112, cols_per_block, half>(ctx, dst); - break; - case 128: - ggml_cuda_flash_attn_ext_wmma_f16_case<128, 128, cols_per_block, half>(ctx, dst); - break; - case 192: - ggml_cuda_flash_attn_ext_wmma_f16_case<192, 128, cols_per_block, half>(ctx, dst); - break; - case 256: - ggml_cuda_flash_attn_ext_wmma_f16_case<256, 256, cols_per_block, half>(ctx, dst); - break; - default: - fprintf(stderr, "======================= %s: Unhandled head size %d\n", __func__, (int)Q->ne[0]); - GGML_ABORT("fatal error"); - break; - } -} #define FATTN_VEC_F16_CASE(D, type_K, type_V) \ if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \ ggml_cuda_flash_attn_ext_vec_f16_case(ctx, dst); \