From 29bd9f40af1bfa548cd628024d1bd48242f974df Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Fri, 7 Feb 2025 09:00:44 +0100 Subject: [PATCH] CK Tile - small fix to hotloop scheduler & KPack value. (#1867) * Use SmemPack in HotLoop scheduler * Additional debug print information * Change KPack value. Hardcode for now, as without AK1/BK1 there's no good way to determine its value. * Fix HotLoopScheduler MFMA instr parameters. [ROCm/composable_kernel commit: 0c15de6a8353c97c289ebc5094752d6e002f0ea4] --- .../impl/device_gemm_xdl_cshuffle_v3.hpp | 5 +- include/ck/utility/blkgemmpipe_scheduler.hpp | 12 ++- .../block/block_universal_gemm_as_bs_cr.hpp | 5 +- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 77 ++++++++++++++++--- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 1 - 5 files changed, 86 insertions(+), 14 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp index 600f12139d..1c14496650 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 0) { arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); } if(!GridwiseGemm::CheckValidity(arg)) @@ -733,7 +734,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 +#include + #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" @@ -83,6 +86,56 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 return Policy::template GetSmemSize(); } + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + template struct PipelineImpl : public PipelineImplBase { @@ -95,29 +148,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 CK_TILE_DEVICE static constexpr auto HotLoopScheduler() { - constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{}); - constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{}); - constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{}); + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; constexpr index_t WaveSize = 64; constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); - constexpr index_t A_LDS_Read_Width = KPerXDL; - constexpr index_t B_LDS_Read_Width = KPerXDL; + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB(); constexpr index_t A_Buffer_Load_Inst_Num = MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); constexpr index_t B_Buffer_Load_Inst_Num = NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); - constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL); - constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); constexpr index_t A_LDS_Read_Inst_Num = - WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); constexpr index_t B_LDS_Read_Inst_Num = - WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index feed32a439..2a9683b36e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { - using ADataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;