mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Merge some updates for ck_tile headers (#3342)
* fix some issues from internal branch
* update cshuffle_epilogue
* update cshuffle_epilogue
* update cshuffle
* update warp_gemm
This commit is contained in:
@@ -333,14 +333,30 @@ struct CShuffleEpilogue
|
||||
{
|
||||
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
|
||||
// BlockedLayout
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{};
|
||||
// this branch is for original a16w4
|
||||
if constexpr(is_any_of<ADataType, pk_int4_t, pk_fp4_t>::value ||
|
||||
is_any_of<BDataType, pk_int4_t, pk_fp4_t>::value)
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<RakedXDLN_PerWarp, BlockedXDLN_PerWarp, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 1>>{};
|
||||
}
|
||||
}
|
||||
}();
|
||||
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
|
||||
@@ -351,7 +367,8 @@ struct CShuffleEpilogue
|
||||
|
||||
// GetSmemSize: compile-time helper that returns a byte count, computed as an
// element count multiplied by sizeof(ODataType).
// NOTE(review): this span is a rendered diff hunk ("-351,7 +367,8") with no +/-
// markers, so two return statements appear back to back. The hunk's line counts
// (7 old, 8 new) suggest the first return is the removed line and the two
// statements after it are its replacement — confirm against the repository.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
// Element count here is the plain product of the two per-iteration shuffle extents.
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
|
||||
// Element count here is taken from the descriptor built by
// MakeLdsBlockDescriptor<Problem>() via get_element_space_size().
// Presumably this can differ from the plain M*N product (e.g. padding) — TODO confirm.
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
return lds_block_desc.get_element_space_size() * sizeof(ODataType);
|
||||
}
|
||||
|
||||
template <index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
|
||||
|
||||
Reference in New Issue
Block a user