From 8b98fe03539a8c65dd2c1651d5b6d3efda99e17a Mon Sep 17 00:00:00 2001
From: Yi DING
Date: Mon, 8 Dec 2025 19:20:44 +0800
Subject: [PATCH] [CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime
 Division by 2 (#3287)

* [CK_TILE] Optimize Flatmm MXFP4 by Eliminating Runtime Division by 2

* typo

[ROCm/composable_kernel commit: 878b4e7f46d7e47618f4d860d71b438cb6d992fd]
---
 .../ops/flatmm/kernel/mx_flatmm_kernel.hpp    | 134 +++++++-----------
 ...mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp |  81 ++++++-----
 ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp |  47 +++++-
 3 files changed, 141 insertions(+), 121 deletions(-)
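MXFP4 stores two 4-bit elements per byte (PackedSize == 2), so the old flat-B window computed
load addresses in element units and paid a runtime division by PackedSize on every offset. The
patch instead builds the flat-B DRAM window over a byte-addressed tensor view
(MakeMX_BFlatBytesDramWindow), so per-iteration offsets such as KFlatBytesPerBlockPerIter are
plain byte adds and the one remaining division happens once, when the window origin is converted
to bytes; AK1/BK1 likewise drop their PackedSize factor, and the MFMA operands are bit_cast to
byte-typed vectors so the warp tiles index in the same units. A minimal standalone sketch of the
addressing change, using illustrative names rather than the ck_tile API:

#include <cstddef>
#include <cstdint>

constexpr std::size_t kPackedSize = 2; // two FP4 values per byte

// Before: offsets counted in elements, so every address computation divides.
inline const std::uint8_t* addr_from_elems(const std::uint8_t* base, std::size_t elem_off)
{
    return base + elem_off / kPackedSize; // runtime division in the hot loop
}

// After: the window is byte-addressed; the division happened once at setup.
inline const std::uint8_t* addr_from_bytes(const std::uint8_t* base, std::size_t byte_off)
{
    return base + byte_off; // plain add per load
}

int main()
{
    std::uint8_t buf[64] = {};
    const std::size_t elem = 10;                 // element index into packed FP4 data
    const std::size_t byte = elem / kPackedSize; // converted once, at window creation
    return addr_from_elems(buf, elem) == addr_from_bytes(buf, byte) ? 0 : 1;
}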
diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp
index d9fb144176..1133da33ad 100644
--- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp
+++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp
@@ -18,21 +18,21 @@ struct MXFlatmmKernel : FlatmmKernel;
-    using TilePartitioner = remove_cvref_t;
-    using FlatmmPipeline  = remove_cvref_t;
+    using TilePartitioner  = remove_cvref_t;
+    using MXFlatmmPipeline = remove_cvref_t;
     using BlockGemmShape = remove_cvref_t; // TileFlatmmShape
     using EpiloguePipeline = remove_cvref_t;
-    using ALayout = remove_cvref_t;
-    using BLayout = remove_cvref_t;
-    using ELayout = remove_cvref_t;
+    using ALayout = remove_cvref_t;
+    using BLayout = remove_cvref_t;
+    using ELayout = remove_cvref_t;
     using DsLayout   = remove_cvref_t;
     using DsDataType = remove_cvref_t;

-    static constexpr index_t KernelBlockSize  = FlatmmPipeline::BlockSize;
-    static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
+    static constexpr index_t KernelBlockSize  = MXFlatmmPipeline::BlockSize;
+    static constexpr bool UsePersistentKernel = MXFlatmmPipeline::UsePersistentKernel;

-    using ADataType = remove_cvref_t;
-    using BDataType = remove_cvref_t;
+    using ADataType = remove_cvref_t;
+    using BDataType = remove_cvref_t;
     // Below type is actually accumulation data type - the output of block GEMM.
     using EDataType = remove_cvref_t;

@@ -43,9 +43,9 @@ struct MXFlatmmKernel : FlatmmKernel
     static constexpr int APackedSize = numeric_traits::PackedSize;
     static constexpr int BPackedSize = numeric_traits::PackedSize;

-    static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
-    static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
-    static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
+    static constexpr int MXdlPack = MXFlatmmPipeline::MXdlPack;
+    static constexpr int NXdlPack = MXFlatmmPipeline::NXdlPack;
+    static constexpr int KXdlPack = MXFlatmmPipeline::KXdlPack;

     static constexpr index_t NumDTensor = DsDataType::size();

@@ -63,7 +63,7 @@ struct MXFlatmmKernel : FlatmmKernel
         // clang-format off
-        return concat('_', "mx_flatmm_gemm", gemm_prec_str, FlatmmPipeline::GetName());
+        return concat('_', "mx_flatmm_gemm", gemm_prec_str, MXFlatmmPipeline::GetName());
         // clang-format on
     }

@@ -123,33 +123,23 @@ struct MXFlatmmKernel : FlatmmKernel
-            if constexpr(std::is_same_v)
-            {
-                return make_naive_tensor_view(
-                    a_ptr,
-                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
-                    make_tuple(kargs.stride_A, 1),
-                    number{},
-                    number<1>{});
-            }
-            else
-            {
-                return make_naive_tensor_view(
-                    a_ptr,
-                    make_tuple(splitk_batch_offset.splitted_k, kargs.M),
-                    make_tuple(kargs.stride_A, 1),
-                    number{},
-                    number<1>{});
-            }
+            static_assert(std::is_same_v,
+                          "A tensor for mx must be RowMajor");
+            return make_naive_tensor_view(
+                a_ptr,
+                make_tuple(kargs.M, splitk_batch_offset.splitted_k),
+                make_tuple(kargs.stride_A, 1),
+                number{},
+                number<1>{});
         }();

-        constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
+        constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock;
         constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
         constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
         const index_t kFlatKBlocks = kargs.K / kKPerBlock;
         const index_t kFlatN = kargs.N / kNWarpTile;

         const auto& b_flat_tensor_view = [&]() {
-            static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
+            static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0,
                           "wrong! vector size for B tensor");
             auto&& naive_desc = make_naive_tensor_descriptor_packed(
                 make_tuple(kFlatN, kFlatKBlocks, number{}));

@@ -262,20 +252,12 @@ struct MXFlatmmKernel : FlatmmKernel
-            if constexpr(std::is_same_v)
-            {
-                return pad_tensor_view(a_tensor_view,
-                                       make_tuple(number{},
-                                                  number{}),
-                                       sequence{});
-            }
-            else
-            {
-                return pad_tensor_view(a_tensor_view,
-                                       make_tuple(number{},
-                                                  number{}),
-                                       sequence{});
-            }
+            static_assert(std::is_same_v,
+                          "A tensor for mx must be RowMajor");
+            return pad_tensor_view(a_tensor_view,
+                                   make_tuple(number{},
+                                              number{}),
+                                   sequence{});
         }();

         const auto& b_flat_tensor_view = views.at(I1);

@@ -289,14 +271,14 @@ struct MXFlatmmKernel : FlatmmKernel
                                            make_tuple(number{}, number{}),
-                                           sequence{});
+                                           sequence{});
                 }
                 else
                 {
                     return pad_tensor_view(d_tensor_view[i],
                                            make_tuple(number{}, number{}),
-                                           sequence{});
+                                           sequence{});
                 }
             },
             number{});

@@ -309,14 +291,14 @@ struct MXFlatmmKernel : FlatmmKernel
                                        make_tuple(number{}, number{}),
-                                       sequence{});
+                                       sequence{});
             }
             else
             {
                 return pad_tensor_view(e_tensor_view,
                                        make_tuple(number{}, number{}),
-                                       sequence{});
+                                       sequence{});
             }
         }();

@@ -334,26 +316,18 @@ struct MXFlatmmKernel : FlatmmKernel
-            if constexpr(std::is_same_v)
-            {
-                return make_tile_window(a_pad_view,
-                                        make_tuple(number{},
-                                                   number{}),
-                                        {i_m, 0});
-            }
-            else
-            {
-                return make_tile_window(a_pad_view,
-                                        make_tuple(number{},
-                                                   number{}),
-                                        {0, i_m});
-            }
+            static_assert(std::is_same_v,
+                          "A tensor for mx must be RowMajor");
+            return make_tile_window(a_pad_view,
+                                    make_tuple(number{},
+                                               number{}),
+                                    {i_m, 0});
         }();

         const auto& b_flat_block_window =
             make_tile_window(b_flat_pad_view,
-                             make_tuple(number{},
-                                        number{}),
+                             make_tuple(number{},
+                                        number{}),
                              {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0});

         const auto ds_block_window = generate_tuple(

@@ -444,14 +418,14 @@ struct MXFlatmmKernel : FlatmmKernel
-        const ADataType* a_ptr = static_cast(kargs.a_ptr) +
-                                 splitk_batch_offset.a_k_split_offset / APackedSize;
-        const BDataType* b_flat_ptr = static_cast(kargs.b_ptr) +
-                                      splitk_batch_offset.b_k_split_offset / BPackedSize;
+        const auto a_ptr = static_cast(kargs.a_ptr) +
+                           splitk_batch_offset.a_k_split_offset / APackedSize;
+        const auto b_flat_ptr = static_cast(kargs.b_ptr) +
+                                splitk_batch_offset.b_k_split_offset / BPackedSize;
         EDataType* e_ptr = static_cast(kargs.e_ptr);

         // allocate LDS

@@ -501,7 +475,7 @@ struct MXFlatmmKernel : FlatmmKernel
                                       ::value))
         {
-            constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
+            constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1);
             RunFlatmm(a_ptr,
                       b_flat_ptr,
                       kargs.ds_ptr,

diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
index ff799cb0fc..87ae7f57d8 100644
--- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
+++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp
@@ -34,13 +34,11 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem;
     using CLayout = remove_cvref_t;

+    static constexpr index_t APackedSize = numeric_traits::PackedSize;
+    static constexpr index_t BPackedSize = numeric_traits::PackedSize;
+
     using BlockFlatmm =
         remove_cvref_t())>;

@@ -81,8 +82,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
-    static constexpr index_t APackedSize = numeric_traits::PackedSize;
-    static constexpr index_t BPackedSize = numeric_traits::PackedSize;
+    // static constexpr index_t WG_AKPacks = WG::kK / APackedSize;
+    // static constexpr index_t WG_BKPacks = WG::kK / BPackedSize;

     static constexpr index_t MXdlPack = Problem::MXdlPack;
     static constexpr index_t NXdlPack = Problem::NXdlPack;
     static constexpr index_t KXdlPack = Problem::KXdlPack;

     static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;

-    static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
-    static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
+    static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType);
+    static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType);

     static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
                                              ? DsReadPreload

@@ -562,11 +563,8 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
-        auto b_flat_dram_window = make_tile_window(
-            b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
-            make_tuple(number{}, number{}),
-            b_flat_dram_block_window_tmp.get_window_origin(),
-            PipelinePolicy::template MakeMX_BFlatDramTileDistribution());
+        auto b_flat_dram_window = PipelinePolicy::template MakeMX_BFlatBytesDramWindow(
+            b_flat_dram_block_window_tmp);
         auto b_flat_dram_offsets = generate_tuple(
             [&](auto nIter) {
                 constexpr auto packed_n_idx = nIter / number{};

@@ -621,7 +619,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1
-            async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
+            async_load_tile(lds, dram, number<-1>{}, true_type{}, true_type{});
         };

         // HEAD

@@ -633,11 +631,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) {
             static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                 b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
-                    b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
+                    b_flat_dram_window,
+                    b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
             });
             // move B window to next flat K
             b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
-                tuple, number>{});
+                tuple, number>{});
         });

         // prefetch Scale A

@@ -698,12 +697,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) {
                 b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
                     b_flat_dram_window,
-                    b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
+                    b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
                 // move B window to next flat K
                 if constexpr(kIter == KIterPerWarp - 1)
                     b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
-                        tuple, number>{});
+                        tuple, number>{});
             });
         });

@@ -739,8 +738,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1(
                             c_warp_tensors(number{})(number{}),
-                            a_warp_tensor(number{}),
-                            b_warp_tensor_ping(number{})(number{}),
+                            bit_cast(
+                                a_warp_tensor(number{})),
+                            bit_cast(
+                                b_warp_tensor_ping(number{})(number{})),
                             scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
                                 .get_thread_buffer()[0],
                             scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)

@@ -792,12 +793,12 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) {
                 b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
                     b_flat_dram_window,
-                    b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
+                    b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
                 // move B window to next flat K
                 if constexpr(kIter == KIterPerWarp - 1)
                     b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
-                        tuple, number>{});
+                        tuple, number>{});
             });
         });

@@ -833,8 +834,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1(
                             c_warp_tensors(number{})(number{}),
-                            a_warp_tensor(number{}),
-                            b_warp_tensor_pong(number{})(number{}),
+                            bit_cast(
+                                a_warp_tensor(number{})),
+                            bit_cast(
+                                b_warp_tensor_pong(number{})(number{})),
                             scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
                                 .get_thread_buffer()[0], // scale A
                             scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)

@@ -897,7 +900,7 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) {
                 b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
                     b_flat_dram_window,
-                    b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
+                    b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
             });
         });

@@ -932,8 +935,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1(
                             c_warp_tensors(number{})(number{}),
-                            a_warp_tensor(number{}),
-                            b_warp_tensor_ping(number{})(number{}),
+                            bit_cast(
+                                a_warp_tensor(number{})),
+                            bit_cast(
+                                b_warp_tensor_ping(number{})(number{})),
                             scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
                                 .get_thread_buffer()[0], // scale A
                             scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)

@@ -986,8 +991,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1(
                             c_warp_tensors(number{})(number{}),
-                            a_warp_tensor(number{}),
-                            b_warp_tensor_pong(number{})(number{}),
+                            bit_cast(
+                                a_warp_tensor(number{})),
+                            bit_cast(
+                                b_warp_tensor_pong(number{})(number{})),
                             scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
                                 .get_thread_buffer()[0], // scale A
                             scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)

@@ -1029,8 +1036,10 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1(
                             c_warp_tensors(number{})(number{}),
-                            a_warp_tensor(number{}),
-                            b_warp_tensor_ping(number{})(number{}),
+                            bit_cast(
+                                a_warp_tensor(number{})),
+                            bit_cast(
+                                b_warp_tensor_ping(number{})(number{})),
                             scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
                                 .get_thread_buffer()[0], // scale A
                             scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)

diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
index 969cddf3e7..4d76ab7da2 100644
--- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
+++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp
@@ -255,9 +255,11 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
     }

     template
-    CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatBytesDramTileDistribution()
     {
-        using TileShape = typename Problem::BlockGemmShape;
+        using TileShape = typename Problem::BlockGemmShape;
+        using BDataType = remove_cvref_t;
+        constexpr index_t BPack = numeric_traits::PackedSize;

         static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");

@@ -282,21 +284,56 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
                 tile_distribution_encoding< // sequence,
                     tuple,                  // 4 2
-                    sequence>,              // 1 64 32
+                    sequence>,              // 1 64 32
                     tuple, sequence<2>>,
                     tuple, sequence<1>>,
                     sequence<2>,
                     sequence<2>>,
                 tile_distribution_encoding< // sequence,
-                    tuple,                  // 4 2
-                    sequence>,              // 2 1 64 16
+                    tuple,                  // 4 2
+                    sequence>,              // 2 1 64 16
                     tuple, sequence<2>>,
                     tuple, sequence<2>>,
                     sequence<2, 2>,
                     sequence<0, 3>>>{});
     }

+    template
+    CK_TILE_HOST_DEVICE static constexpr auto
+    MakeMX_BFlatBytesDramWindow(const WindowTmp& window_tmp)
+    {
+
+        using BDataType = remove_cvref_t;
+        constexpr auto BPackedSize  = numeric_traits::PackedSize;
+        constexpr auto kKPerBlock   = Problem::BlockGemmShape::kK;
+        constexpr auto M_Warp_Tile  = Problem::BlockGemmShape::WarpTile::at(I1);
+        constexpr auto flatNPerWarp = Problem::BlockGemmShape::flatNPerWarp;
+        constexpr auto flatKPerWarp = Problem::BlockGemmShape::flatKPerWarp;
+
+        static_assert(std::decay_t::get_num_of_dimension() == 2);
+        auto&& tensor_view_tmp      = window_tmp.get_bottom_tensor_view();
+        const auto [flat_n, flat_k] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
+        constexpr auto flat_k_per_block = kKPerBlock * M_Warp_Tile;
+        auto&& byte_tensor_desc = transform_tensor_descriptor(
+            make_naive_tensor_descriptor_packed(make_tuple(
+                flat_n, flat_k / flat_k_per_block, number{})),
+            make_tuple(make_pass_through_transform(flat_n),
+                       make_merge_transform_v3_division_mod(make_tuple(
+                           flat_k / flat_k_per_block, number{}))),
+            make_tuple(sequence<0>{}, sequence<1, 2>{}),
+            make_tuple(sequence<0>{}, sequence<1>{}));
+        auto&& byte_ptr = reinterpret_cast(&(tensor_view_tmp.get_buffer_view()(0)));
+        auto&& byte_tensor_view =
+            make_tensor_view(byte_ptr, byte_tensor_desc);
+        auto&& origin_tmp = window_tmp.get_window_origin();
+        return make_tile_window(
+            byte_tensor_view,
+            make_tuple(number{}, number{}),
+            {origin_tmp[0], origin_tmp[1] / BPackedSize},
+            MakeMX_BFlatBytesDramTileDistribution());
+    }
+
     template
     CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
     {
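For reference, the reshape that MakeMX_BFlatBytesDramWindow expresses through
transform_tensor_descriptor and make_merge_transform_v3_division_mod reduces to merging
(k_block, byte_in_block) into a single flat byte coordinate. A small self-contained round-trip
check of that index math, with assumed tile sizes standing in for Problem::BlockGemmShape:

#include <cassert>
#include <cstddef>

int main()
{
    constexpr std::size_t packed_size = 2;   // FP4: two elements per byte
    constexpr std::size_t k_per_block = 256; // assumed elements of K per block
    constexpr std::size_t warp_tile_n = 16;  // XDL_N == 16, as the policy asserts
    constexpr std::size_t block_bytes = k_per_block * warp_tile_n / packed_size;

    // Merge (k_block, byte_in_block) -> flat_k_byte, as the byte descriptor does.
    const std::size_t k_block = 3, byte_in_block = 40;
    const std::size_t flat_k_byte = k_block * block_bytes + byte_in_block;

    // The division/mod pair encoded by make_merge_transform_v3_division_mod
    // recovers the original coordinates, so the mapping round-trips exactly.
    assert(flat_k_byte / block_bytes == k_block);
    assert(flat_k_byte % block_bytes == byte_in_block);
    return 0;
}

Because block_bytes is a compile-time constant, the divide and modulo above fold away in the
window setup, which is exactly where the "division by 2" the commit title refers to goes.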