diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index dbb86cd7e8..540aebb3ed 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -67,6 +67,9 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type::value ? 32 : 32; constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type::value ? 32 : 16; #endif + + using Traits = ck_tile::TileGemmTraits; + using CodegenFlatmmShape = ck_tile::TileFlatmmShape, ck_tile::sequence, @@ -74,15 +77,35 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + using CodegenGemmTraits = ck_tile::TileGemmTraits; - using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; - const auto Run = [&](const auto memory_operation_) { + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; constexpr auto memory_operation = memory_operation_.value; + using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; - if(args.k_batch == 1) + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + if(has_hot_loop) { - return Run(ck_tile::integral_constant{}); + if(tail_num == ck_tile::TailNumber::Odd) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "For compute pipeline tail number should always be Full, but have \"" << tail_num + << "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } } else { - return Run(ck_tile::integral_constant{}); + if(tail_num == ck_tile::TailNumber::Odd) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + std::ostringstream err; + err << "Num K loop must be larger than number of prefetech stages." + << "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages + << "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } } + + return ave_time; } #include "run_flatmm_example.inc" diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index c10570bfae..67c63553b5 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -3,13 +3,29 @@ #pragma once -// #define BLOCKWISE_LOADSTORE +// #define FINEGRADE_LOADSTORE #include "ck_tile/core.hpp" #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" namespace ck_tile { +template +struct BaseFlatmmPipelineAGmemBGmemCRegV1 +{ + static constexpr index_t PrefetchStages = 2; + + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } +}; + template struct FlatmmPipelineAGmemBGmemCRegV1 { @@ -24,6 +40,10 @@ struct FlatmmPipelineAGmemBGmemCRegV1 using BlockFlatmm = remove_cvref_t())>; + + static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; static constexpr index_t BlockSize = Problem::kBlockSize; @@ -54,6 +74,28 @@ struct FlatmmPipelineAGmemBGmemCRegV1 using BlockWarps = remove_cvref_t; using WarpTile = remove_cvref_t; + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; + static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; + + static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; + static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + + static constexpr index_t K1 = 16 / sizeof(ADataType); + static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; + static constexpr index_t ACopyLoadNumPerK = ACopyLoadNum / KIterPerWarp; + static constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum; + static constexpr index_t BloadGap = MIterPerWarp / 2; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off @@ -76,88 +118,167 @@ struct FlatmmPipelineAGmemBGmemCRegV1 CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() { - #if 0 - static_for<0, 7, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + // Keypoint of pipeline optimize is workload balance in time + // instruction schedule example(128X256X256, 1X4, 16X16X128): + // Iter MNK MFMA ds_read ds_write A_load b_load + // -1 M6N3: 60 2 - - - + // -1 M7N0: 61 - - - - + // -1 M7N1: 62 - - - - + // -1 M7N2: 63 - - - - + // -1 M7N3: 64 4 - - - + // 0 M0N0K0: 1 - - - - + // 0 M0N1: 2 - - - 2 + // 0 M0N2: 3 - - - - + // 0 M0N3: 4 6 - - - + // 0 M1N0: 5 - - - - + // 0 M1N1: 6 - - - 4 + // 0 M1N2: 7 - - - - + // 0 M1N3: 8 8 - - - + // 0 M2N0: 9 - - - - + // 0 M2N1: 10 - - - 6 + // 0 M2N2: 11 - - - - + // 0 M2N3: 12 10 - - - + // 0 M3N0: 13 - 1 - - + // 0 M3N1: 14 - - - 8 + // 0 M3N2: 15 - - - - + // 0 M3N3: 16 12 - - - + // 0 M4N0: 17 - 2 - - + // 0 M4N1: 18 - - - - + // 0 M4N2: 19 - - 1 - + // 0 M4N3: 20 14 - - - + // 0 M5N0: 21 - 3 - - + // 0 M5N1: 22 - - - - + // 0 M5N2: 23 - - 2 - + // 0 M5N3: 24 16 - - - + // 0 M6N0: 25 - 4 - - + // 0 M6N1: 26 - - - - + // 0 M6N2: 27 - - 3 - + // 0 M6N3: 28 17 - - - + // 0 M7N0: 29 - - - - + // 0 M7N1: 30 - - - - + // 0 M7N2: 31 - - 4 - + // 0 M7N3: 32 18 - - - + // 0 M0N0K1: 33 - - - - + // 0 M0N1: 34 - - - 10 + // 0 M0N2: 35 - - - - + // 0 M0N3: 36 20 - - - + // 0 M1N0: 37 - - - - + // 0 M1N1: 38 - - - 12 + // 0 M1N2: 39 - - - - + // 0 M1N3: 40 22 - - - + // 0 M2N0: 41 - - - - + // 0 M2N1: 42 - - - 14 + // 0 M2N2: 43 - - - - + // 0 M2N3: 44 24 - - - + // 0 M3N0: 45 - 5 - - + // 0 M3N1: 46 - - - 16 + // 0 M3N2: 47 - - - - + // 0 M3N3: 48 26 - - - + // 0 M4N0: 49 - 6 - - + // 0 M4N1: 50 - - - - + // 0 M4N2: 51 - - 5 - + // 0 M4N3: 52 28 - - - + // 0 M5N0: 53 - 7 - - + // 0 M5N1: 54 - - - - + // 0 M5N2: 55 - - 6 - + // 0 M5N3: 56 30 - - - + // 0 M6N0: 57 - 8 - - + // 0 M6N1: 58 - - - - + // 0 M6N2: 59 - - 7 - + // 0 M6N3: 60 2 - - - + // 0 M7N0: 61 - - - - + // 0 M7N1: 62 - - - - + // 0 M7N2: 63 - - 8 - + // 0 M7N3: 64 4 - - - - static_for<0, 7, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read - - __builtin_amdgcn_sched_barrier(0); - #endif #if 1 - static_for<0, 2, 1>{}([&](auto j) { - ignore = j; - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - - static_for<0, 3, 1>{}([&](auto i) { - ignore = i; + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(0); + #endif + #if 0 + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); #endif } CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler() { + #if 0 static_for<0, 2, 1>{}([&](auto j) { ignore = j; static_for<0, 3, 1>{}([&](auto i) { @@ -199,6 +320,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); __builtin_amdgcn_sched_barrier(0); + #endif } template @@ -218,29 +340,6 @@ struct FlatmmPipelineAGmemBGmemCRegV1 static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); - constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); - - using WG = remove_cvref_t())>; - - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); - - constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); - constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; - - constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; - constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; - - constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; - constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; - - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; - constexpr index_t ACopyLoadNumPerK = ACopyLoadNum / KIterPerWarp; - constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum; - constexpr index_t BloadGap = MIterPerWarp / 2; - const index_t iMWarp = get_warp_id() / NWarp; using CWarpDstr = typename WG::CWarpDstr; @@ -263,7 +362,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 auto a_lds_block_pong = make_tensor_view(p_a_lds_pong, a_lds_block_desc); // A DRAM tile window for load - #ifdef BLOCKWISE_LOADSTORE + #ifndef FINEGRADE_LOADSTORE auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), make_tuple(number{}, number{}), @@ -399,7 +498,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // Prefetch A0 - #ifdef BLOCKWISE_LOADSTORE + #ifndef FINEGRADE_LOADSTORE auto a_block_tile = load_tile(a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); @@ -438,7 +537,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 // { // store_tile(a_copy_lds_window_ping, tile_elementwise_in(a_element_func, a_block_tile)); // } - #ifdef BLOCKWISE_LOADSTORE + #ifndef FINEGRADE_LOADSTORE auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_ping, a_block_tile_tmp); #else @@ -449,7 +548,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 __builtin_amdgcn_sched_barrier(0); // Prefetch A1 - #ifdef BLOCKWISE_LOADSTORE + #ifndef FINEGRADE_LOADSTORE a_block_tile = load_tile(a_copy_dram_window); // move A window to next k move_tile_window(a_copy_dram_window, {0, kKPerBlock}); @@ -485,9 +584,11 @@ struct FlatmmPipelineAGmemBGmemCRegV1 index_t iCounter = num_loop / 2 - 1; - while(iCounter > 0) + // if constexpr(HasMainLoop) + // { + while(iCounter > 0) { - #ifdef BLOCKWISE_LOADSTORE + #ifndef FINEGRADE_LOADSTORE // prefetch B(2i+1) static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -531,7 +632,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - #ifndef BLOCKWISE_LOADSTORE + #ifdef FINEGRADE_LOADSTORE // prefetch B(2i+1) constexpr auto curMNIter = mIter * NIterPerWarp + nIter; if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) @@ -586,7 +687,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 //Next K - #ifdef BLOCKWISE_LOADSTORE + #ifndef FINEGRADE_LOADSTORE // prefetch B(2i+2) static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -629,7 +730,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), c_warp_tensor.get_thread_buffer()); - #ifndef BLOCKWISE_LOADSTORE + #ifdef FINEGRADE_LOADSTORE // prefetch B(2i+2) constexpr auto curMNIter = mIter * NIterPerWarp + nIter; if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) @@ -685,129 +786,191 @@ struct FlatmmPipelineAGmemBGmemCRegV1 iCounter--; } - // tail - { - // __builtin_amdgcn_sched_barrier(0); - #ifdef BLOCKWISE_LOADSTORE - // prefetch B(loopK) - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; - - move_tile_window(b_flat_dram_windows(nIter)(kIter), - {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); - }); - }); - - // Prefill A(loopK) - a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window_pong, a_block_tile_tmp); - #endif - - // GEMM loopK-1 - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = mIter % 2; + // tail + if constexpr(TailNum == TailNumber::Even) + { + // __builtin_amdgcn_sched_barrier(0); + #ifndef FINEGRADE_LOADSTORE + // prefetch B(loopK) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - - #ifndef BLOCKWISE_LOADSTORE - // prefetch B(loopK) - constexpr auto curMNIter = mIter * NIterPerWarp + nIter; - if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(loopK) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + #endif + + // GEMM loopK-1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = mIter % 2; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + #ifdef FINEGRADE_LOADSTORE + // prefetch B(loopK) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) + { + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window(b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_pong(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); + } + // Prefill A(loopK) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) + { + constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; + store_tile(a_copy_lds_window_pong(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); + } + #endif + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2))) { - constexpr auto BnIter = curMNIter / BloadGap; - constexpr auto BkIter = kIter; - b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; - move_tile_window(b_flat_dram_windows(number{})(BkIter), - {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); + constexpr auto AmIter = (mIter + 2) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp); + a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); } - // Prefill A(loopK) - if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) - { - constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; - store_tile(a_copy_lds_window_pong(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); + + //barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2))) + { + block_sync_lds(); } - #endif - __builtin_amdgcn_sched_barrier(0x7F6); }); - // preload next A from lds - if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2))) - { - constexpr auto AmIter = (mIter + 2) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp); - a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); - } - - //barrier - if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2))) - { - block_sync_lds(); - } }); - }); - //block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); + //block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_ping); - TailHotLoopScheduler(); + TailHotLoopScheduler(); - static_for<0, 2, 1>{}([&](auto mIter) { - a_warp_tensor_pong(mIter) = load_tile(a_warp_windows_pong(mIter)(number<0>{})); - }); + static_for<0, 2, 1>{}([&](auto mIter) { + a_warp_tensor_pong(mIter) = load_tile(a_warp_windows_pong(mIter)(number<0>{})); + }); - // __builtin_amdgcn_sched_barrier(0); - - // GEMM loopK - static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { - static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { - constexpr auto AwarpIter = mIter % 2; - static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - - c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - - // warp GEMM - WG{}(c_warp_tensor, a_warp_tensor_pong(number{}), b_warp_tensor_pong(nIter)(kIter)); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); - __builtin_amdgcn_sched_barrier(0x7F6); + // __builtin_amdgcn_sched_barrier(0); + + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = mIter % 2; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_pong(number{}), b_warp_tensor_pong(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + __builtin_amdgcn_sched_barrier(0x7F6); + }); + if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2))) + { + constexpr auto AmIter = (mIter + 2) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp); + a_warp_tensor_pong(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); + } }); - if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2))) - { - constexpr auto AmIter = (mIter + 2) % MIterPerWarp; - constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp); - a_warp_tensor_pong(number{}) = load_tile(a_warp_windows_pong(number{})(number{})); - } }); - }); - // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); + // block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_pong); - // TailHotLoopScheduler(); - // __builtin_amdgcn_sched_barrier(0); - } + // TailHotLoopScheduler(); + // __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = mIter % 2; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor_ping(number{}), b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + #ifdef FINEGRADE_LOADSTORE + // prefetch B(loopK) + constexpr auto curMNIter = mIter * NIterPerWarp + nIter; + if constexpr((curMNIter < NIterPerWarp * BloadGap) && ((curMNIter % BloadGap)==1)) + { + constexpr auto BnIter = curMNIter / BloadGap; + constexpr auto BkIter = kIter; + b_flat_dram_windows(number{})(BkIter) = b_flat_dram_window; + move_tile_window(b_flat_dram_windows(number{})(BkIter), + {BnIter * NFlatPerBlockPerIter, BkIter * KFlatPerBlockPerIter}); + b_warp_tensor_pong(number{})(BkIter) = load_tile(b_flat_dram_windows(number{})(BkIter)); + } + // Prefill A(loopK) + if constexpr((mIter >= (MIterPerWarp - 1 - ACopyLoadNumPerK)) && (mIter < (MIterPerWarp - 1)) && ((nIter % NIterPerWarp)==0)) + { + constexpr auto AIter = (mIter + ACopyLoadNumPerK + 1 + kIter * ACopyLoadNumPerK) % ACopyLoadNum; + store_tile(a_copy_lds_window_pong(number{}), tile_elementwise_in(a_element_func, a_block_tile(number{}))); + } + #endif + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter!=KIterPerWarp-1)||(mIter<(MIterPerWarp-2))) + { + constexpr auto AmIter = (mIter + 2) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + 2) / MIterPerWarp); + a_warp_tensor_ping(number{}) = load_tile(a_warp_windows_ping(number{})(number{})); + } + + //barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == (MIterPerWarp - 2))) + { + block_sync_lds(); + } + }); + }); + } + // } return c_block_tile; } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 0b38e7789e..7d463ec96f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -160,6 +160,148 @@ using GemmPipelineProblem = GemmPipelineProblemBase; +template +struct FlatmmPipelineProblem +{ + using Traits = remove_cvref_t; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + + using BlockGemmShape = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr bool TransposeC = Traits::TransposeC; + + static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; + + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); + + static constexpr bool kPadM = Traits::kPadM; + static constexpr bool kPadN = Traits::kPadN; + static constexpr bool kPadK = Traits::kPadK; + + static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer; + + static constexpr auto Scheduler = GemmPipelineScheduler::Default; + static constexpr index_t VectorLoadSize = Traits::_VectorSize; + + static constexpr auto HasHotLoop = HasHotLoop_; + static constexpr auto TailNum = TailNum_; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm_problem", + concat('x', VectorLoadSize, kBlockSize), + concat('x', kPadM, kPadN, kPadK), + Scheduler); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() + { + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + if constexpr(std::is_same_v) + { + constexpr index_t pixels_per_thread = + BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; + return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType) + ? pixels_per_thread + : PackedSize * VectorLoadSize / sizeof(ADataType); + } + else + { + return VectorLoadSize / sizeof(ADataType); + } + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB() + { + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + if constexpr(std::is_same_v) + { + constexpr index_t pixels_per_thread = + BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; + return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType) + ? pixels_per_thread + : PackedSize * VectorLoadSize / sizeof(BDataType); + } + else + { + return PackedSize * VectorLoadSize / sizeof(BDataType); + } + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC() + { + if constexpr(std::is_same_v) + { + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size()); + constexpr index_t M0 = get_warp_size() / N2; + constexpr index_t M1 = BlockGemmShape::kM / M0; + + return std::min(M1, static_cast(VectorLoadSize / sizeof(CDataType))); + } + else + { + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size()); + constexpr index_t N0 = get_warp_size() / M2; + constexpr index_t N1 = BlockGemmShape::kN / N0; + + return std::min(N1, static_cast(VectorLoadSize / sizeof(CDataType))); + } + } + + static constexpr index_t VectorSizeA = []() { + if constexpr(std::is_same_v) + { + return kPadK ? 1 : GetAlignmentA(); + } + else + { + return kPadM ? 1 : GetAlignmentA(); + } + }(); + + static constexpr index_t VectorSizeB = []() { + if constexpr(std::is_same_v) + { + return kPadN ? 1 : GetAlignmentB(); + } + else + { + return kPadK ? 1 : GetAlignmentB(); + } + }(); + static constexpr index_t VectorSizeC = []() { + if constexpr(std::is_same_v) + { + return kPadN ? 1 : GetAlignmentC(); + } + else + { + return kPadM ? 1 : GetAlignmentC(); + } + }(); +}; + template