diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index c0ad4b5489..437bfee269 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -31,7 +31,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con // This part comes from the Codegen constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t N_Tile = 128; - constexpr ck_tile::index_t K_Tile = 64; + constexpr ck_tile::index_t K_Tile = 128; constexpr ck_tile::index_t M_Warp = 1; constexpr ck_tile::index_t N_Warp = 4; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 4b5f7c0062..831830157c 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -72,6 +72,47 @@ struct FlatmmPipelineAGmemBGmemCRegV1 return PipelinePolicy::template GetSmemSize(); } + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + + constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad; + constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp; + constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; + //constexpr index_t A_LDS_Read_Inst_Remain = A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num; + + static_for<0, A_LDS_Read_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + template CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, @@ -227,7 +268,8 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // LDS write i + 1 auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window, a_block_tile_tmp); - + + HotLoopScheduler(); block_sync_lds(); // iCounter--; @@ -261,6 +303,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window, a_block_tile_tmp); + HotLoopScheduler(); block_sync_lds(); iCounter--;