From 56f84349caa85fcc45978a4dbbb604d379c1a54a Mon Sep 17 00:00:00 2001 From: root Date: Wed, 25 Jun 2025 20:41:25 -0500 Subject: [PATCH] 1.fix loop_num=odd bug 2.optimize mi300 performance of big MNK(tilesize 128x128x128) 3.optimize decode perf on mi300 --- example/ck_tile/18_flatmm/flatmm_basic.cpp | 48 +-- example/ck_tile/18_flatmm/flatmm_basic.hpp | 6 +- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 285 +++++++++++++++--- 3 files changed, 271 insertions(+), 68 deletions(-) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index e1a3b6054c..92a2c37ba8 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -145,8 +145,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con } }; - if(has_hot_loop) - { + // if(has_hot_loop) + // { if(tail_num == ck_tile::TailNumber::Odd) { RunSplitk(ck_tile::bool_constant{}, @@ -165,28 +165,28 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } - } - else - { - if(tail_num == ck_tile::TailNumber::Odd) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else if(tail_num == ck_tile::TailNumber::Even) - { - RunSplitk(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - else - { - std::ostringstream err; - err << "Num K loop must be larger than number of prefetech stages." - << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages - << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } - } + // } + // else + // { + // if(tail_num == ck_tile::TailNumber::Odd) + // { + // RunSplitk(ck_tile::bool_constant{}, + // ck_tile::integral_constant{}); + // } + // else if(tail_num == ck_tile::TailNumber::Even) + // { + // RunSplitk(ck_tile::bool_constant{}, + // ck_tile::integral_constant{}); + // } + // else + // { + // std::ostringstream err; + // err << "Num K loop must be larger than number of prefetech stages." + // << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + // << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + // throw std::runtime_error(err.str()); + // } + // } return ave_time; } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp index 3d75975a04..aa04f3cdf4 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.hpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -161,9 +161,9 @@ struct GemmConfig static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; #elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16 - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128; + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; static constexpr ck_tile::index_t M_Warp = 1; static constexpr ck_tile::index_t N_Warp = 4; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 67c63553b5..613065432c 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -96,6 +96,39 @@ struct FlatmmPipelineAGmemBGmemCRegV1 static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; + /* + defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1 + defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1 + defined(USING_MFMA_16x16x16) && defined(ENABLE_FP16) // mi300 fp16 16c 0.5*K1 + defined(USING_MFMA_32x32x8) && defined(ENABLE_FP16) // mi300 fp16 32c 0.5*K1 + + defined(USING_MFMA_16x16x128) && defined(ENABLE_FP8) // mi350 fp8 32c 2*K1 + defined(USING_MFMA_32x32x64) && defined(ENABLE_FP8) // mi350 fp8 64c 2*K1 + defined(USING_MFMA_16x16x32) && defined(ENABLE_FP16) // mi350 fp16 16c 1*K1 + defined(USING_MFMA_32x32x16) && defined(ENABLE_FP16) // mi350 fp16 32c 1*K1 + + defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1 + defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 + */ + + #if (defined(USING_MFMA_16x16x32_F8) || \ + defined(USING_MFMA_32x32x16_F8) || \ + defined(USING_MFMA_16x16x16_F16) || \ + defined(USING_MFMA_32x32x8_F16)) // K1 per Mfma = 0.5 + static constexpr auto mfma_per_wg = 2; + static constexpr auto dsread_per_wg = 1; + #elif (defined(USING_MFMA_16x16x32_F16) || \ + defined(USING_MFMA_32x32x16_F16) || \ + defined(USING_MFMA_16x16x128_F4) || \ + defined(USING_MFMA_32x32x64_F4)) // K1 per Mfma = 1 + static constexpr auto mfma_per_wg = 1; + static constexpr auto dsread_per_wg = 1; + #elif (defined(USING_MFMA_16x16x128_F8) || \ + defined(USING_MFMA_32x32x64_F8)) // K1 per Mfma = 2 + static constexpr auto mfma_per_wg = 1; + static constexpr auto dsread_per_wg = 2; + #endif + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off @@ -191,7 +224,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // 0 M7N2: 63 - - 8 - // 0 M7N3: 64 4 - - - - #if 1 + #if 0 // MI350 FP8 16X16 128*256*256 static_for<0, 2, 1>{}([&](auto j) { ignore = j; static_for<0, 3, 1>{}([&](auto i) { @@ -237,7 +270,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 __builtin_amdgcn_sched_barrier(0); #endif - #if 0 + #if 0 // MI350 FP8 16X16 static_for<0, 2, 1>{}([&](auto j) { ignore = j; static_for<0, 3, 1>{}([&](auto i) { @@ -273,6 +306,166 @@ struct FlatmmPipelineAGmemBGmemCRegV1 __builtin_amdgcn_sched_barrier(0); #endif + #if 0 // MI300 FP8 16X16 128*128*128 + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + }); + + __builtin_amdgcn_sched_barrier(0); + #endif + #if 0 // MI300 FP8 16X16 128*256*128 + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 4, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + }); + + __builtin_amdgcn_sched_barrier(0); + #endif + #if 0 //MI300 FP8 16X16 16*64*256 + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_barrier(0); + #endif } @@ -340,6 +533,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; const index_t iMWarp = get_warp_id() / NWarp; using CWarpDstr = typename WG::CWarpDstr; @@ -565,11 +759,14 @@ struct FlatmmPipelineAGmemBGmemCRegV1 block_sync_lds(); // preload A00,A10 from lds - statically_indexed_array{})(number<0>{}))), 2> a_warp_tensor_ping; - statically_indexed_array{})(number<0>{}))), 2> a_warp_tensor_pong; - - static_for<0, 2, 1>{}([&](auto mIter) { - a_warp_tensor_ping(mIter) = load_tile(a_warp_windows_ping(mIter)(number<0>{})); + constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2: 1; + statically_indexed_array{})(number<0>{}))), m_preload> a_warp_tensor_ping; + statically_indexed_array{})(number<0>{}))), m_preload> a_warp_tensor_pong; + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); }); __builtin_amdgcn_sched_barrier(0); @@ -583,7 +780,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // } - index_t iCounter = num_loop / 2 - 1; + index_t iCounter = (num_loop - 1) / 2; // if constexpr(HasMainLoop) // { while(iCounter > 0) @@ -614,7 +811,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // GEMM 2i static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = mIter % 2; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; @@ -661,15 +858,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1 __builtin_amdgcn_sched_barrier(0x7F6); }); // preload next A from lds - if constexpr((kIter != KIterPerWarp - 1) || (mIter < (MIterPerWarp - 2))) + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) { - constexpr auto AmIter = (mIter + 2) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp); + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); } //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2))) + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { block_sync_lds(); } @@ -680,8 +877,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - static_for<0, 2, 1>{}([&](auto mIter) { - a_warp_tensor_pong(mIter) = load_tile(a_warp_windows_pong(mIter)(number<0>{})); + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); }); HotLoopScheduler(); @@ -713,7 +912,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // GEMM 2i+1 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = mIter % 2; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; @@ -759,15 +958,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1 __builtin_amdgcn_sched_barrier(0x7F6); }); // preload next A from lds - if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2))) + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) { - constexpr auto AmIter = (mIter + 2) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp); + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); a_warp_tensor_pong(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); } //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2))) + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { block_sync_lds(); } @@ -778,8 +977,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - static_for<0, 2, 1>{}([&](auto mIter) { - a_warp_tensor_ping(mIter) = load_tile(a_warp_windows_ping(mIter)(number<0>{})); + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); }); HotLoopScheduler(); @@ -811,7 +1012,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // GEMM loopK-1 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = mIter % 2; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; @@ -819,7 +1020,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + // warp GEMM WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); @@ -851,15 +1052,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1 __builtin_amdgcn_sched_barrier(0x7F6); }); // preload next A from lds - if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2))) + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) { - constexpr auto AmIter = (mIter + 2) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp); + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); } //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2))) + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { block_sync_lds(); } @@ -869,8 +1070,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1 TailHotLoopScheduler(); - static_for<0, 2, 1>{}([&](auto mIter) { - a_warp_tensor_pong(mIter) = load_tile(a_warp_windows_pong(mIter)(number<0>{})); + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); }); // __builtin_amdgcn_sched_barrier(0); @@ -878,7 +1081,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // GEMM loopK static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = mIter % 2; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; @@ -886,7 +1089,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + // warp GEMM WG{}(c_warp_tensor, a_warp_tensor_pong(number{}), b_warp_tensor_pong(nIter)(kIter)); @@ -897,10 +1100,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1 c_warp_tensor.get_thread_buffer()); __builtin_amdgcn_sched_barrier(0x7F6); }); - if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2))) + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) { - constexpr auto AmIter = (mIter + 2) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp); + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); a_warp_tensor_pong(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); } }); @@ -915,7 +1118,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // GEMM loopK static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = mIter % 2; + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; @@ -923,7 +1126,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + // warp GEMM WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); @@ -955,15 +1158,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1 __builtin_amdgcn_sched_barrier(0x7F6); }); // preload next A from lds - if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2))) + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) { - constexpr auto AmIter = (mIter + 2) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp); + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); } //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2))) + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { block_sync_lds(); }