diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 83fde8764b..9b8dde1905 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -65,6 +65,19 @@ struct CShuffleEpilogue static constexpr index_t kNPerXdl = Problem::kNPerXdl; static constexpr index_t kKPerXdl = Problem::kKPerXdl; static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr index_t kMPerIteration = kMPerXdl * kMWave; + static constexpr index_t kNPerIteration = kNPerXdl * kNWave; + + using WG = WarpGemmMfmaDispatcher; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; /** * @brief Get the vector store size for C tensor. @@ -78,62 +91,10 @@ struct CShuffleEpilogue */ CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() { - constexpr index_t max_vector_store_size = 16; - return max_vector_store_size / sizeof(ODataType); + constexpr index_t MaxVectorStoreSize = 16; + return MaxVectorStoreSize / sizeof(ODataType); } - /** - * @brief Shuffle tile configuration parameters - * - * @details These parameters control the number of XDL tiles processed per wave in each shuffle - * iteration: - * - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave - * - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave - */ - static constexpr auto shuffle_tile_tuple = [] { - constexpr index_t vecPerThread = kMPerXdl * kNPerXdl / get_warp_size(); - if constexpr(vecPerThread >= GetVectorSizeC()) - { - return std::make_tuple(1, 1); - } - else - { - constexpr index_t num_xdl_shuffles = GetVectorSizeC() / vecPerThread; - if constexpr(std::is_same_v) - { - static_assert((kMPerBlock % (kMPerXdl * kMWave) == 0) && - (kMPerBlock % num_xdl_shuffles == 0), - "kMPerBlock must be divisible by kMPerXdl*kMWave and " - "num_xdl_shuffles for CShuffleEpilogue"); - return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (kMPerXdl * kMWave)), 1); - } - else - { - static_assert((kNPerBlock % (kNPerXdl * kNWave) == 0) && - (kNPerBlock % num_xdl_shuffles == 0), - "kNPerBlock must be divisible by kNPerXdl*kNWave and " - "num_xdl_shuffles for CShuffleEpilogue"); - return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (kNPerXdl * kNWave))); - } - } - }(); - static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple); - static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple); - - static constexpr index_t kMPerIteration = kMPerXdl * kMWave * NumMXdlPerWavePerShuffle; - static constexpr index_t kNPerIteration = kNPerXdl * kNWave * NumNXdlPerWavePerShuffle; - - using WG = WarpGemmMfmaDispatcher; - - using CWarpDstr = typename WG::CWarpDstr; - using CWarpTensor = typename WG::CWarpTensor; - template CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() { @@ -141,17 +102,15 @@ struct CShuffleEpilogue if constexpr(std::is_same_v) { return make_naive_tensor_descriptor( - make_tuple(number{}, - number{}), - make_tuple(number{}, number<1>{})); + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{})); } // M is contiguous dimension else if constexpr(std::is_same_v) { return make_naive_tensor_descriptor( - make_tuple(number{}, - number{}), - make_tuple(number<1>{}, number{})); + make_tuple(number{}, number{}), + make_tuple(number<1>{}, number{})); } else { @@ -159,57 +118,34 @@ struct CShuffleEpilogue } } - CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode() - { - constexpr auto block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, - sequence>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding( - block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{}); - - return block_dstr_encoding; - } - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return kMWave * kNWave * kMPerXdl * kNPerXdl * NumMXdlPerWavePerShuffle * - NumNXdlPerWavePerShuffle * sizeof(ODataType); + return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType); } template CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem) { - constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); - auto lds_tile = make_static_distributed_tensor(LdsTileDistr); + const index_t iMWarp = get_warp_id() / kNWave; + const index_t iNWarp = get_warp_id() - iMWarp * kNWave; constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); auto o_lds_block = make_tensor_view( static_cast(p_smem), lds_block_desc); - auto in_lds_window = make_tile_window(o_lds_block, - make_tuple(number{}, - number{}), - {0, 0}, - LdsTileDistr); - + make_tuple(number{}, number{}), + {number{} * iMWarp, number{} * iNWarp}); auto out_lds_window = make_tile_window(o_lds_block, - make_tuple(number{}, - number{}), + make_tuple(number{}, number{}), {0, 0}); using SFC = space_filling_curve, sequence<0, 1>, - sequence>; + sequence>; constexpr index_t num_access = SFC::get_num_of_access(); using TileEncodingPattern = @@ -224,25 +160,21 @@ struct CShuffleEpilogue to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - block_sync_lds(); + CWarpTensor c_warp_in_tensor; static_for<0, num_access, 1>{}([&](auto iAccess) { constexpr auto idx_y_start = SFC::get_index(iAccess); - constexpr auto mIter = number{}) / - (kMPerXdl * kMWave * NumMXdlPerWavePerShuffle)>{}; - constexpr auto nIter = number{}) / - (kNPerXdl * kNWave * NumNXdlPerWavePerShuffle)>{}; + constexpr auto mIter = number{}) / (kMPerXdl * kMWave)>{}; + constexpr auto nIter = number{}) / (kNPerXdl * kNWave)>{}; - lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths)); + c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); + const auto c_warp_in_tensor_casted = cast_tile(c_warp_in_tensor); - store_tile(in_lds_window, c_warptile_in_tensor_casted); + block_sync_lds(); + store_tile(in_lds_window, c_warp_in_tensor_casted); block_sync_lds(); const auto c_out_tensor =