diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 6024e00419..38ed108f6d 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -68,15 +68,19 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN; static constexpr index_t KPerBlockPerIter = WarpGemm::kK; - // TODO: Should we have two policies? Interwave & Intrawave ?? + // Controls how many MAC clusters (MFMA blocks) we have per wave. + // E.g., if + // InterWaveSchedulingMacClusters = 1; + // KPerBlock == 32 + // WarpGemm::kK = 8 + // then we would group all 4 WarpGemms into a single MAC cluster. + // But if we set InterWaveSchedulingMacClusters = 2, then we would + // split those 4 warp gemms into two groups. static constexpr index_t InterWaveSchedulingMacClusters = 1; // should be at least equal to: WarpGemm::Impl::kABKPerLane - // and the question is how to assess upper limit or exact value? - // TODO: Should we introduce AK1/BK1 parameters ? 
- static constexpr index_t KPack = 8; - static constexpr index_t KPerThread = KIterPerWarp * KPack; - static constexpr index_t KRepeat = KPerThread / KPack; + static constexpr index_t KPack = WarpGemm::kKPerThread; + static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread; }; public: @@ -129,11 +133,12 @@ struct BlockUniversalGemmAsBsCr { constexpr index_t KPerThread = Traits::KPerThread; constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack); - constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK; + constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; using KIterSeq = std::conditional_t, + sequence, sequence>; constexpr auto a_block_outer_dstr_encoding = @@ -153,11 +158,12 @@ struct BlockUniversalGemmAsBsCr { constexpr index_t KPerThread = Traits::KPerThread; constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - constexpr index_t KPerInnerLoop = ck_tile::max(KPerThread / NumMacClusters, Traits::KPack); - constexpr index_t KIterInterWave = KPerInnerLoop / WarpGemm::kK; + constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; using KIterSeq = std::conditional_t, + sequence, sequence>; constexpr auto b_block_outer_dstr_encoding = @@ -371,11 +377,9 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t KPerThread = GemmTraits::KPerThread; static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters; static constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, GemmTraits::KPack); - // TODO: do we really need this?? Are there any cases when this would be >=1 ?? 
- // Would we need InterWaveSchedulingMacClusters > 1 ??? + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; - static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack; + static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; static constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 1e3694d24c..71d8ef1b3d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -136,7 +136,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t A_LDS_Read_Inst_Num = WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); @@ -196,7 +196,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr index_t A_LDS_Read_Inst_Num = WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index fd1e76a02b..f5b3523f60 100644 --- 
a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -252,7 +252,7 @@ struct UniversalGemmBasePolicy using ALayout = remove_cvref_t; static_assert(std::is_same_v); constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t VecLoadSize = GetVectorSizeA();