From c4aa2fef46aecb8cd1210ff35251ac2634f7bb97 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Wed, 30 Jul 2025 07:55:09 +0000 Subject: [PATCH] merge M grouped flatmm --- example/ck_tile/18_flatmm/CMakeLists.txt | 2 + example/ck_tile/18_flatmm/grouped_flatmm.cpp | 382 +++++++ .../18_flatmm/run_grouped_flatmm_example.inc | 935 ++++++++++++++++++ include/ck_tile/host/kernel_launch.hpp | 6 + include/ck_tile/ops/flatmm.hpp | 1 + .../ops/flatmm/kernel/flatmm_kernel.hpp | 6 +- .../flatmm/kernel/grouped_flatmm_kernel.hpp | 465 +++++++++ .../ops/gemm/kernel/gemm_tile_partitioner.hpp | 2 +- script/cmake-ck-dev.sh | 2 +- 9 files changed, 1797 insertions(+), 4 deletions(-) create mode 100644 example/ck_tile/18_flatmm/grouped_flatmm.cpp create mode 100644 example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc mode change 100755 => 100644 include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp create mode 100644 include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 30fd769c88..fa5f1b7a24 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -1,4 +1,5 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) +add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp) set(EXAMPLE_FLATMM_COMPILE_OPTIONS) @@ -11,3 +12,4 @@ list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps -Wno-nrvo) #list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0") #list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --disable-schedmodel-in-sched-mi=1 -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental -mllvm --misched-bottomup=1") target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/grouped_flatmm.cpp b/example/ck_tile/18_flatmm/grouped_flatmm.cpp new file mode 100644 index 0000000000..56146f5cbb --- /dev/null +++ b/example/ck_tile/18_flatmm/grouped_flatmm.cpp @@ -0,0 +1,382 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "flatmm_basic.hpp" + +#include "ck_tile/host.hpp" + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("Ms", "512,256,1024", "m dimension") + .insert("Ns", "512,512,512", "n dimension") + .insert("Ks", "1024,1024,512", "k dimension") + .insert("group_count", "3", "group count") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") + .insert("mode", "general", "grouped gemm mode: [general | contiguous], general by default") + .insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8") + .insert("warp_tile", + "0", + "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& s) +{ + using CodegenFlatmmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; + + const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem; + + using CodegenFlatmmPipeline = + ck_tile::FlatmmPipelineAGmemBGmemCRegV1; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = + ck_tile::GroupedFlatmmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + constexpr dim3 blocks = Kernel::BlockSize(); + + // if(!Kernel::IsSupportedArgument(kargs)) + // { + // throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + // } + + // if(s.flush_cache_) + // { + // std::cout << "Flushing cache..." << std::endl; + // static constexpr ck_tile::index_t APackedSize = + // std::is_same_v ? 2 : 1; + // static constexpr ck_tile::index_t BPackedSize = + // std::is_same_v ? 2 : 1; + + // ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + // args.group_count * args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + // ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + // args.K, args.group_count * args.N, args.stride_B, is_row_major(BLayout{}))); + + // auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + // auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + // ck_tile::RotatingMemWrapper rotating_mem( + // kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + // rotating_mem.Print(); + + // auto run_flush_cache = [&]() { + // // flush icache + // ck_tile::flush_icache(); + // // rotating mem + // rotating_mem.Next(); + // // clear c mem + // if(args.k_batch > 1) + // hipGetErrorString(hipMemsetAsync( + // args.e_ptr, 0, args.group_count * args.M * args.N * sizeof(CDataType), s.stream_id_)); + // }; + // ave_time = ck_tile::launch_kernel_preprocess( + // s, + // run_flush_cache, + // ck_tile::make_kernel( + // Kernel{}, grids, blocks, 0, kargs)); + // } + // else + // { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + // } + + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; +} + +#include "run_grouped_flatmm_example.inc" + +template