[CK TILE] Fix KIterPerInnerLoop for block gemm. (#1934)

* Fix KIterPerInnerLoop

* Fix Kpack and KPerInnerLoop for block universal gemm.

* Fix overlooked spelling bugs.
This commit is contained in:
Adam Osewski
2025-03-05 23:17:44 +01:00
committed by GitHub
parent 9b51c08bf7
commit 4814db3905
3 changed files with 23 additions and 19 deletions

View File

@@ -136,7 +136,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
@@ -196,7 +196,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /

View File

@@ -252,7 +252,7 @@ struct UniversalGemmBasePolicy
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();