From 0329d71fb94a19eae78fcb9a6625c31592df9ddf Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Wed, 22 Oct 2025 22:36:11 +0800 Subject: [PATCH] [CK_TILE] Update flatmm related kernels (#3022) --------- Co-authored-by: Ding, Yi Co-authored-by: felix [ROCm/composable_kernel commit: 211d64e18a1bf2ecb1d13c5eb87983bdcabb3b5e] --- example/ck_tile/18_flatmm/CMakeLists.txt | 36 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 334 ++++- example/ck_tile/18_flatmm/flatmm_basic.hpp | 66 +- example/ck_tile/18_flatmm/grouped_flatmm.cpp | 364 +++++ .../18_flatmm/mixed_prec/a16w4_flatmm.hpp | 50 + .../18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp | 511 +++++++ .../18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp | 87 ++ .../mixed_prec/mixed_prec_flatmm.cpp | 482 ++++++ .../mixed_prec/mixed_prec_flatmm.hpp | 15 + .../run_a16w4_moe_flatmm_example.inc | 353 +++++ .../mixed_prec/run_mixed_prec_flatmm.inc | 180 +++ example/ck_tile/18_flatmm/moe_flatmm.cpp | 470 ++++++ example/ck_tile/18_flatmm/moe_flatmm.hpp | 202 +++ .../ck_tile/18_flatmm/run_flatmm_example.inc | 332 ++--- .../18_flatmm/run_grouped_flatmm_example.inc | 605 ++++++++ .../18_flatmm/run_moe_flatmm_example.inc | 323 ++++ .../core/arch/amd_buffer_addressing.hpp | 48 +- .../arch/amd_buffer_addressing_builtins.hpp | 49 +- include/ck_tile/core/numeric/vector_type.hpp | 21 +- include/ck_tile/core/tensor/buffer_view.hpp | 16 +- .../core/tensor/tile_scatter_gather.hpp | 202 +++ include/ck_tile/core/tensor/tile_window.hpp | 27 + include/ck_tile/host.hpp | 1 + .../ck_tile/host/reference/reference_gemm.hpp | 177 +++ .../host/reference/reference_moe_gemm.hpp | 316 ++++ .../ops/epilogue/cshuffle_epilogue.hpp | 52 +- include/ck_tile/ops/flatmm.hpp | 6 + .../block_flatmm_asmem_bsmem_creg_v1.hpp | 1 + .../ops/flatmm/kernel/flatmm_kernel.hpp | 482 ++++-- .../flatmm/kernel/grouped_flatmm_kernel.hpp | 478 ++++++ .../kernel/mixed_prec_flatmm_kernel.hpp | 458 ++++++ .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 1325 +++++++++++++++++ .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1065 +++++++++---- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 125 +- ...ec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1259 ++++++++++++++++ ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 239 +++ .../moe_flatmm_pipeline_agmem_bgmem_creg.hpp | 1012 +++++++++++++ .../gemm/pipeline/gemm_pipeline_problem.hpp | 143 ++ include/ck_tile/ops/moe_flatmm.hpp | 10 + 39 files changed, 11183 insertions(+), 739 deletions(-) create mode 100644 example/ck_tile/18_flatmm/grouped_flatmm.cpp create mode 100644 example/ck_tile/18_flatmm/mixed_prec/a16w4_flatmm.hpp create mode 100644 example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp create mode 100644 example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp create mode 100644 example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp create mode 100644 example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.hpp create mode 100644 example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc create mode 100644 example/ck_tile/18_flatmm/mixed_prec/run_mixed_prec_flatmm.inc create mode 100644 example/ck_tile/18_flatmm/moe_flatmm.cpp create mode 100644 example/ck_tile/18_flatmm/moe_flatmm.hpp create mode 100644 example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc create mode 100644 example/ck_tile/18_flatmm/run_moe_flatmm_example.inc create mode 100644 include/ck_tile/host/reference/reference_moe_gemm.hpp create mode 100644 include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp create mode 100644 include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp create mode 100644 include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp create mode 100644 include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp create mode 100644 include/ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp create mode 100644 include/ck_tile/ops/moe_flatmm.hpp diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 6d6b71ea18..1641549c98 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -1,6 +1,32 @@ -add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) +set(SUPPORTED_GPUS gfx908 gfx90a gfx942 gfx950) + +set(has_supported_gpu FALSE) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST SUPPORTED_GPUS) + set(has_supported_gpu TRUE) + break() + endif() +endforeach() + +if(has_supported_gpu) + add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) + add_executable(tile_example_mixed_prec_flatmm EXCLUDE_FROM_ALL mixed_prec/mixed_prec_flatmm.cpp) + add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp) + add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp) + add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp) + + set(EXAMPLE_FLATMM_COMPILE_OPTIONS) + set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS) + + if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + endif() + + target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + +endif() -set(EXAMPLE_FLATMM_COMPILE_OPTIONS) -# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) -# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) -target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 3273fac674..9155b27dba 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -11,7 +11,102 @@ #include "ck_tile/host.hpp" #include "flatmm_basic.hpp" -#include "run_flatmm_example.inc" +#include + +template +constexpr const char* DataTypeToString() +{ + if constexpr(std::is_same_v) + { + return "fp16"; + } + else if constexpr(std::is_same_v) + { + return "fp8"; + } + else if constexpr(std::is_same_v) + { + return "bf8"; + } + else if constexpr(std::is_same_v) + { + return "bf16"; + } + else + { + return "unknown"; + } +} + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +// mfma_type, 0:32x32, 1:16x16 +template +auto shuffle_b(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + constexpr int MaxVecSize = 16 / sizeof(T); + constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile; + constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane); + + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, + FlatmmConfig::N_Warp_Tile, + k_ / ItemsPerAccess, + ItemsPerAccess}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 1, 3}); +} + +template +auto shuffle_b_v1(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + constexpr int MaxVecSize = 16 / sizeof(T); + constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile; + constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane); + constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp; + + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Tile, + FlatmmConfig::N_Warp, + FlatmmConfig::N_Warp_Tile, + NRepeat, + k_ / ItemsPerAccess, + ItemsPerAccess}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5}); +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} template -float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s) +float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, + const ck_tile::stream_config& s) { using CodegenFlatmmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -80,14 +178,14 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c constexpr auto scheduler = FlatmmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using CodegenPipelineProblem = ck_tile::UniversalGemmPipelineProblem; + using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem; using CodegenFlatmmPipeline = ck_tile::FlatmmPipelineAGmemBGmemCRegV1; @@ -110,7 +208,10 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, memory_operation, - FlatmmConfig::NumWaveGroups>>; + FlatmmConfig::NumWaveGroups, + false, + 1, + FlatmmConfig::TiledMMAPermuteN>>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. @@ -118,8 +219,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -167,40 +268,145 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - return ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - return ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } + return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } else { - return Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } }; - return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; } +template +float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, + ck_tile::DeviceMem& b_shuffle_dev_buf, + ck_tile::DeviceMem& c_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + ScaleM scale_m, + ScaleN scale_n, + int n_warmup, + int n_repeat) +{ + ck_tile::ScaleFlatmmHostArgs args = {a_dev_buf.GetDeviceBuffer(), + b_shuffle_dev_buf.GetDeviceBuffer(), + {}, + c_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C, + scale_m, + scale_n}; + + float ave_time = flatmm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString() + << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A + << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time + << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + return ave_time; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "256", "m dimension") + .insert("n", "256", "n dimension") + .insert("k", "128", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") + .insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8") + .insert("persistent", "0", "0: no persistent, 1: persistent kernel") + .insert("warp_tile", + "0", + "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)"); + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +#include "run_flatmm_example.inc" + template