Merge some updates for ck_tile headers (#3342)

* fix some issues from internal branch

* update cshuffle_epilogue

* update cshuffle_epilogue

* update cshuffle

* update warp_gemm
This commit is contained in:
joyeamd
2026-01-06 15:39:00 +08:00
committed by GitHub
parent 2b563ad048
commit b78563b3d3
14 changed files with 205 additions and 119 deletions

View File

@@ -333,14 +333,30 @@ struct CShuffleEpilogue
{
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
// this branch is for original a16w4
if constexpr(is_any_of<ADataType, pk_int4_t, pk_fp4_t>::value ||
is_any_of<BDataType, pk_int4_t, pk_fp4_t>::value)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, BlockedXDLN_PerWarp, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 1>>{};
}
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
@@ -351,7 +367,8 @@ struct CShuffleEpilogue
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
return lds_block_desc.get_element_space_size() * sizeof(ODataType);
}
template <index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>