From 8bf96c18c60681889a8e3c1de5bfbc46fc3a5984 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Fri, 1 Aug 2025 03:04:54 -0400 Subject: [PATCH] Integration of a new pipeline for weight preshuffle into gemm examples (#2516) * something khushbu can help with * v1 v2 works with flatmm develop * v0 v1 v2 numerical error gone * Fixing numerical error, and interchange preshuffle configs to match with flatmm * Refactor GEMM pipeline configurations and integrate preshuffle support - Updated preshuffle pipeline definitions to include multiple versions (V1, V2, V3). - Changed the pipeline constant from CK_TILE_PIPELINE_PRESHUFFLE to CK_TILE_PIPELINE_PRESHUFFLE_V3 in relevant configurations. - Removed obsolete code and comments * clang format * fix vectorloadsize bug * add the Preshuffle3 * update kwarp calculation in gemm utils * update vector size A and B correctly in V2 pipeline; Added few more changes to align with dteng's branch * fix: add CK_GFX950_SUPPORT macro for gfx950 detection * default disable rotating buffer * docs(CHANGELOG): update changelog for rocm 7.0 * Revert "docs(CHANGELOG): update changelog for rocm 7.0" This reverts commit 2bc16fff84a416b33b8a87692044fc4645fd2086. * Remove unused Preshuffle V3 pipeline and related code; update gemm function to use Preshuffle V2; clean up comments and formatting in various files. 
* revert example/ck_tile/flatmm to its original state * remove comment added by second author * switch to xor ALDSDescriptor * modify the MakeALdsDescriptor() * temporary profiling script * getting rid of line marker compiler error * UniversalWeightPreshufflePipelineAgBgCrPolicy now derives from UniversalGemmBasePolicy * add a minor fix for the config * typo fix * Fix formatting in lambda function for WeightPreshufflePipelineAGmemBGmemCRegV2 * revert change in include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp * revert change in include/ck_tile/core/arch/amd_buffer_addressing.hpp * reenable the GemmSpatiallyLocalTilePartitioner * make GemmConfigPreshuffle_1 for v1 pipeline, GemmConfigPreshuffle_2 for v2 pipeline * remove hardcoded true for preshuffle bool template argument * rename script * remove gemm_profilie.sh script * merge conflict resolve * clang formatted * typo fix * Remove duplicate include of block_gemm_areg_bsmem_creg_v2r1.hpp in gemm.hpp * Remove commented-out code in UniversalWeightPreshufflePipelineAgBgCrPolicy * Fix missing newline at end of file in run_gemm_example.inc * Remove unused barrier call in BlockWeightPreshuffleASmemBSmemCRegV1 * addressing review comments * removing debug code * addressing review comments * Revert "addressing review comments" This reverts commit 29c45192badc2371d78cfba9df4ed65148885b88. 
* updating tile_engine code * addressing review comments --------- Co-authored-by: amd-khushbu Co-authored-by: ThomasNing [ROCm/composable_kernel commit: 1441a0a7eee2930c037d1c7cadde157e8eb3c476] --- example/ck_tile/03_gemm/CMakeLists.txt | 6 + example/ck_tile/03_gemm/gemm_utils.hpp | 33 +- .../03_gemm/gemm_weight_preshuffle.cpp | 4 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 1 + .../ops/flatmm/kernel/flatmm_kernel.hpp | 0 .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1 + include/ck_tile/ops/gemm.hpp | 3 +- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 0 .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 2 +- ...pipeline_agmem_bgmem_creg_base_policy.hpp} | 177 +-- .../wp_pipeline_agmem_bgmem_creg_v1.hpp | 14 +- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 1070 +++++++++++++++++ script/gemm_profile.sh | 107 ++ 13 files changed, 1231 insertions(+), 187 deletions(-) mode change 100755 => 100644 include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp mode change 100755 => 100644 include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp rename include/ck_tile/ops/gemm/pipeline/{wp_pipeline_agmem_bgmem_creg_v1_policy.hpp => wp_pipeline_agmem_bgmem_creg_base_policy.hpp} (64%) create mode 100644 include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp create mode 100755 script/gemm_profile.sh diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 3d3a54020c..e6f67e4c76 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -2,9 +2,15 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) +set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() list(APPEND 
EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0") target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index e9b779c00c..cab110597b 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -14,12 +14,13 @@ #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 -#define CK_TILE_PIPELINE_PRESHUFFLE 5 +#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5 +#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6 template constexpr ck_tile::index_t get_k_warp_tile() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) constexpr bool is_8bit_float = std::is_same_v || std::is_same_v; if constexpr(M_Warp_Tile == 32) @@ -36,7 +37,7 @@ constexpr ck_tile::index_t get_k_warp_tile() template constexpr ck_tile::index_t get_k_warp_tile_flatmm() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) if constexpr(M_Warp_Tile == 32) return sizeof(PrecType) == 2 ? 
16 : 64; else @@ -231,7 +232,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase }; template -struct GemmConfigPreshufle_1 : public GemmConfigBase +struct GemmConfigPreshuffle_1 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -247,13 +248,13 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V1; static constexpr bool Preshuffle = true; static constexpr bool DoubleSmemBuffer = false; }; template -struct GemmConfigPreshufle_2 : public GemmConfigBase +struct GemmConfigPreshuffle_2 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -263,15 +264,15 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = false; + static constexpr bool DoubleSmemBuffer = true; }; template @@ -429,7 +430,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = 
ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; @@ -438,6 +439,16 @@ struct PipelineTypeTraits ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; }; +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index 74e79574d1..0a06787e2b 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -279,13 +279,11 @@ int main(int argc, char* argv[]) { try { - return !run_gemm_example(argc, argv); + return !run_gemm_example(argc, argv); } catch(const std::runtime_error& e) { std::cerr << "Caught runtime error: " << e.what() << '\n'; - // Return a non-zero code to indicate failure return EXIT_FAILURE; } - return EXIT_SUCCESS; } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 0f2beca2c7..475a0c7bf3 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -219,6 +219,7 @@ int run_flatmm_example(int argc, char* argv[]) std::string b_layout = arg_parser.get_str("b_layout"); if(a_layout == "R" && b_layout == "C") { + if(data_type == "fp16") { run_flatmm_example_with_layouts>( diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp old mode 100755 new mode 100644 diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 54f2a777bf..1a28366e24 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ 
b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -32,6 +32,7 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1 return run_func(bool_constant{}, integral_constant{}); } }; + template struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV1 { diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index c201293389..c9bedd7c53 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -48,8 +48,9 @@ #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp" -#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp old mode 100755 new mode 100644 diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 0a6bacdc42..b621468e92 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -112,7 +112,7 @@ struct GemmTile1DPartitioner * @param N GEMM's N dimension. * @return dim3 Structure holding grid's X,Y and Z dimensions. 
*/ - CK_TILE_HOST static auto + CK_TILE_HOST_DEVICE static auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t { const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp similarity index 64% rename from include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp rename to include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 25aad329d9..83555e5295 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -9,77 +9,19 @@ namespace ck_tile { struct UniversalWeightPreshufflePipelineAgBgCrPolicy + : public UniversalGemmBasePolicy { - static constexpr auto I0 = number<0>{}; - static constexpr auto I1 = number<1>{}; - static constexpr auto I2 = number<2>{}; + using BasePolicy = UniversalGemmBasePolicy; // 3d + padding template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { using namespace ck_tile; - - constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); - constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); - if constexpr(MPerXdl == 16 && NPerXdl == 16) - { - /*reduce transform layers,compare with old ck*/ - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackA(); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_xor_transform( - make_tuple(number{}, number{})), - 
make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_pass_through_transform(number{}), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; - } - else - { - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t kKPack = GetSmemPackA(); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_pass_through_transform(kMPerBlock), - make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; - } -/*xor*/ -#if 0 constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPack = GetSmemPackA(); - using ADataType = remove_cvref_t; + using ADataType = remove_cvref_t; constexpr auto DataTypeSize = sizeof(ADataType); constexpr auto MLdsLayer = @@ -87,8 +29,8 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, - number{}, - number{}), + number{}, + number{}), make_tuple(number{}, number{}, number<1>{}), number{}, number<1>{}); @@ -96,119 +38,29 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( a_lds_block_desc_0, 
make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), + number{})), + make_pass_through_transform(number{})), make_tuple(sequence<1, 0>{}, sequence<2>{}), make_tuple(sequence<1, 0>{}, sequence<2>{})); constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( a_lds_block_desc_permuted, make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); constexpr auto a_lds_block_desc = transform_tensor_descriptor( a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return a_lds_block_desc; -#endif - } - - /** - * @brief Get the maximum global memory vector load size. - * - * @tparam Problem The UniversalGemmPipelineProblem object. - * @tparam DataType The tensor data type we're considering. - * @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B). - * @tparam XPerTile The contiguous Tile dimension size. - * @return Maximum DRAM vector load size. 
- */ - template - CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize() - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; - constexpr index_t PackedSize = - ck_tile::numeric_traits>::PackedSize; - - // Assume DataType is even! - if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 && - elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 && - PackedSize == 2) - { - return (PackedSize * 32 / sizeof(DataType)); - } - else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 && - elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0) - { - return (PackedSize * 16 / sizeof(DataType)); - } - else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 && - elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0) - { - return (PackedSize * 8 / sizeof(DataType)); - } - else if constexpr(sizeof(DataType) >= PackedSize * 4 && - XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 && - elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0) - { - return (PackedSize * 4 / sizeof(DataType)); - } - else if constexpr(sizeof(DataType) >= PackedSize * 2 && - XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 && - elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0) - { - return (PackedSize * 2 / sizeof(DataType)); - } - else - { - return PackedSize; - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() - { - using ALayout = remove_cvref_t; - using ADataType = remove_cvref_t; - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - if constexpr(std::is_same_v) - { - return GetGlobalVectorLoadSize(); - } - else - { - return GetGlobalVectorLoadSize(); - } - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() - { - using BLayout 
= remove_cvref_t; - using BDataType = remove_cvref_t; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - - if constexpr(std::is_same_v) - { - return GetGlobalVectorLoadSize(); - } - else - { - return GetGlobalVectorLoadSize(); - } } template @@ -426,7 +278,6 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffle() { - // using AccDataType = float; using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmMfmaDispatcher(); - auto b_flat_dram_window = // tile_window_with_static_distribution - make_tile_window( - b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views - make_tuple(number{}, number{}), - b_flat_dram_block_window_tmp.get_window_origin(), - b_flat_distribution); + auto b_flat_dram_window = + make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); // Acc register tile auto c_block_tile = block_flatmm.MakeCBlockTile(); @@ -468,5 +467,4 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 p_smem); } }; - } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp new file mode 100644 index 0000000000..9c0f257e8e --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -0,0 +1,1070 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" + +namespace ck_tile { + +template +struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2 +{ + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool, TailNumber tail_number) + { + if(tail_number == TailNumber::Odd) + { + run_func(bool_constant{}, integral_constant{}); + } + else if(tail_number == TailNumber::Even) + { + run_func(bool_constant{}, integral_constant{}); + } + } +}; + +template +struct WeightPreshufflePipelineAGmemBGmemCRegV2 + : public BaseWeightPreshufflePipelineAGmemBGmemCRegV2 +{ + using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockWeightPreshuffle = + remove_cvref_t())>; + + static constexpr auto config = + BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = 
BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; + static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; + + static constexpr index_t GetVectorSizeA() + { + return PipelinePolicy::template GetVectorSizeA(); + } + static constexpr index_t GetVectorSizeB() + { + return PipelinePolicy::template GetVectorSizeB(); + } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr index_t kLdsAlignmentInBytes = 16; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; + static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; + + static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; + static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + + static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); + static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; + static constexpr auto TailNum = Problem::TailNum; + + static constexpr auto warp_m = WarpTile::at(idxM); + static constexpr auto warp_n = WarpTile::at(idxN); + static constexpr 
auto warp_k = WarpTile::at(idxK); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AGmemBGmemCRegV2", + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), + concat('x', WG::kM, WG::kN, WG::kK), + concat('x', GetVectorSizeA(), GetVectorSizeB()), + concat('x', kPadM, kPadN, kPadK)); + + // clang-format on + } + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t Preshuffle = Problem::Preshuffle; + using Base::UsePersistentKernel; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return PipelinePolicy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { + + constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad; + constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp; + constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; + + // Keypoint of pipeline optimize is workload balance in time + // instruction schedule example(128X256X256, 1X4, 16X16X128): + // Iter MNK MFMA ds_read ds_write A_load b_load + // -1 M6N3: 60 2 - - - + // -1 M7N0: 61 - - - - + // -1 M7N1: 62 - - - - + // -1 M7N2: 63 - - - - + // -1 M7N3: 64 4 - - - + // 0 M0N0K0: 1 - - - - + // 0 M0N1: 2 - - - 2 + // 0 M0N2: 3 - - - - + // 0 M0N3: 4 6 - - - + // 0 M1N0: 5 - - - - + // 0 M1N1: 6 - - - 4 + // 0 M1N2: 7 - - - - + // 0 M1N3: 8 8 - - - + // 0 M2N0: 9 - - - - + // 0 M2N1: 10 - - - 6 + // 0 M2N2: 11 - - - - + // 0 M2N3: 12 10 - - - + // 0 M3N0: 13 - 1 - - + // 0 M3N1: 14 - - - 8 + // 0 M3N2: 15 - - - - + // 0 M3N3: 16 12 - - - + // 0 M4N0: 17 - 2 - - + // 0 M4N1: 18 - - - - + // 0 M4N2: 19 - - 1 - + // 0 M4N3: 20 14 - - - + // 0 M5N0: 21 - 3 - - + // 0 M5N1: 22 - - - - + // 0 M5N2: 23 
- - 2 - + // 0 M5N3: 24 16 - - - + // 0 M6N0: 25 - 4 - - + // 0 M6N1: 26 - - - - + // 0 M6N2: 27 - - 3 - + // 0 M6N3: 28 17 - - - + // 0 M7N0: 29 - - - - + // 0 M7N1: 30 - - - - + // 0 M7N2: 31 - - 4 - + // 0 M7N3: 32 18 - - - + // 0 M0N0K1: 33 - - - - + // 0 M0N1: 34 - - - 10 + // 0 M0N2: 35 - - - - + // 0 M0N3: 36 20 - - - + // 0 M1N0: 37 - - - - + // 0 M1N1: 38 - - - 12 + // 0 M1N2: 39 - - - - + // 0 M1N3: 40 22 - - - + // 0 M2N0: 41 - - - - + // 0 M2N1: 42 - - - 14 + // 0 M2N2: 43 - - - - + // 0 M2N3: 44 24 - - - + // 0 M3N0: 45 - 5 - - + // 0 M3N1: 46 - - - 16 + // 0 M3N2: 47 - - - - + // 0 M3N3: 48 26 - - - + // 0 M4N0: 49 - 6 - - + // 0 M4N1: 50 - - - - + // 0 M4N2: 51 - - 5 - + // 0 M4N3: 52 28 - - - + // 0 M5N0: 53 - 7 - - + // 0 M5N1: 54 - - - - + // 0 M5N2: 55 - - 6 - + // 0 M5N3: 56 30 - - - + // 0 M6N0: 57 - 8 - - + // 0 M6N1: 58 - - - - + // 0 M6N2: 59 - - 7 - + // 0 M6N3: 60 2 - - - + // 0 M7N0: 61 - - - - + // 0 M7N1: 62 - - - - + // 0 M7N2: 63 - - 8 - + // 0 M7N3: 64 4 - - - + + if constexpr(warp_m == 16 && warp_n == 16) + { +// MFMA -> VMEM READ -> MFMA -> DS Read -> MFMA +// hiding the glbal memory VMEM latency +#if defined(__gfx950__) + if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) + { + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + 
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + 
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } +// MFMA → MFMA → MFMA → MFMA → DS Read +// For other device engine we need more aggressive MFMA with DS writes interleaved +#else + if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) + { + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + // Uses loops to amortize scheduling overhead + static_for<0, 4, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + 
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + 
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + }); + + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(kMPerBlock == 16 && kNPerBlock == 64 && kKPerBlock == 256) + { + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // 
VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128) + { + // prioritize MFMA to avoid LDS write conflicts + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + 
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + }); + + __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA + }); + } + +#endif + } + else + { + if 
constexpr((A_LDS_Read_Inst_Num / 2 > + A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) + { + static_for<0, + A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - + B_Buffer_Load_Inst_Num, + 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA + } + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const + { + static_assert( + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + static_assert(kKPerBlock == 
ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; + const index_t iMWarp = get_warp_id() / NWarp; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + __builtin_amdgcn_sched_barrier(0); + + // A tile in LDS + ADataType* p_a_lds_ping = static_cast(p_smem_ping); + ADataType* p_a_lds_pong = static_cast(p_smem_pong); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); + + auto a_lds_block_ping = + make_tensor_view(p_a_lds_ping, a_lds_block_desc); + auto a_lds_block_pong = + make_tensor_view(p_a_lds_pong, a_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_ping = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_pong = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + // ping-pong window for A LDS + auto a_warp_window_ping_tmp = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + auto a_warp_window_pong_tmp = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + 
MIterPerWarp> + a_warp_windows_ping; + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_pong; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + + move_tile_window(a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Block GEMM + auto block_weight_preshuffle = BlockWeightPreshuffle(); + // Acc register tile + auto c_block_tile = block_weight_preshuffle.MakeCBlockTile(); + + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); + + // pingpong buffer for B + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_flat_dram_windows; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor_ping; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor_pong; + + // Prefetch A0 + auto a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // prefetch B + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + 
move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // Prefill A0 + auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + __builtin_amdgcn_sched_barrier(0); + + // Prefetch A1 + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + block_sync_lds(); + + // preload A00,A10 from lds + constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 2 : 1; + statically_indexed_array{})(number<0>{}))), + m_preload> + a_warp_tensor_ping; + statically_indexed_array{})(number<0>{}))), + m_preload> + a_warp_tensor_pong; + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_ping(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); + }); + __builtin_amdgcn_sched_barrier(0); + + index_t iCounter = (num_loop - 1) / 2; + while(iCounter > 0) + { + // prefetch B(2i+1) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(2i+1) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // Prefetch A(2i+2) + a_block_tile = 
load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_ping(number{}), + b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_pong(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); + HotLoopScheduler(); + + // Next K + + // prefetch B(2i+2) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 
1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(2i+2) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + // Prefetch A(2i+3) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i+1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_pong(number{}), + b_warp_tensor_pong(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_pong(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + // move B 
window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_ping(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); + }); + HotLoopScheduler(); + + iCounter--; + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // __builtin_amdgcn_sched_barrier(0); + // prefetch B(loopK) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(loopK) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // GEMM loopK-1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_ping(number{}), + b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter 
* MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + // TailHotLoopScheduler(); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_pong(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); + + // __builtin_amdgcn_sched_barrier(0); + + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_pong(number{}), + b_warp_tensor_pong(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + __builtin_amdgcn_sched_barrier(0x7F6); + }); + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_pong(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + }); + }); + // 
TailHotLoopScheduler(); + } + else if constexpr(TailNum == TailNumber::Odd) + { + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_ping(number{}), + b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + } + + return c_block_tile; + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType & a) { return a; }, + b_flat_dram_block_window_tmp, + num_loop, + p_smem_ping, + p_smem_pong); + } +}; + +} // namespace ck_tile diff --git a/script/gemm_profile.sh b/script/gemm_profile.sh new file mode 
100755
index 0000000000..b71c43f74f
--- /dev/null
+++ b/script/gemm_profile.sh
@@ -0,0 +1,81 @@
+#!/bin/bash
+# Profile tile_example_gemm_weight_preshuffle over a fixed set of GEMM
+# shapes and collect per-shape pipeline/perf numbers into a CSV file.
+set -u
+
+BIN=./bin/tile_example_gemm_weight_preshuffle
+PREC=fp8
+VERBOSITY=2
+
+# (m, n, k) triplets: m = 1..16 against two decode-style (n, k) pairs,
+# plus a handful of larger shapes.
+ARGS_LIST=()
+for m in $(seq 1 16); do
+    ARGS_LIST+=("$m 2048 5120" "$m 5120 1024")
+done
+ARGS_LIST+=(
+    "2048 5120 1024"
+    "2048 5120 8192"
+    "2048 7168 8192"
+    "2048 8192 3584"
+    "16384 7168 8192"
+    "16384 8192 3584"
+)
+
+# Output file
+OUTPUT_FILE="gemm_profile_results.csv"
+
+# CSV header
+echo "m,n,k,Pipeline,Time_ms,TFlops,GBps,Verification" > "$OUTPUT_FILE"
+
+# Extract the first integer or decimal number followed by the given unit,
+# e.g. "0.042338 ms" or "1521 TFlops" (the original regex missed integers).
+extract_number() {
+    # $1 = line to search, $2 = unit suffix
+    echo "$1" | grep -o "[0-9]\+\(\.[0-9]\+\)\? $2" | head -n1 | grep -o '[0-9.]\+'
+}
+
+# Loop over each argument set
+for args in "${ARGS_LIST[@]}"; do
+    read -r m n k <<< "$args"
+
+    echo "Testing: m=$m, n=$n, k=$k"
+    OUTPUT=$("$BIN" -m="$m" -n="$n" -k="$k" -prec="$PREC" -v="$VERBOSITY" 2>/dev/null)
+
+    # Kernel name, e.g.
+    # "Launching kernel with args: gemm_fp8_pipeline_AGmemBGmemCRegV2_..."
+    PIPELINE=$(echo "$OUTPUT" | sed -n 's/.*Launching kernel with args: \(.*\)/\1/p')
+
+    # Perf line, e.g. ": 0.042338 ms, 1521.67 TFlops, 1126.89 GB/s,"
+    PERF_LINE=$(echo "$OUTPUT" | grep "TFlops")
+
+    # Verification result, e.g. "The GPU verification result is: correct"
+    VERIFICATION=$(echo "$OUTPUT" | sed -n 's/.*The GPU verification result is: \(.*\)/\1/p')
+
+    if [ -n "$PERF_LINE" ]; then
+        TIME_MS=$(extract_number "$PERF_LINE" "ms")
+        TFLOPS=$(extract_number "$PERF_LINE" "TFlops")
+        GBPS=$(extract_number "$PERF_LINE" "GB/s")
+
+        # Fall back to a generic name if the binary did not report one.
+        if [ -z "$PIPELINE" ]; then
+            PIPELINE="gemm_basic"
+        fi
+
+        echo "  Pipeline: $PIPELINE"
+        echo "  Time: ${TIME_MS} ms"
+        echo "  TFlops: ${TFLOPS}"
+        echo "  GB/s: ${GBPS}"
+
+        echo "$m,$n,$k,$PIPELINE,$TIME_MS,$TFLOPS,$GBPS,$VERIFICATION" >> "$OUTPUT_FILE"
+    else
+        echo "  ERROR: Could not parse performance data"
+        echo ""
+        echo "$m,$n,$k,$PIPELINE,,,,$VERIFICATION" >> "$OUTPUT_FILE"
+    fi
+done
+
+echo "=========================================="
+echo "Profile completed!"
+echo "Results saved to: $OUTPUT_FILE"
+echo "Total tests run: ${#ARGS_LIST[@]}"
+echo "=========================================="