From 9565ca21ec3ec7c0092b366b30da27ea54d7acf0 Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Tue, 19 May 2026 20:53:19 +0200 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#5552 (commit 369c7a2) [CK Tile] Eight Waves pipeline for MX GEMM (#5552) ## Motivation Integrate the Eight Waves pipeline into MX GEMM ## Technical Details - EightWaves pipeline: - Add pipeline, policy and block gemm (internally reusing the existing implementation used by GEMM and ABQuant) - Extend the EightWaves policy to support FP4 (packed types) - Async pipeline: - Fix the pipeline with packed scales (requires MRepeat and NRepeat to be contiguous) - Define a block gemm specific to MX GEMM because the distribution encodings have changed - CShuffle: - Add support for contiguous MRepeat and NRepeat (enabled by `TilesPacked`) - Examples: - Refactor examples to easily switch between different configurations (similar to GEMM universal) - Generate scale values consistently with other microscale implementations in CK Tile - Add configuration for the EightWaves pipeline - Tests: - Unify existing FP8 and FP4 tests - Add tests for the EightWaves pipeline - Generate scale values consistently with other microscale implementations in CK Tile Note: FP6 support for MX GEMM was added later; its support in the Eight Waves pipeline will be added in a follow-up PR ## Test Plan Add the new pipeline to the tests: `test_ck_tile_mx_gemm_async` for both FP4 and FP8 ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--- example/ck_tile/42_mx_gemm/CMakeLists.txt | 2 + example/ck_tile/42_mx_gemm/mx_gemm.cpp | 8 +- example/ck_tile/42_mx_gemm/mx_gemm.hpp | 21 +- .../ck_tile/42_mx_gemm/mx_gemm_instance.hpp | 17 +- example/ck_tile/42_mx_gemm/run_mx_gemm.inc | 40 ++- .../arch/amd_buffer_addressing_builtins.hpp | 4 +- .../ops/epilogue/cshuffle_epilogue.hpp | 144 +++++--- .../block/block_gemm_areg_breg_creg_v1.hpp | 144 -------- ...peline_ag_bg_cr_comp_async_eight_waves.hpp | 7 +- ...ag_bg_cr_comp_async_eight_waves_policy.hpp | 48 +-- ...emm_pipeline_ag_bg_cr_eight_waves_base.hpp | 46 ++- include/ck_tile/ops/gemm_mx.hpp | 4 + ..._mx_gemm_areg_breg_creg_eight_waves_v1.hpp | 310 +++++++++++++++++ .../block/block_mx_gemm_areg_breg_creg_v1.hpp | 324 ++++++++++++++++++ .../ops/gemm_mx/kernel/gemm_mx_kernel.hpp | 21 +- .../gemm_pipeline_ag_bg_cr_comp_async.hpp | 23 +- ...ine_ag_bg_cr_comp_async_default_policy.hpp | 15 +- ...peline_ag_bg_cr_comp_async_eight_waves.hpp | 282 +++++++++++++++ ...ag_bg_cr_comp_async_eight_waves_policy.hpp | 203 +++++++++++ ..._abquant_pipeline_ag_bg_cr_eight_waves.hpp | 7 +- ...t_pipeline_ag_bg_cr_eight_waves_policy.hpp | 9 +- test/ck_tile/gemm_mx/CMakeLists.txt | 7 +- test/ck_tile/gemm_mx/test_mx_gemm_async.cpp | 33 ++ test/ck_tile/gemm_mx/test_mx_gemm_config.hpp | 18 +- test/ck_tile/gemm_mx/test_mx_gemm_fp4.cpp | 30 -- test/ck_tile/gemm_mx/test_mx_gemm_fp8.cpp | 30 -- .../ck_tile/gemm_mx/test_mx_gemm_instance.hpp | 20 +- test/ck_tile/gemm_mx/test_mx_gemm_util.hpp | 37 +- .../test_mx_grouped_gemm_ut_cases.inc | 1 - 29 files changed, 1472 insertions(+), 383 deletions(-) create mode 100644 include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp create mode 100644 include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp create mode 100644 
include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp create mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_async.cpp delete mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_fp4.cpp delete mode 100644 test/ck_tile/gemm_mx/test_mx_gemm_fp8.cpp diff --git a/example/ck_tile/42_mx_gemm/CMakeLists.txt b/example/ck_tile/42_mx_gemm/CMakeLists.txt index 3ae8913dd7..ca6a529c54 100644 --- a/example/ck_tile/42_mx_gemm/CMakeLists.txt +++ b/example/ck_tile/42_mx_gemm/CMakeLists.txt @@ -14,6 +14,8 @@ endforeach() if(has_supported_gpu) add_executable(tile_example_mx_gemm mx_gemm.cpp) set(EXAMPLE_MX_GEMM_COMPILE_OPTIONS -Wno-undefined-func-template) + list(APPEND EXAMPLE_MX_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -enable-noalias-to-md-conversion=1") + list(APPEND EXAMPLE_MX_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1") if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_MX_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() diff --git a/example/ck_tile/42_mx_gemm/mx_gemm.cpp b/example/ck_tile/42_mx_gemm/mx_gemm.cpp index b1028021af..182815a30c 100644 --- a/example/ck_tile/42_mx_gemm/mx_gemm.cpp +++ b/example/ck_tile/42_mx_gemm/mx_gemm.cpp @@ -102,9 +102,9 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf, auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "4096", "m dimension") - .insert("n", "4096", "n dimension") - .insert("k", "4096", "k dimension") + arg_parser.insert("m", "1024", "m dimension") + .insert("n", "1024", "n dimension") + .insert("k", "2048", "k dimension") .insert("a_layout", "R", "A tensor data layout - Row by default") .insert("b_layout", "C", "B tensor data layout - Row by default") .insert("c_layout", "R", "C tensor data layout - Row by default") @@ -125,4 +125,4 @@ auto create_args(int argc, char* argv[]) #include "run_mx_gemm.inc" -int main(int argc, char* argv[]) { return run_mx_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return 
run_mx_gemm_example(argc, argv); } diff --git a/example/ck_tile/42_mx_gemm/mx_gemm.hpp b/example/ck_tile/42_mx_gemm/mx_gemm.hpp index e59a2fc57a..f17fe96529 100644 --- a/example/ck_tile/42_mx_gemm/mx_gemm.hpp +++ b/example/ck_tile/42_mx_gemm/mx_gemm.hpp @@ -9,6 +9,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm_mx.hpp" #include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp" template @@ -83,17 +84,23 @@ struct MxGemmConfig static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; static constexpr bool TiledMMAPermuteN = false; }; -struct MXfp4_GemmConfig16 : MxGemmConfig + +struct MX_GemmConfigEightWaves : MxGemmConfig { - static constexpr ck_tile::index_t M_Tile = 64; - static constexpr ck_tile::index_t N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256; + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong! 
+ static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128 * N_Warp; + static constexpr ck_tile::index_t K_Tile = 128 * K_Warp; + + static constexpr int kBlockPerCu = 2; }; -// GEMM config with 16x16 warp tile -struct MXfp8_GemmConfig16 : MxGemmConfig +struct MX_GemmConfig16 : MxGemmConfig { static constexpr ck_tile::index_t M_Tile = 64; - static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 256; }; diff --git a/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp b/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp index fef782cfba..1421b4d705 100644 --- a/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp +++ b/example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp @@ -57,7 +57,12 @@ float mx_gemm_calc(const MXGemmHostArgs& args, const ck_tile::st GemmConfig::Scheduler>; // Use the new MX comp_async pipeline with MX scaling support - using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync; + constexpr bool IsEightWave = + (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp) == 8; + using MXGemmPipeline = + std::conditional_t, + ck_tile::MXGemmPipelineAgBgCrCompAsync>; using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner& args, const ck_tile::st GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile, - MXPipelineProblem::TransposeC>>; + MXPipelineProblem::TransposeC, + 1, // kNumWaveGroups_ (Default) + false, // FixedVectorSize_ (Default) + 1, // VectorSizeC_ (Default) + 1, // BlockedXDLN_PerWarp_ (Default) + false, // DoubleSmemBuffer_ (Default) + ComputeDataType, // AComputeDataType + ComputeDataType, // BComputeDataType + true>>; // TilesPacked_ (because of packed scales) using Kernel = ck_tile::MXGemmKernel; diff --git a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc index ac6c51cde1..7ccd4e4273 100644 --- 
a/example/ck_tile/42_mx_gemm/run_mx_gemm.inc +++ b/example/ck_tile/42_mx_gemm/run_mx_gemm.inc @@ -124,29 +124,42 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout) ck_tile::host_tensor_descriptor(M, scale_k_size, stride_scale_a, is_row_major(ALayout{}))); ck_tile::HostTensor scale_b_host( ck_tile::host_tensor_descriptor(scale_k_size, N, stride_scale_b, is_row_major(BLayout{}))); - int seed = 1234; + + std::mt19937 gen(42); + std::uniform_int_distribution fill_seed(0, 500); + + auto gen_scales = [&](auto& scales, float range_min, float range_max) { + // e8m0_t is basically an exponent of float32 + ck_tile::HostTensor pow2(scales.get_lengths()); + ck_tile::FillUniformDistributionIntegerValue{range_min, range_max, fill_seed(gen)}( + pow2); + scales.ForEach([&](auto& self, const auto& i) { + self(i) = static_cast(std::exp2(pow2(i))); + }); + }; + switch(init_method) { case 0: // Initialize A, B, and scales to random values - ck_tile::FillUniformDistribution{-2.f, 2.f, seed++}(a_host); - ck_tile::FillUniformDistribution{-2.f, 2.f, seed++}(b_host); - ck_tile::FillUniformDistribution{0.001f, 10.f, seed++}(scale_a_host); - ck_tile::FillUniformDistribution{0.001f, 10.f, seed++}(scale_b_host); + ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(a_host); + ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(b_host); + gen_scales(scale_a_host, -2, 2); + gen_scales(scale_b_host, -2, 2); break; case 1: // Initialize A, B, and scales to 1.0 ck_tile::FillConstant{ADataType(1.f)}(a_host); ck_tile::FillConstant{BDataType(1.f)}(b_host); - ck_tile::FillConstant{ScaleType(1.f)}(scale_a_host); - ck_tile::FillConstant{ScaleType(1.f)}(scale_b_host); + gen_scales(scale_a_host, 0, 0); + gen_scales(scale_b_host, 0, 0); break; case 2: // Initialize A and B with random values but with constant 1.0 scales - ck_tile::FillUniformDistribution{-2.f, 2.f, seed++}(a_host); - ck_tile::FillUniformDistribution{-2.f, 2.f, seed++}(b_host); - 
ck_tile::FillConstant{ScaleType(0.1f)}(scale_a_host); - ck_tile::FillConstant{ScaleType(0.1f)}(scale_b_host); + ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(a_host); + ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(b_host); + gen_scales(scale_a_host, 0, 0); + gen_scales(scale_b_host, 0, 0); break; } @@ -248,6 +261,7 @@ int run_mx_gemm_with_layouts(int argc, char* argv[], ALayout, BLayout, CLayout) return pass ? 0 : -1; } +template int run_mx_gemm_example(int argc, char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -268,7 +282,7 @@ int run_mx_gemm_example(int argc, char* argv[]) return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); } else if(mx_prec == "fp8" || mx_prec == "fp8xfp8") @@ -276,7 +290,7 @@ int run_mx_gemm_example(int argc, char* argv[]) return run_mx_gemm_with_layouts(argc, argv, Row{}, Col{}, Row{}); } else diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index e86e04edbf..aac6d39647 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -2793,13 +2793,15 @@ template = {}) { index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); constexpr index_t src_linear_addr_offset = static_cast(linear_offset_t{}) * sizeof(T); + constexpr index_t PackedSize = numeric_traits::PackedSize; + index_t src_wave_addr_offset = src_wave_element_offset * sizeof(T) / PackedSize; amd_async_buffer_load(smem, rsrc, diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index fc3cb52d20..f6d7d82951 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -36,7 +36,8 @@ template + typename BComputeDataType_ = void, + bool TilesPacked_ = false> struct CShuffleEpilogueProblem { using AsDataType = 
remove_cvref_t; @@ -64,7 +65,7 @@ struct CShuffleEpilogueProblem static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; static constexpr index_t kNumWaveGroups = kNumWaveGroups_; static constexpr index_t NumDTensor = DsDataType::size(); - + static constexpr bool TilesPacked = TilesPacked_; static_assert(NumDTensor == DsLayout::size(), "The size of DsDataType and DsLayout should be the same"); }; @@ -140,15 +141,19 @@ struct CShuffleEpilogue static constexpr bool EightWave = false; #endif + // If the wave tiles computed by a single wave are packed + // This implies that in the block gemm MRepeat and NRepeat are contiguous + static constexpr bool TilesPacked = Problem::TilesPacked; static constexpr index_t BlockedXDLN_PerWarp = - EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp; - static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr index_t VectorSizeC = Problem::VectorSizeC; - static constexpr index_t MPerIteration = MPerXdl * MWave; - static constexpr index_t NPerIteration = NPerXdl * NWave; - static constexpr index_t NumDTensor = Problem::NumDTensor; - static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave); - static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave); + (EightWave || TilesPacked) ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp; + static constexpr index_t BlockedXDLM_PerWarp = (TilesPacked) ? 
kMPerBlock / MWave / MPerXdl : 1; + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t VectorSizeC = Problem::VectorSizeC; + static constexpr index_t MPerIteration = MPerXdl * MWave; + static constexpr index_t NPerIteration = NPerXdl * NWave; + static constexpr index_t NumDTensor = Problem::NumDTensor; + static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave); + static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave); CDElementwise elfunc_; @@ -288,7 +293,8 @@ struct CShuffleEpilogue } } }(); - static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple); + static constexpr index_t NumMXdlPerWavePerShuffle = + max(BlockedXDLM_PerWarp, std::get<0>(shuffle_tile_tuple)); static constexpr index_t NumNXdlPerWavePerShuffle = max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple)); @@ -447,64 +453,96 @@ struct CShuffleEpilogue CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode() { constexpr auto block_outer_dstr_encoding = [] { - if constexpr(BlockedXDLN_PerWarp == 1) + if constexpr(TilesPacked) { - return tile_distribution_encoding, - tuple, - sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + if constexpr(EightWave) + { + constexpr int RakedXDLN_PerWarp = + NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp; + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<1, 0, 2>>{}; + } + else + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 1>>{}; + } } else { -#if defined(__gfx950__) || defined(__gfx12__) - constexpr auto UseBlockedLayout = true; -#else - constexpr auto UseBlockedLayout = false; -#endif - constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp; - // BlockedLayout - // this branch is for original a16w4 - if constexpr(UseBlockedLayout || - is_any_of::value || - 
is_any_of::value) + if constexpr(BlockedXDLN_PerWarp == 1) { - if constexpr(EightWave) + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + } + else + { +#if defined(__gfx950__) || defined(__gfx12__) + constexpr auto UseBlockedLayout = true; +#else + constexpr auto UseBlockedLayout = false; +#endif + constexpr int RakedXDLN_PerWarp = + NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp; + // BlockedLayout + // this branch is for original a16w4 + if constexpr(UseBlockedLayout || + is_any_of::value || + is_any_of::value) { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<1, 2, 2>, - sequence<0, 0, 2>>{}; + if constexpr(EightWave) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}; + } + else + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}; + } } else { return tile_distribution_encoding< sequence<>, tuple, - sequence>, + sequence>, + tuple>, tuple>, - tuple>, sequence<1, 2, 2>, - sequence<0, 0, 2>>{}; + sequence<0, 0, 1>>{}; } } - else - { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<1, 2, 2>, - sequence<0, 0, 1>>{}; - } } }(); constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding( diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index ffad4171fa..dc76a410b5 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -388,150 +388,6 @@ struct BlockGemmARegBRegCRegV1 }); } - // C += A * B with MX scaling and packed-in-two (XdlPack) optimization - // Scale tensors contain pre-packed int32_t: 
each int32_t holds MXdlPack * KXdlPack e8m0_t - // values (for A) or NXdlPack * KXdlPack (for B), packed on the host. - // Uses OpSel (0-3) to select which byte within the packed int32_t for each MFMA call. - // XdlPack template parameters default to 2; fall back to 1 when iteration count is too small. - template - CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, - const ABlockTensor& a_block_tensor, - const BBlockTensor& b_block_tensor, - const ScaleATensor& scale_a_tensor, - const ScaleBTensor& scale_b_tensor) const - { - static_assert(std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); - - // check ABC-block-distribution - static_assert( - std::is_same_v, - remove_cvref_t>, - "A distribution is wrong!"); - static_assert( - std::is_same_v, - remove_cvref_t>, - "B distribution is wrong!"); - static_assert( - std::is_same_v, - remove_cvref_t>, - "C distribution is wrong!"); - - using AWarpDstr = typename WarpGemm::AWarpDstr; - using BWarpDstr = typename WarpGemm::BWarpDstr; - using CWarpDstr = typename WarpGemm::CWarpDstr; - - using AWarpTensor = typename WarpGemm::AWarpTensor; - using BWarpTensor = typename WarpGemm::BWarpTensor; - using CWarpTensor = typename WarpGemm::CWarpTensor; - - constexpr auto a_warp_y_lengths = - to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto b_warp_y_lengths = - to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - constexpr auto c_warp_y_lengths = - to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); - - constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; - constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; - constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - - // Effective XdlPack: fall back to 1 when iteration count is insufficient - constexpr index_t MXdlPack = - (MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? 
MXdlPack_ : 1; - constexpr index_t NXdlPack = - (NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1; - constexpr index_t KXdlPack = - (KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? KXdlPack_ : 1; - - constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack; - constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack; - constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack; - - // hot loop with MX scaling and pre-packed int32_t scales: - // Outer loops iterate over pack groups (scale tile indices) - static_ford>{}([&](auto ii) { - constexpr auto ikpack = number{}]>{}; - constexpr auto impack = number{}]>{}; - // Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t) - auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); - - static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { - // Get pre-packed int32_t B scale - auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( - sequence{}, sequence<1, 1, 1>{}); - const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); - - // Inner loops: issue MFMAs within the pack group using OpSel - static_ford>{}([&](auto jj) { - constexpr auto ikxdl = number{}]>{}; - constexpr auto imxdl = number{}]>{}; - constexpr auto kIter = ikpack * KXdlPack + ikxdl; - constexpr auto mIter = impack * MXdlPack + imxdl; - - // read A warp tensor from A block tensor - AWarpTensor a_warp_tensor; - a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, a_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); - - // OpSel for A: selects byte within packed int32_t - constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; - - static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - constexpr auto nIter = inpack * NXdlPack + inxdl; - - // read B warp tensor from B block tensor - 
BWarpTensor b_warp_tensor; - b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( - merge_sequences(sequence{}, b_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); - - // OpSel for B: selects byte within packed int32_t - constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; - - // read C warp tensor from C block tensor - using c_iter_idx = std::conditional_t, - sequence>; - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM with MX scaling using pre-packed scale and OpSel - WarpGemm{}.template operator(), OpSelB>( - c_warp_tensor, - a_warp_tensor, - b_warp_tensor, - a_scale_packed, - b_scale_packed); - - // write C warp tensor into C block tensor - c_block_tensor.set_y_sliced_thread_data( - merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - }); - }); - }); - }); - } - CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { using c_distr_ys_major = std::conditional_t, sequence<1, 2>>; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp index 2f2a67deae..a9f6dced9d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp @@ -118,12 +118,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrComp CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - // We are not storing the original packed type in LDS, so we need to multiply the smem size - // by the packed size. 
- constexpr index_t smem_size_a = Policy::template GetSmemSizeA() * APackedSize; - constexpr index_t smem_size_b = Policy::template GetSmemSizeB() * BPackedSize; - - return 2 * (smem_size_a + smem_size_b); + return Policy::template GetSmemSize(); } static constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp index 80728dba67..1a12eaa4fe 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp @@ -17,10 +17,9 @@ namespace detail { template struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy { - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; - static constexpr auto WGAccessDouble = WGAttrNumAccessEnum::Double; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; @@ -29,14 +28,21 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy using CDataType = remove_cvref_t; using AComputeDataType = remove_cvref_t; using BComputeDataType = remove_cvref_t; - static_assert(std::is_same_v, "Wrong!"); - static_assert(std::is_same_v, "Wrong!"); - static_assert(std::is_same_v || - std::is_same_v); - static_assert(std::is_same_v || - std::is_same_v); + using ComputeDataType = AComputeDataType; + static_assert(std::is_same_v, + "ALayout must be RowMajor!"); + static_assert(std::is_same_v, + "BLayout must be ColumnMajor!"); + static_assert(is_any_of::value); + static_assert(is_any_of::value); + static_assert(std::is_same_v); static_assert(std::is_same_v); + static constexpr auto WGAccess = std::is_same_v + ? 
WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single; + static constexpr auto PackedSize = numeric_traits::PackedSize; + using BlockGemmShape = typename Problem::BlockGemmShape; using BlockWarps = typename BlockGemmShape::BlockWarps; using WarpTile = typename BlockGemmShape::WarpTile; @@ -88,7 +94,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy static constexpr index_t NIterPerWarp = NWarpTiles / NWarps; static constexpr index_t KPerWarp = KPerBlock / KWarps; static constexpr index_t NPerWarp = NPerBlock / NWarps; - static_assert(NWarps == 2, "KWarps == 2 for ping-pong!"); + static_assert(NWarps == 2, "NWarps == 2 for ping-pong!"); static_assert(KWarpTiles == KWarps, "Wrong!"); static constexpr index_t warp_size = get_warp_size(); @@ -98,8 +104,8 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy static_assert(sizeof(ADataType) == sizeof(BDataType), "Wrong!"); static constexpr index_t ElementSize = sizeof(ADataType); - static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize; // 16 - static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8 + static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize * PackedSize; // 16 + static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8 static constexpr index_t K0 = KPerWarp / (K1 * K2); static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!"); static_assert(K0 == 1, "Wrong!"); @@ -176,7 +182,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy const index_t k_tiles = cols / (KWarps * K1 * K2); const auto col_lens = make_tuple(k_tiles, number{}, number{}, number{}); - constexpr index_t M1 = warp_size / static_cast(WGAccessDouble) / K1; // 4 + constexpr index_t M1 = warp_size / static_cast(WGAccess) / K1; // 4 const index_t M0 = integer_divide_ceil(rows, M1); const auto row_lens = make_tuple(M0, number{}); @@ -227,9 +233,9 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy template CK_TILE_DEVICE static constexpr auto MakeABLdsBlockDescriptor_() { - constexpr index_t M4 = warp_size / 
static_cast(WGAccessDouble) / K1; // 4 - constexpr index_t M3 = static_cast(WGAccessDouble); // 2 - constexpr index_t M2 = WarpTileM / M4 / M3; // 2 + constexpr index_t M4 = warp_size / static_cast(WGAccess) / K1; // 4 + constexpr index_t M3 = static_cast(WGAccess); // 2 + constexpr index_t M2 = WarpTileM / M4 / M3; // 2 constexpr index_t M1 = (warp_num / warp_groups_) / M2; constexpr index_t M0 = MNPerBlock / M1 / M2 / M3 / M4; @@ -337,12 +343,14 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy CK_TILE_DEVICE static constexpr index_t GetSmemSizeA() { constexpr index_t desc_size = MakeALdsBlockDescriptor().get_element_space_size(); - return integer_least_multiple(sizeof(typename Problem::ADataType) * desc_size, 16); + return integer_least_multiple(sizeof(typename Problem::ADataType) * desc_size / PackedSize, + 16); } CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { constexpr index_t desc_size = MakeBLdsBlockDescriptor().get_element_space_size(); - return integer_least_multiple(sizeof(typename Problem::BDataType) * desc_size, 16); + return integer_least_multiple(sizeof(typename Problem::BDataType) * desc_size / PackedSize, + 16); } CK_TILE_DEVICE static constexpr index_t GetSmemSize() @@ -361,7 +369,7 @@ struct GemmPipelineAgBgCrCompAsyncEightWavesPolicy CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { // TODO: Fix for transpose - constexpr auto wg_attr_num_access = WGAttrNumAccessEnum::Double; + constexpr auto wg_attr_num_access = WGAccess; using WarpGemm = WarpGemmDispatcher, + bool>* = nullptr> + CK_TILE_DEVICE static constexpr auto GetInstCountAQ(const AQDramBlockWindowTmp&) + { + return 0; + } + + template , + bool>* = nullptr> + CK_TILE_DEVICE static constexpr auto GetInstCountBQ(const BQDramBlockWindowTmp&) + { + return 0; + } + // A/B Quant template , @@ -234,6 +250,22 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< return Policy::template GetKStepBQ(); } + template , + bool>* = nullptr> + 
CK_TILE_DEVICE static constexpr auto GetInstCountAQ(const AQDramBlockWindowTmp&) + { + return Policy::template GetInstCountAQ(); + } + + template , + bool>* = nullptr> + CK_TILE_DEVICE static constexpr auto GetInstCountBQ(const BQDramBlockWindowTmp&) + { + return Policy::template GetInstCountBQ(); + } + template = 1, "wrong!"); - // Instructions Count - constexpr index_t VectorSizeB = Policy::template GetVectorSizeB(); - constexpr index_t B_LOAD_INST = NPerBlock * KPerBlock / BlockSize / VectorSizeB; - constexpr index_t AQ_LOAD_INST = - std::is_same_v ? 0 : MIterPerWarp; - constexpr index_t BQ_LOAD_INST = - std::is_same_v ? 0 : 1; - // ----- // Setup // ----- @@ -314,6 +338,12 @@ struct GemmPipelineAgBgCrEightWavesImplBase : public GemmPipelineAgBgCrImplBase< constexpr AQDramTileWindowStep aq_move_step = {0, GetKStepAQ(aq_copy_dram_window)}; constexpr BQDramTileWindowStep bq_move_step = {0, GetKStepBQ(bq_copy_dram_window)}; + // Instructions Count + constexpr index_t VectorSizeB = Policy::template GetVectorSizeB(); + constexpr index_t B_LOAD_INST = NPerBlock * KPerBlock / BlockSize / VectorSizeB; + constexpr index_t AQ_LOAD_INST = GetInstCountAQ(aq_copy_dram_window); + constexpr index_t BQ_LOAD_INST = GetInstCountBQ(bq_copy_dram_window); + // ------- // Lambdas // ------- diff --git a/include/ck_tile/ops/gemm_mx.hpp b/include/ck_tile/ops/gemm_mx.hpp index 29fccf8057..edd2f6d657 100644 --- a/include/ck_tile/ops/gemm_mx.hpp +++ b/include/ck_tile/ops/gemm_mx.hpp @@ -2,10 +2,14 @@ // SPDX-License-Identifier: MIT #pragma once +#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp" +#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp" #include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp" #include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp" #include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" #include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" 
+#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp" +#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp b/include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp new file mode 100644 index 0000000000..79fbe347c6 --- /dev/null +++ b/include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp @@ -0,0 +1,310 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block distributed tensor +// C is block distributed tensor +template +struct BlockMXGemmARegBRegCRegEightWavesV1 +{ + private: + template + struct GemmTraits_ + { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using AComputeDataType = remove_cvref_t; + using BComputeDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr auto Scheduler = Problem::Scheduler; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + static constexpr index_t MWarp = 
config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + static constexpr index_t KWarp = Problem::BlockGemmShape::BlockWarps::at(number<2>{}); + + using I0 = number<0>; + using I1 = number<1>; + + static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}), + "Error! WarpGemm's MWarp is not consistent with BlockGemmShape!"); + static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}), + "Error! WarpGemm's NWarp is not consistent with BlockGemmShape!"); + static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}), + "Error! WarpGemm's M is not consistent with BlockGemmShape!"); + static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}), + "Error! WarpGemm's N is not consistent with BlockGemmShape!"); + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / (KWarp * WarpGemm::kK); + + // Controls how many MAC clusters (MFMA blocks) we have per wave + // If InterWaveSchedulingMacClusters = 1; + // Then we group all WarpGemms into single MAC cluster. + // But if InterWaveSchedulingMacClusters = 2, then we + // split the warp gemms into two groups. 
+ static constexpr index_t InterWaveSchedulingMacClusters = 1; + + static constexpr index_t KPackA = WarpGemm::kAKPack; + static constexpr index_t KPackB = WarpGemm::kBKPack; + static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; + static constexpr bool TransposeC = Problem::TransposeC; + }; + + public: + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using Traits = GemmTraits_; + + using WarpGemm = typename Traits::WarpGemm; + using BlockGemmShape = typename Traits::BlockGemmShape; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using AComputeDataType = remove_cvref_t; + using BComputeDataType = remove_cvref_t; + + static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; + static constexpr index_t MIterPerWarp = Traits::MIterPerWarp; + static constexpr index_t NIterPerWarp = Traits::NIterPerWarp; + + static constexpr index_t MWarp = Traits::MWarp; + static constexpr index_t NWarp = Traits::NWarp; + static constexpr index_t KWarp = Traits::KWarp; + + static constexpr auto Scheduler = Traits::Scheduler; + static constexpr bool TransposeC = Traits::TransposeC; + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + static_assert(std::is_same_v); + + static constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto b_warp_y_index_zeros = 
uniform_sequence_gen_t{}; + static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + using I0 = number<0>; + using I1 = number<1>; + + // Note: distribution encodings have MIterPerWarp and NIterPerWarp contiguous because of scale + // packing. + + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + constexpr index_t KPerThread = Traits::KPerThread; + constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; + + constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + + constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; + + using KIterSeq = std::conditional_t, + sequence>; + + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 1>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + constexpr index_t KPerThread = Traits::KPerThread; + constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; + constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; + + using KIterSeq = std::conditional_t, + sequence>; + + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, KIterSeq>, + tuple>, + tuple>, + sequence<>, + sequence<>>{}; + + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return 
b_block_dstr_encode; + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence<2, NIterPerWarp, NWarp / 2>>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 1>>{}; + constexpr auto c_block_dstr_encoding = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + return c_block_dstr_encoding; + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + return make_static_distributed_tensor( + make_static_tile_distribution(MakeCBlockDistributionEncode())); + } + + using ALdsTile = decltype(make_static_distributed_tensor( + make_static_tile_distribution(MakeABlockDistributionEncode()))); + using BLdsTiles = statically_indexed_array< + statically_indexed_array( + make_static_tile_distribution( + MakeBBlockDistributionEncode()))), + KIterPerWarp>, + NIterPerWarp>; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ALdsTile& a_warp_tile_, + const BLdsTiles& b_warp_tiles_, + const ScaleATensor& scale_a_tensor, + const ScaleBTensor& scale_b_tensor) const + { + // checks + static_assert(std::is_same_v>, + "CDataType must be same as CBlockTensor::DataType!"); + static_assert( + std::is_same_v, + remove_cvref_t>, + "C distribution is wrong!"); + + // Effective XdlPack: fall back to 1 when iteration count is insufficient + constexpr index_t MXdlPack = + (MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1; + constexpr index_t NXdlPack = + (NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1; + constexpr index_t KXdlPack = + (KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? 
KXdlPack_ : 1; + + constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack; + constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack; + constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack; + + // hot loop: + static_for_product, + number, + number>{}([&](auto ikpack, auto inpack, auto impack) { + // get A scale for this M-K tile using get_y_sliced_thread_data + auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( + sequence{}, sequence<1, 1, 1>{}); + const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); + + // get B scale for this N-K tile using get_y_sliced_thread_data + auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( + sequence{}, sequence<1, 1, 1>{}); + const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); + + // Inner loops: issue MFMAs within the pack group using OpSel + static_for_product, number, number>{}( + [&](auto ikxdl, auto inxdl, auto imxdl) { + constexpr auto kIter = ikpack * KXdlPack + ikxdl; + constexpr auto mIter = impack * MXdlPack + imxdl; + constexpr auto nIter = inpack * NXdlPack + inxdl; + + // OpSel for A: selects byte within packed int32_t + constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; + + // OpSel for B: selects byte within packed int32_t + constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; + + // read A warp tensor from A Block window + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = + b_warp_tiles_[number{}][number{}].get_thread_buffer(); + + // read C warp tensor from C block tensor + using c_iter_idx = sequence; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + 
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM with MX scaling + WarpGemm{}.template operator(), OpSelB>( + c_warp_tensor, + a_warp_tensor, + b_warp_tensor, + a_scale_packed, + b_scale_packed); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp new file mode 100644 index 0000000000..7e190dc8e1 --- /dev/null +++ b/include/ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp @@ -0,0 +1,324 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp" + +namespace ck_tile { + +// A is block distributed tensor +// B is block distributed tensor +// C is block distributed tensor +template +struct BlockMXGemmARegBRegCRegV1 +{ + private: + template + struct GemmTraits_ + { + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + static constexpr index_t 
MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + static constexpr index_t KPackA = WarpGemm::kAKPack; + static constexpr index_t KPackB = WarpGemm::kBKPack; + }; + + public: + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + static constexpr bool TransposeC = TransposeC_; + + using Traits = GemmTraits_; + + using WarpGemm = typename Traits::WarpGemm; + using BlockGemmShape = typename Traits::BlockGemmShape; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; + static constexpr index_t MIterPerWarp = Traits::MIterPerWarp; + static constexpr index_t NIterPerWarp = Traits::NIterPerWarp; + + static constexpr index_t MWarp = Traits::MWarp; + static constexpr index_t NWarp = Traits::NWarp; + static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1); + + // Note: distribution encodings have MIterPerWarp and NIterPerWarp contiguous because of scale + // packing. 
+ + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() + { + if constexpr(UseDefaultScheduler) + { + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + else + { + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + } + + CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() + { + if constexpr(UseDefaultScheduler) + { + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + else + { + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<1, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() + { + using c_distr_ys_major = std::conditional_t, sequence<1, 2>>; + if constexpr(UseDefaultScheduler) + { + using c_distr_ys_minor = std::conditional_t, sequence<0, 1>>; + constexpr auto c_block_outer_dstr_encoding = 
tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + c_distr_ys_major, + c_distr_ys_minor>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + return c_block_dstr_encode; + } + else + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + c_distr_ys_major, + sequence<1, 1>>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + return c_block_dstr_encode; + } + } + + // C += A * B with MX scaling and packed-in-two (XdlPack) optimization + // Scale tensors contain pre-packed int32_t: each int32_t holds MXdlPack * KXdlPack e8m0_t + // values (for A) or NXdlPack * KXdlPack (for B), packed on the host. + // Uses OpSel (0-3) to select which byte within the packed int32_t for each MFMA call. + // XdlPack template parameters default to 2; fall back to 1 when iteration count is too small. 
+ template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ABlockTensor& a_block_tensor, + const BBlockTensor& b_block_tensor, + const ScaleATensor& scale_a_tensor, + const ScaleBTensor& scale_b_tensor) const + { + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "Datatypes do not match BlockTensor datatypes!"); + + // check ABC-block-distribution + static_assert( + std::is_same_v, + remove_cvref_t>, + "A distribution is wrong!"); + static_assert( + std::is_same_v, + remove_cvref_t>, + "B distribution is wrong!"); + static_assert( + std::is_same_v, + remove_cvref_t>, + "C distribution is wrong!"); + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // Effective XdlPack: fall back to 1 when iteration count is insufficient + constexpr index_t MXdlPack = + (MIterPerWarp >= MXdlPack_ && MIterPerWarp % MXdlPack_ == 0) ? MXdlPack_ : 1; + constexpr index_t NXdlPack = + (NIterPerWarp >= NXdlPack_ && NIterPerWarp % NXdlPack_ == 0) ? NXdlPack_ : 1; + constexpr index_t KXdlPack = + (KIterPerWarp >= KXdlPack_ && KIterPerWarp % KXdlPack_ == 0) ? 
KXdlPack_ : 1; + + constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack; + constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack; + constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack; + + // hot loop with MX scaling and pre-packed int32_t scales: + // Outer loops iterate over pack groups (scale tile indices) + static_ford>{}([&](auto ii) { + constexpr auto ikpack = number{}]>{}; + constexpr auto impack = number{}]>{}; + // Get pre-packed int32_t A scale (already contains MXdlPack*KXdlPack e8m0_t) + auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data( + sequence{}, sequence<1, 1, 1>{}); + const int32_t a_scale_packed = bit_cast(scale_a_slice[number<0>{}]); + + static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) { + // Get pre-packed int32_t B scale + auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data( + sequence{}, sequence<1, 1, 1>{}); + const int32_t b_scale_packed = bit_cast(scale_b_slice[number<0>{}]); + + // Inner loops: issue MFMAs within the pack group using OpSel + static_ford>{}([&](auto jj) { + constexpr auto ikxdl = number{}]>{}; + constexpr auto imxdl = number{}]>{}; + constexpr auto kIter = ikpack * KXdlPack + ikxdl; + constexpr auto mIter = impack * MXdlPack + imxdl; + + // read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // OpSel for A: selects byte within packed int32_t + constexpr index_t kOpSelA = ikxdl * MXdlPack + imxdl; + + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto nIter = inpack * NXdlPack + inxdl; + + // read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // OpSel 
for B: selects byte within packed int32_t + constexpr index_t kOpSelB = ikxdl * NXdlPack + inxdl; + + // read C warp tensor from C block tensor + using c_iter_idx = std::conditional_t, + sequence>; + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM with MX scaling using pre-packed scale and OpSel + WarpGemm{}.template operator(), OpSelB>( + c_warp_tensor, + a_warp_tensor, + b_warp_tensor, + a_scale_packed, + b_scale_packed); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(c_iter_idx{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + }); + } + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + return make_static_distributed_tensor( + make_static_tile_distribution(MakeCBlockDistributionEncode())); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp index a6428b88ac..bd647dfc87 100644 --- a/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp +++ b/include/ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp @@ -332,8 +332,7 @@ struct MXGemmKernel : UniversalGemmKernel& bs_ptr, const std::array& ds_ptr, EDataType* e_ptr, - void* smem_ptr_ping, - void* smem_ptr_pong, + void* smem_ptr, const KernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t i_m, @@ -363,24 +362,18 @@ struct MXGemmKernel : UniversalGemmKernel CK_TILE_DEVICE void operator()(KernelArgs kargs, int partition_idx = get_block_id()) const @@ -389,8 +382,7 @@ struct MXGemmKernel : UniversalGemmKernel(); + constexpr index_t smem_size = Policy::template GetSmemSize(); + return 2 * smem_size; } CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() 
@@ -688,9 +689,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< const ScaleADramBlockWindowTmp& scale_a_window, const ScaleBDramBlockWindowTmp& scale_b_window, index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { + constexpr index_t smem_size = Policy::template GetSmemSize(); + const auto smem = reinterpret_cast(p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); const auto tail_number = Base::GetBlockLoopTailNum(num_loop); @@ -703,8 +706,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< scale_a_window, scale_b_window, num_loop, - p_smem_0, - p_smem_1); + smem, + smem + smem_size); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); @@ -720,9 +723,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< const ScaleADramBlockWindowTmp& scale_a_window, const ScaleBDramBlockWindowTmp& scale_b_window, const index_t num_loop, - void* __restrict__ p_smem_0, - void* __restrict__ p_smem_1) const + void* __restrict__ p_smem) const { + constexpr index_t smem_size = Policy::template GetSmemSize(); + const auto smem = reinterpret_cast(p_smem); + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); const auto tail_number = Base::GetBlockLoopTailNum(num_loop); @@ -735,8 +740,8 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync< scale_a_window, scale_b_window, num_loop, - p_smem_0, - p_smem_1); + smem, + smem + smem_size); }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp index d90271d235..a67beb5544 100644 --- a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp +++ 
b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_v1.hpp" #include namespace ck_tile { @@ -128,7 +129,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy BlockWarps, WarpGemm>; - return BlockGemmARegBRegCRegV1{}; + return BlockMXGemmARegBRegCRegV1{}; } // XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension @@ -170,12 +171,12 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy return make_static_tile_distribution( tile_distribution_encoding, - tuple, + tuple, sequence>, tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, + tuple, sequence<1, 2>>, sequence<2, 1, 2>, - sequence<0, 0, 2>>{}); + sequence<0, 1, 2>>{}); } template @@ -208,12 +209,12 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy return make_static_tile_distribution( tile_distribution_encoding, - tuple, + tuple, sequence>, tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, + tuple, sequence<1, 2>>, sequence<2, 1, 2>, - sequence<0, 0, 2>>{}); + sequence<0, 1, 2>>{}); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp new file mode 100644 index 0000000000..3b25d6091a --- /dev/null +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves.hpp @@ -0,0 +1,282 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_eight_waves_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" +#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp" + +namespace ck_tile { + +/** + * @brief Compute optimized pipeline version async for 8 waves + * + * This pipeline introduces asynchronous load from global memory to LDS, + * skipping the intermediate loading into pipeline registers. + */ +template +struct MXGemmPipelineAgBgCrCompAsyncEightWaves : public BaseGemmPipelineAgBgCrCompV3 +{ + using Base = BaseGemmPipelineAgBgCrCompV3; + using PipelineImplBase = GemmPipelineAgBgCrEightWavesImplBase; + + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + static_assert(!std::is_same_v, "Not implemented"); + + static constexpr index_t APackedSize = ck_tile::numeric_traits::PackedSize; + static constexpr index_t BPackedSize = ck_tile::numeric_traits::PackedSize; + + using BlockGemm = remove_cvref_t())>; + using WarpGemm = typename BlockGemm::WarpGemm; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = 
BlockGemmShape::kK; + + static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(I0); + static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1); + static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2); + + static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock; + + static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * WarpGemm::kM); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * WarpGemm::kN); + static constexpr index_t KIterPerWarp = KPerBlock / (KWarps * WarpGemm::kK); + + static constexpr bool Async = true; + + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + static constexpr index_t Preshuffle = Problem::Preshuffle; + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr auto Scheduler = Problem::Scheduler; + + [[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName() + { + // clang-format off + return "COMPUTE_ASYNC_EIGHT_WAVES"; + // clang-format on + } + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1); + return concat('_', "pipeline_AgBgCrCompAsyncEightWaves", + concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB()), + concat('x', WaveNumM, WaveNumN), + concat('x', 
kPadM, kPadN, kPadK), + Problem::GetName()); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + static constexpr index_t MFMA_INST = MIterPerWarp * NIterPerWarp * KIterPerWarp; + + // Scales are packed so odd numbers of iterations greater than 1 are not supported + static_assert((MIterPerWarp == 1) || (MIterPerWarp % 2 == 0)); + static_assert((NIterPerWarp == 1) || (NIterPerWarp % 2 == 0)); + static_assert((KIterPerWarp == 1) || (KIterPerWarp % 2 == 0)); + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* __restrict__ p_smem) const + { + // TODO: A/B elementwise functions currently not supported + ignore = a_element_func; + ignore = b_element_func; + + // ------ + // Checks + // ------ + static_assert( + std::is_same_v> && + std::is_same_v>, + "A/B Dram block window should have the same data type as appropriate " + "([A|B]DataType) defined in Problem definition!"); + + static_assert(std::is_same_v, "Wrong!"); + static_assert(std::is_same_v, "Wrong!"); + + static_assert((MPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I0] && + KPerBlock == AsDramBlockWindowTmp{}.get_window_lengths()[I1]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(Preshuffle // + ? 
(NWarps == BsDramBlockWindowTmp{}.get_window_lengths()[I0] && + kflatKPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1]) + : (NPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I0] && + KPerBlock == BsDramBlockWindowTmp{}.get_window_lengths()[I1]), + "B block window has incorrect lengths for defined BLayout!"); + + // ------------------ + // Hot loop scheduler + // ------------------ + auto hot_loop_scheduler = [&]() { + __builtin_amdgcn_sched_group_barrier(0x008, MIterPerWarp, 0); // MFMA + s_waitcnt_lgkm<4>(); + __builtin_amdgcn_sched_group_barrier(0x004, 1, 0); // lgkmcnt / SALU + static_for<0, MFMA_INST - MIterPerWarp, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_barrier(0); + }; + + // ------- + // Compute + // ------- + return Base::template Run_(p_smem, + num_loop, + a_dram_block_window_tmp, + b_dram_block_window_tmp, + scale_a_window, + scale_b_window, + hot_loop_scheduler); + } + }; + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* p_smem) const + { + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + scale_a_window, + scale_b_window, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const 
AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* p_smem) const + { + const bool has_hot_loop = Base::BlockHasHotloop(num_loop); + const auto tail_number = Base::GetBlockLoopTailNum(num_loop); + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + identity{}, + b_dram_block_window_tmp, + identity{}, + scale_a_window, + scale_b_window, + num_loop, + p_smem); + }; + + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp new file mode 100644 index 0000000000..ec5db8afdd --- /dev/null +++ b/include/ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp @@ -0,0 +1,203 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_eight_waves_policy.hpp" +#include "ck_tile/ops/gemm_mx/block/block_mx_gemm_areg_breg_creg_eight_waves_v1.hpp" + +namespace ck_tile { +namespace detail { + +template +struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy +{ + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + // MX scaling configuration: each e8m0 scale covers 32 elements in K + static constexpr int BlockScaleSize = 32; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using AComputeDataType = remove_cvref_t; + using BComputeDataType = remove_cvref_t; + using ComputeDataType = AComputeDataType; + static_assert(std::is_same_v, "Wrong!"); + static_assert(std::is_same_v, "Wrong!"); + static_assert(is_any_of::value); + static_assert(is_any_of::value); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + using BlockGemmShape = typename Problem::BlockGemmShape; + using BlockWarps = typename BlockGemmShape::BlockWarps; + using WarpTile = typename BlockGemmShape::WarpTile; + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + static constexpr index_t MWarps = BlockGemmShape::BlockWarps::at(I0); + static constexpr index_t NWarps = BlockGemmShape::BlockWarps::at(I1); + static constexpr index_t KWarps = BlockGemmShape::BlockWarps::at(I2); + static constexpr index_t WarpTileM = WarpTile::at(I0); + static constexpr index_t WarpTileN = WarpTile::at(I1); + static constexpr index_t WarpTileK = WarpTile::at(I2); + static constexpr index_t MWarpTiles = MPerBlock / WarpTileM; + static constexpr 
index_t NWarpTiles = NPerBlock / WarpTileN; + static constexpr index_t KWarpTiles = KPerBlock / WarpTileK; + + // XdlPack: how many e8m0_t scale values are packed into one int32_t per dimension + // Host packs MXdlPack * KXdlPack e8m0_t into one int32_t for A scales + // Host packs NXdlPack * KXdlPack e8m0_t into one int32_t for B scales + static constexpr int MXdlPack = 2; + static constexpr int NXdlPack = 2; + static constexpr int KXdlPack = 2; + + // Compute effective XdlPack sizes (fall back to 1 when iter count < pack) + static constexpr index_t MPerXdl = WarpTile::at(I0); + static constexpr index_t NPerXdl = WarpTile::at(I1); + static constexpr index_t KPerXdl = WarpTile::at(I2); + static constexpr index_t MIterPerWarp = MPerBlock / (MWarps * MPerXdl); + static constexpr index_t NIterPerWarp = NPerBlock / (NWarps * NPerXdl); + static constexpr index_t KIterPerWarp = KPerBlock / KPerXdl; + + static constexpr index_t MXdlPackEff = + (MIterPerWarp >= MXdlPack && MIterPerWarp % MXdlPack == 0) ? MXdlPack : 1; + static constexpr index_t NXdlPackEff = + (NIterPerWarp >= NXdlPack && NIterPerWarp % NXdlPack == 0) ? NXdlPack : 1; + static constexpr index_t KXdlPackEff = + (KIterPerWarp >= KXdlPack && KIterPerWarp % KXdlPack == 0) ? 
KXdlPack : 1; + + static constexpr index_t KPerBlockScale = KPerBlock / BlockScaleSize / KXdlPackEff; + + static constexpr index_t KPerWarp = KPerBlock / KWarps; + static constexpr index_t NPerWarp = NPerBlock / NWarps; + static_assert(NWarps == 2, "NWarps == 2 for ping-pong!"); + static_assert(KWarpTiles == KWarps, "Wrong!"); + + static constexpr index_t warp_size = get_warp_size(); + static constexpr index_t warp_num = BlockSize / warp_size; + static_assert(warp_size == 64, "Wrong!"); + static_assert(warp_num * warp_size == BlockSize, "Wrong!"); + + static_assert(sizeof(ADataType) == sizeof(BDataType), "Wrong!"); + static constexpr index_t ElementSize = sizeof(ADataType); + static constexpr index_t K2 = Problem::VectorLoadSize / ElementSize; // 16 + static constexpr index_t K1 = WarpTile::at(I2) / K2; // 8 + static constexpr index_t K0 = KPerWarp / (K1 * K2); + static_assert(K0 * K1 * K2 == KPerWarp, "Wrong!"); + static_assert(K0 == 1, "Wrong!"); + + CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockScale; } + CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockScale; } + + CK_TILE_HOST_DEVICE static constexpr auto GetInstCountAQ() + { + return (MIterPerWarp / MXdlPackEff) * (KIterPerWarp / KXdlPackEff); + } + + CK_TILE_HOST_DEVICE static constexpr auto GetInstCountBQ() + { + return (NIterPerWarp / NXdlPackEff) * (KIterPerWarp / KXdlPackEff); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeAQBlockDistribution() + { + constexpr index_t K_Lane = get_warp_size() / WarpTileM; + + constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane; + + constexpr index_t MIterPerWarp_packed = MIterPerWarp / MXdlPackEff; + constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // repeat over MWarps + tuple, // M dimension (first) + sequence>, // K dimension (second) + tuple, sequence<2, 1>>, // , + tuple, sequence<1, 2>>, + 
sequence<2, 1, 2>, // + sequence<0, 1, 2>>{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeBQBlockDistribution() + { + constexpr index_t K_Lane = get_warp_size() / WarpTileN; + + constexpr index_t KPerLane = WarpTileK / BlockScaleSize / K_Lane; + + constexpr index_t NIterPerWarp_packed = NIterPerWarp / NXdlPackEff; + constexpr index_t KIterPerWarp_packed = KIterPerWarp / KXdlPackEff; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // repeat over MWarps + tuple, // N dimension + // (first) + sequence>, // K dimension (second) + tuple, sequence<2, 1>>, // , + tuple, sequence<1, 3>>, + sequence<2, 1, 2>, // + sequence<0, 1, 2>>{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + constexpr auto wg_attr_num_access = + (std::is_same_v || std::is_same_v) + ? WGAttrNumAccessEnum::Double + : WGAttrNumAccessEnum::Single; + + using WarpGemm = WarpGemmDispatcher; + + using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockMXGemmARegBRegCRegEightWavesV1{}; + } +}; +} // namespace detail + +struct MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy + : public GemmPipelineAgBgCrCompAsyncEightWavesPolicy +{ + +#define FORWARD_METHOD_(method) \ + template \ + CK_TILE_HOST_DEVICE static constexpr auto method(Args&&... 
args) \ + { \ + return detail::MXGemmPipelineAgBgCrCompAsyncEightWavesPolicy::method( \ + std::forward(args)...); \ + } + + FORWARD_METHOD_(MakeAQBlockDistribution); + FORWARD_METHOD_(MakeBQBlockDistribution); + FORWARD_METHOD_(GetBlockGemm); + FORWARD_METHOD_(GetKStepAQ); + FORWARD_METHOD_(GetKStepBQ); + FORWARD_METHOD_(GetInstCountAQ); + FORWARD_METHOD_(GetInstCountBQ); + +#undef FORWARD_METHOD_ +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves.hpp index e6249ffa4f..23ad2dd12a 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves.hpp @@ -134,12 +134,7 @@ struct ABQuantGemmPipelineAgBgCrEightWaves : public BaseGemmPipelineAgBgCrCompV3 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - // We are not storing the original packed type in LDS, so we need to multiply the smem size - // by the packed size. 
- constexpr index_t smem_size_a = Policy::template GetSmemSizeA() * APackedSize; - constexpr index_t smem_size_b = Policy::template GetSmemSizeB() * BPackedSize; - - return 2 * (smem_size_a + smem_size_b); + return Policy::template GetSmemSize(); } CK_TILE_HOST static std::string Print() { return "ABQuantGemmPipelineAgBgCrEightWaves\n"; } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp index 860c102cb0..d52cb9ddc1 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_eight_waves_policy.hpp @@ -61,7 +61,7 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy static constexpr index_t NIterPerWarp = NWarpTiles / NWarps; static constexpr index_t KPerWarp = KPerBlock / KWarps; static constexpr index_t NPerWarp = NPerBlock / NWarps; - static_assert(NWarps == 2, "KWarps == 2 for ping-pong!"); + static_assert(NWarps == 2, "NWarps == 2 for ping-pong!"); static_assert(KWarpTiles == KWarps, "Wrong!"); static constexpr index_t KPerWarpAQ = KPerWarp / Problem::AQuantGroupSize::kK; @@ -87,6 +87,11 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy CK_TILE_HOST_DEVICE static constexpr auto GetKStepAQ() { return KPerBlockAQ; } CK_TILE_HOST_DEVICE static constexpr auto GetKStepBQ() { return KPerBlockBQ; } + // TODO: generalize instruction count calculation + CK_TILE_HOST_DEVICE static constexpr auto GetInstCountAQ() { return MIterPerWarp; } + + CK_TILE_HOST_DEVICE static constexpr auto GetInstCountBQ() { return 1; } + CK_TILE_HOST_DEVICE static constexpr auto MakeAQBlockDistribution() { return make_static_tile_distribution( @@ -156,6 +161,8 @@ struct GemmABQuantPipelineAgBgCrAsyncPolicy : public GemmPipelineAgBgCrCompAsync FORWARD_METHOD_(GetBlockGemm); FORWARD_METHOD_(GetKStepAQ); FORWARD_METHOD_(GetKStepBQ); + 
FORWARD_METHOD_(GetInstCountAQ); + FORWARD_METHOD_(GetInstCountBQ); #undef FORWARD_METHOD_ }; diff --git a/test/ck_tile/gemm_mx/CMakeLists.txt b/test/ck_tile/gemm_mx/CMakeLists.txt index 51a16fbc3e..4b6e6b795c 100644 --- a/test/ck_tile/gemm_mx/CMakeLists.txt +++ b/test/ck_tile/gemm_mx/CMakeLists.txt @@ -7,11 +7,8 @@ if(CK_USE_OCP_FP8) endif() if(GPU_TARGETS MATCHES "gfx95") - add_gtest_executable(test_ck_tile_mx_gemm_fp4 test_mx_gemm_fp4.cpp) - target_compile_options(test_ck_tile_mx_gemm_fp4 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) - - add_gtest_executable(test_ck_tile_mx_gemm_fp8 test_mx_gemm_fp8.cpp) - target_compile_options(test_ck_tile_mx_gemm_fp8 PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_mx_gemm_async test_mx_gemm_async.cpp) + target_compile_options(test_ck_tile_mx_gemm_async PRIVATE ${TEST_MX_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile MX GEMM tests for current target") endif() diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_async.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_async.cpp new file mode 100644 index 0000000000..489bb4d25c --- /dev/null +++ b/test/ck_tile/gemm_mx/test_mx_gemm_async.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_mx_gemm_config.hpp" +#include "test_mx_gemm_util.hpp" + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; +using F4 = ck_tile::pk_fp4_t; +using F8 = ck_tile::fp8_t; +using F6 = ck_tile::pk_fp6x16_t; + +// clang-format off +using MxTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; +// clang-format on + +template +class TestMxGemm : public TestMxGemmUtil +{ +}; + +TYPED_TEST_SUITE(TestMxGemm, MxTypes); + +TYPED_TEST(TestMxGemm, Default) +{ + // No M/N/K padding so we use 128x256x256 as smallest dimensions + this->Run(128, 256, 256); + this->Run(256, 256, 512); + this->Run(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp index 3cce36a85d..ab1f1a20f4 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_config.hpp @@ -80,16 +80,22 @@ struct MxGemmConfig static constexpr bool TiledMMAPermuteN = false; }; -struct MXfp4_GemmConfig16 : MxGemmConfig +struct MX_GemmConfig16 : MxGemmConfig { static constexpr ck_tile::index_t M_Tile = 64; - static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 256; }; -struct MXfp8_GemmConfig16 : MxGemmConfig +struct MX_GemmConfigEightWaves : MxGemmConfig { - static constexpr ck_tile::index_t M_Tile = 64; - static constexpr ck_tile::index_t N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256; + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong! 
+ static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128 * N_Warp; + static constexpr ck_tile::index_t K_Tile = 128 * K_Warp; + + static constexpr int kBlockPerCu = 2; }; diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_fp4.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_fp4.cpp deleted file mode 100644 index 307ea9bcf8..0000000000 --- a/test/ck_tile/gemm_mx/test_mx_gemm_fp4.cpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "test_mx_gemm_config.hpp" -#include "test_mx_gemm_util.hpp" - -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - -using MxFp4Types = ::testing::Types< - std::tuple>; - -template -class TestMxGemmFp4 : public TestMxGemmUtil, - std::tuple_element_t<1, TypeParam>, - std::tuple_element_t<2, TypeParam>, - std::tuple_element_t<3, TypeParam>, - std::tuple_element_t<4, TypeParam>, - std::tuple_element_t<5, TypeParam>> -{ -}; - -TYPED_TEST_SUITE(TestMxGemmFp4, MxFp4Types); - -TYPED_TEST(TestMxGemmFp4, BasicSizes) -{ - this->Run(64, 64, 256); - this->Run(128, 128, 256); - this->Run(64, 128, 512); -} diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_fp8.cpp b/test/ck_tile/gemm_mx/test_mx_gemm_fp8.cpp deleted file mode 100644 index f7f9891b73..0000000000 --- a/test/ck_tile/gemm_mx/test_mx_gemm_fp8.cpp +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include "test_mx_gemm_config.hpp" -#include "test_mx_gemm_util.hpp" - -using Row = ck_tile::tensor_layout::gemm::RowMajor; -using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - -using MxFp8Types = - ::testing::Types>; - -template -class TestMxGemmFp8 : public TestMxGemmUtil, - std::tuple_element_t<1, TypeParam>, - std::tuple_element_t<2, TypeParam>, - std::tuple_element_t<3, TypeParam>, - std::tuple_element_t<4, TypeParam>, - std::tuple_element_t<5, TypeParam>> -{ -}; - -TYPED_TEST_SUITE(TestMxGemmFp8, MxFp8Types); - -TYPED_TEST(TestMxGemmFp8, BasicSizes) -{ - this->Run(64, 64, 256); - this->Run(128, 128, 256); - this->Run(64, 128, 512); -} diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp index 1eac249f13..775b0ca978 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_instance.hpp @@ -4,8 +4,7 @@ #pragma once #include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" -#include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp" +#include "ck_tile/ops/gemm_mx.hpp" #include "test_mx_gemm_config.hpp" template & args, const ck_tile::st MXGemmTraits, GemmConfig::Scheduler>; - using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync; + constexpr bool IsEightWave = + (GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp) == 8; + using MXGemmPipeline = + std::conditional_t, + ck_tile::MXGemmPipelineAgBgCrCompAsync>; using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner& args, const ck_tile::st GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile, - MXPipelineProblem::TransposeC>>; + MXPipelineProblem::TransposeC, + 1, // kNumWaveGroups_ (Default) + false, // FixedVectorSize_ (Default) + 1, // VectorSizeC_ (Default) + 1, // BlockedXDLN_PerWarp_ (Default) + false, // DoubleSmemBuffer_ (Default) + ADataType, // AComputeDataType + BDataType, 
// BComputeDataType + true>>; // TilesPacked_ (because of packed scales) using Kernel = ck_tile::MXGemmKernel; diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp index 1f510a6b77..6020a5a4b1 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_util.hpp @@ -30,15 +30,17 @@ auto calculate_rtol_atol_mx(ck_tile::index_t K, float max_accumulated_value) return ck_tile::make_tuple(rtol, atol); } -template +template class TestMxGemmUtil : public ::testing::Test { protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using GemmConfig = std::tuple_element_t<2, Tuple>; + using ALayout = std::tuple_element_t<3, Tuple>; + using BLayout = std::tuple_element_t<4, Tuple>; + using CLayout = std::tuple_element_t<5, Tuple>; + using AccDataType = float; using CDataType = ck_tile::fp16_t; using ScaleType = ck_tile::e8m0_t; @@ -94,7 +96,7 @@ class TestMxGemmUtil : public ::testing::Test return packed; } - void Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, int seed = 1234) + void Run(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) { const ck_tile::index_t scale_k_size = K / 32; const ck_tile::index_t stride_A = @@ -119,10 +121,23 @@ class TestMxGemmUtil : public ::testing::Test ck_tile::HostTensor scale_b_host(ck_tile::host_tensor_descriptor( scale_k_size, N, stride_scale_b, is_row_major(BLayout{}))); - ck_tile::FillUniformDistribution{-2.f, 2.f, seed++}(a_host); - ck_tile::FillUniformDistribution{-2.f, 2.f, seed++}(b_host); - ck_tile::FillUniformDistribution{0.001f, 10.f, seed++}(scale_a_host); - ck_tile::FillUniformDistribution{0.001f, 10.f, seed++}(scale_b_host); + std::mt19937 gen(42); + std::uniform_int_distribution fill_seed(0, 500); + + auto gen_scales = [&](auto& scales, float range_min, float range_max) { + // e8m0_t is basically an exponent of float32 + ck_tile::HostTensor 
pow2(scales.get_lengths()); + ck_tile::FillUniformDistributionIntegerValue{ + range_min, range_max, fill_seed(gen)}(pow2); + scales.ForEach([&](auto& self, const auto& i) { + self(i) = static_cast(std::exp2(pow2(i))); + }); + }; + + ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(a_host); + ck_tile::FillUniformDistribution{-2.f, 2.f, fill_seed(gen)}(b_host); + gen_scales(scale_a_host, -2, 2); + gen_scales(scale_b_host, -2, 2); // Compute effective XdlPack sizes based on GemmConfig tile dimensions constexpr ck_tile::index_t MPerXdl = GemmConfig::M_Warp_Tile; diff --git a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_ut_cases.inc b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_ut_cases.inc index 716ef4b626..c41350db33 100644 --- a/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_ut_cases.inc +++ b/test/ck_tile/grouped_gemm_mx/test_mx_grouped_gemm_ut_cases.inc @@ -20,4 +20,3 @@ TYPED_TEST(TestCkTileMxGroupedGemm, Basic) this->Run(Ms, Ns, Ks, kbatch, group_count); } -