From e34599e8a9126ea00955cef7b9b2e605d017bb08 Mon Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Fri, 11 Jul 2025 08:27:55 -0700 Subject: [PATCH] Merge flatmm Operator with universal gemm (#2434) * Initial commit * Adding new tile partitioner to flatmm * intermediate changes * debugging kernels * Updating flatmm example to universal gemm example * updated flatmm kernel to run via gemmKernel * update universal gemm to incorporate flatmm * debug * Fix flatmm call * Fixing other kernels and tests for API changes * clang formatted * fixing gemm tests * added test for flatmm and simplify kernel arguments * adding flatmm test * fix test for flatmm * simplify gemm kernel with flatmm * remove flatmm related files * addressing review comments and code clean up * resolving empty file * resolving empty file * clang formatted * addressing review comments * enable persistent kernel for flatmm * reverted the removed files for flatmm * reverted the removed files for flatmm * changed flatmm to weightPReshuffle; removed the _1 added in teh faltmm example * some more renames * clang formatted [ROCm/composable_kernel commit: d239b91fd54f63cc6e46ba2f6fe7d02512ebe3f1] --- example/ck_tile/03_gemm/CMakeLists.txt | 1 + example/ck_tile/03_gemm/gemm_utils.hpp | 71 +++ .../03_gemm/gemm_weight_preshuffle.cpp | 294 +++++++++++ example/ck_tile/03_gemm/run_gemm_example.inc | 76 ++- example/ck_tile/03_gemm/universal_gemm.cpp | 27 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 137 +++-- example/ck_tile/18_flatmm/flatmm_basic.hpp | 61 ++- .../ck_tile/18_flatmm/run_flatmm_example.inc | 64 ++- .../ops/flatmm/kernel/flatmm_kernel.hpp | 357 +++++++++---- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 83 ++- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 124 ++++- include/ck_tile/ops/gemm.hpp | 4 + .../block/block_wp_asmem_bsmem_creg_v1.hpp | 122 +++++ ...k_wp_asmem_bsmem_creg_v1_custom_policy.hpp | 38 ++ .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 86 +++- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 23 +- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 11 + ...peline_ag_bg_cr_comp_v4_default_policy.hpp | 4 +- .../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 1 + ...peline_ag_bg_cr_comp_v5_default_policy.hpp | 4 +- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 1 + .../gemm/pipeline/gemm_pipeline_problem.hpp | 39 +- .../ops/gemm/pipeline/tile_gemm_shape.hpp | 4 + .../ops/gemm/pipeline/tile_gemm_traits.hpp | 12 +- .../wp_pipeline_agmem_bgmem_creg_v1.hpp | 472 ++++++++++++++++++ ...wp_pipeline_agmem_bgmem_creg_v1_policy.hpp | 450 +++++++++++++++++ test/ck_tile/CMakeLists.txt | 1 + .../gemm/test_gemm_pipeline_kernel_types.hpp | 7 +- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 25 +- .../gemm_weight_preshuffle/CMakeLists.txt | 22 + .../test_gemm_pipeline_kernel_types.hpp | 32 ++ .../test_gemm_pipeline_ut_cases.inc | 21 + .../test_gemm_pipeline_util.hpp | 384 ++++++++++++++ .../test_gemm_pipeline_wp.cpp | 16 + 34 files changed, 2736 insertions(+), 338 deletions(-) create mode 100644 example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp mode change 100644 => 100755 include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp mode change 100644 => 100755 include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp create mode 100644 test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt create mode 100644 test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp create mode 100755 test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc create mode 100644 test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp create mode 100644 test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_wp.cpp diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 411db2e317..3d3a54020c 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,5 +1,6 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) +add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 2157397f1d..9deccc7f16 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -14,6 +14,7 @@ #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 +#define CK_TILE_PIPELINE_PRESHUFFLE 5 template constexpr ck_tile::index_t get_k_warp_tile() @@ -32,6 +33,21 @@ constexpr ck_tile::index_t get_k_warp_tile() return 32; #endif } +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(__gfx950__) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif +} struct GemmConfigBase { @@ -51,6 +67,7 @@ struct GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; }; template @@ -213,6 +230,50 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; +template +struct GemmConfigPreshufle_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmConfigPreshufle_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + template struct GemmTypeConfig; @@ -367,6 +428,16 @@ struct PipelineTypeTraits using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; }; +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp new file mode 100644 index 0000000000..f57c24f458 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "gemm_utils.hpp" +#include "run_gemm_example.inc" + +template +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; +} + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + auto [result, arg_parser] = create_args(argc, argv); + bool preshuffle = GemmConfig::Preshuffle; + + if(preshuffle && a_layout != "R" && b_layout != "C") + { + throw std::runtime_error( + "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); + } + + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } +} + +template