From a26ba690fd08aa6b6aef967a39f857292ab2b8bd Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Thu, 10 Jul 2025 13:00:47 -0400 Subject: [PATCH 01/92] fix(precommit_install): fix bug for bare metal machines (#2448) Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> --- script/install_precommit.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/script/install_precommit.sh b/script/install_precommit.sh index 83e526035c..6132f6a287 100755 --- a/script/install_precommit.sh +++ b/script/install_precommit.sh @@ -9,13 +9,13 @@ run_and_check() { return $status } -echo "I: Installing tools required for pre-commit checks..." -run_and_check apt install clang-format-12 - echo "I: Creating and activating virtual environment for pre-commit..." python3 -m venv "$(dirname "$0")/../.venv" source "$(dirname "$0")/../.venv/bin/activate" +echo "I: Installing tools required for pre-commit checks..." +run_and_check pip install dos2unix +run_and_check pip install clang-format==12.0.1 echo "I: Installing pre-commit in virtual environment..." 
run_and_check pip install pre-commit run_and_check pre-commit install From 45904b8fd7cde71dfc3741970325b3d552b06d27 Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Fri, 11 Jul 2025 18:14:47 +0800 Subject: [PATCH 02/92] Add separate mask checking for scope [aligned_physical_seqlen_k_start, physical_seqlen_k_end) (#2487) * Add separate mask checking for scope [aligned_physical_seqlen_k_start, physical_seqlen_k_end) in pagedkv pipeline * i_nhead_ conversion type to prevent overflow --------- Co-authored-by: ltqin --- .../fmha/kernel/fmha_fwd_pagedkv_kernel.hpp | 6 ++- ...ock_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp | 54 ++++++++++++------- 2 files changed, 40 insertions(+), 20 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index e56d518634..d8cd006c60 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -1122,7 +1122,8 @@ struct FmhaFwdPagedKVKernel const index_t num_blocks = integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); - const long_index_t fixed_offset = i_nhead_ * kargs.nhead_stride_k; + const long_index_t fixed_offset = + static_cast(i_nhead_) * kargs.nhead_stride_k; return make_page_block_navigator( kargs.k_ptr, @@ -1152,7 +1153,8 @@ struct FmhaFwdPagedKVKernel const index_t num_blocks = integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size); - const long_index_t fixed_offset = i_nhead_ * kargs.nhead_stride_v; + const long_index_t fixed_offset = + static_cast(i_nhead_) * kargs.nhead_stride_v; return make_page_block_navigator( kargs.v_ptr, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp index 6ad5844b69..9d267e1cee 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp @@ -441,28 +441,46 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS } } move_tile_window(bias_dram_window, {0, kN0}); - if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { const auto k_origin = k_page_block_navigator.to_global_window_origin( i_page_block_k, k_dram_block_window.get_window_origin()); - // mask accept only logical coordinates, do conversion here - bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), - k_origin.at(number<0>{}) - kv_l2p_offset, - number{}, - number{}); - if(need_perpixel_check) + + if constexpr(kIsPagedKV) { - set_tile_if( - s_acc, -numeric::infinity(), [&](auto tile_idx) { - const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return !variant.LogitsMask(variant_params, - block_indices.batch_idx, - row, - col - kv_l2p_offset, - block_indices.qo_head_idx, - block_indices.kv_head_idx); - }); + // check columns in [aligned_physical_seqlen_k_start, physical_seqlen_k_end) + if(kv_l2p_offset > 0) + { + set_tile_if( + s_acc, + -numeric::infinity(), + [&, physical_seqlen_k_start_ = physical_seqlen_k_start](auto tile_idx) { + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return col < physical_seqlen_k_start_; + }); + }; + } + + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + // mask accept only logical coordinates, do conversion here + bool need_perpixel_check = + mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}) - kv_l2p_offset, + number{}, + number{}); + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = + q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col - kv_l2p_offset); + }); + } } } From d239b91fd54f63cc6e46ba2f6fe7d02512ebe3f1 Mon 
Sep 17 00:00:00 2001 From: Khushbu Agarwal Date: Fri, 11 Jul 2025 08:27:55 -0700 Subject: [PATCH 03/92] Merge flatmm Operator with universal gemm (#2434) * Initial commit * Adding new tile partitioner to flatmm * intermediate changes * debugging kernels * Updating flatmm example to universal gemm example * updated flatmm kernel to run via gemmKernel * update universal gemm to incorporate flatmm * debug * Fix flatmm call * Fixing other kernels and tests for API changes * clang formatted * fixing gemm tests * added test for flatmm and simplify kernel arguments * adding flatmm test * fix test for flatmm * simplify gemm kernel with flatmm * remove flatmm related files * addressing review comments and code clean up * resolving empty file * resolving empty file * clang formatted * addressing review comments * enable persistent kernel for flatmm * reverted the removed files for flatmm * reverted the removed files for flatmm * changed flatmm to weightPreshuffle; removed the _1 added in the flatmm example * some more renames * clang formatted --- example/ck_tile/03_gemm/CMakeLists.txt | 1 + example/ck_tile/03_gemm/gemm_utils.hpp | 71 +++ .../03_gemm/gemm_weight_preshuffle.cpp | 294 +++++++++++ example/ck_tile/03_gemm/run_gemm_example.inc | 76 ++- example/ck_tile/03_gemm/universal_gemm.cpp | 27 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 137 +++-- example/ck_tile/18_flatmm/flatmm_basic.hpp | 61 ++- .../ck_tile/18_flatmm/run_flatmm_example.inc | 64 ++- .../ops/flatmm/kernel/flatmm_kernel.hpp | 357 +++++++++---- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 83 ++- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 124 ++++- include/ck_tile/ops/gemm.hpp | 4 + .../block/block_wp_asmem_bsmem_creg_v1.hpp | 122 +++++ ...k_wp_asmem_bsmem_creg_v1_custom_policy.hpp | 38 ++ .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 86 +++- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 23 +- .../gemm_pipeline_ag_bg_cr_comp_v4.hpp | 11 + ...peline_ag_bg_cr_comp_v4_default_policy.hpp | 4 +- 
.../gemm_pipeline_ag_bg_cr_comp_v5.hpp | 1 + ...peline_ag_bg_cr_comp_v5_default_policy.hpp | 4 +- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 1 + .../gemm/pipeline/gemm_pipeline_problem.hpp | 39 +- .../ops/gemm/pipeline/tile_gemm_shape.hpp | 4 + .../ops/gemm/pipeline/tile_gemm_traits.hpp | 12 +- .../wp_pipeline_agmem_bgmem_creg_v1.hpp | 472 ++++++++++++++++++ ...wp_pipeline_agmem_bgmem_creg_v1_policy.hpp | 450 +++++++++++++++++ test/ck_tile/CMakeLists.txt | 1 + .../gemm/test_gemm_pipeline_kernel_types.hpp | 7 +- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 25 +- .../gemm_weight_preshuffle/CMakeLists.txt | 22 + .../test_gemm_pipeline_kernel_types.hpp | 32 ++ .../test_gemm_pipeline_ut_cases.inc | 21 + .../test_gemm_pipeline_util.hpp | 384 ++++++++++++++ .../test_gemm_pipeline_wp.cpp | 16 + 34 files changed, 2736 insertions(+), 338 deletions(-) create mode 100644 example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp mode change 100644 => 100755 include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp mode change 100644 => 100755 include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1_policy.hpp create mode 100644 test/ck_tile/gemm_weight_preshuffle/CMakeLists.txt create mode 100644 test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp create mode 100755 test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc create mode 100644 test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp create mode 100644 test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_wp.cpp diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 411db2e317..3d3a54020c 100644 --- 
a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,5 +1,6 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) +add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 2157397f1d..9deccc7f16 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -14,6 +14,7 @@ #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 +#define CK_TILE_PIPELINE_PRESHUFFLE 5 template constexpr ck_tile::index_t get_k_warp_tile() @@ -32,6 +33,21 @@ constexpr ck_tile::index_t get_k_warp_tile() return 32; #endif } +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(__gfx950__) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 
32 : 64; +#endif +} struct GemmConfigBase { @@ -51,6 +67,7 @@ struct GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; }; template @@ -213,6 +230,50 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; +template +struct GemmConfigPreshufle_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmConfigPreshufle_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = 
ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + template struct GemmTypeConfig; @@ -367,6 +428,16 @@ struct PipelineTypeTraits using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; }; +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp new file mode 100644 index 0000000000..f57c24f458 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -0,0 +1,294 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "gemm_utils.hpp" +#include "run_gemm_example.inc" + +template +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = 
Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; + + const 
auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; +} + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + auto [result, arg_parser] = create_args(argc, argv); + bool preshuffle = GemmConfig::Preshuffle; + + if(preshuffle && (a_layout != "R" || b_layout != "C")) + { + throw std::runtime_error( + "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); + } + + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } +} + +template