From 6c9ffce86773b1e0df33cb9599609e8b0a26520f Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Tue, 9 Sep 2025 05:57:14 +0800 Subject: [PATCH] [CK_TILE] add atomic IGLP scheduler for wp gemm (#2739) * add atomic IGLP scheduler * clang format --------- Co-authored-by: ThomasNing [ROCm/composable_kernel commit: e4a77289031fe3841cefc9f678d655ce9ff7983c] --- .../03_gemm/gemm_weight_preshuffle.cpp | 2 +- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 708 ++++++++---------- 2 files changed, 320 insertions(+), 390 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 40e826f7dd..2b8f8b32ae 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -91,7 +91,7 @@ int main(int argc, char* argv[]) try { - return !run_gemm_example(arg_parser); + return !run_gemm_example(arg_parser); } catch(const std::runtime_error& e) { diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index d7749d031e..7104e318d2 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -43,6 +43,7 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2 return run_func(bool_constant{}, integral_constant{}); } + return run_func(bool_constant{}, integral_constant{}); } }; @@ -69,7 +70,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 using WG = remove_cvref_t())>; + static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2 + static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t WaveSize = get_warp_size(); static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kNPerBlock = BlockGemmShape::kN; @@ -129,13 +134,33 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; - static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); - static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; - static constexpr auto TailNum = Problem::TailNum; + static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); + static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) + ? DsReadPreload + : MIterPerWarp * KIterPerWarp; + static constexpr auto TailNum = Problem::TailNum; - static constexpr auto warp_m = WarpTile::at(idxM); - static constexpr auto warp_n = WarpTile::at(idxN); - static constexpr auto warp_k = WarpTile::at(idxK); +#ifdef __gfx942__ + static constexpr index_t mfma_per_wg = 2; +#else + static constexpr index_t mfma_per_wg = 1; +#endif + static constexpr index_t dsread_per_wg = + WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize; + static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0); + + static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp; + static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp); + static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp; + static constexpr index_t Aload_num_perK = dswrite_num_perK; + static constexpr index_t Aload_rep = dswrite_rep; + static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize; + static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2; + static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter; + + static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg; + static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp; + static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp; [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -160,411 +185,314 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 return PipelinePolicy::template GetSmemSize(); } - CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + // dsread_perM: how many LDS reads want to issue in this M-iter + // dswrite_perM: how many LDS writes you want to do this M-iter + // load_perM: how many global loads VMEM want to do in this M-iter + CK_TILE_HOST_DEVICE static constexpr auto + SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM) { - constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType); - constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad; - constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp; - constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; + // Init inst order + index_t max_data_inst = dsread_perM > load_perM + ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM) + : (load_perM > dswrite_perM ? load_perM : dswrite_perM); + index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM; + index_t round_data_inst = ck_tile::integer_divide_ceil(sum_data_inst, mfma_perM_perK); + constexpr int kOrderCap = NIterPerWarp * 10; + index_t inst_order[kOrderCap] = {}; + index_t index = 0; +#pragma unroll + // round-robin + // Index: 0 1 2 3 4 5 ... + // Value: 1 2 3 1 2 3 ... + for(int j = 0; j < max_data_inst; j++) + { + if(dswrite_perM > j) + { + inst_order[index] = 1; + index++; + } + if(load_perM > j) + { + inst_order[index] = 2; + index++; + } + if(dsread_perM > j) + { + inst_order[index] = 3; + index++; + } + } + +// Schedule IGLP +#pragma unroll + for(int j = 0; j < mfma_perM_perK; j++) + { + index_t inst_idx = 0; + if(j == 0) + ; + else if(j == 1) + inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2; + else if(j == 2) + inst_idx = mfma_perM_perK - 1; + else + inst_idx = mfma_perM_perK - j; + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + +#pragma unroll + for(int r = 0; r < round_data_inst; r++) + { + if(r % 2 == 0) + { + if(inst_order[inst_idx + r * mfma_perM_perK] == 1) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if(inst_order[inst_idx + r * mfma_perM_perK] == 2) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if(inst_order[inst_idx + r * mfma_perM_perK] == 3) + { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + } + else + { + if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3) + { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + } + } + } + } + + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { // Keypoint of pipeline optimize is workload balance in time // instruction schedule example(128X256X256, 1X4, 16X16X128): // Iter MNK MFMA ds_read ds_write A_load b_load + // -1 M6N0: 57 - 8 - - + // -1 M6N1: 58 1 - - - + // -1 M6N2: 59 - - 7 - // -1 M6N3: 60 2 - - - // -1 M7N0: 61 - - - - - // -1 M7N1: 62 - - - - - // -1 M7N2: 63 - - - - + // -1 M7N1: 62 3 - - - + // -1 M7N2: 63 - - 8 - // -1 M7N3: 64 4 - - - - // 0 M0N0K0: 1 - - - - - // 0 M0N1: 2 - - - 2 - // 0 M0N2: 3 - - - - + // 0 M0N0K0: 1 - - - 1 + // 0 M0N1: 2 5 - - - + // 0 M0N2: 3 - - - 2 // 0 M0N3: 4 6 - - - - // 0 M1N0: 5 - - - - - // 0 M1N1: 6 - - - 4 - // 0 M1N2: 7 - - - - + // 0 M1N0: 5 - - - 3 + // 0 M1N1: 6 7 - - - + // 0 M1N2: 7 - - - 4 // 0 M1N3: 8 8 - - - - // 0 M2N0: 9 - - - - - // 0 M2N1: 10 - - - 6 - // 0 M2N2: 11 - - - - + // 0 M2N0: 9 - - - 5 + // 0 M2N1: 10 9 - - - + // 0 M2N2: 11 - - - 6 // 0 M2N3: 12 10 - - - - // 0 M3N0: 13 - 1 - - - // 0 M3N1: 14 - - - 8 - // 0 M3N2: 15 - - - - + // 0 M3N0: 13 - 1 - 7 + // 0 M3N1: 14 11 - - - + // 0 M3N2: 15 - - - 8 // 0 M3N3: 16 12 - - - // 0 M4N0: 17 - 2 - - - // 0 M4N1: 18 - - - - + // 0 M4N1: 18 13 - - - // 0 M4N2: 19 - - 1 - // 0 M4N3: 20 14 - - - // 0 M5N0: 21 - 3 - - - // 0 M5N1: 22 - - - - + // 0 M5N1: 22 15 - - - // 0 M5N2: 23 - - 2 - // 0 M5N3: 24 16 - - - // 0 M6N0: 25 - 4 - - - // 0 M6N1: 26 - - - - + // 0 M6N1: 26 17 - - - // 0 M6N2: 27 - - 3 - - // 0 M6N3: 28 17 - - - + // 0 M6N3: 28 18 - - - // 0 M7N0: 29 - - - - - // 0 M7N1: 30 - - - - + // 0 M7N1: 30 19 - - - // 0 M7N2: 31 - - 4 - - // 0 M7N3: 32 18 - - - - // 0 M0N0K1: 33 - - - - - // 0 M0N1: 34 - - - 10 - // 0 M0N2: 35 - - - - - // 0 M0N3: 36 20 - - - - // 0 M1N0: 37 - - - - - // 0 M1N1: 38 - - - 12 - // 0 M1N2: 39 - - - - - // 0 M1N3: 40 22 - - - - // 0 M2N0: 41 - - - - - // 0 M2N1: 42 - - - 14 - // 0 M2N2: 43 - - - - - // 0 M2N3: 44 24 - - - - // 0 M3N0: 45 - 5 - - - // 0 M3N1: 46 - - - 16 - // 0 M3N2: 47 - - - - - // 0 M3N3: 48 26 - - - + // 0 M7N3: 32 20 - - - + // 0 M0N0K1: 33 - - - 9 + // 0 M0N1: 34 21 - - - + // 0 M0N2: 35 - - - 10 + // 0 M0N3: 36 22 - - - + // 0 M1N0: 37 - - - 11 + // 0 M1N1: 38 23 - - - + // 0 M1N2: 39 - - - 12 + // 0 M1N3: 40 24 - - - + // 0 M2N0: 41 - - - 13 + // 0 M2N1: 42 25 - - - + // 0 M2N2: 43 - - - 14 + // 0 M2N3: 44 26 - - - + // 0 M3N0: 45 - 5 - 15 + // 0 M3N1: 46 27 - - - + // 0 M3N2: 47 - - - 16 + // 0 M3N3: 48 28 - - - // 0 M4N0: 49 - 6 - - - // 0 M4N1: 50 - - - - + // 0 M4N1: 50 29 - - - // 0 M4N2: 51 - - 5 - - // 0 M4N3: 52 28 - - - + // 0 M4N3: 52 30 - - - // 0 M5N0: 53 - 7 - - - // 0 M5N1: 54 - - - - + // 0 M5N1: 54 31 - - - // 0 M5N2: 55 - - 6 - - // 0 M5N3: 56 30 - - - + // 0 M5N3: 56 32 - - - // 0 M6N0: 57 - 8 - - - // 0 M6N1: 58 - - - - + // 0 M6N1: 58 1 - - - // 0 M6N2: 59 - - 7 - // 0 M6N3: 60 2 - - - // 0 M7N0: 61 - - - - - // 0 M7N1: 62 - - - - + // 0 M7N1: 62 3 - - - // 0 M7N2: 63 - - 8 - // 0 M7N3: 64 4 - - - - if constexpr(warp_m == 16 && warp_n == 16) +#pragma unroll + for(int kIter = 0; kIter < KIterPerWarp; kIter++) { -// MFMA -> VMEM READ -> MFMA -> DS Read -> MFMA -// hiding the glbal memory VMEM latency -#if defined(__gfx950__) - if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) +#pragma unroll + for(int mIter = 0; mIter < MIterPerWarp; mIter++) { - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + index_t dsread_perM = 0; + index_t dswrite_perM = 0; + index_t load_perM = 0; - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); + // Calculate ds_read number per M + dsread_perM = dsread_per_wg; - __builtin_amdgcn_sched_barrier(0); + // Calculate ds_write number per M + if(mIter == 0) + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0 + ? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep + : 0; + } + else if(mIter >= MIterPerWarp - DsWritePreIssue + 1) + { + dswrite_perM = 0; + } + else + { + dswrite_perM = (dswrite_num_perK - + (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0 + ? dswrite_rep + : 0; + } + // Add ds write when ds write data > needed + if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter)) + { + if(mIter == MIterPerWarp - 1 - dswrite_mIter) + dswrite_perM = 1; + } + + // Calculate buffer_load number per M + if(mIter < HalfMIter) + { + load_perM = + ((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep + : 0) + + ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep + : 0); + } + else + { + load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 + ? Aload_rep + : 0; + } + SchedulerPerM(dsread_perM, dswrite_perM, load_perM); } - else - { - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } -// MFMA → MFMA → MFMA → MFMA → DS Read -// For other device engine we need more aggressive MFMA with DS writes interleaved -#else - if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) - { - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - // Uses loops to amortize scheduling overhead - static_for<0, 4, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(kMPerBlock == 16 && kNPerBlock == 64 && kKPerBlock == 256) - { - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128) - { - // prioritize MFMA to avoid LDS write conflicts - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 2, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 2, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - }); - - __builtin_amdgcn_sched_barrier(0); - } - else - { - static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA - }); - static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA - }); - static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA - }); - } - -#endif } - else + // Add Aload when Aload data > needed + if(Aload_num_perK == 0) + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_barrier(0); + } + + CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler() + { +#pragma unroll + for(int kIter = 0; kIter < KIterPerWarp; kIter++) { - if constexpr((A_LDS_Read_Inst_Num / 2 > - A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) +#pragma unroll + for(int mIter = 0; mIter < MIterPerWarp; mIter++) { - static_for<0, - A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - - B_Buffer_Load_Inst_Num, - 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); + index_t dsread_perM = 0; + index_t dswrite_perM = 0; + index_t load_perM = 0; + + // Calculate ds_read number per M + dsread_perM = dsread_per_wg; + + // Calculate ds_write number per M + if(mIter == 0) + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0 + ? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep + : 0; + } + else if(mIter >= MIterPerWarp - DsWritePreIssue + 1) + { + dswrite_perM = 0; + } + else + { + dswrite_perM = (dswrite_num_perK - + (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0 + ? dswrite_rep + : 0; + } + // Add ds write when ds write data > needed + if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter)) + { + if(mIter == MIterPerWarp - 1 - dswrite_mIter) + dswrite_perM = 1; + } + + // Calculate buffer_load number per M + if(mIter < HalfMIter) + { + load_perM = + ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep + : 0); + } + SchedulerPerM(dsread_perM, dswrite_perM, load_perM); } - static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - }); - static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA - }); - __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA } + __builtin_amdgcn_sched_barrier(0); + } + + CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler() + { +#pragma unroll + for(int kIter = 0; kIter < KIterPerWarp; kIter++) + { +#pragma unroll + for(int mIter = 0; mIter < MIterPerWarp; mIter++) + { + index_t dsread_perM = 0; + index_t dswrite_perM = 0; + index_t load_perM = 0; + + // Calculate ds_read number per M + if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + dsread_perM = dsread_per_wg; + + SchedulerPerM(dsread_perM, dswrite_perM, load_perM); + } + } + // __builtin_amdgcn_sched_barrier(0); } template = 2) ? 2 : 1; statically_indexed_array{})(number<0>{}))), m_preload> - a_warp_tensor_ping; - statically_indexed_array{})(number<0>{}))), - m_preload> - a_warp_tensor_pong; + a_warp_tensor; static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_ping(loadIter) = + a_warp_tensor(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); }); __builtin_amdgcn_sched_barrier(0); + // MAIN LOOP index_t iCounter = (num_loop - 1) / 2; while(iCounter > 0) { @@ -792,7 +717,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // warp GEMM WG{}(c_warp_tensor, - a_warp_tensor_ping(number{}), + a_warp_tensor(number{}), b_warp_tensor_ping(nIter)(kIter)); // write C warp tensor into C block tensor @@ -809,7 +734,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 { constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_ping(number{}) = + a_warp_tensor(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); } @@ -826,7 +751,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_pong(loadIter) = + a_warp_tensor(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); }); HotLoopScheduler(); @@ -867,7 +792,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // warp GEMM WG{}(c_warp_tensor, - a_warp_tensor_pong(number{}), + a_warp_tensor(number{}), b_warp_tensor_pong(nIter)(kIter)); // write C warp tensor into C block tensor @@ -884,7 +809,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 { constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_pong(number{}) = + a_warp_tensor(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); } @@ -901,7 +826,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_ping(loadIter) = + a_warp_tensor(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); }); HotLoopScheduler(); @@ -943,7 +868,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // warp GEMM WG{}(c_warp_tensor, - a_warp_tensor_ping(number{}), + a_warp_tensor(number{}), b_warp_tensor_ping(nIter)(kIter)); // write C warp tensor into C block tensor @@ -960,7 +885,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 { constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_ping(number{}) = + a_warp_tensor(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); } @@ -976,11 +901,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_pong(loadIter) = + a_warp_tensor(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); }); - // __builtin_amdgcn_sched_barrier(0); + Last2ndHotLoopScheduler(); // GEMM loopK static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { @@ -996,7 +921,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // warp GEMM WG{}(c_warp_tensor, - a_warp_tensor_pong(number{}), + a_warp_tensor(number{}), b_warp_tensor_pong(nIter)(kIter)); // write C warp tensor into C block tensor @@ -1004,19 +929,23 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - __builtin_amdgcn_sched_barrier(0x7F6); }); if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) { constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_pong(number{}) = + a_warp_tensor(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); } + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } }); }); - // TailHotLoopScheduler(); + LastHotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) { @@ -1034,7 +963,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // warp GEMM WG{}(c_warp_tensor, - a_warp_tensor_ping(number{}), + a_warp_tensor(number{}), b_warp_tensor_ping(nIter)(kIter)); // write C warp tensor into C block tensor @@ -1051,7 +980,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 { constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_ping(number{}) = + a_warp_tensor(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); } @@ -1062,6 +991,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 } }); }); + LastHotLoopScheduler(); } return c_block_tile;