From d454d0e201b528fa7f26fc382b8e6c48809ebe8d Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Mon, 22 Sep 2025 08:05:55 +0000 Subject: [PATCH] cherry pick related code --- example/ck_tile/18_flatmm/CMakeLists.txt | 26 +- example/ck_tile/18_flatmm/README.md | 2 +- example/ck_tile/18_flatmm/flatmm_basic.cpp | 307 +++- example/ck_tile/18_flatmm/flatmm_basic.hpp | 66 +- .../18_flatmm/mixed_prec/a16w4_flatmm.hpp | 50 + .../18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp | 513 +++++++ .../18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp | 87 ++ .../mixed_prec/mixed_prec_flatmm.cpp | 484 ++++++ .../mixed_prec/mixed_prec_flatmm.hpp | 15 + .../run_a16w4_moe_flatmm_example.inc | 356 +++++ .../mixed_prec/run_mixed_prec_flatmm.inc | 180 +++ example/ck_tile/18_flatmm/moe_flatmm.cpp | 473 ++++++ example/ck_tile/18_flatmm/moe_flatmm.hpp | 202 +++ .../ck_tile/18_flatmm/run_flatmm_example.inc | 332 ++--- .../18_flatmm/run_moe_flatmm_example.inc | 323 ++++ .../core/arch/amd_buffer_addressing.hpp | 55 +- .../arch/amd_buffer_addressing_builtins.hpp | 49 +- .../core/arch/generic_memory_space_atomic.hpp | 97 +- include/ck_tile/core/container/sequence.hpp | 24 +- include/ck_tile/core/numeric/vector_type.hpp | 26 +- include/ck_tile/core/tensor/buffer_view.hpp | 192 ++- .../core/tensor/tile_scatter_gather.hpp | 272 ++++ include/ck_tile/core/tensor/tile_window.hpp | 27 + include/ck_tile/host/kernel_launch.hpp | 135 +- .../ck_tile/host/reference/reference_gemm.hpp | 367 +++-- .../host/reference/reference_moe_gemm.hpp | 315 ++++ .../unary_element_wise_operation.hpp | 9 +- .../ops/epilogue/cshuffle_epilogue.hpp | 41 +- include/ck_tile/ops/flatmm.hpp | 3 + .../block_flatmm_asmem_bsmem_creg_v1.hpp | 1 + .../ops/flatmm/kernel/flatmm_kernel.hpp | 498 +++++-- .../kernel/mixed_prec_flatmm_kernel.hpp | 458 ++++++ .../flatmm_pipeline_agmem_bgmem_creg_v0.hpp | 883 +++++++++++ .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1065 +++++++++---- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 139 +- ...ec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1256 ++++++++++++++++ ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 240 +++ .../gemm/pipeline/gemm_pipeline_problem.hpp | 147 +- include/ck_tile/ops/moe_flatmm.hpp | 10 + .../moe_flatmm/kernel/moe_flatmm_kernel.hpp | 1322 +++++++++++++++++ .../moe_flatmm_pipeline_agmem_bgmem_creg.hpp | 1012 +++++++++++++ 41 files changed, 10947 insertions(+), 1112 deletions(-) create mode 100644 example/ck_tile/18_flatmm/mixed_prec/a16w4_flatmm.hpp create mode 100644 example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp create mode 100644 example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp create mode 100644 example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp create mode 100644 example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.hpp create mode 100644 example/ck_tile/18_flatmm/mixed_prec/run_a16w4_moe_flatmm_example.inc create mode 100644 example/ck_tile/18_flatmm/mixed_prec/run_mixed_prec_flatmm.inc create mode 100644 example/ck_tile/18_flatmm/moe_flatmm.cpp create mode 100644 example/ck_tile/18_flatmm/moe_flatmm.hpp create mode 100644 example/ck_tile/18_flatmm/run_moe_flatmm_example.inc create mode 100644 include/ck_tile/host/reference/reference_moe_gemm.hpp mode change 100644 => 100755 include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp create mode 100644 include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp create mode 100644 include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v0.hpp create mode 100644 include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp create mode 100644 include/ck_tile/ops/moe_flatmm.hpp create mode 100644 include/ck_tile/ops/moe_flatmm/kernel/moe_flatmm_kernel.hpp create mode 100644 include/ck_tile/ops/moe_flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index 6d6b71ea18..50c0a78026 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -1,6 +1,28 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) +add_executable(tile_example_mixed_prec_flatmm EXCLUDE_FROM_ALL mixed_prec/mixed_prec_flatmm.cpp) + +add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp) +add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp) + set(EXAMPLE_FLATMM_COMPILE_OPTIONS) -# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) -# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) + +set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS) + +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef -Wno-unused-variable -Wno-unused-parameter) +list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS -Wno-nrvo -Wno-unused-variable -Wno-unused-parameter -Wno-unused-local-typedef -Wno-float-equal) + +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) + list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + + +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps -Wno-nrvo) + target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) + +list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS --save-temps) +target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/README.md b/example/ck_tile/18_flatmm/README.md index c58700fc7b..49420a7325 100644 --- a/example/ck_tile/18_flatmm/README.md +++ b/example/ck_tile/18_flatmm/README.md @@ -7,7 +7,7 @@ This folder contains example for FLATMM using ck_tile tile-programming implement # in the root of ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -../script/cmake-ck-dev.sh ../ +sh ../script/cmake-ck-dev.sh ../ # The basic pipeline method on the flatmm calculation make tile_example_flatmm_basic -j ``` diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 280da8d333..c19116fa9e 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -11,7 +11,102 @@ #include "ck_tile/host.hpp" #include "flatmm_basic.hpp" -#include "run_flatmm_example.inc" +#include + +template +constexpr const char* DataTypeToString() +{ + if constexpr(std::is_same_v) + { + return "fp16"; + } + else if constexpr(std::is_same_v) + { + return "fp8"; + } + else if constexpr(std::is_same_v) + { + return "bf8"; + } + else if constexpr(std::is_same_v) + { + return "bf16"; + } + else + { + return "unknown"; + } +} + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +// mfma_type, 0:32x32, 1:16x16 +template +auto shuffle_b(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + constexpr int MaxVecSize = 16 / sizeof(T); + constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile; + constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane); + + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, + FlatmmConfig::N_Warp_Tile, + k_ / ItemsPerAccess, + ItemsPerAccess}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 1, 3}); +} + +template +auto shuffle_b_v1(const ck_tile::HostTensor& t) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + constexpr int MaxVecSize = 16 / sizeof(T); + constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile; + constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane); + constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp; + + ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Tile, + FlatmmConfig::N_Warp, + FlatmmConfig::N_Warp_Tile, + NRepeat, + k_ / ItemsPerAccess, + ItemsPerAccess}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5}); +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} template -float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s) +float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, + const ck_tile::stream_config& s) { using CodegenFlatmmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -80,14 +178,14 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c constexpr auto scheduler = FlatmmConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; - using CodegenPipelineProblem = ck_tile::UniversalGemmPipelineProblem; + using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem; using CodegenFlatmmPipeline = ck_tile::FlatmmPipelineAGmemBGmemCRegV1; @@ -101,6 +199,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c DsLayout, ELayout, CDEElementWise, + CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, FlatmmConfig::M_Warp, @@ -110,7 +209,10 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, memory_operation, - FlatmmConfig::NumWaveGroups>>; + FlatmmConfig::NumWaveGroups, + false, + 1, + FlatmmConfig::TiledMMAPermuteN>>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. @@ -118,8 +220,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -167,16 +269,18 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( + ave_time = ck_tile::launch_kernel_preprocess( s, run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } return ave_time; }; @@ -201,6 +305,111 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c return ave_time; } +template +float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, + ck_tile::DeviceMem& b_shuffle_dev_buf, + ck_tile::DeviceMem& c_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + ScaleM scale_m, + ScaleN scale_n, + int n_warmup, + int n_repeat) +{ + ck_tile::ScaleFlatmmHostArgs args = {a_dev_buf.GetDeviceBuffer(), + b_shuffle_dev_buf.GetDeviceBuffer(), + {}, + c_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C, + scale_m, + scale_n}; + + float ave_time = flatmm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString() + << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A + << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time + << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + return ave_time; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "256", "m dimension") + .insert("n", "256", "n dimension") + .insert("k", "128", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8") + .insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8") + .insert("persistent", "0", "0: no persistent, 1: persistent kernel") + .insert("warp_tile", + "0", + "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)"); + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +#include "run_flatmm_example.inc" + template