add CShuffleM/NXdlPerWavePerShuffle in cshuffle_epilogue (#2185)

* add cshuffle's mxdlperwavepershuffle support, not finished

* add epilogue functions

* add cshuffle's mxdlperwavepershuffle support, not finished

* add epilogue functions

* update cshuffle logic

* update cshuffle_logics

* add some change within review

* update some codes following the code review

* update epilogue logic

* remove from problem

* update codes following review.

* fix some issues
This commit is contained in:
joyeamd
2025-05-29 20:31:14 +08:00
committed by GitHub
parent 28cd0dffc9
commit fd6a859b44

View File

@@ -65,19 +65,6 @@ struct CShuffleEpilogue
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
using WG = WarpGemmMfmaDispatcher<ADataType,
BTypeToUse,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
/**
* @brief Get the vector store size for C tensor.
@@ -91,10 +78,62 @@ struct CShuffleEpilogue
*/
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
constexpr index_t MaxVectorStoreSize = 16;
return MaxVectorStoreSize / sizeof(ODataType);
constexpr index_t max_vector_store_size = 16;
return max_vector_store_size / sizeof(ODataType);
}
/**
* @brief Shuffle tile configuration parameters
*
* @details These parameters control the number of XDL tiles processed per wave in each shuffle
* iteration:
* - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
*/
static constexpr auto shuffle_tile_tuple = [] {
constexpr index_t vecPerThread = kMPerXdl * kNPerXdl / get_warp_size();
if constexpr(vecPerThread >= GetVectorSizeC())
{
return std::make_tuple(1, 1);
}
else
{
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / vecPerThread;
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (kMPerXdl * kMWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
"kMPerBlock must be divisible by kMPerXdl*kMWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (kMPerXdl * kMWave)), 1);
}
else
{
static_assert((kNPerBlock % (kNPerXdl * kNWave) == 0) &&
(kNPerBlock % num_xdl_shuffles == 0),
"kNPerBlock must be divisible by kNPerXdl*kNWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (kNPerXdl * kNWave)));
}
}
}();
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
static constexpr index_t kMPerIteration = kMPerXdl * kMWave * NumMXdlPerWavePerShuffle;
static constexpr index_t kNPerIteration = kNPerXdl * kNWave * NumNXdlPerWavePerShuffle;
using WG = WarpGemmMfmaDispatcher<ADataType,
BTypeToUse,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
@@ -102,15 +141,17 @@ struct CShuffleEpilogue
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
make_tuple(number<kNWave * kNPerXdl>{}, number<1>{}));
make_tuple(number<kMWave * kMPerXdl * NumMXdlPerWavePerShuffle>{},
number<kNWave * kNPerXdl * NumNXdlPerWavePerShuffle>{}),
make_tuple(number<kNWave * kNPerXdl * NumNXdlPerWavePerShuffle>{}, number<1>{}));
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
make_tuple(number<1>{}, number<kMWave * kMPerXdl>{}));
make_tuple(number<kMWave * kMPerXdl * NumMXdlPerWavePerShuffle>{},
number<kNWave * kNPerXdl * NumNXdlPerWavePerShuffle>{}),
make_tuple(number<1>{}, number<kMWave * kMPerXdl * NumMXdlPerWavePerShuffle>{}));
}
else
{
@@ -118,34 +159,57 @@ struct CShuffleEpilogue
}
}
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, kMWave>,
sequence<NumNXdlPerWavePerShuffle, kNWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
return block_dstr_encoding;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
return kMWave * kNWave * kMPerXdl * kNPerXdl * NumMXdlPerWavePerShuffle *
NumNXdlPerWavePerShuffle * sizeof(ODataType);
}
template <typename ODramWindow, typename OAccTile>
CK_TILE_DEVICE auto
operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
const index_t iMWarp = get_warp_id() / kNWave;
const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
auto in_lds_window =
make_tile_window(o_lds_block,
make_tuple(number<kMPerXdl>{}, number<kNPerXdl>{}),
{number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
make_tuple(number<kMWave * kMPerXdl * NumMXdlPerWavePerShuffle>{},
number<kNWave * kNPerXdl * NumNXdlPerWavePerShuffle>{}),
{0, 0},
LdsTileDistr);
auto out_lds_window =
make_tile_window(o_lds_block,
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
make_tuple(number<kMWave * kMPerXdl * NumMXdlPerWavePerShuffle>{},
number<kNWave * kNPerXdl * NumNXdlPerWavePerShuffle>{}),
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<kMPerXdl * kMWave, kNPerXdl * kNWave>>;
sequence<kMPerXdl * kMWave * NumMXdlPerWavePerShuffle,
kNPerXdl * kNWave * NumNXdlPerWavePerShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
using TileEncodingPattern =
@@ -160,21 +224,25 @@ struct CShuffleEpilogue
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
CWarpTensor c_warp_in_tensor;
block_sync_lds();
static_for<0, num_access, 1>{}([&](auto iAccess) {
constexpr auto idx_y_start = SFC::get_index(iAccess);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
constexpr auto mIter = number<idx_y_start.at(number<0>{}) /
(kMPerXdl * kMWave * NumMXdlPerWavePerShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) /
(kNPerXdl * kNWave * NumNXdlPerWavePerShuffle)>{};
c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
block_sync_lds();
store_tile(in_lds_window, c_warp_in_tensor_casted);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
block_sync_lds();
const auto c_out_tensor =