diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 4820ff6958..c0ad4b5489 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -39,7 +39,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con constexpr ck_tile::index_t M_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32; - constexpr ck_tile::index_t K_Warp_Tile = std::is_same_v ? 32 : 16; + constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type::value ? 32 : 16; using CodegenFlatmmShape = ck_tile::TileFlatmmShape, @@ -132,6 +132,10 @@ int run_flatmm_example(int argc, char* argv[]) { run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); } + else if(data_type == "bf8") + { + run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } else { throw std::runtime_error("Unsupported data_type!"); diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index b995fced86..125e418061 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -61,6 +61,15 @@ struct GemmBasicTypeConfig // ToDo: Add more bias config to support different categories of GEMM. }; +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + template struct DataTypeTraits; @@ -69,6 +78,12 @@ struct DataTypeTraits { static constexpr const char* name = "fp8"; }; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; template <> struct DataTypeTraits { @@ -87,6 +102,12 @@ struct DataTypeTraits static constexpr const char* name = "fp16"; }; +template +struct is_8bit_type + : std::bool_constant || std::is_same_v> +{ +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc index 2413f16056..5b30035031 100644 --- a/example/ck_tile/18_flatmm/run_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -9,6 +9,8 @@ constexpr const char* DataTypeToString() { return "fp16"; } else if constexpr (std::is_same_v) { return "fp8"; + } else if constexpr (std::is_same_v) { + return "bf8"; } else { return "unknown"; } @@ -41,13 +43,13 @@ auto shuffle_b(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } - else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0) + else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 0) { ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 32, 2, 16}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } - else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1) + else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 1) { ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 64, 4, 16}); std::copy(t.begin(), t.end(), t_view.begin()); diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index f7ff3c3b19..3f34e4268b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -127,6 +127,10 @@ using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl< using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, 2>>; + +using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, + 2>>; using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index caad358d78..9cd5c5c59b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -59,6 +59,7 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; // clang-format on