mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[CK_TILE]fix ck_tile's moe_sorting example in gfx11 (#2667)
* fix ck_tile's moe_sorting example in gfx11 * fix clang format --------- Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
@@ -63,48 +63,15 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
|
||||
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4
|
||||
|
||||
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
|
||||
private:
|
||||
template <index_t LanesPerK, index_t WarpSize, typename = void>
|
||||
struct LdsStoreDescSelector;
|
||||
|
||||
template <index_t LanesPerK, index_t WarpSize>
|
||||
struct LdsStoreDescSelector<LanesPerK, WarpSize, std::enable_if_t<(LanesPerK >= WarpSize)>>
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<2, 1>, // !! note here is different
|
||||
sequence<0, 0>>{};
|
||||
|
||||
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>;
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
return c_block_dstr;
|
||||
}
|
||||
|
||||
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
|
||||
{
|
||||
using CDataType = float;
|
||||
constexpr auto c_block_dstr = MakeCBlockDist();
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
|
||||
{
|
||||
// A async->LDS
|
||||
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
|
||||
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
|
||||
constexpr index_t KPad = KPack_; // pad between warps
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= WarpSize)
|
||||
template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto MakeDesc()
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % WarpSize == 0);
|
||||
@@ -143,7 +110,13 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
return lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
}
|
||||
else
|
||||
};
|
||||
|
||||
template <index_t LanesPerK, index_t WarpSize>
|
||||
struct LdsStoreDescSelector<LanesPerK, WarpSize, std::enable_if_t<(LanesPerK < WarpSize)>>
|
||||
{
|
||||
template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto MakeDesc()
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(WarpSize % LanesPerK == 0);
|
||||
@@ -175,6 +148,49 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
|
||||
return lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<2, 1>, // !! note here is different
|
||||
sequence<0, 0>>{};
|
||||
|
||||
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>;
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
return c_block_dstr;
|
||||
}
|
||||
|
||||
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
|
||||
{
|
||||
using CDataType = float;
|
||||
constexpr auto c_block_dstr = MakeCBlockDist();
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
|
||||
{
|
||||
// A async->LDS
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
|
||||
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
|
||||
constexpr index_t KPad = KPack_; // pad between warps
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
|
||||
return LdsStoreDescSelector<LanesPerK, WarpSize>::
|
||||
template MakeDesc<NumWarps, Block_M, Block_K, KVector, KPad>();
|
||||
}
|
||||
|
||||
// template <typename Problem>
|
||||
|
||||
Reference in New Issue
Block a user