Add permuteN optimization when NRepeat % 2 == 0 on flatmm

This commit is contained in:
Feng Shijie
2025-07-27 11:57:38 +00:00
parent bfb9f4002f
commit 5473f06461
5 changed files with 228 additions and 104 deletions

View File

@@ -41,6 +41,7 @@ struct FlatmmConfig32
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
// PermuteN is switched off unconditionally for this config: it is disabled
// whenever NWarpTile != 16.
// NOTE(review): this config's N_Warp_Tile is declared outside the visible
// hunk -- presumably 32; confirm.
static constexpr bool TiledMMAPermuteN = false;
};
template <typename DataType>
@@ -78,6 +79,9 @@ struct FlatmmConfig16
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
// N_Tile divided by N_Warp_Tile and then by N_Warp (all three are declared
// earlier in this struct, outside the visible hunk).
// NOTE(review): presumably the number of N-direction warp-tile repeats per
// warp -- confirm against the pipeline that consumes it.
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
// PermuteN is enabled only when N_Repeat is even.
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
template <typename DataType>
@@ -87,6 +91,10 @@ struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
// K_Tile scales inversely with the element size: 256 bytes' worth of DataType.
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
// 2-byte element types use a K warp tile of 32; every other size uses 128.
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
static constexpr int kBlockPerCu = 1;
// Shadows the N_Repeat of the FlatmmConfig16<DataType> base. N_Warp_Tile and
// N_Warp are read from the base explicitly; N_Tile is unqualified.
// NOTE(review): because the base is a dependent type, unqualified N_Tile must
// resolve to a member declared in this struct above the visible hunk --
// confirm it is the overridden value and not the base's.
static constexpr int N_Repeat =
N_Tile / FlatmmConfig16<DataType>::N_Warp_Tile / FlatmmConfig16<DataType>::N_Warp;
// Shadows the base's flag so it reflects this struct's N_Repeat: PermuteN is
// enabled only when that value is even.
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
template <typename ADataType>