WIP: Separate epilogue for merged conv groups.

This commit is contained in:
Ville Pietilä
2025-09-19 13:52:33 +00:00
parent 437599c517
commit 7dfbac5d0b

View File

@@ -195,23 +195,12 @@ struct CShuffleEpilogue
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
static constexpr auto MNPerIterationShuffle = [] {
if constexpr (NumGroupsToMerge == 1)
{
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
else
return std::make_tuple(m_val, n_val);
}
else
{
// When NumGroupsToMerge > 1, we want to write out only the diagonal blocks.
// Hence, we configure the shuffle such that it iterates one merge group block at a time.
constexpr index_t MPerGroupBlock = kMPerBlock / NumGroupsToMerge;
constexpr index_t NPerGroupBlock = kNPerBlock / NumGroupsToMerge;
return std::make_tuple(MPerGroupBlock, NPerGroupBlock);
}
}();
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
@@ -276,6 +265,26 @@ struct CShuffleEpilogue
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem)
{
if constexpr (NumGroupsToMerge > 1)
{
// When NumGroupsToMerge > 1, we want to write out only the diagonal blocks.
// Hence, we configure the shuffle such that it iterates one merge group block at a time.
return merged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
}
else
{
// When NumGroupsToMerge == 1, we want to write out all the blocks.
return unmerged_op(out_dram_window, o_acc_tile, ds_dram_windows, p_smem);
}
}
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto merged_op(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem)
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
@@ -284,31 +293,11 @@ struct CShuffleEpilogue
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
if constexpr (NumGroupsToMerge > 1)
{
// The full tile size is (kMPerBlock, kNPerBlock).
// We access the tile in MPerIterationShuffle x NPerIterationShuffle chunks
// using the serpentine order.
// The serpentine access order should be better than row-major order when
// the number of merged groups is even (we typically have powers of 2 for the number of merged groups).
// When the NumGroupsToMerge > 1, we want to write out only the diagonal blocks.
// Hence, we configure the SFC such that it iterates one merge group block at a time.
constexpr index_t MPerGroupBlock = kMPerBlock / NumGroupsToMerge;
constexpr index_t NPerGroupBlock = kNPerBlock / NumGroupsToMerge;
// TODO: Remove these debug asserts that are specific to a given case.
static_assert(kMPerBlock == 8, "kMPerBlock must be 8");
static_assert(kNPerBlock == 128, "kNPerBlock must be 128");
static_assert(MPerIterationShuffle == 1, "MPerIterationShuffle must be 1");
static_assert(NPerIterationShuffle == 16, "NPerIterationShuffle must be 16");
static_assert(NumMXdlPerWavePerShuffle == 1, "NumMXdlPerWavePerShuffle must be 1");
static_assert(NumNXdlPerWavePerShuffle == 1, "NumNXdlPerWavePerShuffle must be 1");
static_assert(MPerIterationShuffle == MPerGroupBlock,
"MPerIterationShuffle should be equal to MPerGroupBlock");
static_assert(NPerIterationShuffle == NPerGroupBlock,
"NPerIterationShuffle should be equal to NPerGroupBlock");
}
// TODO: Remove these debug asserts that are specific to a given case.
static_assert(kMPerBlock == 8, "kMPerBlock must be 8");
static_assert(kNPerBlock == 128, "kNPerBlock must be 128");
static_assert(NumMXdlPerWavePerShuffle == 1, "NumMXdlPerWavePerShuffle must be 1");
static_assert(NumNXdlPerWavePerShuffle == 1, "NumNXdlPerWavePerShuffle must be 1");
auto in_lds_window = make_tile_window(
o_lds_block,
@@ -316,19 +305,10 @@ struct CShuffleEpilogue
{0, 0},
LdsTileDistr);
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
using SFC_dram = space_filling_curve<sequence<kMPerBlock, kNPerBlock / NumGroupsToMerge>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
@@ -353,91 +333,55 @@ struct CShuffleEpilogue
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// Ensure that we have the expected number of accesses.
if constexpr (NumGroupsToMerge > 1)
{
static_assert(num_access == NumGroupsToMerge * NumGroupsToMerge,
"Number of accesses must be equal to NumGroupsToMerge squared.");
// Store full data to LDS.
block_sync_lds();
static_for<0, num_access, 1>{}([&](auto iAccess) {
constexpr auto idx_y_start = SFC::get_index(iAccess);
static_for<0, NumGroupsToMerge, 1>{}
(
[&](auto group)
{
block_sync_lds();
constexpr auto iAccess = number<group * NumGroupsToMerge + group>{};
constexpr auto idx_y_start = SFC::get_index(iAccess);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
block_sync_lds();
move_tile_window(in_lds_window, {step.at(number<0>{}), step.at(number<1>{})});
}
});
block_sync_lds();
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
// Copy diagonal block from LDS to global memory.
constexpr index_t MPerGroup = kMPerBlock / NumGroupsToMerge;
constexpr index_t NPerGroup = kNPerBlock / NumGroupsToMerge;
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerGroup>{}, number<NPerGroup>{}),
{0, 0});
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
using SFC_lds = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerGroup, NPerGroup>>;
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
if constexpr(group != NumGroupsToMerge - 1)
{
constexpr auto step = SFC_dram::get_forward_step(group);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx],
{step.at(number<0>{}), step.at(number<1>{})});
});
}
}
);
}
else
{
static_for<0, num_access, 1>{}([&](auto iAccess) {
block_sync_lds();
constexpr auto idx_y_start = SFC::get_index(iAccess);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
block_sync_lds();
using SFC_dram = space_filling_curve<sequence<kMPerBlock, kNPerBlock / NumGroupsToMerge>,
sequence<0, 1>,
sequence<MPerGroup, NPerGroup>>;
static_for<0, NumGroupsToMerge, 1>{}
(
[&](auto group)
{
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
const auto ds_tensor = generate_tuple(
@@ -459,20 +403,133 @@ struct CShuffleEpilogue
update_tile(out_dram_window, c_out_tensor);
}
// TODO: This probably doesn't work correctly.
if constexpr(iAccess != num_access - 1)
if constexpr(group != NumGroupsToMerge - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
// Move the global memory window
constexpr auto step_dram = SFC_dram::get_forward_step(group);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
move_tile_window(out_dram_window, {step_dram.at(number<0>{}), step_dram.at(number<1>{})});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx],
{step.at(number<0>{}), step.at(number<1>{})});
{step_dram.at(number<0>{}), step_dram.at(number<1>{})});
});
}
});
}
// Move the LDS window
constexpr auto iAccess = number<group * NumGroupsToMerge + group>{};
constexpr auto next_iAccess = number<(group+1) * NumGroupsToMerge + (group+1)>{};
constexpr auto step_lds = SFC_lds::get_step_between(iAccess, next_iAccess);
move_tile_window(out_lds_window, {step_lds.at(number<0>{}), step_lds.at(number<1>{})});
}
}
);
}
// Epilogue path for the non-merged case: the dispatching code calls this when
// NumGroupsToMerge is not greater than 1, so that all blocks of the tile are written out.
//
// The accumulator tile is processed in MPerIterationShuffle x NPerIterationShuffle chunks.
// Each chunk is cast to ODataType, staged through LDS, reloaded with the output tile
// distribution, passed through Problem::CDElementwise together with the D tensors, and then
// stored to (or accumulated into) the output window.
//
// Parameters:
//   out_dram_window - output tile window; it is advanced in place between chunks and is not
//                     moved back at the end.
//   o_acc_tile      - accumulator tile; only read (sliced per chunk).
//   ds_dram_windows - tuple-like collection of NumDTensor D-tensor windows; local windows are
//                     derived from it, the argument itself is not modified.
//   p_smem          - pointer to the LDS staging buffer, reinterpreted as ODataType*.
//                     NOTE(review): required size/alignment is defined by MakeLdsBlockDescriptor,
//                     which is not visible here - confirm against the caller.
//
// The function contains no return statement; all results are side effects on the output window.
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto unmerged_op(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem)
{
// Per-thread staging tensor (AccDataType) that holds one chunk sliced from the accumulator.
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
// View p_smem as an LDS tensor of ODataType elements with the layout from the block descriptor.
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
// Window used to WRITE a chunk into LDS (carries the LDS tile distribution).
// It stays at origin {0, 0} for the whole function: the same LDS region is reused per chunk.
auto in_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
LdsTileDistr);
// Window used to READ the chunk back from LDS; the distribution is attached at load time.
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
// Traversal of the full (kMPerBlock, kNPerBlock) tile in chunk-sized steps.
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
// Distribution used when reloading a chunk from LDS and when accessing the D tensors.
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::sparse_row,
Problem::kNumWaveGroups>;
constexpr auto dram_tile_distribution =
TileEncodingPattern::make_2d_static_tile_distribution();
// Local, movable windows over each D tensor using the same distribution as the output chunk.
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
// Lengths and all-zero start indices of the warp-level Y dimensions, used to build the slice
// arguments for the accumulator tile below.
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// Compile-time unrolled loop over all chunks of the tile.
static_for<0, num_access, 1>{}([&](auto iAccess) {
// Barrier before the LDS staging region is overwritten
// (presumably so that reads of the previous chunk have completed - confirm).
block_sync_lds();
// Convert the chunk's element offset into chunk indices along M and N.
constexpr auto idx_y_start = SFC::get_index(iAccess);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
// Copy this chunk's slice of the accumulator into the per-thread staging tensor.
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
// Convert from AccDataType to ODataType before writing to LDS.
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
// Barrier between the LDS store above and the LDS load below.
block_sync_lds();
// Reload the chunk from LDS using the output (DRAM-side) tile distribution.
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
// Load the matching chunk of every D tensor.
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
// Argument pack for the elementwise op: c_out_tensor appears twice, followed by the D tiles.
// NOTE(review): presumably the first entry is the in/out destination and the second the
// C input - confirm against tile_elementwise_inout_unpack / CDElementwise.
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
// Write the chunk out: plain store for memory_operation_enum::set, update_tile otherwise.
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
// Advance the output and D windows to the next chunk; skipped after the last chunk.
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx],
{step.at(number<0>{}), step.at(number<1>{})});
});
}
});
}
};
} // namespace ck_tile