From cc4122c8133056f5009c2c177dcf4e0764804745 Mon Sep 17 00:00:00 2001 From: Aleksander Dudek Date: Mon, 23 Jun 2025 05:52:13 -0500 Subject: [PATCH] Review changes --- .../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 45 ++++++++----------- 1 file changed, 19 insertions(+), 26 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index 58e7075b0b..0dccfc9b82 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -73,6 +73,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 using BlockGemmShape = remove_cvref_t; static_assert(!std::is_same_v, "Not implemented"); + static_assert(!std::is_same_v, "Not implemented"); static constexpr index_t APackedSize = ck_tile::numeric_traits>::PackedSize; @@ -103,9 +104,6 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } - // TODO check KRepeat - static constexpr index_t KRepeat = KPerBlock / GetSmemPackA(); - static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; @@ -190,19 +188,14 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 constexpr auto ds_read_b_mfma_rate = (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); - constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1); - constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1); - constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat; - constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat; - constexpr auto num_dsread_stage1_a_mfma = - (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; constexpr auto num_dsread_stage1_b_mfma = - (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; constexpr auto num_dsread_stage3_a_mfma = - (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; constexpr auto num_dsread_stage3_b_mfma = - (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; constexpr auto num_mfma_stage2 = C_MFMA_Inst_Num - num_ds_read_inst_a / ds_read_a_mfma_rate - @@ -215,7 +208,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 // stage 1 static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) { ignore = i; - if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >= + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= ds_read_a_mfma_rate) { __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read @@ -224,14 +217,14 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 { __builtin_amdgcn_sched_group_barrier( 0x100, - num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate, + num_ds_read_inst_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate, 0); // DS read } __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA }); static_for<0, num_dsread_stage1_b_mfma, 1>{}([&](auto i) { ignore = i; - if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >= + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= ds_read_b_mfma_rate) { __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read @@ -240,7 +233,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 { __builtin_amdgcn_sched_group_barrier( 0x100, - num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate, + num_ds_read_inst_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate, 0); // DS read } __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA @@ -273,7 +266,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 // stage 3 static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) { ignore = i; - if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >= + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= ds_read_a_mfma_rate) { __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read @@ -282,14 +275,14 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 { __builtin_amdgcn_sched_group_barrier( 0x100, - num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate, + num_ds_read_inst_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate, 0); // DS read } __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA }); static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) { ignore = i; - if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >= + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= ds_read_b_mfma_rate) { __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read @@ -298,7 +291,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 { __builtin_amdgcn_sched_group_barrier( 0x100, - num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate, + num_ds_read_inst_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate, 0); // DS read } __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA @@ -397,16 +390,16 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 // A DRAM tile window for load // A LDS tile window for store // A LDS tile for block GEMM - acopy_dram_type& a_copy_dram_window = aWindows.at(I0); - a_copy_lds_window_type& a_copy_lds_window = aWindows.at(I1); - a_lds_load_tile_distr_type& a_lds_gemm_window = aWindows.at(I2); + auto& a_copy_dram_window = aWindows.at(I0); + auto& a_copy_lds_window = aWindows.at(I1); + auto& a_lds_gemm_window = aWindows.at(I2); // B DRAM tile window for load // B LDS tile window for store // B LDS tile for block GEMM - bcopy_dram_type& b_copy_dram_window = bWindows.at(I0); - b_copy_lds_window_type& b_copy_lds_window = bWindows.at(I1); - b_lds_load_tile_distr_type& b_lds_gemm_window = bWindows.at(I2); + auto& b_copy_dram_window = bWindows.at(I0); + auto& b_copy_lds_window = bWindows.at(I1); + auto& b_lds_gemm_window = bWindows.at(I2); // Block GEMM auto block_gemm = BlockGemm();