[CK_TILE] Fix flatmm on gfx11 and gfx12 (#2790)

1. Correct shuffle_b and MakeBFlatDramTileDistribution according to WMMA warp layout
2. Add FlatmmConfig16_Wmma for gfx11 and gfx12
This commit is contained in:
linqunAMD
2025-09-10 08:28:00 +08:00
committed by GitHub
parent 82890192dd
commit df4ee556d6
14 changed files with 224 additions and 67 deletions

View File

@@ -185,11 +185,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&

View File

@@ -237,8 +237,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
using TileShape = typename Problem::BlockGemmShape;
#if defined(__gfx11__)
constexpr index_t scale = 4;
#else
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
#endif
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) * scale / 2;
@@ -342,7 +346,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
@@ -350,8 +354,13 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
#if defined(__gfx11__)
constexpr index_t KRepeatInWave = 2;
#else
constexpr index_t KRepeatInWave = 1;
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
@@ -362,16 +371,15 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t NRepeat = 1;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>, // ?
sequence<WaveRepeat, KRepeatInWave>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});