From 10498656ef9b5714993b34be8301e21b387e3bf6 Mon Sep 17 00:00:00 2001 From: kylasa Date: Thu, 12 Jun 2025 18:24:02 -0700 Subject: [PATCH] Code drop for 2 warp ping pong scheduler along K dimension. (#2276) * Code drop for 2 warp ping pong scheduler along K dimension. * Addressing code review comments. * Addressing Clang formatting issues. * Addressing build issues. * Addressing build issues of other GEMM pipelines with ping pong scheduler code drop. * Fix for LDS memory size for GEMM pipelines. * Addressing code review feedback comments. * Change log update. * Addressing code review comments and build issues. * Added new policy for pipeline specific logic about LDS needs. * Clang Fix during build. [ROCm/composable_kernel commit: 5f1ad09b610cb0e083f63988479ab022bda70588] --- CHANGELOG.md | 1 + example/ck_tile/03_gemm/gemm_utils.hpp | 35 +- example/ck_tile/03_gemm/universal_gemm.cpp | 8 +- .../algorithm/static_encoding_pattern.hpp | 92 +++-- .../ops/epilogue/cshuffle_epilogue.hpp | 7 +- include/ck_tile/ops/gemm.hpp | 2 + .../block/block_gemm_areg_breg_creg_v1.hpp | 160 ++++++-- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 17 +- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 10 +- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 1 + .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 1 + .../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 379 ++++++++++++++++++ ...peline_ag_bg_cr_comp_v5_default_policy.hpp | 63 +++ .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 1 + .../gemm/pipeline/gemm_pipeline_problem.hpp | 2 + ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 54 ++- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 4 +- 17 files changed, 727 insertions(+), 110 deletions(-) create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index aecf16d83d..af8d965b30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for Split K for grouped convolution backward data. * Added logit soft-capping support for fMHA forward kernels. * Added benchmarking support for tile engine GEMM. +* Added Ping-pong scheduler support for GEMM operation along the K dimension. * Added rotating buffer feature for CK_Tile GEMM. ### Optimized diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index cd4ace6d2f..f3d11c751b 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -14,6 +14,7 @@ #define CK_TILE_PIPELINE_COMPUTE_V3 1 #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 #ifndef CK_TILE_PIPELINE_DEFAULT #define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 @@ -31,6 +32,10 @@ #define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4 #define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4 #define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV5 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV5 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave #else #error "unsupported CK_TILE_PIPELINE_DEFAULT value" #endif @@ -51,7 +56,8 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t NumWaveGroups = 1; #endif #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) // Compute friendly for Intrawave scheduler @@ -67,7 +73,8 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 32; - static constexpr bool DoubleSmemBuffer = false; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t NumWaveGroups = 1; #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) // Compute friendly for Intrawave scheduler // Using the ping pong reader in the lds level @@ -83,7 +90,29 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t NumWaveGroups = 1; +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5) + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + + // Available wavegroups will be split into `NumWaveGroups` and each of these wavegroups + // will be responsible for specific jobs. For instance, perform Global Memory read operations, + // perform block-gemm operation etc... + static constexpr ck_tile::index_t NumWaveGroups = 2; #endif static constexpr bool kPadM = false; diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index 3a7cc93df8..fafe40c333 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -50,7 +50,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& CLayout, GemmConfig::TransposeC, GemmConfig::UseStructuredSparsity, - Persistent>; + Persistent, + GemmConfig::NumWaveGroups>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -96,7 +97,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile, UniversalGemmProblem::TransposeC, - memory_operation>>; + memory_operation, + GemmConfig::NumWaveGroups>>; + using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -190,7 +193,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& }; BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; } diff --git a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp index b56bda3741..d8a8f6ab66 100644 --- a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp +++ b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp @@ -56,19 +56,24 @@ template + tile_distribution_pattern DistributionPattern, + index_t NumWaveGroups = 1> struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern { }; // Thread raked -template +template struct TileDistributionEncodingPattern2D - : public TileDistributionEncodingPattern + tile_distribution_pattern::thread_raked, + NumWaveGroups> : public TileDistributionEncodingPattern { // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! @@ -83,45 +88,76 @@ struct TileDistributionEncodingPattern2D, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<1, 2>, - sequence<2, 1>>{}); + if constexpr(NumWaveGroups != 1) + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + else + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } } CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution() { - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 0>>, - sequence<1, 2>, - sequence<1, 2>>{}); + if constexpr(NumWaveGroups != 1) + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + else + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<1, 2>>{}); + } } }; // Warp raked -template +template struct TileDistributionEncodingPattern2D - : public TileDistributionEncodingPattern + tile_distribution_pattern::warp_raked, + NumWaveGroups> : public TileDistributionEncodingPattern { static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); @@ -164,13 +200,17 @@ struct TileDistributionEncodingPattern2D +template struct TileDistributionEncodingPattern2D - : public TileDistributionEncodingPattern + tile_distribution_pattern::block_raked, + NumWaveGroups> : public TileDistributionEncodingPattern { // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 5a6521deb5..6613ceebb2 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -23,7 +23,8 @@ template + memory_operation_enum MemoryOperation_, + index_t kNumWaveGroups_ = 1> struct CShuffleEpilogueProblem { using ADataType = remove_cvref_t; @@ -41,6 +42,7 @@ struct CShuffleEpilogueProblem static constexpr index_t KPerXdl = KPerXdl_; static constexpr index_t isCTransposed = isCTransposed_; static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; + static constexpr index_t kNumWaveGroups = kNumWaveGroups_; }; template @@ -236,7 +238,8 @@ struct CShuffleEpilogue MPerIterationShuffle, NPerIterationShuffle, GetVectorSizeC(), - tile_distribution_pattern::thread_raked>; + tile_distribution_pattern::thread_raked, + Problem::kNumWaveGroups>; constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); constexpr auto c_warp_y_lengths = diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 35f5170179..8db822ebd1 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -31,6 +31,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp index b4362d9069..28d8b3eead 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp @@ -60,52 +60,105 @@ struct BlockGemmARegBRegCRegV1 static constexpr index_t MIterPerWarp = Traits::MIterPerWarp; static constexpr index_t NIterPerWarp = Traits::NIterPerWarp; - static constexpr index_t MWarp = Traits::MWarp; - static constexpr index_t NWarp = Traits::NWarp; + static constexpr index_t MWarp = Traits::MWarp; + static constexpr index_t NWarp = Traits::NWarp; + static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1); CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + if constexpr(UseDefaultScheduler) + { + constexpr auto a_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; - return a_block_dstr_encode; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } + else + { + constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + + return a_block_dstr_encode; + } } CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { - constexpr auto b_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + if constexpr(UseDefaultScheduler) + { + constexpr auto b_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - return b_block_dstr_encode; + return b_block_dstr_encode; + } + else + { + constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + return b_block_dstr_encode; + } } CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + if constexpr(UseDefaultScheduler) + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - return c_block_dstr_encode; + return c_block_dstr_encode; + } + else + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + return c_block_dstr_encode; + } } // C += A * B @@ -201,19 +254,38 @@ struct BlockGemmARegBRegCRegV1 CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { - constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< - sequence<>, - tuple, sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; + if constexpr(UseDefaultScheduler) + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple<>, + tuple<>, + sequence<1, 2>, + sequence<0, 0>>{}; - constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); - constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); - return c_block_tensor; + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + else + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } } // C = A * B diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index edcde4a09f..bfb0d2626b 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -644,6 +644,7 @@ struct GemmKernel * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * */ + template CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, CDataType* c_ptr, @@ -671,11 +672,15 @@ struct GemmKernel const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I2); + if(UseDefaultScheduler || (get_warp_id() == 0)) + { + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I2); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, smem_ptr_0); + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, smem_ptr_0); + } } /** @@ -772,7 +777,9 @@ struct GemmKernel EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) { - RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); + constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); + RunGemm( + a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); } } } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 24bd66a59e..07bfb33252 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -71,7 +71,8 @@ struct GemmPipelineAgBgCrImplBase template CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, const ALdsTensorView& a_lds_block_view, - const ALdsLoadTileDistr&) const + const ALdsLoadTileDistr&, + const array& offset = {0, 0}) const { constexpr bool is_col_major = std::is_same_v; @@ -82,7 +83,7 @@ struct GemmPipelineAgBgCrImplBase auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(YPerTile{}, XPerTile{}), - a_dram_block_window_tmp.get_window_origin(), + a_dram_block_window_tmp.get_window_origin() + offset, Policy::template MakeADramTileDistribution()); // A LDS tile window for store @@ -103,7 +104,8 @@ struct GemmPipelineAgBgCrImplBase template CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp, const BLdsTensorView& b_lds_block_view, - const BLdsLoadTileDistr&) const + const BLdsLoadTileDistr&, + const array& offset = {0, 0}) const { constexpr bool is_row_major = std::is_same_v; @@ -113,7 +115,7 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_dram_window = make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(YPerTile{}, XPerTile{}), - b_dram_block_window_tmp.get_window_origin(), + b_dram_block_window_tmp.get_window_origin() + offset, Policy::template MakeBDramTileDistribution()); // TODO: Do we really need those two tile windows??? diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index a6267e4c89..eb47d9bad6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -143,6 +143,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr bool kPadK = Problem::kPadK; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index 6fc6ba2ba2..8424c43e86 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -134,6 +134,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr bool kPadK = Problem::kPadK; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp new file mode 100644 index 0000000000..9ef7f3f0ef --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -0,0 +1,379 @@ +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed Tensor: register + +template +struct BaseGemmPipelineAgBgCrCompV5 +{ + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; } + + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t) + { + return TailNumber::Empty; + } + + template + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber) + { + return run_func(bool_constant{}, integral_constant{}); + } +}; + +template +struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 +{ + using Base = BaseGemmPipelineAgBgCrCompV5; + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + using BlockGemm = remove_cvref_t())>; + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + static constexpr index_t NumWarps = BlockGemmShape::NumWarps; + static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{}); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AgBgCrCompV5", BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Policy::template IsTransposeC(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem_0) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "Data Type conflict on A and B matrix input data type."); + + static_assert( + KPerBlock % ((NumWarps / 2) * KTileSize) == 0, + "Ping Pong Warps, TileSize and Block Size for K dimensions does not match."); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + index_t warp_id = get_warp_id(); + index_t operation_id = + __builtin_amdgcn_readfirstlane(get_warp_id()); // 0 - Memory read, 1 - block-gemm + + auto a_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock); + auto b_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock); + + auto tensor_views = + Base::GetABLdsTensorViews(static_cast(static_cast(p_smem_0))); + auto& a_lds_block = tensor_views.get(number<0>{}); + auto& b_lds_block = tensor_views.get(number<1>{}); + + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + auto a_windows = Base::GetAWindows( + a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr, a_offset); + auto& a_copy_dram_window = a_windows.get(number<0>{}); + auto& a_copy_lds_window = a_windows.get(number<1>{}); + auto& a_lds_window = a_windows.get(number<2>{}); + + auto b_windows = Base::GetBWindows( + b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr, b_offset); + auto& b_copy_dram_window = b_windows.get(number<0>{}); + auto& b_copy_lds_window = b_windows.get(number<1>{}); + auto& b_lds_window = b_windows.get(number<2>{}); + + // DRAM window steps. + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock * NumWarps, 0) + : make_array(0, KPerBlock * NumWarps); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock * NumWarps, 0) + : make_array(0, KPerBlock * NumWarps); + + constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + + using AGemmTile = decltype(make_static_distributed_tensor(AGemmTileDistr)); + using BGemmTile = decltype(make_static_distributed_tensor(BGemmTileDistr)); + AGemmTile a_tile_0, a_tile_1; + BGemmTile b_tile_0, b_tile_1; + + // Register tile for A and B. + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + ABlockTile a_global_load_tile; + BBlockTile b_global_load_tile; + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile_0 = block_gemm.MakeCBlockTile(); + auto c_block_tile_1 = block_gemm.MakeCBlockTile(); + + CDataType* __restrict__ p_c_lds = static_cast(p_smem_0); + auto c_lds_block_0 = + make_naive_tensor_view(p_c_lds, + make_tuple(MPerBlock, NPerBlock), + make_tuple(NPerBlock, 1), + number{}, + number<1>{}); + auto c_window_0 = make_tile_window(c_lds_block_0, + make_tuple(number{}, number{}), + {0, 0}, + c_block_tile_1.get_tile_distribution()); + + // initialize C + if(warp_id == 0) + { + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_0); + } + else + { + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_1); + } + + // define ping, pong steps here as lambda functions. + auto MemoryOpsStep = [&](auto idx) { + // Memory read half here. + Base::GlobalPrefetch( + a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( + b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_global_load_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func); + } + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_global_load_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func); + } + + if(idx == 0) + { + Base::LocalPrefetch(a_tile_0, a_lds_window); + Base::LocalPrefetch(b_tile_0, b_lds_window); + } + else + { + Base::LocalPrefetch(a_tile_1, a_lds_window); + Base::LocalPrefetch(b_tile_1, b_lds_window); + } + }; + + auto ComputeStep = [&](auto idx) { + if(idx == 0) + { + block_gemm(c_block_tile_0, a_tile_0, b_tile_0); + } + else + { + block_gemm(c_block_tile_1, a_tile_1, b_tile_1); + } + }; + + if(operation_id == 0) + { + MemoryOpsStep(warp_id); + } + + index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop); + while(num_compute_steps > 1) + { + block_sync_lds(); + operation_id = (operation_id + 1) % NumWaveGroups; + + if(operation_id == 0) + { + MemoryOpsStep(warp_id); + } + else + { + ComputeStep(warp_id); + } + num_compute_steps -= 1; + } + block_sync_lds(); + + if(operation_id == 0) + { + ComputeStep(warp_id); + } + block_sync_lds(); + + if(warp_id == 1) + { + store_tile(c_window_0, c_block_tile_1); + } + block_sync_lds(); + + if(warp_id == 0) + { + load_tile(c_block_tile_1, c_window_0); + + constexpr auto s_spans = decltype(c_block_tile_0)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + auto idx2 = make_tuple(idx0, idx1); + c_block_tile_0(idx2) += c_block_tile_1(idx2); + }); + }); + } + return c_block_tile_0; + } + }; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_0) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem_0); + } + + public: + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem_0) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + num_loop, + p_smem_0); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp new file mode 100644 index 0000000000..c03db08c3f --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { +// Default policy for GemmPipelineAGmemBGmemCregComputeV5, except the block gemm method, it shares +// the same vector size implementation, SmemSize, Global memory tile distiribution as the +// UniversalGemm Pipeline Policy. +// Default policy class should not be templated, put template on +// member functions instead. +struct GemmPipelineAgBgCrCompV5DefaultPolicy + : public UniversalGemmBasePolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } + + template + CK_TILE_DEVICE static constexpr index_t GetSmemSizeC() + { + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + + return integer_least_multiple(sizeof(typename Problem::CDataType) * MPerBlock * NPerBlock, + 16); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + constexpr index_t smem_size_b = GetSmemSizeB(); + constexpr index_t smem_size_c = GetSmemSizeC(); + + return smem_size_a + smem_size_b >= smem_size_c ? (smem_size_a + smem_size_b) + : (smem_size_c); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index f7b5f9b3cb..1f2ab80797 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -188,6 +188,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr bool kPadK = Problem::kPadK; static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; // Where is the right place for HasHotLoop and TailNum ??? static constexpr bool HasHotLoop = Problem::HasHotLoop; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 0b38e7789e..678fb6eb46 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -198,6 +198,8 @@ struct UniversalGemmPipelineProblem static constexpr bool TransposeC = Traits::TransposeC; static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; + + static constexpr index_t NumWaveGroups = Traits::NumWaveGroups; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 6890cf2f64..91e845d200 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -426,10 +426,11 @@ struct UniversalGemmBasePolicy { using ALayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeA(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeA(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; // Tile: MPerBlock X KPerBlock if constexpr(std::is_same_v) @@ -438,7 +439,8 @@ struct UniversalGemmBasePolicy MPerBlock, KPerBlock, VecLoadSize, - ATileAccessPattern>; + ATileAccessPattern, + NumWaveGroups>; return TileEncodingPattern::Make2DStaticTileDistribution(); } // Tile: KPerBlock X MPerBlock @@ -448,7 +450,8 @@ struct UniversalGemmBasePolicy KPerBlock, MPerBlock, VecLoadSize, - ATileAccessPattern>; + ATileAccessPattern, + NumWaveGroups>; return TileEncodingPattern::Make2DStaticTileDistribution(); } } @@ -458,10 +461,11 @@ struct UniversalGemmBasePolicy { using BLayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeB(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) @@ -470,7 +474,8 @@ struct UniversalGemmBasePolicy KPerBlock, NPerBlock, VecLoadSize, - BTileAccessPattern>; + BTileAccessPattern, + NumWaveGroups>; return TileEncodingPattern::Make2DStaticTileDistribution(); } // Tile: NPerBlock X KPerBlock @@ -480,7 +485,8 @@ struct UniversalGemmBasePolicy NPerBlock, KPerBlock, VecLoadSize, - BTileAccessPattern>; + BTileAccessPattern, + NumWaveGroups>; return TileEncodingPattern::Make2DStaticTileDistribution(); } } @@ -490,16 +496,18 @@ struct UniversalGemmBasePolicy { using ALayout = remove_cvref_t; static_assert(std::is_same_v); - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeA(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeA(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; using TileEncodingPattern = TileDistributionEncodingPattern2D; + ATileAccessPattern, + NumWaveGroups>; return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } @@ -508,16 +516,18 @@ struct UniversalGemmBasePolicy { using BLayout = remove_cvref_t; static_assert(std::is_same_v); - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t VecLoadSize = GetVectorSizeB(); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; using TileEncodingPattern = TileDistributionEncodingPattern2D; + BTileAccessPattern, + NumWaveGroups>; return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index a61b0eee3c..353192d86f 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -39,7 +39,8 @@ template + bool UsePersistentKernel_ = false, + index_t NumWaveGroups_ = 1> struct TileGemmUniversalTraits { static constexpr bool kPadM = kPadM_; @@ -55,6 +56,7 @@ struct TileGemmUniversalTraits static constexpr bool TransposeC = TransposeC_; static constexpr bool UseStructuredSparsity = UseStructuredSparsity_; static constexpr bool UsePersistentKernel = UsePersistentKernel_; + static constexpr index_t NumWaveGroups = NumWaveGroups_; }; template