From 235019100910c0e507de65ef0370ea30ca7b052a Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Tue, 10 Jun 2025 22:44:50 -0700 Subject: [PATCH] Epilogue cshuffle Improvement (#2312) * add cshuffle's mxdlperwavepershuffle support, not finished * add epilogue functions * add cshuffle's mxdlperwavepershuffle support, not finished * add epilogue functions * update cshuffle logic * update cshuffle_logics * add some change within review * update some codes following the code review * update epilogue logic * remove from problem * update codes following review. * fix some issues * solve the previous PR error, refine the code * Update include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Comment addressed * handling tile_engine failing case * handling tile_engine failing case --------- Co-authored-by: joyeamd Co-authored-by: joye Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> Co-authored-by: khushbu agarwal [ROCm/composable_kernel commit: 06e0b8436c218349f08527cf0e5d2c502c622b77] --- .../ops/epilogue/cshuffle_epilogue.hpp | 194 ++++++++++++------ 1 file changed, 133 insertions(+), 61 deletions(-) diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 1f53dfd93c..5a6521deb5 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -17,11 +17,11 @@ template struct CShuffleEpilogueProblem @@ -34,11 +34,11 @@ struct CShuffleEpilogueProblem static constexpr index_t kBlockSize = kBlockSize_; static constexpr index_t kMPerBlock = kM_; static constexpr index_t kNPerBlock = kN_; - static constexpr index_t kMWave = kMWave_; - static constexpr index_t kNWave = kNWave_; - static constexpr index_t kMPerXdl = kMPerXdl_; - static constexpr index_t kNPerXdl = kNPerXdl_; - static constexpr index_t kKPerXdl = kKPerXdl_; + static constexpr index_t MWave = MWave_; + static constexpr index_t NWave = NWave_; + static constexpr index_t MPerXdl = MPerXdl_; + static constexpr index_t NPerXdl = NPerXdl_; + static constexpr index_t KPerXdl = KPerXdl_; static constexpr index_t isCTransposed = isCTransposed_; static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; }; @@ -59,25 +59,14 @@ struct CShuffleEpilogue static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = Problem::kMPerBlock; static constexpr index_t kNPerBlock = Problem::kNPerBlock; - static constexpr index_t kMWave = Problem::kMWave; - static constexpr index_t kNWave = Problem::kNWave; - static constexpr index_t kMPerXdl = Problem::kMPerXdl; - static constexpr index_t kNPerXdl = Problem::kNPerXdl; - static constexpr index_t kKPerXdl = Problem::kKPerXdl; + static constexpr index_t MWave = Problem::MWave; + static constexpr index_t NWave = Problem::NWave; + static constexpr index_t MPerXdl = Problem::MPerXdl; + static constexpr index_t NPerXdl = Problem::NPerXdl; + static constexpr index_t KPerXdl = Problem::KPerXdl; static constexpr index_t isCTransposed = Problem::isCTransposed; - static constexpr index_t kMPerIteration = kMPerXdl * kMWave; - static constexpr index_t kNPerIteration = kNPerXdl * kNWave; - - using WG = WarpGemmMfmaDispatcher; - - using CWarpDstr = typename WG::CWarpDstr; - using CWarpTensor = typename WG::CWarpTensor; + static constexpr index_t MPerIteration = MPerXdl * MWave; + static constexpr index_t NPerIteration = NPerXdl * NWave; /** * @brief Get the vector store size for C tensor. @@ -89,18 +78,18 @@ struct CShuffleEpilogue * * @return The vector store size for C tensor. */ - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC() { - constexpr index_t MaxVectorStoreSize = 16; + constexpr index_t max_vector_size = 16; if constexpr(std::is_same_v) { - return std::min(static_cast(kNPerIteration), - static_cast(MaxVectorStoreSize / sizeof(ODataType))); + return std::min(static_cast(NPerIteration), + static_cast(max_vector_size / sizeof(ODataType))); } else if constexpr(std::is_same_v) { - return std::min(static_cast(kMPerIteration), - static_cast(MaxVectorStoreSize / sizeof(ODataType))); + return std::min(static_cast(MPerIteration), + static_cast(max_vector_size / sizeof(ODataType))); } else { @@ -108,6 +97,65 @@ struct CShuffleEpilogue } } + /** + * @brief Shuffle tile configuration parameters + * + * @details These parameters control the number of XDL tiles processed per wave in each shuffle + * iteration: + * - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave + * - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave + */ + static constexpr auto shuffle_tile_tuple = [] { + constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size(); + if constexpr(elem_per_thread >= GetVectorSizeC()) + { + return std::make_tuple(1, 1); + } + else + { + constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread; + if constexpr(std::is_same_v) + { + static_assert((kMPerBlock % (MPerXdl * MWave) == 0) && + (kMPerBlock % num_xdl_shuffles == 0), + "kMPerBlock must be divisible by MPerXdl*MWave and " + "num_xdl_shuffles for CShuffleEpilogue"); + return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1); + } + else + { + static_assert((kNPerBlock % (NPerXdl * NWave) == 0) && + (kNPerBlock % num_xdl_shuffles == 0), + "kNPerBlock must be divisible by NPerXdl*NWave and " + "num_xdl_shuffles for CShuffleEpilogue"); + return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave))); + } + } + }(); + static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple); + static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple); + + static constexpr auto MNPerIterationShuffle = [] { + constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle; + constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle; + if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0) + return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave); + else + return std::make_tuple(m_val, n_val); + }(); + static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle); + static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle); + using WG = WarpGemmMfmaDispatcher; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + template CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() { @@ -115,15 +163,15 @@ struct CShuffleEpilogue if constexpr(std::is_same_v) { return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{})); + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{})); } // M is contiguous dimension else if constexpr(std::is_same_v) { return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number<1>{}, number{})); + make_tuple(number{}, number{}), + make_tuple(number<1>{}, number{})); } else { @@ -131,40 +179,62 @@ struct CShuffleEpilogue } } + CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode() + { + constexpr auto block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding( + block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{}); + + return block_dstr_encoding; + } + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType); + return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType); } template CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem) { + constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); - const index_t iMWarp = get_warp_id() / kNWave; - const index_t iNWarp = get_warp_id() - iMWarp * kNWave; + auto lds_tile = make_static_distributed_tensor(LdsTileDistr); constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); auto o_lds_block = make_tensor_view( static_cast(p_smem), lds_block_desc); - auto in_lds_window = - make_tile_window(o_lds_block, - make_tuple(number{}, number{}), - {number{} * iMWarp, number{} * iNWarp}); - auto out_lds_window = - make_tile_window(o_lds_block, - make_tuple(number{}, number{}), - {0, 0}); + + auto in_lds_window = make_tile_window( + o_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + LdsTileDistr); + + auto out_lds_window = make_tile_window( + o_lds_block, + make_tuple(number{}, number{}), + {0, 0}); using SFC = space_filling_curve, sequence<0, 1>, - sequence>; + sequence>; constexpr index_t num_access = SFC::get_num_of_access(); + static_assert(std::is_same_v, + "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); + using TileEncodingPattern = TileDistributionEncodingPattern2D; constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); @@ -173,21 +243,23 @@ struct CShuffleEpilogue to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - CWarpTensor c_warp_in_tensor; static_for<0, num_access, 1>{}([&](auto iAccess) { + block_sync_lds(); constexpr auto idx_y_start = SFC::get_index(iAccess); - constexpr auto mIter = number{}) / (kMPerXdl * kMWave)>{}; - constexpr auto nIter = number{}) / (kNPerXdl * kNWave)>{}; + constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; + constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; - c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence{}, + c_warp_y_lengths)); - const auto c_warp_in_tensor_casted = cast_tile(c_warp_in_tensor); + const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); - block_sync_lds(); - store_tile(in_lds_window, c_warp_in_tensor_casted); + store_tile(in_lds_window, c_warptile_in_tensor_casted); block_sync_lds(); const auto c_out_tensor =