From 07ec843015e12e7686235ec5eb4b560c6d741096 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Wed, 5 Mar 2025 23:17:44 +0100 Subject: [PATCH] [CK TILE] Fix KIterPerInnerLoop for block gemm. (#1934) * Fix KIterPerInnerLoop * Fix Kpack and KPerInnerLoop for block universal gemm. * Fix overlooked spelling bugs. [ROCm/composable_kernel commit: 4814db39054691c7e72ddc893ba68fe3ae2c5df8] --- .../block/block_universal_gemm_as_bs_cr.hpp | 36 ++++++++++--------- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 4 +-- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 2 +- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 6024e00419..38ed108f6d 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -68,15 +68,19 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN; static constexpr index_t KPerBlockPerIter = WarpGemm::kK; - // TODO: Should we have two policies? Interwave & Intrawave ?? + // Controls how many MAC clusters (MFMA blocks) we have per wave + // Ie if + // InterWaveSchedulingMacClusters = 1; + // KPerBlock == 32 + // WarpGemm::kK = 8 + // Then we would group all 4 WarpGemms into single MAC cluster. + // But if we would set InterWaveSchedulingMacClusters = 2, then we would + // split those 4 warp gemms into two groups. static constexpr index_t InterWaveSchedulingMacClusters = 1; // should be at least equal to: WarpGemm::Impl::kABKPerLane - // and the question is how to assess upper limit or exact value? - // TODO: Should we introduce AK1/BK1 parameters ? - static constexpr index_t KPack = 8; - static constexpr index_t KPerThread = KIterPerWarp * KPack; - static constexpr index_t KRepeat = KPerThread / KPack; + static constexpr index_t KPack = WarpGemm::kKPerThread; + static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; }; public: @@ -129,11 +133,12 @@ struct BlockUniversalGemmAsBsCr { constexpr index_t KPerThread = Traits::KPerThread; constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack); - constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK; + constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; using KIterSeq = std::conditional_t, + sequence, sequence>; constexpr auto a_block_outer_dstr_encoding = @@ -153,11 +158,12 @@ struct BlockUniversalGemmAsBsCr { constexpr index_t KPerThread = Traits::KPerThread; constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack); - constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK; + constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; using KIterSeq = std::conditional_t, + sequence, sequence>; constexpr auto b_block_outer_dstr_encoding = @@ -371,11 +377,9 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t KPerThread = GemmTraits::KPerThread; static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters; static constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, GemmTraits::KPack); - // TODO: do we really need this?? Are there any cases when this would be >=1 ?? - // Would we need InterWaveSchedulingMacClusters > 1 ??? + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; - static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack; + static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; static constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 1e3694d24c..71d8ef1b3d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -136,7 +136,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t A_LDS_Read_Inst_Num = WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); @@ -196,7 +196,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t A_LDS_Read_Inst_Num = WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index fd1e76a02b..f5b3523f60 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -252,7 +252,7 @@ struct UniversalGemmBasePolicy using ALayout = remove_cvref_t; static_assert(std::is_same_v); constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t VecLoadSize = GetVectorSizeA();