diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt
index 9fbe65e3a7..f4d823e91a 100644
--- a/example/ck_tile/18_flatmm/CMakeLists.txt
+++ b/example/ck_tile/18_flatmm/CMakeLists.txt
@@ -3,5 +3,6 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
 set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
 # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
 # list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
-# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef)
+list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1 -Wno-unused-local-typedef)
+#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -DENABLE_FP8=1 -Wno-unused-local-typedef)
 target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
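Note: the two `USING_MFMA_*` flags are mutually exclusive build-time switches; exactly one should be active per build, and the rest of this patch keys tile shapes and LDS layouts off them. A minimal stand-alone sketch of that selection pattern (the `tile_shape` helper is hypothetical, purely for illustration):

```cpp
// Hypothetical illustration of the macro-driven warp-tile selection used in
// this patch. Build with e.g. -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1.
#include <cstdio>

struct tile_shape { int m, n, k; }; // illustrative helper, not part of ck_tile

constexpr tile_shape select_warp_tile()
{
#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
    return {16, 16, 32}; // small MN footprint, deep K per MFMA issue
#elif defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8)
    return {32, 32, 16}; // larger MN footprint, shallower K per MFMA issue
#else
    return {32, 32, 16}; // fp16/bf16 default in the example
#endif
}

int main()
{
    constexpr tile_shape t = select_warp_tile();
    std::printf("warp tile: %dx%dx%d\n", t.m, t.n, t.k);
}
```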
diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp
index 05d0c73b7e..5f2c2a5aab 100644
--- a/example/ck_tile/18_flatmm/flatmm_basic.cpp
+++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp
@@ -12,7 +12,13 @@
 #include "ck_tile/host.hpp"
 #include "flatmm_basic.hpp"
 
-template <typename ALayout, typename BLayout, typename CLayout>
+template <typename ADataType,
+          typename BDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
 float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
 {
     // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
@@ -23,18 +29,32 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
     constexpr int kBlockPerCu = 2;
 
     // This part comes from the Codegen
+#if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16)
     constexpr ck_tile::index_t M_Tile = 128;
     constexpr ck_tile::index_t N_Tile = 128;
-    constexpr ck_tile::index_t K_Tile = 64;
+    constexpr ck_tile::index_t K_Tile = 128;
 
     constexpr ck_tile::index_t M_Warp = 1;
     constexpr ck_tile::index_t N_Warp = 4;
     constexpr ck_tile::index_t K_Warp = 1;
 
-    constexpr ck_tile::index_t M_Warp_Tile = 32;
-    constexpr ck_tile::index_t N_Warp_Tile = 32;
-    constexpr ck_tile::index_t K_Warp_Tile = 16;
+    constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
+    constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
+    constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 64 : 16;
+#elif defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8)
+    constexpr ck_tile::index_t M_Tile = 128;
+    constexpr ck_tile::index_t N_Tile = 256;
+    constexpr ck_tile::index_t K_Tile = 128;
+
+    constexpr ck_tile::index_t M_Warp = 1;
+    constexpr ck_tile::index_t N_Warp = 8;
+    constexpr ck_tile::index_t K_Warp = 1;
+
+    constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
+    constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
+    constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 16;
+#endif
 
     using CodegenFlatmmShape =
         ck_tile::TileFlatmmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                                  ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
@@ -49,54 +69,112 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
                                                               AccDataType,
                                                               CodegenFlatmmShape,
                                                               CodegenGemmTraits>;
 
-    using GemmEpilogue = ck_tile::CShuffleEpilogue<
-        ck_tile::CShuffleEpilogueProblem<ADataType,
-                                         BDataType,
-                                         AccDataType,
-                                         CDataType,
-                                         CLayout,
-                                         CodegenPipelineProblem::kBlockSize,
-                                         TilePartitioner::MPerBlock,
-                                         TilePartitioner::NPerBlock,
-                                         M_Warp,
-                                         N_Warp,
-                                         M_Warp_Tile,
-                                         N_Warp_Tile,
-                                         K_Warp_Tile,
-                                         CodegenPipelineProblem::TransposeC>>;
+    const auto Run = [&](const auto memory_operation_) {
+        constexpr auto memory_operation = memory_operation_.value;
 
-    using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy;
-    using CodegenFlatmmPipeline =
-        ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenFlatmmPolicy>;
+        using GemmEpilogue = ck_tile::CShuffleEpilogue<
+            ck_tile::CShuffleEpilogueProblem<ADataType,
+                                             BDataType,
+                                             AccDataType,
+                                             CDataType,
+                                             CLayout,
+                                             CodegenPipelineProblem::kBlockSize,
+                                             TilePartitioner::MPerBlock,
+                                             TilePartitioner::NPerBlock,
+                                             M_Warp,
+                                             N_Warp,
+                                             M_Warp_Tile,
+                                             N_Warp_Tile,
+                                             K_Warp_Tile,
+                                             CodegenPipelineProblem::TransposeC,
+                                             memory_operation>>;
 
-    // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
-    // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
-    using Kernel = ck_tile::FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
+        using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy;
+        using CodegenFlatmmPipeline =
+            ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenFlatmmPolicy>;
 
-    auto kargs = Kernel::MakeKernelArgs(args);
+        // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
+        // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
+        using Kernel = ck_tile::FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
 
-    const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
-    constexpr dim3 blocks = Kernel::BlockSize();
+        auto kargs = Kernel::MakeKernelArgs(args);
 
-    if(!Kernel::IsSupportedArgument(kargs))
+        const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
+        constexpr dim3 blocks = Kernel::BlockSize();
+
+        if(!Kernel::IsSupportedArgument(kargs))
+        {
+            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
+        }
+
+        if(s.log_level_ > 0)
+        {
+            std::cout << "Launching kernel with args:"
+                      << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
+                      << std::endl;
+        }
+
+        float ave_time = ck_tile::launch_kernel(
+            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
+
+        return ave_time;
+    };
+
+    if(args.k_batch == 1)
     {
-        throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
+        return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
+                                              ck_tile::memory_operation_enum::set>{});
     }
-
-    if(s.log_level_ > 0)
+    else
     {
-        std::cout << "Launching kernel with args:"
-                  << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
-                  << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
-                  << std::endl;
+        return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
+                                              ck_tile::memory_operation_enum::atomic_add>{});
     }
-
-    float ave_time = ck_tile::launch_kernel(
-        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
-
-    return ave_time;
 }
 
 #include "run_flatmm_example.inc"
 
+int run_flatmm_example(int argc, char* argv[])
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;
+
+    using Row = ck_tile::tensor_layout::gemm::RowMajor;
+    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
+
+    std::string data_type = arg_parser.get_str("prec");
+    std::string a_layout  = arg_parser.get_str("a_layout");
+    std::string b_layout  = arg_parser.get_str("b_layout");
+
+    if(a_layout == "R" && b_layout == "C")
+    {
+        if(data_type == "fp16")
+        {
+            run_flatmm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
+        }
+        else if(data_type == "bf16")
+        {
+            run_flatmm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
+        }
+        else if(data_type == "fp8")
+        {
+            run_flatmm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
+        }
+        else if(data_type == "bf8")
+        {
+            run_flatmm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
+        }
+        else
+        {
+            throw std::runtime_error("Unsupported data_type!");
+        }
+    }
+    else
+    {
+        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
+    }
+    return -1;
+}
+
 int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); }
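The `Run` lambda above turns the runtime `k_batch` test into a compile-time `memory_operation` template argument by tag-dispatching through `integral_constant` into a generic lambda. A minimal self-contained sketch of the pattern, using the standard-library `std::integral_constant` in place of ck_tile's (names are illustrative):

```cpp
#include <iostream>
#include <type_traits>

enum class memory_operation_enum { set, atomic_add };

int main()
{
    // Generic lambda: each call site passes a distinct integral_constant type,
    // so 'op' is a true compile-time constant usable in if constexpr.
    const auto run = [&](auto memory_operation_) {
        constexpr memory_operation_enum op = memory_operation_.value;
        if constexpr(op == memory_operation_enum::set)
            std::cout << "instantiated the 'set' kernel\n";
        else
            std::cout << "instantiated the 'atomic_add' (split-K) kernel\n";
    };

    int k_batch = 2; // runtime value, e.g. parsed from the command line
    if(k_batch == 1)
        run(std::integral_constant<memory_operation_enum, memory_operation_enum::set>{});
    else
        run(std::integral_constant<memory_operation_enum, memory_operation_enum::atomic_add>{});
}
```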
diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp
index 355ac45ebe..bbce978724 100644
--- a/example/ck_tile/18_flatmm/flatmm_basic.hpp
+++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp
@@ -31,7 +31,7 @@
 #error "unsupported CK_TILE_PIPELINE_DEFAULT value"
 #endif
 
-template <typename T>
+template <typename DataType>
 struct GemmBasicTypeConfig;
 
 template <>
@@ -44,9 +44,47 @@ struct GemmBasicTypeConfig<ck_tile::half_t>
     // ToDo: Add more bias config to support different categories of GEMM.
 };
 
+template <>
+struct GemmBasicTypeConfig<ck_tile::bf16_t>
+{
+    using ADataType = ck_tile::bf16_t;
+    using BDataType = ck_tile::bf16_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::bf16_t;
+};
+
+template <>
+struct GemmBasicTypeConfig<ck_tile::fp8_t>
+{
+    using ADataType = ck_tile::fp8_t;
+    using BDataType = ck_tile::fp8_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::half_t;
+    // ToDo: Add more bias config to support different categories of GEMM.
+};
+
+template <>
+struct GemmBasicTypeConfig<ck_tile::bf8_t>
+{
+    using ADataType = ck_tile::bf8_t;
+    using BDataType = ck_tile::bf8_t;
+    using AccDataType = float;
+    using CDataType = ck_tile::half_t;
+};
+
 template <typename DataType>
 struct DataTypeTraits;
 
+template <>
+struct DataTypeTraits<ck_tile::fp8_t>
+{
+    static constexpr const char* name = "fp8";
+};
+
+template <>
+struct DataTypeTraits<ck_tile::bf8_t>
+{
+    static constexpr const char* name = "bf8";
+};
 template <>
 struct DataTypeTraits<float>
 {
@@ -65,13 +103,11 @@ struct DataTypeTraits<ck_tile::half_t>
     static constexpr const char* name = "fp16";
 };
 
-using Types = GemmBasicTypeConfig<ck_tile::half_t>;
-
-// Specific type aliases for easy access
-using ADataType = Types::ADataType;
-using BDataType = Types::BDataType;
-using AccDataType = Types::AccDataType;
-using CDataType = Types::CDataType;
+template <typename T>
+struct is_8bit_type
+    : std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>
+{
+};
 
 auto create_args(int argc, char* argv[])
 {
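Dropping the file-scope `ADataType`/`BDataType` aliases in favor of `GemmBasicTypeConfig<DataType>` is what lets one binary host fp16, bf16, fp8 and bf8 paths side by side; `is_8bit_type` then steers the warp-tile selection. A small host-only sketch of how the two traits compose (mock tag types stand in for ck_tile's):

```cpp
#include <iostream>
#include <type_traits>

// Stand-ins for ck_tile::half_t / ck_tile::fp8_t, for illustration only.
struct half_t {};
struct fp8_t {};

template <typename DataType>
struct GemmBasicTypeConfig;

template <>
struct GemmBasicTypeConfig<fp8_t>
{
    using ADataType = fp8_t;
    using AccDataType = float;  // fp8 inputs accumulate in fp32
    using CDataType = half_t;   // ...and store fp16, as in the patch
};

template <typename T>
struct is_8bit_type : std::bool_constant<std::is_same_v<T, fp8_t>>
{
};

int main()
{
    using Cfg = GemmBasicTypeConfig<fp8_t>;
    // Mirrors the example's selection: 8-bit inputs quadruple the K depth per warp tile.
    constexpr int k_warp_tile = is_8bit_type<Cfg::ADataType>::value ? 64 : 16;
    std::cout << "K_Warp_Tile = " << k_warp_tile << '\n';
}
```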
diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc
index 864d888074..15a9df2c0c 100644
--- a/example/ck_tile/18_flatmm/run_flatmm_example.inc
+++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc
@@ -1,6 +1,20 @@
 // SPDX-License-Identifier: MIT
 // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
+#include <type_traits>
+
+template <typename DataType>
+constexpr const char* DataTypeToString() {
+    if constexpr (std::is_same_v<DataType, ck_tile::half_t>) {
+        return "fp16";
+    } else if constexpr (std::is_same_v<DataType, ck_tile::fp8_t>) {
+        return "fp8";
+    } else if constexpr (std::is_same_v<DataType, ck_tile::bf8_t>) {
+        return "bf8";
+    } else {
+        return "unknown";
+    }
+}
 
 template <typename Layout>
 static constexpr inline auto is_row_major(Layout layout_)
@@ -11,7 +25,7 @@ static constexpr inline auto is_row_major(Layout layout_)
 
 // mfma_type, 0:32x32, 1:16x16
 template <typename T>
-auto shuffle_b(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
+auto shuffle_b(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type)
 {
     assert(t.get_lengths().size() == 2);
     int n_ = t.get_lengths()[1];
@@ -29,13 +43,13 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfm
         std::copy(t.begin(), t.end(), t_view.begin());
         return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
     }
-    else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
+    else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 0)
     {
         ck_tile::HostTensor<T> t_view({n_ / 32, 32, k_ / 32, 2, 16});
         std::copy(t.begin(), t.end(), t_view.begin());
         return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
    }
-    else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
+    else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 1)
    {
         ck_tile::HostTensor<T> t_view({n_ / 16, 16, k_ / 64, 4, 16});
         std::copy(t.begin(), t.end(), t_view.begin());
@@ -44,6 +58,7 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfm
     return t;
 }
 
+template <typename ADataType, typename CDataType>
 auto calculate_rtol_atol(const ck_tile::index_t K,
                          const ck_tile::index_t kbatch,
                          const float max_accumulated_value)
@@ -64,7 +79,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
     return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
 }
 
-template <typename ALayout, typename BLayout, typename CLayout>
+template <typename ADataType,
+          typename BDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
 float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
                     ck_tile::DeviceMem& b_shuffle_dev_buf,
                     ck_tile::DeviceMem& c_dev_buf,
@@ -91,7 +112,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
     args.stride_B = stride_B;
     args.stride_C = stride_C;
 
-    float ave_time = flatmm_calc<ALayout, BLayout, CLayout>(
+    float ave_time = flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
         args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
 
     std::size_t flop = std::size_t(2) * M * N * K;
@@ -100,7 +121,7 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
     float gb_per_sec = num_byte / 1.E6 / ave_time;
 
-    std::cout << "Run Flatmm kernel with M =" << M << " N =" << N << " K =" << K
+    std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>() << " M =" << M << " N =" << N << " K =" << K
               << " StrideA =" << stride_A << " StrideB =" << stride_B
               << " StrideC =" << stride_C << " : " << ave_time << " ms, " << tflops
               << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
@@ -108,7 +129,10 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
     return ave_time;
 }
 
-template <typename ALayout, typename BLayout, typename CLayout>
+template <typename DataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
 int run_flatmm_example_with_layouts(int argc,
                                     char* argv[],
                                     const ALayout a_layout = ALayout{},
@@ -119,6 +143,11 @@ int run_flatmm_example_with_layouts(int argc,
     if(!result)
         return -1;
 
+    using ADataType   = typename GemmBasicTypeConfig<DataType>::ADataType;
+    using BDataType   = typename GemmBasicTypeConfig<DataType>::BDataType;
+    using CDataType   = typename GemmBasicTypeConfig<DataType>::CDataType;
+    using AccDataType = typename GemmBasicTypeConfig<DataType>::AccDataType;
+
     ck_tile::index_t M = arg_parser.get_int("m");
     ck_tile::index_t N = arg_parser.get_int("n");
     ck_tile::index_t K = arg_parser.get_int("k");
@@ -154,11 +183,17 @@ int run_flatmm_example_with_layouts(int argc,
 
     // do pre-shuffle
     std::string mfma = arg_parser.get_str("prec");
-    ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b(b_origin_host, mfma, 0);
+#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
+    ck_tile::index_t mfma_type = 1;
+#else
+    ck_tile::index_t mfma_type = 0;
+#endif
+    ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b(b_origin_host, mfma, mfma_type);
     ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
     b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
 
-    invoke_flatmm<ALayout, BLayout, CLayout>(a_dev_buf,
+    invoke_flatmm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
+        a_dev_buf,
         b_shuffle_dev_buf,
         c_dev_buf,
         M,
@@ -184,7 +219,7 @@ int run_flatmm_example_with_layouts(int argc,
                                              a_host, b_origin_host, c_ref_host);
         const float max_accumulated_value =
             *std::max_element(c_ref_host.mData.begin(), c_ref_host.mData.end());
-        const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
+        const auto rtol_atol = calculate_rtol_atol<ADataType, CDataType>(K, kbatch, max_accumulated_value);
         pass = ck_tile::check_err(c_rslt_host,
                                   c_ref_host,
                                   "Error: Incorrect results!",
@@ -242,7 +277,7 @@ int run_flatmm_example_with_layouts(int argc,
         c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
         const float max_accumulated_value =
             *std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end());
-        const auto rtol_atol = calculate_rtol_atol(K, kbatch, max_accumulated_value);
+        const auto rtol_atol = calculate_rtol_atol<ADataType, CDataType>(K, kbatch, max_accumulated_value);
         pass = ck_tile::check_err(c_rslt_host,
                                   c_gpu_ref_host,
                                   "Error: Incorrect results!",
@@ -257,25 +292,3 @@ int run_flatmm_example_with_layouts(int argc,
 
     return pass;
 }
-
-int run_flatmm_example(int argc, char* argv[])
-{
-    auto [result, arg_parser] = create_args(argc, argv);
-    if(!result)
-        return -1;
-
-    using Row = ck_tile::tensor_layout::gemm::RowMajor;
-    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
-
-    std::string a_layout = arg_parser.get_str("a_layout");
-    std::string b_layout = arg_parser.get_str("b_layout");
-
-    if(a_layout == "R" && b_layout == "C")
-    {
-        return run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
-    }
-    else
-    {
-        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
-    }
-}
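For the new 16x16 fp8/bf8 path (`mfma_type == 1`), `shuffle_b` views the K-major [N, K] tensor as [N/16, 16, K/64, 4, 16] and permutes it with axis order {0, 2, 3, 1, 4}, so each warp's B fragment lands contiguously in memory. A host-only sketch of the identical index arithmetic on raw bytes, assuming the shapes of that branch (standard library only):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Permute a [N/16, 16, K/64, 4, 16] view of a row-major [N, K] byte tensor
// into [N/16, K/64, 4, 16, 16] - the {0, 2, 3, 1, 4} order used by shuffle_b.
std::vector<uint8_t> shuffle_b_16x16(const std::vector<uint8_t>& b, int N, int K)
{
    assert(N % 16 == 0 && K % 64 == 0);
    const int d0 = N / 16, d1 = 16, d2 = K / 64, d3 = 4, d4 = 16;
    std::vector<uint8_t> out(b.size());
    for(int i0 = 0; i0 < d0; ++i0)
        for(int i1 = 0; i1 < d1; ++i1)
            for(int i2 = 0; i2 < d2; ++i2)
                for(int i3 = 0; i3 < d3; ++i3)
                    for(int i4 = 0; i4 < d4; ++i4)
                    {
                        // src follows the input view; dst swaps axis 1 behind axes 2 and 3
                        const size_t src = (((size_t(i0) * d1 + i1) * d2 + i2) * d3 + i3) * d4 + i4;
                        const size_t dst = (((size_t(i0) * d2 + i2) * d3 + i3) * d1 + i1) * d4 + i4;
                        out[dst] = b[src];
                    }
    return out;
}

int main()
{
    const int N = 32, K = 128;
    std::vector<uint8_t> b(size_t(N) * K);
    for(size_t i = 0; i < b.size(); ++i) b[i] = uint8_t(i);
    auto shuffled = shuffle_b_16x16(b, N, K); // same bytes, warp-friendly order
    (void)shuffled;
}
```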
diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp
index 935eb2c028..18b2fe6483 100644
--- a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp
+++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp
@@ -66,76 +66,24 @@ struct BlockFlatmmASmemBSmemCRegV1
     }
 
     // C += A * B
-    template <typename CBlockTensor, typename ABlockWindow, typename BFlatBlockWindow>
+    template <typename CBlockTensor, typename ABlockWindow, typename BFlatBlockTensor>
     CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
-                                   const ABlockWindow& a_block_window,
-                                   const BFlatBlockWindow& b_flat_block_window) const
+                                   ABlockWindow& a_warp_windows,
+                                   BFlatBlockTensor& b_warp_tensor) const
     {
-        static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
-                          std::is_same_v<BDataType, typename BFlatBlockWindow::DataType> &&
-                          std::is_same_v<CDataType, typename CBlockTensor::DataType>,
-                      "wrong!");
-
-        constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
-        constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
-
-        static_assert(MPerBlock == BlockGemmShape::kM && KPerBlock == BlockGemmShape::kK, "wrong!");
+        constexpr index_t MPerBlock = BlockGemmShape::kM;
+        constexpr index_t KPerBlock = BlockGemmShape::kK;
 
         constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
         using WG = remove_cvref_t<decltype(config.template at<0>())>;
 
         constexpr index_t MWarp = config.template at<1>();
-        constexpr index_t NWarp = config.template at<2>();
 
         constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
         constexpr index_t NIterPerWarp =
             BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
         constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
 
-        constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
-        constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
-
-        constexpr index_t NFlatPerBlockPerIter = BlockGemmShape::flatNPerWarp;
-        constexpr index_t KFlatPerBlockPerIter = BlockGemmShape::flatKPerWarp;
-
-        const index_t iMWarp = get_warp_id() / NWarp;
-
-        // construct A-warp-window
-        auto a_warp_window_tmp = make_tile_window(
-            a_block_window.get_bottom_tensor_view(),
-            make_tuple(number<WG::kM>{}, number<WG::kK>{}),
-            a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
-            make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
-        statically_indexed_array<
-            statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
-            MIterPerWarp>
-            a_warp_windows;
-        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
-
-                move_tile_window(a_warp_windows(mIter)(kIter),
-                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
-            });
-        });
-
-        // construct Bflat-warp-window
-        auto b_flat_warp_windows_tmp = b_flat_block_window;
-        statically_indexed_array<
-            statically_indexed_array<decltype(b_flat_warp_windows_tmp), KIterPerWarp>,
-            NIterPerWarp>
-            b_flat_warp_windows;
-        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
-                b_flat_warp_windows(nIter)(kIter) = b_flat_warp_windows_tmp;
-
-                move_tile_window(b_flat_warp_windows(nIter)(kIter),
-                                 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
-            });
-        });
-
-        // auto b_warp_windows = b_origin_warp_windows;
-        auto b_warp_windows = b_flat_warp_windows;
-
         using CWarpDstr = typename WG::CWarpDstr;
         using CWarpTensor = typename WG::CWarpTensor;
@@ -150,9 +98,6 @@ struct BlockFlatmmASmemBSmemCRegV1
                 const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
 
                 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
-                    // read B warp tensor from B Block window
-                    const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
-
                     // read C warp tensor from C block tensor
                     CWarpTensor c_warp_tensor;
 
@@ -161,7 +106,7 @@ struct BlockFlatmmASmemBSmemCRegV1
                         merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
 
                     // warp GEMM
-                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
+                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
 
                     // write C warp tensor into C block tensor
                     c_block_tensor.set_y_sliced_thread_data(
@@ -172,16 +117,6 @@ struct BlockFlatmmASmemBSmemCRegV1
             });
         });
     }
-
-    // C = A * B
-    template <typename ABlockTensorTmp, typename BFlatBlockWindow>
-    CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
-                                   const BFlatBlockWindow& b_flat_block_window) const
-    {
-        auto c_block_tensor = MakeCBlockTile();
-        operator()(c_block_tensor, a_block_tensor_tmp, b_flat_block_window);
-        return c_block_tensor;
-    }
 };
 
 } // namespace ck_tile
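This change narrows the block GEMM to pure math: the caller (the pipeline) now owns the per-warp A windows and the preloaded B fragments, so they are constructed once per tile instead of on every block-GEMM invocation inside the hot loop. A plain-C++ sketch of that hoisting pattern, with `std::array` standing in for `statically_indexed_array` (all names illustrative):

```cpp
#include <array>
#include <cstdio>

constexpr int MIter = 2, KIter = 4;

struct Window { int origin_m, origin_k; }; // stand-in for a tile window

int main()
{
    // Hoisted: build the MIter x KIter grid of warp windows once...
    std::array<std::array<Window, KIter>, MIter> warp_windows{};
    for(int m = 0; m < MIter; ++m)
        for(int k = 0; k < KIter; ++k)
            warp_windows[m][k] = {m * 32, k * 16}; // per-iteration offsets

    // ...then every "block GEMM" call only indexes into it; no rebuild, no
    // redundant address arithmetic in the inner loop.
    for(int m = 0; m < MIter; ++m)
        for(int k = 0; k < KIter; ++k)
            std::printf("mfma on window (%d,%d)\n",
                        warp_windows[m][k].origin_m, warp_windows[m][k].origin_k);
}
```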
diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
index eb45e6c0bd..a9ed1519e6 100644
--- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
+++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
@@ -321,7 +321,7 @@ struct FlatmmKernel
         const auto& c_tensor_view = [&]() {
             if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {
-                return make_naive_tensor_view<address_space_enum::global>(
+                return make_naive_tensor_view<address_space_enum::global, memory_operation>(
                     c_ptr,
                     make_tuple(kargs.M, kargs.N),
                     make_tuple(kargs.stride_C, 1),
@@ -330,7 +330,7 @@ struct FlatmmKernel
             }
             else
             {
-                return make_naive_tensor_view<address_space_enum::global>(
+                return make_naive_tensor_view<address_space_enum::global, memory_operation>(
                     c_ptr,
                     make_tuple(kargs.M, kargs.N),
                     make_tuple(1, kargs.stride_C),
@@ -426,7 +426,6 @@ struct FlatmmKernel
         return make_tuple(a_block_window, b_flat_block_window, c_block_window);
     }
 
-    template <memory_operation_enum memory_operation = memory_operation_enum::set>
     CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
                                          const BDataType* b_flat_ptr,
                                          CDataType* c_ptr,
@@ -438,7 +437,8 @@ struct FlatmmKernel
     {
         // Create Gemm tensor views, pad views and tile windows
         const auto& gemm_tensor_views_tuple =
-            MakeGemmTensorViews(a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset);
+            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
+                a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset);
 
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
@@ -453,9 +453,8 @@ struct FlatmmKernel
 
         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(I2);
-        EpiloguePipeline{}
-            .template operator()<decltype(c_block_window), decltype(c_block_tile)>(
-                c_block_window, c_block_tile, smem_ptr);
+        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
+            c_block_window, c_block_tile, smem_ptr);
     }
 
     CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const
@@ -475,21 +474,12 @@ struct FlatmmKernel
         // allocate LDS
         __shared__ char smem_ptr[GetSmemSize()];
 
-        if(kargs.k_batch == 1)
+        if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
+                       EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
+                       is_any_of<CDataType, fp16_t, bf16_t>::value))
         {
             RunFlatmm(a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
         }
-        else
-        {
-            // Do not compile in case where we have unsupported
-            // VectorSizeC & data type configuration.
-            if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
-                           is_any_of<CDataType, fp16_t, bf16_t>::value))
-            {
-                RunFlatmm<memory_operation_enum::atomic_add>(
-                    a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
-            }
-        }
     }
 };
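Moving the guard into a single `if constexpr` keys it off the epilogue's compile-time `MemoryOperation` instead of the runtime `k_batch` branch, so the unsupported split-K variant is never even instantiated. The constraint itself comes from packed 16-bit atomics: global `atomic_add` on fp16/bf16 is lowered as 2-element packed operations, so an odd C vector size cannot be emitted. A hedged, self-contained sketch of just the compile-time check (all type names are stand-ins):

```cpp
#include <type_traits>

enum class memory_operation_enum { set, atomic_add };

struct fp16_t {}; // stand-ins for ck_tile's 16-bit types
struct bf16_t {};

template <typename T, typename... Ts>
struct is_any_of : std::bool_constant<(std::is_same_v<T, Ts> || ...)> {};

template <typename CDataType, memory_operation_enum MemOp, int VectorSizeC>
void run_epilogue()
{
    // Odd vector sizes cannot be lowered to packed 2-element fp16/bf16 atomics,
    // so that combination is compiled out rather than miscompiled.
    if constexpr(!(MemOp == memory_operation_enum::atomic_add &&
                   VectorSizeC % 2 != 0 &&
                   is_any_of<CDataType, fp16_t, bf16_t>::value))
    {
        // ... instantiate and run the store path ...
    }
}

int main()
{
    run_epilogue<fp16_t, memory_operation_enum::set, 1>();        // fine
    run_epilogue<fp16_t, memory_operation_enum::atomic_add, 2>(); // fine
    run_epilogue<fp16_t, memory_operation_enum::atomic_add, 1>(); // compiles to a no-op
}
```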
diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp
index 611aff318f..2ff9d1ebf0 100644
--- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp
+++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp
@@ -73,6 +73,83 @@ struct FlatmmPipelineAGmemBGmemCRegV1
         return PipelinePolicy::template GetSmemSize<Problem>();
     }
 
+    CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
+    {
+        constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
+
+        using WG = remove_cvref_t<decltype(config.template at<0>())>;
+
+        constexpr index_t MWarp = config.template at<1>();
+        constexpr index_t NWarp = config.template at<2>();
+
+        constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
+        constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
+        constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
+
+        constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType);
+        constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad;
+        constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp;
+        constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
+        // constexpr index_t A_LDS_Read_Inst_Remain = A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num;
+#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
+        static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
+            ignore = i;
+            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
+            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+        });
+        static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
+            ignore = i;
+            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
+            __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
+        });
+        static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
+            ignore = i;
+            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
+            __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
+        });
+        static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
+            ignore = i;
+            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
+            __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
+        });
+
+#elif defined(USING_MFMA_32x32x16)
+        static_for<0,
+                   A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num,
+                   1>{}([&](auto i) {
+            ignore = i;
+            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+        });
+        static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
+            ignore = i;
+            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+        });
+        static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) {
+            ignore = i;
+            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+        });
+        static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
+            ignore = i;
+            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+        });
+        static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
+            ignore = i;
+            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
+            __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA
+        });
+        __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
+#endif
+    }
+
     template <typename ADramBlockWindowTmp, typename AElementFunction, typename BFlatDramBlockWindowTmp>
     CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                         const AElementFunction& a_element_func,
@@ -89,6 +166,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1
         static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                       "wrong!");
 
+        constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
+
+        using WG = remove_cvref_t<decltype(config.template at<0>())>;
+
+        constexpr index_t MWarp = config.template at<1>();
+        constexpr index_t NWarp = config.template at<2>();
+
+        constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
+        constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
+        constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
+
+        constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
+        constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
+
+        constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
+        constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
+
+        const index_t iMWarp = get_warp_id() / NWarp;
+
         // A tile in LDS
         ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
 
@@ -112,6 +208,25 @@ struct FlatmmPipelineAGmemBGmemCRegV1
         auto a_lds_gemm_window = make_tile_window(
             a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
 
+        auto a_warp_window_tmp = make_tile_window(
+            a_lds_gemm_window.get_bottom_tensor_view(),
+            make_tuple(number<WG::kM>{}, number<WG::kK>{}),
+            a_lds_gemm_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
+            make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
+
+        statically_indexed_array<
+            statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
+            MIterPerWarp>
+            a_warp_windows;
+
+        static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
+
+                move_tile_window(a_warp_windows(mIter)(kIter),
+                                 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
+            });
+        });
+
         // Block GEMM
         auto block_flatmm = BlockFlatmm();
 
@@ -126,16 +241,45 @@ struct FlatmmPipelineAGmemBGmemCRegV1
             b_flat_distribution);
 
         // Acc register tile
-        auto c_block_tile = decltype(block_flatmm(a_lds_gemm_window, b_flat_dram_window)){};
+        auto c_block_tile = block_flatmm.MakeCBlockTile();
 
         // prefetch
         // global read 0
        auto a_block_tile = load_tile(a_copy_dram_window);
 
+        statically_indexed_array<
+            statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
+            NIterPerWarp>
+            b_flat_dram_windows;
+
+        statically_indexed_array<
+            statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
+            NIterPerWarp>
+            b_warp_tensor;
+
+        statically_indexed_array<
+            statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
+            NIterPerWarp>
+            b_warp_tensor_2;
+
+        static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+
+                move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+
+                b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
+            });
+        });
+
         {
             // move to 1
             move_tile_window(a_copy_dram_window, {0, kKPerBlock});
 
+            // move to next flat K
+            move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
+
             // initialize C
             tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
@@ -152,40 +296,116 @@ struct FlatmmPipelineAGmemBGmemCRegV1
             {
                 store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
             }
+            block_sync_lds();
         }
 
-        index_t iCounter = num_loop - 1;
+        index_t iCounter = num_loop / 2 - 1;
         while(iCounter > 0)
         {
             // global read i + 1
             a_block_tile = load_tile(a_copy_dram_window);
 
-            block_sync_lds();
-
             // GEMM i
-            block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window);
+            block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
 
             block_sync_lds();
 
+            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+
+                    b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
+                });
+            });
+
             // move to i + 2
             move_tile_window(a_copy_dram_window, {0, kKPerBlock});
 
+            // move to next flat K
+            move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
+
             // LDS write i + 1
-            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
+            auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
             store_tile(a_copy_lds_window, a_block_tile_tmp);
 
+            HotLoopScheduler();
+            block_sync_lds();
+
+            // iCounter--;
+
+            // global read i + 1
+            a_block_tile = load_tile(a_copy_dram_window);
+
+            // GEMM i
+            block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2);
+
+            block_sync_lds();
+
+            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+
+                    b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
+                });
+            });
+
+            // move to i + 2
+            move_tile_window(a_copy_dram_window, {0, kKPerBlock});
 
             // move to next flat K
             move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
 
+            // LDS write i + 1
+            a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
+            store_tile(a_copy_lds_window, a_block_tile_tmp);
+
+            HotLoopScheduler();
+            block_sync_lds();
+
             iCounter--;
         }
 
         // tail
         {
+            // global read i + 1
+            a_block_tile = load_tile(a_copy_dram_window);
+
+            // GEMM i
+            block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor);
+
+            block_sync_lds();
+
+            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                    b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
+
+                    move_tile_window(b_flat_dram_windows(nIter)(kIter),
+                                     {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
+
+                    b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
+                });
+            });
+
+            // move to i + 2
+            // move_tile_window(a_copy_dram_window, {0, kKPerBlock});
+
+            // LDS write i + 1
+            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
+            store_tile(a_copy_lds_window, a_block_tile_tmp);
+
+            // move to next flat K
+            // move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
+
+            HotLoopScheduler();
             block_sync_lds();
 
             // GEMM num_loop - 1
-            block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window);
+            block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2);
        }
 
         return c_block_tile;
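`HotLoopScheduler` pins the per-iteration instruction mix with `__builtin_amdgcn_sched_group_barrier(mask, size, sync_id)`; the masks used above are 0x008 (MFMA), 0x020 (VMEM read), 0x100 (DS read) and 0x200 (DS write). A minimal device-side sketch of the idiom, assuming hipcc targeting a gfx9 MFMA part (the kernel body is illustrative, not ck_tile code):

```cpp
// Sketch: interleave one LDS/global read with one MFMA-slot per scheduling
// group so the compiler cannot clump all loads ahead of the math.
// Builtin: __builtin_amdgcn_sched_group_barrier(instruction_mask, group_size, sync_id)
#include <hip/hip_runtime.h>

__global__ void scheduled_loop_sketch(const float* in, float* out)
{
    float acc = 0.f;
#if defined(__gfx90a__) || defined(__gfx942__)
    #pragma unroll
    for(int i = 0; i < 4; ++i)
    {
        acc += in[threadIdx.x + i * blockDim.x]; // stands in for DS-read + MFMA work
        __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // 1 DS read...
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // ...then 1 MFMA
    }
#else
    for(int i = 0; i < 4; ++i)
        acc += in[threadIdx.x + i * blockDim.x];
#endif
    out[threadIdx.x] = acc;
}
```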
diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
index d1aac07d54..474924ec84 100644
--- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
+++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
@@ -19,23 +19,100 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
     CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
     {
         using namespace ck_tile;
-
-        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
-        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
+#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8)
+        /* fewer transform layers compared with the old CK implementation */
+        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t KPack     = GetSmemPackA<Problem>();
 
         constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
-            make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
-            make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
-            number<8>{},
+            make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
+            make_tuple(number<MPerBlock * KPack>{}, number<KPack>{}, number<1>{}),
+            number<KPack>{},
+            number<1>{});
+
+        constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
+            a_lds_block_desc_0,
+            make_tuple(
+                make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
+                make_pass_through_transform(number<KPack>{})),
+            make_tuple(sequence<1, 0>{}, sequence<2>{}),
+            make_tuple(sequence<1, 0>{}, sequence<2>{}));
+
+        constexpr auto a_lds_block_desc = transform_tensor_descriptor(
+            a_lds_block_desc_permuted,
+            make_tuple(make_pass_through_transform(number<MPerBlock>{}),
+                       make_merge_transform_v3_division_mod(
+                           make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
+            make_tuple(sequence<1>{}, sequence<0, 2>{}),
+            make_tuple(sequence<0>{}, sequence<1>{}));
+
+        return a_lds_block_desc;
+#elif defined(USING_MFMA_32x32x16)
+        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
+        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t kKPack     = GetSmemPackA<Problem>();
+
+        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
+            make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
+            make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
+            number<kKPack>{},
             number<1>{});
 
         constexpr auto a_lds_block_desc = transform_tensor_descriptor(
             a_lds_block_desc_0,
             make_tuple(make_pass_through_transform(kMPerBlock),
-                       make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
+                       make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))),
             make_tuple(sequence<1>{}, sequence<0, 2>{}),
             make_tuple(sequence<0>{}, sequence<1>{}));
 
+        return a_lds_block_desc;
+#endif
+/* XOR-swizzle variant, kept for reference */
+#if 0
+        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
+        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
+        constexpr index_t kKPack     = GetSmemPackA<Problem>();
+        using ADataType = remove_cvref_t<typename Problem::ADataType>;
+
+        constexpr auto DataTypeSize = sizeof(ADataType);
+        constexpr auto MLdsLayer =
+            (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
+
+        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
+            make_tuple(number<kKPerBlock / kKPack * MLdsLayer>{},
+                       number<kMPerBlock / MLdsLayer>{},
+                       number<kKPack>{}),
+            make_tuple(number<kKPack>{}, number<kKPerBlock * MLdsLayer>{}, number<1>{}),
+            number<kKPack>{},
+            number<1>{});
+
+        constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
+            a_lds_block_desc_0,
+            make_tuple(make_xor_transform(make_tuple(number<kMPerBlock / MLdsLayer>{},
+                                                     number<kKPerBlock / kKPack * MLdsLayer>{})),
+                       make_pass_through_transform(number<kKPack>{})),
+            make_tuple(sequence<1, 0>{}, sequence<2>{}),
+            make_tuple(sequence<1, 0>{}, sequence<2>{}));
+
+        constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
+            a_lds_block_desc_permuted,
+            make_tuple(make_unmerge_transform(
+                           make_tuple(number<kKPerBlock / kKPack>{}, number<MLdsLayer>{})),
+                       make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
+                       make_pass_through_transform(number<kKPack>{})),
+            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
+            make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
+
+        constexpr auto a_lds_block_desc = transform_tensor_descriptor(
+            a_lds_block_desc_xk0_mnldslayer_mn_xk1,
+            make_tuple(make_merge_transform(
+                           make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
+                       make_merge_transform(
+                           make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
+            make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
+            make_tuple(sequence<0>{}, sequence<1>{}));
+#endif
         return a_lds_block_desc;
     }
@@ -58,7 +135,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
     {
-        return Problem::VectorLoadSize;
+        return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
     }
 
     template <typename Problem>
@@ -82,7 +159,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
         constexpr index_t KPack = GetSmemPackA<Problem>();
         static_assert(KPack % K3 == 0);
         constexpr index_t K2 = KPack / K3;
-        if constexpr(get_warp_size() % (K2 * M0))
+        if constexpr(get_warp_size() >= (K2 * M0))
         {
             constexpr index_t K1 = get_warp_size() / (K2 * M0);
             constexpr index_t K0 = BlockSize / get_warp_size();
@@ -209,7 +286,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
         static_assert(kKPack % K3 == 0);
         constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside single wave
         constexpr index_t warp_size = get_warp_size();
-        if constexpr(warp_size % (K2 * M0) == 0)
+        if constexpr(warp_size >= (K2 * M0))
        {
             constexpr index_t K1 = warp_size / (K2 * M0);
             constexpr index_t K0 = kBlockSize / warp_size;
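The 16x16 fp8 path drops the classic `(kMPerBlock + 1)` padding column and instead XOR-permutes the (M, K0) coordinates, scattering a column of M across LDS banks without wasting the padding storage. A small host-side model of why the XOR removes bank conflicts, assuming 32 banks of 4-byte words as on gfx9:

```cpp
#include <cstdio>

// Model: a wavefront reads one fixed column of a 32 x 32 tile, one lane per
// row. With a linear layout every lane hits the same bank; XOR-ing the column
// index with the row index spreads the accesses across all 32 banks.
int main()
{
    const int banks = 32, rows = 32;
    int hits_plain[32] = {}, hits_xor[32] = {};
    const int j = 5; // any fixed column
    for(int i = 0; i < rows; ++i)
    {
        hits_plain[j % banks]++;     // linear layout: same bank every row
        hits_xor[(i ^ j) % banks]++; // xor swizzle: a distinct bank per row
    }
    int worst_plain = 0, worst_xor = 0;
    for(int b = 0; b < banks; ++b)
    {
        if(hits_plain[b] > worst_plain) worst_plain = hits_plain[b];
        if(hits_xor[b] > worst_xor) worst_xor = hits_xor[b];
    }
    std::printf("worst-case bank pressure: plain=%d, xor=%d\n", worst_plain, worst_xor);
}
```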
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
index e75aca1d91..c98d46e3a0 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp
@@ -193,6 +193,14 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<
     WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K16>>;
 
+using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
+    WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8,
+    2>>;
+
+using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
+    WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8,
+    2>>;
+
 using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
     WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
index 96c3c3d29f..69d22496f1 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
@@ -1022,7 +1022,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
         }
         else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
         {
-            DISPATCH_MFMA_("mfma_f32_116x16x32_fp8_bf8", "+v", "v", "v", "v")
+            DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "v", "v")
         }
         else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
         {
diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
index 64bd61a3dc..b2f5d56d01 100644
--- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
@@ -57,6 +57,7 @@ template<> struct WarpGemmMfmaDispatcher
 template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
+template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
 template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
 template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
 template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
@@ -65,6 +66,7 @@ template<> struct WarpGemmMfmaDispatcher
 template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
 template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
 template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
+template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; };
 template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
 template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
 template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
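The dispatcher is a flat table of specializations from (AType, BType, AccType, M, N, K, TransposeC) to a concrete warp GEMM; the two new entries route 32x32x32 fp8/bf8 warp tiles to the IterateK-doubled 32x32x16 MFMA. A cut-down, compile-checked sketch of the same lookup pattern, with mock tags replacing the ck_tile types:

```cpp
#include <type_traits>

// Mock tags standing in for ck_tile::fp8_t and the warp-GEMM types.
struct fp8_t {};
struct WarpGemmMfma_f32_32x32x16_fp8_fp8 {};
struct WarpGemmMfma_f32_32x32x32_fp8_fp8 {};

// Primary template intentionally undefined: unsupported combinations fail to compile.
template <typename AType, typename BType, typename AccType,
          int MPerWave, int NPerWave, int KPerWave, bool TransposeC>
struct WarpGemmMfmaDispatcher;

template <>
struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, false>
{ using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };

// New-style entry: one logical warp tile, two chained 32x32x16 MFMA issues along K.
template <>
struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 32, false>
{ using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };

static_assert(std::is_same_v<
    WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 32, false>::Type,
    WarpGemmMfma_f32_32x32x32_fp8_fp8>);

int main() {}
```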