From 7336398fb66bf864cad652e7cb6e1cbea12c490c Mon Sep 17 00:00:00 2001 From: Yi DING Date: Tue, 18 Nov 2025 13:46:30 +0800 Subject: [PATCH] [CK_TILE] MX Flatmm Split kernel instances (#3207) * [CK_TILE] MX Flatmm Split kernel instances * Fix flatmm example compile [ROCm/composable_kernel commit: b6720531de9cbbe5f6022f173ead11c61860f57f] --- example/ck_tile/18_flatmm/CMakeLists.txt | 7 +- example/ck_tile/18_flatmm/moe_flatmm.cpp | 2 +- .../ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp | 306 ++++-------------- .../18_flatmm/mxgemm/mx_flatmm_instance.cmake | 27 ++ .../mxgemm/mx_flatmm_instance.cpp.in | 53 +++ .../18_flatmm/mxgemm/mx_flatmm_instance.hpp | 172 ++++++++++ .../ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp | 20 ++ .../18_flatmm/mxgemm/run_mx_flatmm.inc | 2 +- .../18_flatmm/run_grouped_flatmm_example.inc | 17 - .../18_flatmm/run_moe_flatmm_example.inc | 2 +- .../flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 30 +- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 9 +- 12 files changed, 371 insertions(+), 276 deletions(-) create mode 100644 example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake create mode 100644 example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in create mode 100644 example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt index d2ad442248..c5cecceb9c 100644 --- a/example/ck_tile/18_flatmm/CMakeLists.txt +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -14,7 +14,12 @@ if(has_supported_gpu) add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp) add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp) add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp) - add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp) # TODO: 950 only + + include(mxgemm/mx_flatmm_instance.cmake) + mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES) + message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}") + add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp ${EXAMPLE_MX_FLATMM_FILES}) + target_include_directories(tile_example_mx_flatmm PRIVATE mxgemm) set(EXAMPLE_FLATMM_COMPILE_OPTIONS) set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS) diff --git a/example/ck_tile/18_flatmm/moe_flatmm.cpp b/example/ck_tile/18_flatmm/moe_flatmm.cpp index 4db6a1171f..064522a360 100644 --- a/example/ck_tile/18_flatmm/moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/moe_flatmm.cpp @@ -29,7 +29,7 @@ static constexpr inline auto is_row_major(Layout layout_) } template -auto shuffle_b(const ck_tile::HostTensor& t) +auto flatmm_shuffle_b(const ck_tile::HostTensor& t) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index 0474e4b1d6..33a2ba3135 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -20,211 +20,6 @@ static constexpr inline auto is_row_major(Layout layout_) ck_tile::tensor_layout::gemm::RowMajor>>{}; } -template -float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, - const ck_tile::stream_config& s) -{ - using CodegenFlatmmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; - - using TilePartitioner = - ck_tile::GemmSpatiallyLocalTilePartitioner; - - using Traits = ck_tile::TileGemmTraits; - - using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits; - - using ComputeDataType = ADataType; - static_assert(sizeof(ComputeDataType) >= sizeof(BDataType), - "mixed_prec_flatmm requires ADataType is a wider type than BDataType"); - - using GemmPipelineProblem = ck_tile::GemmPipelineProblem; - - using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; - - const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile; - const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; - - constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern - - using CodegenPipelineProblem = ck_tile::MXFlatmmPipelineProblem; - - using CodegenMXFlatmmPipeline = - ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1; - - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = - ck_tile::MXFlatmmKernel; - - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(kargs); - constexpr dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n" - << "Shape: " << CodegenFlatmmShape::GetName() << "\n" - << "problem: " << CodegenPipelineProblem::GetName() << "\n" - << "pipeline: " << CodegenMXFlatmmPipeline::GetName() << "\n" - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; - - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits::PackedSize; - constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits::PackedSize; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; - auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - - rotating_mem_ptr = std::make_unique>( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; - - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; -} - template ( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); + using FlatmmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseFlatmmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1; + + const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile; + const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split); + const bool has_hot_loop = BaseFlatmmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseFlatmmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time = BaseFlatmmPipeline::template TailHandler( + [&](auto has_hot_loop_, auto tail_num_) { + constexpr auto has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_num_v = tail_num_.value; + auto invoke_splitk_path = [&](auto split_k_) { + return mx_flatmm_calc( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); + }; + return (args.k_batch == 1) ? invoke_splitk_path(std::false_type{}) + : invoke_splitk_path(std::true_type{}); + }, + has_hot_loop, + tail_num); constexpr int APackedSize = ck_tile::numeric_traits::PackedSize; constexpr int BPackedSize = ck_tile::numeric_traits::PackedSize; @@ -297,8 +137,8 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf, float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << "Run MXFP4_Flatmm kernel " // - << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A - << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time + << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << stride_A + << " StrideB = " << stride_B << " StrideC = " << stride_C << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; @@ -441,21 +281,13 @@ int run_mx_flatmm_example(int argc, char* argv[]) if(mx_prec == "fp4xfp4") { if(persistent_opt == 0) - { - run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } + return run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); else - { - run_mx_flatmm_with_layouts(argc, argv, Row{}, Col{}, Row{}); - } + throw std::runtime_error("Only non-persistent kernels are supported currently!"); } else if(mx_prec == "fp6xfp6") { @@ -487,7 +319,7 @@ int main(int argc, char* argv[]) int warp_tile = arg_parser.get_int("warp_tile"); if(warp_tile == 0) { - return !run_mx_flatmm_example(argc, argv); + return run_mx_flatmm_example(argc, argv); } else if(warp_tile == 1) { diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake new file mode 100644 index 0000000000..950b0c72a6 --- /dev/null +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake @@ -0,0 +1,27 @@ +function(mx_flatmm_instance_generate FILE_LIST) + set(FLATMM_CONFIG MXfp4_FlatmmConfig16) + set(A_DATA_TYPE FP4) + set(B_DATA_TYPE FP4) + set(C_DATA_TYPE FP16) + set(A_LAYOUT ROW) + set(B_LAYOUT COL) + set(C_LAYOUT ROW) + + # foreach(PERSISTENT false true) + # TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions. + foreach(PERSISTENT false) + foreach(SPLIT_K false true) + foreach(HAS_HOT_LOOP false true) + foreach(TAIL_NUMBER ODD EVEN) + set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp) + configure_file( + ${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in + ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE} + @ONLY) + list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}) + endforeach() + endforeach() + endforeach() + endforeach() + set(${FILE_LIST} ${${FILE_LIST}} PARENT_SCOPE) +endfunction() diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in new file mode 100644 index 0000000000..0be9fc7bb7 --- /dev/null +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "mx_flatmm_instance.hpp" + +// clang-format off +#define FLATMM_CONFIG @FLATMM_CONFIG@ +#define A_DATA_TYPE @A_DATA_TYPE@ +#define B_DATA_TYPE @B_DATA_TYPE@ +#define C_DATA_TYPE @C_DATA_TYPE@ +#define A_LAYOUT @A_LAYOUT@ +#define B_LAYOUT @B_LAYOUT@ +#define C_LAYOUT @C_LAYOUT@ +#define PERSISTENT @PERSISTENT@ +#define SPLIT_K @SPLIT_K@ +#define HAS_HOT_LOOP @HAS_HOT_LOOP@ +#define TAIL_NUMBER @TAIL_NUMBER@ +// clang-format on + +using FP4 = ck_tile::pk_fp4_t; +using FP16 = ck_tile::fp16_t; +using BF16 = ck_tile::bf16_t; + +using ROW = ck_tile::tensor_layout::gemm::RowMajor; +using COL = ck_tile::tensor_layout::gemm::ColumnMajor; + +inline constexpr auto ODD = ck_tile::TailNumber::Odd; +inline constexpr auto EVEN = ck_tile::TailNumber::Even; + +inline constexpr int ScaleGranularityM = 1; +inline constexpr int ScaleGranularityN = 1; +inline constexpr int ScaleGranularityK = 32; +using ScaleM = ck_tile::FlatmmScalePointer; +using ScaleN = ck_tile::FlatmmScalePointer; + +template float mx_flatmm_calc, + /*AccDataType*/ float, + C_DATA_TYPE, + A_LAYOUT, + B_LAYOUT, + /*DsLayout*/ ck_tile::tuple<>, + C_LAYOUT, + ScaleM, + ScaleN, + PERSISTENT, + /*CDEElementWise*/ ck_tile::element_wise::PassThrough, + SPLIT_K, + HAS_HOT_LOOP, + TAIL_NUMBER>(const ck_tile::ScaleFlatmmHostArgs& args, + const ck_tile::stream_config& s); diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp new file mode 100644 index 0000000000..c7614e9bd4 --- /dev/null +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "mx_flatmm.hpp" + +template +using is_row_major_t = ck_tile::bool_constant< + std::is_same_v, ck_tile::tensor_layout::gemm::RowMajor>>; + +template +float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, + const ck_tile::stream_config& s) +{ + using FlatmmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using MXGemmTraits = ck_tile::TileGemmUniversalTraits; + + using ComputeDataType = ADataType; + static_assert(sizeof(ComputeDataType) >= sizeof(BDataType), + "mixed_prec_flatmm requires ADataType is a wider type than BDataType"); + + constexpr auto scheduler = FlatmmConfig::Scheduler; + constexpr auto memory_operation = + Splitk ? ck_tile::memory_operation_enum::atomic_add : ck_tile::memory_operation_enum::set; + + constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern + + using MXPipelineProblem = ck_tile::MXFlatmmPipelineProblem; + + using MXFlatmmPipeline = ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + using GemmEpilogue = + ck_tile::CShuffleEpilogue>; + + using Kernel = ck_tile::MXFlatmmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" << FlatmmShape::GetName() << "\n" + << "Shape: " << FlatmmShape::GetName() << "\n" + << "problem: " << MXPipelineProblem::GetName() << "\n" + << "pipeline: " << MXFlatmmPipeline::GetName() << "\n" + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; + + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString( + hipMemsetAsync(args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits::PackedSize; + constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits::PackedSize; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major_t{})); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major_t{})); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + rotating_mem_ptr = std::make_unique>( + kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; + } + else + { + preprocess = clear_gemm_output; + } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} diff --git a/example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp b/example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp index 4ef627969c..02f58a6269 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp @@ -38,3 +38,23 @@ struct MXfp4_FlatmmConfig16 static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; static constexpr bool TiledMMAPermuteN = false; }; + +template +float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, + const ck_tile::stream_config& s); diff --git a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc index bc24427780..0171fc1403 100644 --- a/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc +++ b/example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc @@ -163,5 +163,5 @@ int run_mx_flatmm_with_layouts(int argc, std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; } - return pass; + return pass ? 0 : -1; } diff --git a/example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc b/example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc index fbab5b6d0e..3bb039aae8 100644 --- a/example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc @@ -3,23 +3,6 @@ #pragma once -// mfma_type, 0:32x32, 1:16x16 -template -auto shuffle_b(const ck_tile::HostTensor& t) -{ - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view({n_ / FlatmmConfig::N_Warp_Tile, - FlatmmConfig::N_Warp_Tile, - k_ / FlatmmConfig::K_Warp_Tile, - divisor, - FlatmmConfig::K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); -} - template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, diff --git a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc index 9e0cbda0c0..d898ed2f29 100644 --- a/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc +++ b/example/ck_tile/18_flatmm/run_moe_flatmm_example.inc @@ -114,7 +114,7 @@ int run_moe_gemm_example_with_layouts(int argc, ck_tile::FillUniformDistribution{0.0f, 1.0f}(a_m_k_tensor); ck_tile::FillUniformDistribution{-.5f, .5f}(b_k_n_tensor); - auto b_shuffle_host = shuffle_b(b_k_n_tensor); + auto b_shuffle_host = flatmm_shuffle_b(b_k_n_tensor); std::cout << "moe_flatmm:" // << "\n num_experts: " << experts << "\n num_tokens: " << num_tokens diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index ceb6ef6734..5bb5436edf 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -23,22 +23,28 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1 { return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; } - template + + template + CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop) + { + if constexpr(!DispatchHotloop) + return run_func(bool_constant{}, integral_constant{}); + else if(has_hot_loop) + return run_func(bool_constant{}, integral_constant{}); + else + return run_func(bool_constant{}, integral_constant{}); + } + + template CK_TILE_HOST_DEVICE static auto - TailHandler(const RunFunction& run_func, bool, TailNumber tail_num) + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_num) { if(TailNumber::Even == tail_num) - { - return run_func(bool_constant{}, - integral_constant{}); - } + return TailHandler(run_func, has_hot_loop); else if(TailNumber::Odd == tail_num) - { - return run_func(bool_constant{}, - integral_constant{}); - } - // return run_func(bool_constant{}, integral_constant{}); + return TailHandler(run_func, has_hot_loop); + else + assert(("Wrong TailNumber!", false)); } }; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index d3da488a88..896f6613a7 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -216,17 +216,14 @@ struct UniversalFlatmmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * - MakeALdsBlockDescriptor().get_element_space_size(); - return smem_size_a; + return sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size(); } template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - constexpr index_t smem_size_a = GetSmemSizeA(); - - return smem_size_a; + return GetSmemSizeA(); } template