diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 92cd558876..0dab7fc22d 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -195,23 +195,12 @@ struct CShuffleEpilogue static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple); static constexpr auto MNPerIterationShuffle = [] { - if constexpr (NumGroupsToMerge == 1) - { - constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle; + constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle; constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle; if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0) return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave); else return std::make_tuple(m_val, n_val); - } - else - { - // When NumGroupsToMerge > 1, we want to write out only the diagonal blocks. - // Hence, we configure the shuffle such that it iterates one merge group block at a time. - constexpr index_t MPerGroupBlock = kMPerBlock / NumGroupsToMerge; - constexpr index_t NPerGroupBlock = kNPerBlock / NumGroupsToMerge; - return std::make_tuple(MPerGroupBlock, NPerGroupBlock); - } }(); static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle); static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle); @@ -276,6 +265,26 @@ struct CShuffleEpilogue const OAccTile& o_acc_tile, const DsDramWindows& ds_dram_windows, void* p_smem) + + { + if constexpr (NumGroupsToMerge > 1) + { + // When NumGroupsToMerge > 1, we want to write out only the diagonal blocks. + // Hence, we configure the shuffle such that it iterates one merge group block at a time. + return merged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem); + } + else + { + // When NumGroupsToMerge == 1, we want to write out all the blocks. + return unmerged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem); + } + } + + template + CK_TILE_DEVICE auto merged_op(ODramWindow& out_dram_window, + const OAccTile& o_acc_tile, + const DsDramWindows& ds_dram_windows, + void* p_smem) { constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); auto lds_tile = make_static_distributed_tensor(LdsTileDistr); @@ -284,31 +293,11 @@ struct CShuffleEpilogue auto o_lds_block = make_tensor_view( static_cast(p_smem), lds_block_desc); - if constexpr (NumGroupsToMerge > 1) - { - // The full tile size is (kMPerBlock, kNPerBlock). - // We access the tile in MPerIterationShuffle x NPerIterationShuffle chunks - // using the serpentine order. - // The serpentine access order should be better than row-major order when - // the number of merged groups is even (we typically have powers of 2 for the number of merged groups). - // When the NumGroupsToMerge > 1, we want to write out only the diagonal blocks. - // Hence, we configure the SFC such that it iterates one merge group block at a time. - constexpr index_t MPerGroupBlock = kMPerBlock / NumGroupsToMerge; - constexpr index_t NPerGroupBlock = kNPerBlock / NumGroupsToMerge; - - // TODO: Remove these debug asserts that are specific to a given case. - static_assert(kMPerBlock == 8, "kMPerBlock must be 8"); - static_assert(kNPerBlock == 128, "kNPerBlock must be 128"); - static_assert(MPerIterationShuffle == 1, "MPerIterationShuffle must be 1"); - static_assert(NPerIterationShuffle == 16, "NPerIterationShuffle must be 16"); - static_assert(NumMXdlPerWavePerShuffle == 1, "NumMXdlPerWavePerShuffle must be 1"); - static_assert(NumNXdlPerWavePerShuffle == 1, "NumNXdlPerWavePerShuffle must be 1"); - - static_assert(MPerIterationShuffle == MPerGroupBlock, - "MPerIterationShuffle should be equal to MPerGroupBlock"); - static_assert(NPerIterationShuffle == NPerGroupBlock, - "NPerIterationShuffle should be equal to NPerGroupBlock"); - } + // TODO: Remove these debug asserts that are specific to a given case. + static_assert(kMPerBlock == 8, "kMPerBlock must be 8"); + static_assert(kNPerBlock == 128, "kNPerBlock must be 128"); + static_assert(NumMXdlPerWavePerShuffle == 1, "NumMXdlPerWavePerShuffle must be 1"); + static_assert(NumNXdlPerWavePerShuffle == 1, "NumNXdlPerWavePerShuffle must be 1"); auto in_lds_window = make_tile_window( o_lds_block, @@ -316,19 +305,10 @@ struct CShuffleEpilogue {0, 0}, LdsTileDistr); - auto out_lds_window = make_tile_window( - o_lds_block, - make_tuple(number{}, number{}), - {0, 0}); - using SFC = space_filling_curve, sequence<0, 1>, sequence>; - using SFC_dram = space_filling_curve, - sequence<0, 1>, - sequence>; - constexpr index_t num_access = SFC::get_num_of_access(); static_assert(std::is_same_v, @@ -353,91 +333,55 @@ struct CShuffleEpilogue to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; - // Ensure that we have the expected number of accesses. - if constexpr (NumGroupsToMerge > 1) - { - static_assert(num_access == NumGroupsToMerge * NumGroupsToMerge, - "Number of accesses must be equal to NumGroupsToMerge squared."); + // Store full data to LDS. + block_sync_lds(); + static_for<0, num_access, 1>{}([&](auto iAccess) { + + constexpr auto idx_y_start = SFC::get_index(iAccess); - static_for<0, NumGroupsToMerge, 1>{} - ( - [&](auto group) - { - block_sync_lds(); - constexpr auto iAccess = number{}; - constexpr auto idx_y_start = SFC::get_index(iAccess); + constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; + constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; - constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; - constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; + lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence{}, + c_warp_y_lengths)); - lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths)); + const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); - const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); + store_tile(in_lds_window, c_warptile_in_tensor_casted); + + if constexpr(iAccess != num_access - 1) + { + constexpr auto step = SFC::get_forward_step(iAccess); - store_tile(in_lds_window, c_warptile_in_tensor_casted); - block_sync_lds(); + move_tile_window(in_lds_window, {step.at(number<0>{}), step.at(number<1>{})}); + } + }); + block_sync_lds(); - auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); + // Copy diagonal block from LDS to global memory. + constexpr index_t MPerGroup = kMPerBlock / NumGroupsToMerge; + constexpr index_t NPerGroup = kNPerBlock / NumGroupsToMerge; + auto out_lds_window = make_tile_window( + o_lds_block, + make_tuple(number{}, number{}), + {0, 0}); - const auto ds_tensor = generate_tuple( - [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{}); + using SFC_lds = space_filling_curve, + sequence<0, 1>, + sequence>; - const auto c_ds_tiles = concat_tuple_of_reference( - tie(c_out_tensor, c_out_tensor), - generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; }, - number{})); - - tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); - - if constexpr(MemoryOperation == memory_operation_enum::set) - { - store_tile(out_dram_window, c_out_tensor); - } - else - { - update_tile(out_dram_window, c_out_tensor); - } - - if constexpr(group != NumGroupsToMerge - 1) - { - constexpr auto step = SFC_dram::get_forward_step(group); - - move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); - - static_for<0, NumDTensor, 1>{}([&](auto idx) { - move_tile_window(d_dram_windows[idx], - {step.at(number<0>{}), step.at(number<1>{})}); - }); - } - } - ); - } - else - { - static_for<0, num_access, 1>{}([&](auto iAccess) { - block_sync_lds(); - constexpr auto idx_y_start = SFC::get_index(iAccess); - - constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; - constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; - - lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence{}, - c_warp_y_lengths)); - - const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); - - store_tile(in_lds_window, c_warptile_in_tensor_casted); - block_sync_lds(); + using SFC_dram = space_filling_curve, + sequence<0, 1>, + sequence>; + static_for<0, NumGroupsToMerge, 1>{} + ( + [&](auto group) + { auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); const auto ds_tensor = generate_tuple( @@ -459,20 +403,133 @@ struct CShuffleEpilogue update_tile(out_dram_window, c_out_tensor); } - // TODO: This probably doesn't work correctly. - if constexpr(iAccess != num_access - 1) + if constexpr(group != NumGroupsToMerge - 1) { - constexpr auto step = SFC::get_forward_step(iAccess); + // Move the global memory window + constexpr auto step_dram = SFC_dram::get_forward_step(group); - move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); + move_tile_window(out_dram_window, {step_dram.at(number<0>{}), step_dram.at(number<1>{})}); static_for<0, NumDTensor, 1>{}([&](auto idx) { move_tile_window(d_dram_windows[idx], - {step.at(number<0>{}), step.at(number<1>{})}); + {step_dram.at(number<0>{}), step_dram.at(number<1>{})}); }); - } - }); - } + + // Move the LDS window + constexpr auto iAccess = number{}; + constexpr auto next_iAccess = number<(group+1) * NumGroupsToMerge + (group+1)>{}; + constexpr auto step_lds = SFC_lds::get_step_between(iAccess, next_iAccess); + move_tile_window(out_lds_window, {step_lds.at(number<0>{}), step_lds.at(number<1>{})}); + } + } + ); + } + + template + CK_TILE_DEVICE auto unmerged_op(ODramWindow& out_dram_window, + const OAccTile& o_acc_tile, + const DsDramWindows& ds_dram_windows, + void* p_smem) + { + constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode()); + auto lds_tile = make_static_distributed_tensor(LdsTileDistr); + + constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); + auto o_lds_block = make_tensor_view( + static_cast(p_smem), lds_block_desc); + + auto in_lds_window = make_tile_window( + o_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + LdsTileDistr); + + auto out_lds_window = make_tile_window( + o_lds_block, + make_tuple(number{}, number{}), + {0, 0}); + + using SFC = space_filling_curve, + sequence<0, 1>, + sequence>; + + constexpr index_t num_access = SFC::get_num_of_access(); + + static_assert(std::is_same_v, + "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); + + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + constexpr auto dram_tile_distribution = + TileEncodingPattern::make_2d_static_tile_distribution(); + + auto d_dram_windows = generate_tuple( + [&](auto idx) { + return make_tile_window(ds_dram_windows[idx], dram_tile_distribution); + }, + number{}); + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + + static_for<0, num_access, 1>{}([&](auto iAccess) { + block_sync_lds(); + constexpr auto idx_y_start = SFC::get_index(iAccess); + + constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; + constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; + + lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence{}, + c_warp_y_lengths)); + + const auto c_warptile_in_tensor_casted = cast_tile(lds_tile); + + store_tile(in_lds_window, c_warptile_in_tensor_casted); + block_sync_lds(); + + auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); + + const auto ds_tensor = generate_tuple( + [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number{}); + + const auto c_ds_tiles = concat_tuple_of_reference( + tie(c_out_tensor, c_out_tensor), + generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; }, + number{})); + + tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); + + if constexpr(MemoryOperation == memory_operation_enum::set) + { + store_tile(out_dram_window, c_out_tensor); + } + else + { + update_tile(out_dram_window, c_out_tensor); + } + + if constexpr(iAccess != num_access - 1) + { + constexpr auto step = SFC::get_forward_step(iAccess); + + move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); + + static_for<0, NumDTensor, 1>{}([&](auto idx) { + move_tile_window(d_dram_windows[idx], + {step.at(number<0>{}), step.at(number<1>{})}); + }); + } + }); } }; } // namespace ck_tile