mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
WIP: Separate epilogue for merged conv groups.
This commit is contained in:
@@ -195,23 +195,12 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
|
||||
|
||||
static constexpr auto MNPerIterationShuffle = [] {
|
||||
if constexpr (NumGroupsToMerge == 1)
|
||||
{
|
||||
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
|
||||
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
|
||||
constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
|
||||
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
|
||||
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
|
||||
else
|
||||
return std::make_tuple(m_val, n_val);
|
||||
}
|
||||
else
|
||||
{
|
||||
// When NumGroupsToMerge > 1, we want to write out only the diagonal blocks.
|
||||
// Hence, we configure the shuffle such that it iterates one merge group block at a time.
|
||||
constexpr index_t MPerGroupBlock = kMPerBlock / NumGroupsToMerge;
|
||||
constexpr index_t NPerGroupBlock = kNPerBlock / NumGroupsToMerge;
|
||||
return std::make_tuple(MPerGroupBlock, NPerGroupBlock);
|
||||
}
|
||||
}();
|
||||
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
|
||||
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
|
||||
@@ -276,6 +265,26 @@ struct CShuffleEpilogue
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
|
||||
{
|
||||
if constexpr (NumGroupsToMerge > 1)
|
||||
{
|
||||
// When NumGroupsToMerge > 1, we want to write out only the diagonal blocks.
|
||||
// Hence, we configure the shuffle such that it iterates one merge group block at a time.
|
||||
return merged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
|
||||
}
|
||||
else
|
||||
{
|
||||
// When NumGroupsToMerge == 1, we want to write out all the blocks.
|
||||
return unmerged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto merged_op(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
|
||||
@@ -284,31 +293,11 @@ struct CShuffleEpilogue
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem), lds_block_desc);
|
||||
|
||||
if constexpr (NumGroupsToMerge > 1)
|
||||
{
|
||||
// The full tile size is (kMPerBlock, kNPerBlock).
|
||||
// We access the tile in MPerIterationShuffle x NPerIterationShuffle chunks
|
||||
// using the serpentine order.
|
||||
// The serpentine access order should be better than row-major order when
|
||||
// the number of merged groups is even (we typically have powers of 2 for the number of merged groups).
|
||||
// When the NumGroupsToMerge > 1, we want to write out only the diagonal blocks.
|
||||
// Hence, we configure the SFC such that it iterates one merge group block at a time.
|
||||
constexpr index_t MPerGroupBlock = kMPerBlock / NumGroupsToMerge;
|
||||
constexpr index_t NPerGroupBlock = kNPerBlock / NumGroupsToMerge;
|
||||
|
||||
// TODO: Remove these debug asserts that are specific to a given case.
|
||||
static_assert(kMPerBlock == 8, "kMPerBlock must be 8");
|
||||
static_assert(kNPerBlock == 128, "kNPerBlock must be 128");
|
||||
static_assert(MPerIterationShuffle == 1, "MPerIterationShuffle must be 1");
|
||||
static_assert(NPerIterationShuffle == 16, "NPerIterationShuffle must be 16");
|
||||
static_assert(NumMXdlPerWavePerShuffle == 1, "NumMXdlPerWavePerShuffle must be 1");
|
||||
static_assert(NumNXdlPerWavePerShuffle == 1, "NumNXdlPerWavePerShuffle must be 1");
|
||||
|
||||
static_assert(MPerIterationShuffle == MPerGroupBlock,
|
||||
"MPerIterationShuffle should be equal to MPerGroupBlock");
|
||||
static_assert(NPerIterationShuffle == NPerGroupBlock,
|
||||
"NPerIterationShuffle should be equal to NPerGroupBlock");
|
||||
}
|
||||
// TODO: Remove these debug asserts that are specific to a given case.
|
||||
static_assert(kMPerBlock == 8, "kMPerBlock must be 8");
|
||||
static_assert(kNPerBlock == 128, "kNPerBlock must be 128");
|
||||
static_assert(NumMXdlPerWavePerShuffle == 1, "NumMXdlPerWavePerShuffle must be 1");
|
||||
static_assert(NumNXdlPerWavePerShuffle == 1, "NumNXdlPerWavePerShuffle must be 1");
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
@@ -316,19 +305,10 @@ struct CShuffleEpilogue
|
||||
{0, 0},
|
||||
LdsTileDistr);
|
||||
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
|
||||
using SFC_dram = space_filling_curve<sequence<kMPerBlock, kNPerBlock / NumGroupsToMerge>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
@@ -353,91 +333,55 @@ struct CShuffleEpilogue
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// Ensure that we have the expected number of accesses.
|
||||
if constexpr (NumGroupsToMerge > 1)
|
||||
{
|
||||
static_assert(num_access == NumGroupsToMerge * NumGroupsToMerge,
|
||||
"Number of accesses must be equal to NumGroupsToMerge squared.");
|
||||
// Store full data to LDS.
|
||||
block_sync_lds();
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
|
||||
static_for<0, NumGroupsToMerge, 1>{}
|
||||
(
|
||||
[&](auto group)
|
||||
{
|
||||
block_sync_lds();
|
||||
constexpr auto iAccess = number<group * NumGroupsToMerge + group>{};
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
|
||||
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
block_sync_lds();
|
||||
move_tile_window(in_lds_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
}
|
||||
});
|
||||
block_sync_lds();
|
||||
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
// Copy diagonal block from LDS to global memory.
|
||||
constexpr index_t MPerGroup = kMPerBlock / NumGroupsToMerge;
|
||||
constexpr index_t NPerGroup = kNPerBlock / NumGroupsToMerge;
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerGroup>{}, number<NPerGroup>{}),
|
||||
{0, 0});
|
||||
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
using SFC_lds = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerGroup, NPerGroup>>;
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
|
||||
number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
|
||||
if constexpr(group != NumGroupsToMerge - 1)
|
||||
{
|
||||
constexpr auto step = SFC_dram::get_forward_step(group);
|
||||
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
|
||||
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
block_sync_lds();
|
||||
using SFC_dram = space_filling_curve<sequence<kMPerBlock, kNPerBlock / NumGroupsToMerge>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerGroup, NPerGroup>>;
|
||||
|
||||
static_for<0, NumGroupsToMerge, 1>{}
|
||||
(
|
||||
[&](auto group)
|
||||
{
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
const auto ds_tensor = generate_tuple(
|
||||
@@ -459,20 +403,133 @@ struct CShuffleEpilogue
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
|
||||
// TODO: This probably doesn't work correctly.
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
if constexpr(group != NumGroupsToMerge - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
// Move the global memory window
|
||||
constexpr auto step_dram = SFC_dram::get_forward_step(group);
|
||||
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
move_tile_window(out_dram_window, {step_dram.at(number<0>{}), step_dram.at(number<1>{})});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
{step_dram.at(number<0>{}), step_dram.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Move the LDS window
|
||||
constexpr auto iAccess = number<group * NumGroupsToMerge + group>{};
|
||||
constexpr auto next_iAccess = number<(group+1) * NumGroupsToMerge + (group+1)>{};
|
||||
constexpr auto step_lds = SFC_lds::get_step_between(iAccess, next_iAccess);
|
||||
move_tile_window(out_lds_window, {step_lds.at(number<0>{}), step_lds.at(number<1>{})});
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto unmerged_op(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* p_smem)
|
||||
{
|
||||
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem), lds_block_desc);
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0},
|
||||
LdsTileDistr);
|
||||
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
|
||||
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::sparse_row,
|
||||
Problem::kNumWaveGroups>;
|
||||
constexpr auto dram_tile_distribution =
|
||||
TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
|
||||
auto d_dram_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
block_sync_lds();
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
|
||||
|
||||
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
block_sync_lds();
|
||||
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
const auto ds_tensor = generate_tuple(
|
||||
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
|
||||
number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user