From c1f3d81e76873ea5f2caf8c8f6416daee8b2efc6 Mon Sep 17 00:00:00 2001 From: joyeamd Date: Thu, 29 May 2025 20:31:14 +0800 Subject: [PATCH] add CShuffleM/NXdlPerWavePerShuffle in cshuffle_epilogue (#2185) * add cshuffle's mxdlperwavepershuffle support, not finished * add epilogue functions * add cshuffle's mxdlperwavepershuffle support, not finished * add epilogue functions * update cshuffle logic * update cshuffle_logics * add some change within review * update some codes following the code review * update epilogue logic * remove from problem * update codes following review. * fix some issues [ROCm/composable_kernel commit: fd6a859b447387c871bc763e4bbccac298a4f30a] --- .../ops/epilogue/cshuffle_epilogue.hpp | 138 +++++++++++++----- 1 file changed, 103 insertions(+), 35 deletions(-) diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 9b8dde1905..83fde8764b 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -65,19 +65,6 @@ struct CShuffleEpilogue static constexpr index_t kNPerXdl = Problem::kNPerXdl; static constexpr index_t kKPerXdl = Problem::kKPerXdl; static constexpr index_t isCTransposed = Problem::isCTransposed; - static constexpr index_t kMPerIteration = kMPerXdl * kMWave; - static constexpr index_t kNPerIteration = kNPerXdl * kNWave; - - using WG = WarpGemmMfmaDispatcher; - - using CWarpDstr = typename WG::CWarpDstr; - using CWarpTensor = typename WG::CWarpTensor; /** * @brief Get the vector store size for C tensor. @@ -91,10 +78,62 @@ struct CShuffleEpilogue */ CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { - constexpr index_t MaxVectorStoreSize = 16; - return MaxVectorStoreSize / sizeof(ODataType); + constexpr index_t max_vector_store_size = 16; + return max_vector_store_size / sizeof(ODataType); } + /** + * @brief Shuffle tile configuration parameters + * + * @details These parameters control the number of XDL tiles processed per wave in each shuffle + * iteration: + * - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave + * - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave + */ + static constexpr auto shuffle_tile_tuple = [] { + constexpr index_t vecPerThread = kMPerXdl * kNPerXdl / get_warp_size(); + if constexpr(vecPerThread >= GetVectorSizeC()) + { + return std::make_tuple(1, 1); + } + else + { + constexpr index_t num_xdl_shuffles = GetVectorSizeC() / vecPerThread; + if constexpr(std::is_same_v) + { + static_assert((kMPerBlock % (kMPerXdl * kMWave) == 0) && + (kMPerBlock % num_xdl_shuffles == 0), + "kMPerBlock must be divisible by kMPerXdl*kMWave and " + "num_xdl_shuffles for CShuffleEpilogue"); + return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (kMPerXdl * kMWave)), 1); + } + else + { + static_assert((kNPerBlock % (kNPerXdl * kNWave) == 0) && + (kNPerBlock % num_xdl_shuffles == 0), + "kNPerBlock must be divisible by kNPerXdl*kNWave and " + "num_xdl_shuffles for CShuffleEpilogue"); + return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (kNPerXdl * kNWave))); + } + } + }(); + static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple); + static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple); + + static constexpr index_t kMPerIteration = kMPerXdl * kMWave * NumMXdlPerWavePerShuffle; + static constexpr index_t kNPerIteration = kNPerXdl * kNWave * NumNXdlPerWavePerShuffle; + + using WG = WarpGemmMfmaDispatcher; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + template CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() { @@ -102,15 +141,17 @@ struct CShuffleEpilogue if constexpr(std::is_same_v) { return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number{}, number<1>{})); + make_tuple(number{}, + number{}), + make_tuple(number{}, number<1>{})); } // M is contiguous dimension else if constexpr(std::is_same_v) { return make_naive_tensor_descriptor( - make_tuple(number{}, number{}), - make_tuple(number<1>{}, number{})); + make_tuple(number{}, + number{}), + make_tuple(number<1>{}, number{})); } else { @@ -118,34 +159,57 @@ struct CShuffleEpilogue } } + CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode() + { + constexpr auto block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding( + block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{}); + + return block_dstr_encoding; + } + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType); + return kMWave * kNWave * kMPerXdl * kNPerXdl * NumMXdlPerWavePerShuffle * + NumNXdlPerWavePerShuffle * sizeof(ODataType); } template CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem) { + constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); - const index_t iMWarp = get_warp_id() / kNWave; - const index_t iNWarp = get_warp_id() - iMWarp * kNWave; + auto lds_tile = make_static_distributed_tensor(LdsTileDistr); constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); auto o_lds_block = make_tensor_view( static_cast(p_smem), lds_block_desc); + auto in_lds_window = make_tile_window(o_lds_block, - make_tuple(number{}, number{}), - {number{} * iMWarp, number{} * iNWarp}); + make_tuple(number{}, + number{}), + {0, 0}, + LdsTileDistr); + auto out_lds_window = make_tile_window(o_lds_block, - make_tuple(number{}, number{}), + make_tuple(number{}, + number{}), {0, 0}); using SFC = space_filling_curve, sequence<0, 1>, - sequence>; + sequence>; constexpr index_t num_access = SFC::get_num_of_access(); using TileEncodingPattern = @@ -160,21 +224,25 @@ struct CShuffleEpilogue to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - CWarpTensor c_warp_in_tensor; + block_sync_lds(); static_for<0, num_access, 1>{}([&](auto iAccess) { constexpr auto idx_y_start = SFC::get_index(iAccess); - constexpr auto mIter = number{}) / (kMPerXdl * kMWave)>{}; - constexpr auto nIter = number{}) / (kNPerXdl * kNWave)>{}; + constexpr auto mIter = number{}) / + (kMPerXdl * kMWave * NumMXdlPerWavePerShuffle)>{}; + constexpr auto nIter = number{}) / + (kNPerXdl * kNWave * NumNXdlPerWavePerShuffle)>{}; - c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence{}, + c_warp_y_lengths)); - const auto c_warp_in_tensor_casted = cast_tile(c_warp_in_tensor); + const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); - block_sync_lds(); - store_tile(in_lds_window, c_warp_in_tensor_casted); + store_tile(in_lds_window, c_warptile_in_tensor_casted); block_sync_lds(); const auto c_out_tensor =