[CK_TILE]fix ck_tile's moe_sorting example in gfx11 (#2667)

* fix ck_tile's moe_sorting example in gfx11

* fix clang format

---------

Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
joyeamd
2025-08-13 03:33:56 +08:00
committed by GitHub
parent bbf41b27f2
commit 0856b3f4a2

View File

@@ -63,48 +63,15 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
private:
template <index_t LanesPerK, index_t WarpSize, typename = void>
struct LdsStoreDescSelector;
template <index_t LanesPerK, index_t WarpSize>
struct LdsStoreDescSelector<LanesPerK, WarpSize, std::enable_if_t<(LanesPerK >= WarpSize)>>
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
{
using CDataType = float;
constexpr auto c_block_dstr = MakeCBlockDist();
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t WarpSize = ck_tile::get_warp_size();
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack_; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= WarpSize)
template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
static CK_TILE_HOST_DEVICE constexpr auto MakeDesc()
{
// need multiple waves to load K
static_assert(LanesPerK % WarpSize == 0);
@@ -143,7 +110,13 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
return lds_block_desc_issues_warps_lanes;
}
}
else
};
template <index_t LanesPerK, index_t WarpSize>
struct LdsStoreDescSelector<LanesPerK, WarpSize, std::enable_if_t<(LanesPerK < WarpSize)>>
{
template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
static CK_TILE_HOST_DEVICE constexpr auto MakeDesc()
{
// lanes within a wave load different M but same K
static_assert(WarpSize % LanesPerK == 0);
@@ -175,6 +148,49 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
return lds_block_desc_issues_warps_lanes;
}
};
public:
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
{
using CDataType = float;
constexpr auto c_block_dstr = MakeCBlockDist();
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack_; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
return LdsStoreDescSelector<LanesPerK, WarpSize>::
template MakeDesc<NumWarps, Block_M, Block_K, KVector, KPad>();
}
// template <typename Problem>