[CK TILE] Fix KIterPerInnerLoop for block gemm. (#1934)

* Fix KIterPerInnerLoop

* Fix Kpack and KPerInnerLoop for block universal gemm.

* Fix overlooked spelling bugs.
This commit is contained in:
Adam Osewski
2025-03-05 23:17:44 +01:00
committed by GitHub
parent 9b51c08bf7
commit 4814db3905
3 changed files with 23 additions and 19 deletions

View File

@@ -68,15 +68,19 @@ struct BlockUniversalGemmAsBsCr
static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
// TODO: Should we have two policies? Interwave & Intrawave ??
// Controls how many MAC clusters (MFMA blocks) we have per wave
// Ie if
// InterWaveSchedulingMacClusters = 1;
// KPerBlock == 32
// WarpGemm::kK = 8
// Then we would group all 4 WarpGemms into single MAC cluster.
// But if we would set InterWaveSchedulingMacClusters = 2, then we would
// split those 4 warp gemms into two groups.
static constexpr index_t InterWaveSchedulingMacClusters = 1;
// should be at least equal to: WarpGemm::Impl::kABKPerLane
// and the question is how to assess upper limit or exact value?
// TODO: Should we introduce AK1/BK1 parameters ?
static constexpr index_t KPack = 8;
static constexpr index_t KPerThread = KIterPerWarp * KPack;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
};
public:
@@ -129,11 +133,12 @@ struct BlockUniversalGemmAsBsCr
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack);
constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterWave>,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto a_block_outer_dstr_encoding =
@@ -153,11 +158,12 @@ struct BlockUniversalGemmAsBsCr
{
constexpr index_t KPerThread = Traits::KPerThread;
constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack);
constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK;
constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
sequence<KIterInterWave>,
sequence<KIterInterwave>,
sequence<KIterPerWarp>>;
constexpr auto b_block_outer_dstr_encoding =
@@ -371,11 +377,9 @@ struct BlockUniversalGemmAsBsCr
static constexpr index_t KPerThread = GemmTraits::KPerThread;
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
static constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, GemmTraits::KPack);
// TODO: do we really need this?? Are there any cases when this would be >=1 ??
// Would we need InterWaveSchedulingMacClusters > 1 ???
ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack;
static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread;
static constexpr auto ALdsTileDistr =
decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){};

View File

@@ -136,7 +136,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
@@ -196,7 +196,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /

View File

@@ -252,7 +252,7 @@ struct UniversalGemmBasePolicy
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();