diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 1a2348810e..b8e923a52e 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -13,7 +13,7 @@ namespace ck_tile { template struct BaseFlatmmPipelineAGmemBGmemCRegV1 { - static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefetchStages = 2; CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) { @@ -25,19 +25,23 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; } template - CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_num) + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool, TailNumber tail_num) { - if (TailNumber::Even == tail_num) + if(TailNumber::Even == tail_num) { - return run_func(bool_constant{}, integral_constant{}); + return run_func(bool_constant{}, + integral_constant{}); } - else if (TailNumber::Odd == tail_num) + else if(TailNumber::Odd == tail_num) { - return run_func(bool_constant{}, integral_constant{}); + return run_func(bool_constant{}, + integral_constant{}); } // assert(false); return run_func(bool_constant{}, integral_constant{}); - // return run_func(bool_constant{}, integral_constant{}); + // return run_func(bool_constant{}, integral_constant{}); } }; @@ -56,8 +60,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV using BlockFlatmm = remove_cvref_t())>; - - static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + static constexpr auto config = + BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -109,11 +114,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; - static constexpr index_t K1 = 16 / sizeof(ADataType); - static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; + static constexpr index_t K1 = 16 / sizeof(ADataType); + static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; static constexpr index_t ACopyLoadNumPerK = ACopyLoadNum / KIterPerWarp; - static constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum; - static constexpr index_t BloadGap = MIterPerWarp / 2; + static constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum; + static constexpr index_t BloadGap = MIterPerWarp / 2; static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; @@ -145,21 +150,21 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // K1 per Mfma = 0.5 cases: mfma_per_wg = 2, dsread_per_wg = 1 if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 && - std::is_same_v) || - (warp_m == 32 && warp_n == 32 && warp_k == 16 && - std::is_same_v) || - (warp_m == 16 && warp_n == 16 && warp_k == 16 && - std::is_same_v) || - (warp_m == 32 && warp_n == 32 && warp_k == 8 && - std::is_same_v)) + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 16 && + std::is_same_v) || + (warp_m == 16 && warp_n == 16 && warp_k == 16 && + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 8 && + std::is_same_v)) { return {2, 1}; } // K1 per Mfma = 2 cases: mfma_per_wg = 1, dsread_per_wg = 2 else if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 128 && - std::is_same_v) || - (warp_m == 32 && warp_n == 32 && warp_k == 64 && - std::is_same_v)) + std::is_same_v) || + (warp_m == 32 && warp_n == 32 && warp_k == 64 && + std::is_same_v)) { return {1, 2}; } @@ -227,73 +232,73 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // instruction schedule example(128X256X256, 1X4, 16X16X128): // Iter MNK MFMA ds_read ds_write A_load b_load // -1 M6N3: 60 2 - - - - // -1 M7N0: 61 - - - - - // -1 M7N1: 62 - - - - - // -1 M7N2: 63 - - - - - // -1 M7N3: 64 4 - - - - // 0 M0N0K0: 1 - - - - - // 0 M0N1: 2 - - - 2 - // 0 M0N2: 3 - - - - - // 0 M0N3: 4 6 - - - - // 0 M1N0: 5 - - - - - // 0 M1N1: 6 - - - 4 - // 0 M1N2: 7 - - - - - // 0 M1N3: 8 8 - - - - // 0 M2N0: 9 - - - - - // 0 M2N1: 10 - - - 6 - // 0 M2N2: 11 - - - - - // 0 M2N3: 12 10 - - - - // 0 M3N0: 13 - 1 - - - // 0 M3N1: 14 - - - 8 - // 0 M3N2: 15 - - - - + // -1 M7N0: 61 - - - - + // -1 M7N1: 62 - - - - + // -1 M7N2: 63 - - - - + // -1 M7N3: 64 4 - - - + // 0 M0N0K0: 1 - - - - + // 0 M0N1: 2 - - - 2 + // 0 M0N2: 3 - - - - + // 0 M0N3: 4 6 - - - + // 0 M1N0: 5 - - - - + // 0 M1N1: 6 - - - 4 + // 0 M1N2: 7 - - - - + // 0 M1N3: 8 8 - - - + // 0 M2N0: 9 - - - - + // 0 M2N1: 10 - - - 6 + // 0 M2N2: 11 - - - - + // 0 M2N3: 12 10 - - - + // 0 M3N0: 13 - 1 - - + // 0 M3N1: 14 - - - 8 + // 0 M3N2: 15 - - - - // 0 M3N3: 16 12 - - - - // 0 M4N0: 17 - 2 - - - // 0 M4N1: 18 - - - - - // 0 M4N2: 19 - - 1 - + // 0 M4N0: 17 - 2 - - + // 0 M4N1: 18 - - - - + // 0 M4N2: 19 - - 1 - // 0 M4N3: 20 14 - - - - // 0 M5N0: 21 - 3 - - - // 0 M5N1: 22 - - - - - // 0 M5N2: 23 - - 2 - + // 0 M5N0: 21 - 3 - - + // 0 M5N1: 22 - - - - + // 0 M5N2: 23 - - 2 - // 0 M5N3: 24 16 - - - - // 0 M6N0: 25 - 4 - - - // 0 M6N1: 26 - - - - - // 0 M6N2: 27 - - 3 - + // 0 M6N0: 25 - 4 - - + // 0 M6N1: 26 - - - - + // 0 M6N2: 27 - - 3 - // 0 M6N3: 28 17 - - - - // 0 M7N0: 29 - - - - - // 0 M7N1: 30 - - - - - // 0 M7N2: 31 - - 4 - + // 0 M7N0: 29 - - - - + // 0 M7N1: 30 - - - - + // 0 M7N2: 31 - - 4 - // 0 M7N3: 32 18 - - - - // 0 M0N0K1: 33 - - - - - // 0 M0N1: 34 - - - 10 - // 0 M0N2: 35 - - - - - // 0 M0N3: 36 20 - - - - // 0 M1N0: 37 - - - - - // 0 M1N1: 38 - - - 12 - // 0 M1N2: 39 - - - - - // 0 M1N3: 40 22 - - - - // 0 M2N0: 41 - - - - - // 0 M2N1: 42 - - - 14 - // 0 M2N2: 43 - - - - - // 0 M2N3: 44 24 - - - - // 0 M3N0: 45 - 5 - - - // 0 M3N1: 46 - - - 16 - // 0 M3N2: 47 - - - - + // 0 M0N0K1: 33 - - - - + // 0 M0N1: 34 - - - 10 + // 0 M0N2: 35 - - - - + // 0 M0N3: 36 20 - - - + // 0 M1N0: 37 - - - - + // 0 M1N1: 38 - - - 12 + // 0 M1N2: 39 - - - - + // 0 M1N3: 40 22 - - - + // 0 M2N0: 41 - - - - + // 0 M2N1: 42 - - - 14 + // 0 M2N2: 43 - - - - + // 0 M2N3: 44 24 - - - + // 0 M3N0: 45 - 5 - - + // 0 M3N1: 46 - - - 16 + // 0 M3N2: 47 - - - - // 0 M3N3: 48 26 - - - - // 0 M4N0: 49 - 6 - - - // 0 M4N1: 50 - - - - - // 0 M4N2: 51 - - 5 - + // 0 M4N0: 49 - 6 - - + // 0 M4N1: 50 - - - - + // 0 M4N2: 51 - - 5 - // 0 M4N3: 52 28 - - - - // 0 M5N0: 53 - 7 - - - // 0 M5N1: 54 - - - - - // 0 M5N2: 55 - - 6 - + // 0 M5N0: 53 - 7 - - + // 0 M5N1: 54 - - - - + // 0 M5N2: 55 - - 6 - // 0 M5N3: 56 30 - - - - // 0 M6N0: 57 - 8 - - - // 0 M6N1: 58 - - - - - // 0 M6N2: 59 - - 7 - + // 0 M6N0: 57 - 8 - - + // 0 M6N1: 58 - - - - + // 0 M6N2: 59 - - 7 - // 0 M6N3: 60 2 - - - - // 0 M7N0: 61 - - - - - // 0 M7N1: 62 - - - - - // 0 M7N2: 63 - - 8 - + // 0 M7N0: 61 - - - - + // 0 M7N1: 62 - - - - + // 0 M7N2: 63 - - 8 - // 0 M7N3: 64 4 - - - if constexpr(warp_m == 16 && warp_n == 16) { @@ -473,7 +478,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); - + static_for<0, 1, 1>{}([&](auto i) { ignore = i; __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA @@ -531,7 +536,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA @@ -545,13 +550,13 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_barrier(0); } - } #endif + } } CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler() { - #if 0 +#if 0 static_for<0, 2, 1>{}([&](auto j) { ignore = j; static_for<0, 3, 1>{}([&](auto i) { @@ -593,7 +598,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); __builtin_amdgcn_sched_barrier(0); - #endif +#endif } template @@ -630,7 +635,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; __builtin_amdgcn_sched_barrier(0); - + // A tile in LDS ADataType* p_a_lds_ping = static_cast(p_smem_ping); ADataType* p_a_lds_pong = static_cast(p_smem_pong); @@ -638,11 +643,13 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor(); - auto a_lds_block_ping = make_tensor_view(p_a_lds_ping, a_lds_block_desc); - auto a_lds_block_pong = make_tensor_view(p_a_lds_pong, a_lds_block_desc); + auto a_lds_block_ping = + make_tensor_view(p_a_lds_ping, a_lds_block_desc); + auto a_lds_block_pong = + make_tensor_view(p_a_lds_pong, a_lds_block_desc); - // A DRAM tile window for load - #ifndef FINEGRADE_LOADSTORE +// A DRAM tile window for load +#ifndef FINEGRADE_LOADSTORE auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -657,10 +664,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV auto a_copy_lds_window_pong = make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); - #else + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); +#else auto a_copy_dram_window_tmp = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -673,49 +680,49 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV move_tile_window(a_copy_dram_window(AIter), {AIter * AcopyPerLoadM, 0}); }); - auto a_copy_lds_window_ping_tmp = make_tile_window( - a_lds_block_ping, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramDistribution() - ); + auto a_copy_lds_window_ping_tmp = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramDistribution()); - statically_indexed_array a_copy_lds_window_ping; + statically_indexed_array + a_copy_lds_window_ping; static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { a_copy_lds_window_ping(AIter) = a_copy_lds_window_ping_tmp; move_tile_window(a_copy_lds_window_ping(AIter), {AIter * AcopyPerLoadM, 0}); }); - auto a_copy_lds_window_pong_tmp = make_tile_window( - a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramDistribution() - ); + auto a_copy_lds_window_pong_tmp = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramDistribution()); - statically_indexed_array a_copy_lds_window_pong; + statically_indexed_array + a_copy_lds_window_pong; static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { a_copy_lds_window_pong(AIter) = a_copy_lds_window_pong_tmp; move_tile_window(a_copy_lds_window_pong(AIter), {AIter * AcopyPerLoadM, 0}); }); - #endif +#endif // A LDS tile for block GEMM // auto a_lds_gemm_window = make_tile_window( // a_lds_block_ping, make_tuple(number{}, number{}), {0, 0}); // ping-pong window for A LDS - auto a_warp_window_ping_tmp = make_tile_window( - a_lds_block_ping, - make_tuple(number{}, number{}), - {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + auto a_warp_window_ping_tmp = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - auto a_warp_window_pong_tmp = make_tile_window( - a_lds_block_pong, - make_tuple(number{}, number{}), - {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + auto a_warp_window_pong_tmp = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); statically_indexed_array< statically_indexed_array, @@ -726,7 +733,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV statically_indexed_array, MIterPerWarp> a_warp_windows_pong; - + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; @@ -776,19 +783,19 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV NIterPerWarp> b_warp_tensor_pong; - - // Prefetch A0 - #ifndef FINEGRADE_LOADSTORE +// Prefetch A0 +#ifndef FINEGRADE_LOADSTORE auto a_block_tile = load_tile(a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - #else - statically_indexed_array{}))), ACopyLoadNum> a_block_tile; +#else + statically_indexed_array{}))), ACopyLoadNum> + a_block_tile; static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter)); move_tile_window(a_copy_dram_window(AIter), {0, kKPerBlock}); }); - #endif +#endif // prefetch B static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -796,7 +803,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); @@ -815,29 +822,31 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // } // else // { - // store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile)); + // store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, + // a_block_tile)); // } - #ifndef FINEGRADE_LOADSTORE +#ifndef FINEGRADE_LOADSTORE auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_ping, a_block_tile_tmp); - #else +#else static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { - store_tile(a_copy_lds_window_ping(AIter), tile_elementwise_in(a_element_func, a_block_tile(AIter))); + store_tile(a_copy_lds_window_ping(AIter), + tile_elementwise_in(a_element_func, a_block_tile(AIter))); }); - #endif +#endif __builtin_amdgcn_sched_barrier(0); - // Prefetch A1 - #ifndef FINEGRADE_LOADSTORE +// Prefetch A1 +#ifndef FINEGRADE_LOADSTORE a_block_tile = load_tile(a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - #else +#else static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter)); move_tile_window(a_copy_dram_window(AIter), {0, kKPerBlock}); }); - #endif +#endif // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -845,420 +854,493 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV block_sync_lds(); // preload A00,A10 from lds - constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2: 1; - statically_indexed_array{})(number<0>{}))), m_preload> a_warp_tensor_ping; - statically_indexed_array{})(number<0>{}))), m_preload> a_warp_tensor_pong; - + constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2 : 1; + statically_indexed_array{})(number<0>{}))), + m_preload> + a_warp_tensor_ping; + statically_indexed_array{})(number<0>{}))), + m_preload> + a_warp_tensor_pong; + static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); + a_warp_tensor_ping(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); }); __builtin_amdgcn_sched_barrier(0); // if(threadIdx.x==0){ // for(int i=0;i(a_block_tile.thread_buf_(i)),a_block_tile.get_thread_buffer_size()); + // printf("dteng--A buffer load: idx.x=%u, ablocktile=%f, buffer size=%d\n", + // threadIdx.x, + // type_convert(a_block_tile.thread_buf_(i)),a_block_tile.get_thread_buffer_size()); // } // } // for(int i=0;i{}).get_thread_buffer_size();i++) { - // printf("dteng--A lds load 00: idx.x=%u, awarptensor=%f, buffer size=%d\n", threadIdx.x, type_convert(a_warp_tensor_ping(number<0>{}).thread_buf_(i)),a_warp_tensor_ping(number<0>{}).get_thread_buffer_size()); + // printf("dteng--A lds load 00: idx.x=%u, awarptensor=%f, buffer size=%d\n", + // threadIdx.x, + // type_convert(a_warp_tensor_ping(number<0>{}).thread_buf_(i)),a_warp_tensor_ping(number<0>{}).get_thread_buffer_size()); // } - index_t iCounter = (num_loop - 1) / 2; // if constexpr(HasMainLoop) // { - while(iCounter > 0) - { - #ifndef FINEGRADE_LOADSTORE - // prefetch B(2i+1) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + while(iCounter > 0) + { +#ifndef FINEGRADE_LOADSTORE + // prefetch B(2i+1) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(2i+1) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // Prefetch A(2i+2) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); +#endif + + // GEMM 2i + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); - }); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_ping(number{}), + b_warp_tensor_ping(nIter)(kIter)); - // Prefill A(2i+1) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); - // Prefetch A(2i+2) - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - #endif - - // GEMM 2i - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - #ifdef FINEGRADE_LOADSTORE - // prefetch B(2i+1) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) - { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); - } - // Prefill A(2i+1) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - store_tile(a_copy_lds_window_pong(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); - } - // Prefetch A(2i+2) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && (mIter < (MIterPerWarp - 1 + 1)) && ((nIter % NIterPerWarp)==(NIterPerWarp-2))) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - a_block_tile(number{}) = load_tile(a_copy_dram_window(number{})); - move_tile_window(a_copy_dram_window(number{}), {0, kKPerBlock}); - } - #endif - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) +#ifdef FINEGRADE_LOADSTORE + // prefetch B(2i+1) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && + ((curMNIter % BloadGap) == 1)) { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window( + b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_pong(number{})(BkIter) = + load_tile(b_flat_dram_windows(number{})(BkIter)); } - - //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + // Prefill A(2i+1) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && + (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0)) { - block_sync_lds(); + constexpr auto AIter = + (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % + ACopyLoadNum; + store_tile( + a_copy_lds_window_pong(number{}), + tile_elementwise_in(a_element_func, a_block_tile(number{}))); } + // Prefetch A(2i+2) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && + (mIter < (MIterPerWarp - 1 + 1)) && + ((nIter % NIterPerWarp) == (NIterPerWarp - 2))) + { + constexpr auto AIter = + (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % + ACopyLoadNum; + a_block_tile(number{}) = + load_tile(a_copy_dram_window(number{})); + move_tile_window(a_copy_dram_window(number{}), {0, kKPerBlock}); + } +#endif + __builtin_amdgcn_sched_barrier(0x7F6); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); - //block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); + }); + // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_pong(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); + HotLoopScheduler(); + + // Next K + +#ifndef FINEGRADE_LOADSTORE + // prefetch B(2i+2) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); - HotLoopScheduler(); - - //Next K + }); - #ifndef FINEGRADE_LOADSTORE - // prefetch B(2i+2) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + // Prefill A(2i+2) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + // Prefetch A(2i+3) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); +#endif + + // GEMM 2i+1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_pong(number{}), + b_warp_tensor_pong(nIter)(kIter)); - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // Prefill A(2i+2) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); - // Prefetch A(2i+3) - a_block_tile = load_tile(a_copy_dram_window); - // move A window to next k - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - #endif - - // GEMM 2i+1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor_pong(number{}), b_warp_tensor_pong(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - #ifdef FINEGRADE_LOADSTORE - // prefetch B(2i+2) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) - { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); - } - // Prefill A(2i+1) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - store_tile(a_copy_lds_window_ping(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); - } - // Prefetch A(2i+2) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && (mIter < (MIterPerWarp - 1 + 1)) && ((nIter % NIterPerWarp)==(NIterPerWarp-2))) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - a_block_tile(number{}) = load_tile(a_copy_dram_window(number{})); - move_tile_window(a_copy_dram_window(number{}), {0, kKPerBlock}); - } - #endif - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) +#ifdef FINEGRADE_LOADSTORE + // prefetch B(2i+2) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && + ((curMNIter % BloadGap) == 1)) { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_pong(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window( + b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_ping(number{})(BkIter) = + load_tile(b_flat_dram_windows(number{})(BkIter)); } - - //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + // Prefill A(2i+1) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && + (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0)) { - block_sync_lds(); + constexpr auto AIter = + (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % + ACopyLoadNum; + store_tile( + a_copy_lds_window_ping(number{}), + tile_elementwise_in(a_element_func, a_block_tile(number{}))); } + // Prefetch A(2i+2) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && + (mIter < (MIterPerWarp - 1 + 1)) && + ((nIter % NIterPerWarp) == (NIterPerWarp - 2))) + { + constexpr auto AIter = + (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % + ACopyLoadNum; + a_block_tile(number{}) = + load_tile(a_copy_dram_window(number{})); + move_tile_window(a_copy_dram_window(number{}), {0, kKPerBlock}); + } +#endif + __builtin_amdgcn_sched_barrier(0x7F6); }); - }); - // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_pong(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } - // move B window to next flat K - move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_ping(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); - HotLoopScheduler(); + }); + // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); - iCounter--; - } + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); - // tail - if constexpr(TailNum == TailNumber::Even) - { - // __builtin_amdgcn_sched_barrier(0); - #ifndef FINEGRADE_LOADSTORE - // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_ping(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); + }); + HotLoopScheduler(); + + iCounter--; + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { +// __builtin_amdgcn_sched_barrier(0); +#ifndef FINEGRADE_LOADSTORE + // prefetch B(loopK) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(loopK) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); +#endif + + // GEMM loopK-1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_ping(number{}), + b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + +#ifdef FINEGRADE_LOADSTORE + // prefetch B(loopK) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && + ((curMNIter % BloadGap) == 1)) + { + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window( + b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_pong(number{})(BkIter) = + load_tile(b_flat_dram_windows(number{})(BkIter)); + } + // Prefill A(loopK) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && + (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0)) + { + constexpr auto AIter = + (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % + ACopyLoadNum; + store_tile( + a_copy_lds_window_pong(number{}), + tile_elementwise_in(a_element_func, a_block_tile(number{}))); + } +#endif + __builtin_amdgcn_sched_barrier(0x7F6); }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); + }); + // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); - // Prefill A(loopK) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); - #endif + TailHotLoopScheduler(); - // GEMM loopK-1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - #ifdef FINEGRADE_LOADSTORE - // prefetch B(loopK) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) - { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); - } - // Prefill A(loopK) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - store_tile(a_copy_lds_window_pong(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); - } - #endif - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); - } + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_pong(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); - //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } + // __builtin_amdgcn_sched_barrier(0); + + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_pong(number{}), + b_warp_tensor_pong(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + __builtin_amdgcn_sched_barrier(0x7F6); }); + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_pong(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } }); - //block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); + }); + // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); - TailHotLoopScheduler(); + // TailHotLoopScheduler(); + // __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; - static_for<0, m_preload, 1>{}([&](auto loadIter) { - constexpr auto mIter = loadIter % MIterPerWarp; - constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_pong(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); - }); + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // __builtin_amdgcn_sched_barrier(0); - - // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor_pong(number{}), b_warp_tensor_pong(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - __builtin_amdgcn_sched_barrier(0x7F6); - }); - if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_ping(number{}), + b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + +#ifdef FINEGRADE_LOADSTORE + // prefetch B(loopK) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && + ((curMNIter % BloadGap) == 1)) { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_pong(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window( + b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_pong(number{})(BkIter) = + load_tile(b_flat_dram_windows(number{})(BkIter)); } + // Prefill A(loopK) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && + (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0)) + { + constexpr auto AIter = + (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % + ACopyLoadNum; + store_tile( + a_copy_lds_window_pong(number{}), + tile_elementwise_in(a_element_func, a_block_tile(number{}))); + } +#endif + __builtin_amdgcn_sched_barrier(0x7F6); }); - }); - // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } - // TailHotLoopScheduler(); - // __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(TailNum == TailNumber::Odd) - { - // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - #ifdef FINEGRADE_LOADSTORE - // prefetch B(loopK) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) - { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); - } - // Prefill A(loopK) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - store_tile(a_copy_lds_window_pong(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); - } - #endif - __builtin_amdgcn_sched_barrier(0x7F6); - }); - // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) - { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); - } - - //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) - { - block_sync_lds(); - } - }); + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); - } + }); + } // } return c_block_tile; @@ -1273,7 +1355,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV { return operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const ADataType & a) { return a; }, b_flat_dram_block_window_tmp, num_loop, p_smem_ping,