From a5877011175aa939dc53f4493f5c0d247bafa7f1 Mon Sep 17 00:00:00 2001 From: AMD-dteng Date: Tue, 29 Jul 2025 22:48:00 +0800 Subject: [PATCH] update pipeline v1: add atomic IGLP schedule --- example/ck_tile/18_flatmm/flatmm_basic.cpp | 4 +- example/ck_tile/18_flatmm/flatmm_basic.hpp | 4 +- .../ck_tile/18_flatmm/run_flatmm_example.inc | 5 +- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1227 ++++++----------- 4 files changed, 429 insertions(+), 811 deletions(-) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 2cac0cd9b6..85f2ce482d 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -155,7 +155,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, using GemmPipelineProblem = ck_tile::GemmPipelineProblem; - using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV0; + using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile; @@ -182,7 +182,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, tail_number_v>; using CodegenFlatmmPipeline = - ck_tile::FlatmmPipelineAGmemBGmemCRegV0; + ck_tile::FlatmmPipelineAGmemBGmemCRegV1; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem{-.5f, .5f}(a_host); - memset(a_host.data(), 0, 4); + // ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); + // ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); + ck_tile::FillUniformDistribution{0.0f, 1.0f}(a_host); ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); ck_tile::FillUniformDistribution{-1.f, 1.f}(per_token_scale); ck_tile::FillUniformDistribution{-1.f, 1.f}(per_channel_scale); diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index aa39e2c779..169cd29f77 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -3,7 +3,6 @@ #pragma once -// #define FINEGRADE_LOADSTORE #include "ck_tile/core.hpp" #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" @@ -25,30 +24,23 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; } template - CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool, TailNumber tail_num) + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_num) { - if(TailNumber::Even == tail_num) + if (TailNumber::Even == tail_num) { - return run_func(bool_constant{}, - integral_constant{}); + return run_func(bool_constant{}, integral_constant{}); } - else if(TailNumber::Odd == tail_num) + else if (TailNumber::Odd == tail_num) { - return run_func(bool_constant{}, - integral_constant{}); + return run_func(bool_constant{}, integral_constant{}); } - // assert(false); return run_func(bool_constant{}, integral_constant{}); - // return run_func(bool_constant{}, integral_constant{}); } }; template -struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV1 +struct FlatmmPipelineAGmemBGmemCRegV1 { - using Base = BaseFlatmmPipelineAGmemBGmemCRegV1; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -60,13 +52,16 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV using BlockFlatmm = remove_cvref_t())>; - - static constexpr auto config = - BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; + static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2 + static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t WaveSize = get_warp_size(); static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kNPerBlock = BlockGemmShape::kN; @@ -75,14 +70,9 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; - static constexpr index_t GetVectorSizeA() - { - return PipelinePolicy::template GetVectorSizeA(); - } - static constexpr index_t GetVectorSizeB() - { - return PipelinePolicy::template GetVectorSizeB(); - } + static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; } + static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; } + static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; @@ -114,18 +104,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; - static constexpr index_t K1 = 16 / sizeof(ADataType); - static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; - static constexpr index_t ACopyLoadNumPerK = ACopyLoadNum / KIterPerWarp; - static constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum; - static constexpr index_t BloadGap = MIterPerWarp / 2; + static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); + static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) ? DsReadPreload: MIterPerWarp * KIterPerWarp; static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr auto TailNum = Problem::TailNum; - static constexpr auto warp_m = WarpTile::at(idxM); - static constexpr auto warp_n = WarpTile::at(idxN); - static constexpr auto warp_k = WarpTile::at(idxK); /* defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1 defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1 @@ -140,56 +124,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1 defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1 */ - struct MfmaConfig - { - int mfma_per_wg; - int dsread_per_wg; - }; - static constexpr MfmaConfig GetMfmaConfig() - { - - // K1 per Mfma = 0.5 cases: mfma_per_wg = 2, dsread_per_wg = 1 - if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 && - std::is_same_v) || - (warp_m == 32 && warp_n == 32 && warp_k == 16 && - std::is_same_v) || - (warp_m == 16 && warp_n == 16 && warp_k == 16 && - std::is_same_v) || - (warp_m == 32 && warp_n == 32 && warp_k == 8 && - std::is_same_v)) - { - return {2, 1}; - } - // K1 per Mfma = 2 cases: mfma_per_wg = 1, dsread_per_wg = 2 - else if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 128 && - std::is_same_v) || - (warp_m == 32 && warp_n == 32 && warp_k == 64 && - std::is_same_v)) - { - return {1, 2}; - } - // K1 per Mfma = 1 cases: mfma_per_wg = 1, dsread_per_wg = 1 - else if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 && - std::is_same_v) || - (warp_m == 32 && warp_n == 32 && warp_k == 16 && - std::is_same_v) || - (warp_m == 16 && warp_n == 16 && warp_k == 128 /*&& - std::is_same_v */) || - (warp_m == 32 && warp_n == 32 && warp_k == 64 /*&& - std::is_same_v */)) - { - return {1, 1}; - } - // Default configuration - else - { - return {1, 1}; - } - } - - static constexpr auto mfma_config = GetMfmaConfig(); - static constexpr auto mfma_per_wg = mfma_config.mfma_per_wg; - static constexpr auto dsread_per_wg = mfma_config.dsread_per_wg; // #if (defined(USING_MFMA_16x16x32_F8) || \ // defined(USING_MFMA_32x32x16_F8) || \ @@ -208,17 +142,41 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // static constexpr auto mfma_per_wg = 1; // static constexpr auto dsread_per_wg = 2; // #endif + #ifdef __gfx942__ + static constexpr index_t mfma_per_wg = 2; + #else + static constexpr index_t mfma_per_wg = 1; + #endif + static constexpr index_t dsread_per_wg = WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize; + static_assert((WG::kM * WG::kK * sizeof(ADataType) / WaveSize) % Problem::VectorLoadSize == 0); + + static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp; + static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp); + static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp; + static constexpr index_t Aload_num_perK = dswrite_num_perK; + static constexpr index_t Aload_rep = dswrite_rep; + static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize; + static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2; + static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter; + + static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg; + static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp; + static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp; [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off return concat('_', "pipeline_AGmemBGmemCRegV1", concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), - concat('x', GetVectorSizeA(), GetVectorSizeB()), + concat('x', WG::kM, WG::kN, WG::kK), + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), concat('x', kPadM, kPadN, kPadK)); // clang-format on } + // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. + static constexpr bool DoubleSmemBuffer = false; + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -226,380 +184,308 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV return PipelinePolicy::template GetSmemSize(); } + CK_TILE_HOST_DEVICE static constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM) + { + // Init inst order + index_t max_data_inst = + dsread_perM > load_perM + ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM) + : (load_perM > dswrite_perM ? load_perM : dswrite_perM); + index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM; + index_t round_data_inst = (sum_data_inst + mfma_perM_perK - 1) / mfma_perM_perK; + + index_t inst_order[NIterPerWarp * 10]; + #pragma unroll + for(int idx = 0; idx < NIterPerWarp * 10; idx++) + { + inst_order[idx] = 0; + } + + index_t index = 0; + #pragma unroll + for(int j = 0; j < max_data_inst; j++) + { + if(dswrite_perM > j) + { + inst_order[index] = 1; + index++; + } + if(load_perM > j) + { + inst_order[index] = 2; + index++; + } + if(dsread_perM > j) + { + inst_order[index] = 3; + index++; + } + } + + // Schedule IGLP + #pragma unroll + for(int j = 0; j < mfma_perM_perK; j++) + { + index_t inst_idx = 0; + if(j == 0) + ; + else if(j == 1) + inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2; + else if(j == 2) + inst_idx = mfma_perM_perK - 1; + else + inst_idx = mfma_perM_perK - j; + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + #pragma unroll + for(int r = 0; r < round_data_inst; r++) + { + if(r % 2 == 0) + { + if(inst_order[inst_idx + r * mfma_perM_perK] == 1) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if(inst_order[inst_idx + r * mfma_perM_perK] == 2) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if(inst_order[inst_idx + r * mfma_perM_perK] == 3) + { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + } + else + { + if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3) + { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + } + } + } + } CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { // Keypoint of pipeline optimize is workload balance in time // instruction schedule example(128X256X256, 1X4, 16X16X128): // Iter MNK MFMA ds_read ds_write A_load b_load + // -1 M6N0: 57 - 8 - - + // -1 M6N1: 58 1 - - - + // -1 M6N2: 59 - - 7 - // -1 M6N3: 60 2 - - - - // -1 M7N0: 61 - - - - - // -1 M7N1: 62 - - - - - // -1 M7N2: 63 - - - - - // -1 M7N3: 64 4 - - - - // 0 M0N0K0: 1 - - - - - // 0 M0N1: 2 - - - 2 - // 0 M0N2: 3 - - - - - // 0 M0N3: 4 6 - - - - // 0 M1N0: 5 - - - - - // 0 M1N1: 6 - - - 4 - // 0 M1N2: 7 - - - - - // 0 M1N3: 8 8 - - - - // 0 M2N0: 9 - - - - - // 0 M2N1: 10 - - - 6 - // 0 M2N2: 11 - - - - - // 0 M2N3: 12 10 - - - - // 0 M3N0: 13 - 1 - - - // 0 M3N1: 14 - - - 8 - // 0 M3N2: 15 - - - - + // -1 M7N0: 61 - - - - + // -1 M7N1: 62 3 - - - + // -1 M7N2: 63 - - 8 - + // -1 M7N3: 64 4 - - - + // 0 M0N0K0: 1 - - - 1 + // 0 M0N1: 2 5 - - - + // 0 M0N2: 3 - - - 2 + // 0 M0N3: 4 6 - - - + // 0 M1N0: 5 - - - 3 + // 0 M1N1: 6 7 - - - + // 0 M1N2: 7 - - - 4 + // 0 M1N3: 8 8 - - - + // 0 M2N0: 9 - - - 5 + // 0 M2N1: 10 9 - - - + // 0 M2N2: 11 - - - 6 + // 0 M2N3: 12 10 - - - + // 0 M3N0: 13 - 1 - 7 + // 0 M3N1: 14 11 - - - + // 0 M3N2: 15 - - - 8 // 0 M3N3: 16 12 - - - - // 0 M4N0: 17 - 2 - - - // 0 M4N1: 18 - - - - - // 0 M4N2: 19 - - 1 - + // 0 M4N0: 17 - 2 - - + // 0 M4N1: 18 13 - - - + // 0 M4N2: 19 - - 1 - // 0 M4N3: 20 14 - - - - // 0 M5N0: 21 - 3 - - - // 0 M5N1: 22 - - - - - // 0 M5N2: 23 - - 2 - + // 0 M5N0: 21 - 3 - - + // 0 M5N1: 22 15 - - - + // 0 M5N2: 23 - - 2 - // 0 M5N3: 24 16 - - - - // 0 M6N0: 25 - 4 - - - // 0 M6N1: 26 - - - - - // 0 M6N2: 27 - - 3 - - // 0 M6N3: 28 17 - - - - // 0 M7N0: 29 - - - - - // 0 M7N1: 30 - - - - - // 0 M7N2: 31 - - 4 - - // 0 M7N3: 32 18 - - - - // 0 M0N0K1: 33 - - - - - // 0 M0N1: 34 - - - 10 - // 0 M0N2: 35 - - - - - // 0 M0N3: 36 20 - - - - // 0 M1N0: 37 - - - - - // 0 M1N1: 38 - - - 12 - // 0 M1N2: 39 - - - - - // 0 M1N3: 40 22 - - - - // 0 M2N0: 41 - - - - - // 0 M2N1: 42 - - - 14 - // 0 M2N2: 43 - - - - - // 0 M2N3: 44 24 - - - - // 0 M3N0: 45 - 5 - - - // 0 M3N1: 46 - - - 16 - // 0 M3N2: 47 - - - - - // 0 M3N3: 48 26 - - - - // 0 M4N0: 49 - 6 - - - // 0 M4N1: 50 - - - - - // 0 M4N2: 51 - - 5 - - // 0 M4N3: 52 28 - - - - // 0 M5N0: 53 - 7 - - - // 0 M5N1: 54 - - - - - // 0 M5N2: 55 - - 6 - - // 0 M5N3: 56 30 - - - - // 0 M6N0: 57 - 8 - - - // 0 M6N1: 58 - - - - - // 0 M6N2: 59 - - 7 - + // 0 M6N0: 25 - 4 - - + // 0 M6N1: 26 17 - - - + // 0 M6N2: 27 - - 3 - + // 0 M6N3: 28 18 - - - + // 0 M7N0: 29 - - - - + // 0 M7N1: 30 19 - - - + // 0 M7N2: 31 - - 4 - + // 0 M7N3: 32 20 - - - + // 0 M0N0K1: 33 - - - 9 + // 0 M0N1: 34 21 - - - + // 0 M0N2: 35 - - - 10 + // 0 M0N3: 36 22 - - - + // 0 M1N0: 37 - - - 11 + // 0 M1N1: 38 23 - - - + // 0 M1N2: 39 - - - 12 + // 0 M1N3: 40 24 - - - + // 0 M2N0: 41 - - - 13 + // 0 M2N1: 42 25 - - - + // 0 M2N2: 43 - - - 14 + // 0 M2N3: 44 26 - - - + // 0 M3N0: 45 - 5 - 15 + // 0 M3N1: 46 27 - - - + // 0 M3N2: 47 - - - 16 + // 0 M3N3: 48 28 - - - + // 0 M4N0: 49 - 6 - - + // 0 M4N1: 50 29 - - - + // 0 M4N2: 51 - - 5 - + // 0 M4N3: 52 30 - - - + // 0 M5N0: 53 - 7 - - + // 0 M5N1: 54 31 - - - + // 0 M5N2: 55 - - 6 - + // 0 M5N3: 56 32 - - - + // 0 M6N0: 57 - 8 - - + // 0 M6N1: 58 1 - - - + // 0 M6N2: 59 - - 7 - // 0 M6N3: 60 2 - - - - // 0 M7N0: 61 - - - - - // 0 M7N1: 62 - - - - - // 0 M7N2: 63 - - 8 - + // 0 M7N0: 61 - - - - + // 0 M7N1: 62 3 - - - + // 0 M7N2: 63 - - 8 - // 0 M7N3: 64 4 - - - - if constexpr(warp_m == 16 && warp_n == 16) + + #pragma unroll + for(int kIter = 0; kIter < KIterPerWarp; kIter++) { -#if defined(__gfx950__) // MI350 FP8 16X16 128*256*256 - if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) + #pragma unroll + for(int mIter = 0; mIter < MIterPerWarp; mIter++) { - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + index_t dsread_perM = 0; + index_t dswrite_perM = 0; + index_t load_perM = 0; - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); + // Calculate ds_read number per M + dsread_perM = dsread_per_wg; - __builtin_amdgcn_sched_barrier(0); + // Calculate ds_write number per M + if(mIter == 0) + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0 + ? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep + : 0; + } + else if(mIter >= MIterPerWarp - DsWritePreIssue + 1) + { + dswrite_perM = 0; + } + else + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0 + ? dswrite_rep + : 0; + } + // Add ds write when ds write data > needed + if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter)) + { + if(mIter == MIterPerWarp - 1 - dswrite_mIter) + dswrite_perM = 1; + } + + // Calculate buffer_load number per M + if(mIter < HalfMIter) + { + load_perM = + ((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep + : 0) + + ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep : 0); + } + else + { + load_perM = + (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep : 0; + } + SchedulerPerM(dsread_perM, dswrite_perM, load_perM); } - else - { - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } -#else - if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128) - { - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 2, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 2, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 128) - { - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 4, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - }); - __builtin_amdgcn_sched_barrier(0); - } - else if(kMPerBlock == 16 && kNPerBlock == 64 && kKPerBlock == 256) - { - static_for<0, 1, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_barrier(0); - } -#endif } + // Add Aload when Aload data > needed + if(Aload_num_perK == 0) + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_barrier(0); + } + + CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler() + { + #pragma unroll + for(int kIter = 0; kIter < KIterPerWarp; kIter++) + { + #pragma unroll + for(int mIter = 0; mIter < MIterPerWarp; mIter++) + { + index_t dsread_perM = 0; + index_t dswrite_perM = 0; + index_t load_perM = 0; + + // Calculate ds_read number per M + dsread_perM = dsread_per_wg; + + // Calculate ds_write number per M + if(mIter == 0) + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0 + ? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep + : 0; + } + else if(mIter >= MIterPerWarp - DsWritePreIssue + 1) + { + dswrite_perM = 0; + } + else + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0 + ? dswrite_rep + : 0; + } + // Add ds write when ds write data > needed + if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter)) + { + if(mIter == MIterPerWarp - 1 - dswrite_mIter) + dswrite_perM = 1; + } + + // Calculate buffer_load number per M + if(mIter < HalfMIter) + { + load_perM = + ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep : 0); + } + SchedulerPerM(dsread_perM, dswrite_perM, load_perM); + } + } + __builtin_amdgcn_sched_barrier(0); } + CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler() + { + #pragma unroll + for(int kIter = 0; kIter < KIterPerWarp; kIter++) + { + #pragma unroll + for(int mIter = 0; mIter < MIterPerWarp; mIter++) + { + index_t dsread_perM = 0; + index_t dswrite_perM = 0; + index_t load_perM = 0; - CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler() - { -#if 0 - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + // Calculate ds_read number per M + if ((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + dsread_perM = dsread_per_wg; - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_barrier(0); -#endif + SchedulerPerM(dsread_perM, dswrite_perM, load_perM); + } + } + // __builtin_amdgcn_sched_barrier(0); } template @@ -611,21 +497,15 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV void* p_smem_pong) const { static_assert( - std::is_same_v> && - std::is_same_v>, - "A/B Dram block window should have the same data type as appropriate " - "([A|B]DataType) defined in Problem definition!"); + std::is_same_v>, + "wrong!"); - constexpr bool is_a_col_major = std::is_same_v; - static_assert(is_a_col_major - ? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] && - kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]) - : (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]), - "A block window has incorrect lengths for defined ALayout!"); + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; - const index_t iMWarp = get_warp_id() / NWarp; using CWarpDstr = typename WG::CWarpDstr; @@ -636,7 +516,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; __builtin_amdgcn_sched_barrier(0); - + // A tile in LDS ADataType* p_a_lds_ping = static_cast(p_smem_ping); ADataType* p_a_lds_pong = static_cast(p_smem_pong); @@ -644,13 +524,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV constexpr auto a_lds_block_desc = PipelinePolicy::template MakeALdsBlockDescriptor(); - auto a_lds_block_ping = - make_tensor_view(p_a_lds_ping, a_lds_block_desc); - auto a_lds_block_pong = - make_tensor_view(p_a_lds_pong, a_lds_block_desc); + auto a_lds_block_ping = make_tensor_view(p_a_lds_ping, a_lds_block_desc); + auto a_lds_block_pong = make_tensor_view(p_a_lds_pong, a_lds_block_desc); -// A DRAM tile window for load -#ifndef FINEGRADE_LOADSTORE + // A DRAM tile window for load auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -665,65 +542,22 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV auto a_copy_lds_window_pong = make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramTileDistribution()); -#else - auto a_copy_dram_window_tmp = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeADramDistribution()); - - statically_indexed_array a_copy_dram_window; - static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { - a_copy_dram_window(AIter) = a_copy_dram_window_tmp; - move_tile_window(a_copy_dram_window(AIter), {AIter * AcopyPerLoadM, 0}); - }); - - auto a_copy_lds_window_ping_tmp = - make_tile_window(a_lds_block_ping, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramDistribution()); - - statically_indexed_array - a_copy_lds_window_ping; - static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { - a_copy_lds_window_ping(AIter) = a_copy_lds_window_ping_tmp; - move_tile_window(a_copy_lds_window_ping(AIter), {AIter * AcopyPerLoadM, 0}); - }); - - auto a_copy_lds_window_pong_tmp = - make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {0, 0}, - PipelinePolicy::template MakeADramDistribution()); - - statically_indexed_array - a_copy_lds_window_pong; - static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { - a_copy_lds_window_pong(AIter) = a_copy_lds_window_pong_tmp; - move_tile_window(a_copy_lds_window_pong(AIter), {AIter * AcopyPerLoadM, 0}); - }); -#endif - - // A LDS tile for block GEMM - // auto a_lds_gemm_window = make_tile_window( - // a_lds_block_ping, make_tuple(number{}, number{}), {0, 0}); + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); // ping-pong window for A LDS - auto a_warp_window_ping_tmp = - make_tile_window(a_lds_block_ping, - make_tuple(number{}, number{}), - {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + auto a_warp_window_ping_tmp = make_tile_window( + a_lds_block_ping, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); - auto a_warp_window_pong_tmp = - make_tile_window(a_lds_block_pong, - make_tuple(number{}, number{}), - {iMWarp * WG::kM, 0}, - make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + auto a_warp_window_pong_tmp = make_tile_window( + a_lds_block_pong, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); statically_indexed_array< statically_indexed_array, @@ -734,7 +568,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV statically_indexed_array, MIterPerWarp> a_warp_windows_pong; - + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; @@ -784,19 +618,12 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV NIterPerWarp> b_warp_tensor_pong; -// Prefetch A0 -#ifndef FINEGRADE_LOADSTORE + + // HEAD + // Prefetch A0 auto a_block_tile = load_tile(a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); -#else - statically_indexed_array{}))), ACopyLoadNum> - a_block_tile; - static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { - a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter)); - move_tile_window(a_copy_dram_window(AIter), {0, kKPerBlock}); - }); -#endif // prefetch B static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -804,7 +631,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); @@ -823,80 +650,43 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // } // else // { - // store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, - // a_block_tile)); + // store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile)); // } -#ifndef FINEGRADE_LOADSTORE auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_ping, a_block_tile_tmp); -#else - static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { - store_tile(a_copy_lds_window_ping(AIter), - tile_elementwise_in(a_element_func, a_block_tile(AIter))); - }); -#endif __builtin_amdgcn_sched_barrier(0); -// Prefetch A1 -#ifndef FINEGRADE_LOADSTORE + // Prefetch A1 a_block_tile = load_tile(a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); -#else - static_for<0, ACopyLoadNum, 1>{}([&](auto AIter) { - a_block_tile(AIter) = load_tile(a_copy_dram_window(AIter)); - move_tile_window(a_copy_dram_window(AIter), {0, kKPerBlock}); - }); -#endif // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); block_sync_lds(); - // preload A00,A10 from lds - constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2 : 1; - statically_indexed_array{})(number<0>{}))), - m_preload> - a_warp_tensor_ping; - statically_indexed_array{})(number<0>{}))), - m_preload> - a_warp_tensor_pong; - + // preload A00,A10... from lds + statically_indexed_array{})(number<0>{}))), m_preload> a_warp_tensor; + static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_ping(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); + a_warp_tensor(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); }); __builtin_amdgcn_sched_barrier(0); - // if(threadIdx.x==0){ - // for(int i=0;i(a_block_tile.thread_buf_(i)),a_block_tile.get_thread_buffer_size()); - // } - // } - // for(int i=0;i{}).get_thread_buffer_size();i++) { - // printf("dteng--A lds load 00: idx.x=%u, awarptensor=%f, buffer size=%d\n", - // threadIdx.x, - // type_convert(a_warp_tensor_ping(number<0>{}).thread_buf_(i)),a_warp_tensor_ping(number<0>{}).get_thread_buffer_size()); - // } - + // MAIN LOOP index_t iCounter = (num_loop - 1) / 2; - // if constexpr(HasMainLoop) - // { while(iCounter > 0) { -#ifndef FINEGRADE_LOADSTORE // prefetch B(2i+1) static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); @@ -910,8 +700,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV a_block_tile = load_tile(a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); -#endif - + // GEMM 2i static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -919,81 +708,35 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; - + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor_ping(number{}), - b_warp_tensor_ping(nIter)(kIter)); - + WG{}(c_warp_tensor, a_warp_tensor(number{}), b_warp_tensor_ping(nIter)(kIter)); + // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - -#ifdef FINEGRADE_LOADSTORE - // prefetch B(2i+1) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && - ((curMNIter % BloadGap) == 1)) - { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window( - b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(number{})(BkIter) = - load_tile(b_flat_dram_windows(number{})(BkIter)); - } - // Prefill A(2i+1) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && - (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0)) - { - constexpr auto AIter = - (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % - ACopyLoadNum; - store_tile( - a_copy_lds_window_pong(number{}), - tile_elementwise_in(a_element_func, a_block_tile(number{}))); - } - // Prefetch A(2i+2) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && - (mIter < (MIterPerWarp - 1 + 1)) && - ((nIter % NIterPerWarp) == (NIterPerWarp - 2))) - { - constexpr auto AIter = - (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % - ACopyLoadNum; - a_block_tile(number{}) = - load_tile(a_copy_dram_window(number{})); - move_tile_window(a_copy_dram_window(number{}), {0, kKPerBlock}); - } -#endif - __builtin_amdgcn_sched_barrier(0x7F6); }); // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_ping(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); } - // barrier + //barrier if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { block_sync_lds(); } }); }); - // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -1001,26 +744,24 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_pong(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); + a_warp_tensor(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); }); HotLoopScheduler(); - + // Next K -#ifndef FINEGRADE_LOADSTORE // prefetch B(2i+2) static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); }); - + // Prefill A(2i+2) a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_ping, a_block_tile_tmp); @@ -1029,7 +770,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV a_block_tile = load_tile(a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); -#endif // GEMM 2i+1 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { @@ -1041,77 +781,31 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor_pong(number{}), - b_warp_tensor_pong(nIter)(kIter)); - + WG{}(c_warp_tensor, a_warp_tensor(number{}), b_warp_tensor_pong(nIter)(kIter)); + // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - -#ifdef FINEGRADE_LOADSTORE - // prefetch B(2i+2) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && - ((curMNIter % BloadGap) == 1)) - { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window( - b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(number{})(BkIter) = - load_tile(b_flat_dram_windows(number{})(BkIter)); - } - // Prefill A(2i+1) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && - (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0)) - { - constexpr auto AIter = - (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % - ACopyLoadNum; - store_tile( - a_copy_lds_window_ping(number{}), - tile_elementwise_in(a_element_func, a_block_tile(number{}))); - } - // Prefetch A(2i+2) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK + 1)) && - (mIter < (MIterPerWarp - 1 + 1)) && - ((nIter % NIterPerWarp) == (NIterPerWarp - 2))) - { - constexpr auto AIter = - (mIter + ACopyLoadNumPerK + kIter * ACopyLoadNumPerK) % - ACopyLoadNum; - a_block_tile(number{}) = - load_tile(a_copy_dram_window(number{})); - move_tile_window(a_copy_dram_window(number{}), {0, kKPerBlock}); - } -#endif - __builtin_amdgcn_sched_barrier(0x7F6); }); // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_pong(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); } - // barrier + //barrier if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { block_sync_lds(); } }); }); - // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); // move B window to next flat K move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -1119,26 +813,23 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_ping(loadIter) = - load_tile(a_warp_windows_ping(number{})(number{})); + a_warp_tensor(loadIter) = load_tile(a_warp_windows_ping(number{})(number{})); }); HotLoopScheduler(); iCounter--; } - // tail + // TAIL if constexpr(TailNum == TailNumber::Even) { -// __builtin_amdgcn_sched_barrier(0); -#ifndef FINEGRADE_LOADSTORE // prefetch B(loopK) static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); }); @@ -1147,7 +838,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV // Prefill A(loopK) a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_pong, a_block_tile_tmp); -#endif // GEMM loopK-1 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { @@ -1156,81 +846,44 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; - + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor_ping(number{}), - b_warp_tensor_ping(nIter)(kIter)); - + WG{}(c_warp_tensor, a_warp_tensor(number{}), b_warp_tensor_ping(nIter)(kIter)); + // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - -#ifdef FINEGRADE_LOADSTORE - // prefetch B(loopK) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && - ((curMNIter % BloadGap) == 1)) - { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window( - b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(number{})(BkIter) = - load_tile(b_flat_dram_windows(number{})(BkIter)); - } - // Prefill A(loopK) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && - (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0)) - { - constexpr auto AIter = - (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % - ACopyLoadNum; - store_tile( - a_copy_lds_window_pong(number{}), - tile_elementwise_in(a_element_func, a_block_tile(number{}))); - } -#endif - __builtin_amdgcn_sched_barrier(0x7F6); }); // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_ping(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); } - // barrier + //barrier if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { block_sync_lds(); } }); }); - // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); - - TailHotLoopScheduler(); static_for<0, m_preload, 1>{}([&](auto loadIter) { constexpr auto mIter = loadIter % MIterPerWarp; constexpr auto kIter = loadIter / MIterPerWarp; - a_warp_tensor_pong(loadIter) = - load_tile(a_warp_windows_pong(number{})(number{})); + a_warp_tensor(loadIter) = load_tile(a_warp_windows_pong(number{})(number{})); }); - - // __builtin_amdgcn_sched_barrier(0); - + + Last2ndHotLoopScheduler(); + // GEMM loopK static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -1238,37 +891,34 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; - + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor_pong(number{}), - b_warp_tensor_pong(nIter)(kIter)); - + WG{}(c_warp_tensor, a_warp_tensor(number{}), b_warp_tensor_pong(nIter)(kIter)); + // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - __builtin_amdgcn_sched_barrier(0x7F6); }); - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_pong(number{}) = - load_tile(a_warp_windows_pong(number{})(number{})); + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); + } + //barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); } }); }); - // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); - - // TailHotLoopScheduler(); - // __builtin_amdgcn_sched_barrier(0); + LastHotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) { @@ -1279,70 +929,37 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // read C warp tensor from C block tensor CWarpTensor c_warp_tensor; - + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + // warp GEMM - WG{}(c_warp_tensor, - a_warp_tensor_ping(number{}), - b_warp_tensor_ping(nIter)(kIter)); - + WG{}(c_warp_tensor, a_warp_tensor(number{}), b_warp_tensor_ping(nIter)(kIter)); + // write C warp tensor into C block tensor c_block_tile.set_y_sliced_thread_data( merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - -#ifdef FINEGRADE_LOADSTORE - // prefetch B(loopK) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && - ((curMNIter % BloadGap) == 1)) - { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window( - b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(number{})(BkIter) = - load_tile(b_flat_dram_windows(number{})(BkIter)); - } - // Prefill A(loopK) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && - (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp) == 0)) - { - constexpr auto AIter = - (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % - ACopyLoadNum; - store_tile( - a_copy_lds_window_pong(number{}), - tile_elementwise_in(a_element_func, a_block_tile(number{}))); - } -#endif - __builtin_amdgcn_sched_barrier(0x7F6); }); // preload next A from lds - if constexpr((kIter * MIterPerWarp + mIter) < - (KIterPerWarp * MIterPerWarp - m_preload)) + if constexpr((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) { - constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); - a_warp_tensor_ping(number{}) = - load_tile(a_warp_windows_ping(number{})(number{})); + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); } - // barrier + //barrier if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) { block_sync_lds(); } }); }); + LastHotLoopScheduler(); } - // } return c_block_tile; } @@ -1356,7 +973,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV { return operator()( a_dram_block_window_tmp, - [](const ADataType & a) { return a; }, + [](const ADataType& a) { return a; }, b_flat_dram_block_window_tmp, num_loop, p_smem_ping,