From 276863ca874bedeee72fa8f46094de085c258aa6 Mon Sep 17 00:00:00 2001 From: jefyang1 <146495389+jefyang1@users.noreply.github.com> Date: Thu, 11 Jun 2026 17:33:11 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#8259 (commit df03f10) Add cluster launch in test ck_tile mx gemm tdm wmma ## Motivation Add a cluster launch test to test_ck_tile_mx_gemm_pipeline_tdm_wmma on gfx1250, so that we can check performance on gfx1250 hardware. ## Technical Details Add an out-of-bounds guard in RunGemm of MxGemmKernel to skip blocks that exist only because the grid is padded to cluster boundaries. Add ClusterEnable/ClusterDisable aliases and extend the kernel type list in test_mx_gemm_pipeline_kernel_types.hpp with two ClusterEnable kernel types, for F8 CompTDMV1 and F8 CompTDMV2 respectively. The existing F4 non-ClusterLaunch kernel types have an outstanding issue that has yet to be fixed, so this PR does not include F4 cases. Read ClusterLaunch from the tuple in test_mx_gemm_pipeline_util.hpp. Update invoke_mx_gemm to branch on ClusterLaunch: add cluster size constants, and switch the GemmShape type, the TilePartitioner type, and the kernel launch call. ## Test Plan Tested the changes on gfx1250 FFM. ## Test Result The added kernel types (instances) passed the tests on gfx1250 FFM. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--- .../ops/gemm/kernel/mx_gemm_kernel.hpp | 7 +++ .../test_mx_gemm_pipeline_kernel_types.hpp | 7 ++- .../gemm_mx/test_mx_gemm_pipeline_util.hpp | 47 +++++++++++++++---- 3 files changed, 52 insertions(+), 9 deletions(-) diff --git a/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp index 6de2e6b384..89d07a149a 100644 --- a/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp @@ -231,6 +231,13 @@ struct MxGemmKernel bs_scale_ptr[i] = reinterpret_cast(kargs.bs_scale_ptr[i]); }); + // cluster launch pads grid to cluster boundaries; skip out-of-bound blocks + if constexpr(BaseKernel::ClusterLaunch) + { + if(block_idx_m >= kargs.M || block_idx_n >= kargs.N) + return; + } + const auto& as_block_window = BaseKernel::MakeABlockWindows( as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); const auto& bs_block_window = BaseKernel::MakeBBlockWindows( diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp index e38606c143..6c16621bb6 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp @@ -24,6 +24,9 @@ using I64 = ck_tile::number<64>; using I128 = ck_tile::number<128>; using I256 = ck_tile::number<256>; +using ClusterEnable = std::true_type; +using ClusterDisable = std::false_type; + // clang-format off // MX GEMM kernel types using TDM pipeline with scale support // Tuple format: @@ -43,6 +46,8 @@ using KernelTypesMxGemmCompTDMWmma = ::testing::Types< std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2>, std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2>, std::tuple< Row, Row, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1>, - std::tuple< Col, Row, Row, F4, 
F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1> + std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, ClusterEnable>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, ClusterEnable> >; // clang-format on diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp index 981d4c1d33..157da722b9 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp @@ -191,8 +191,9 @@ class TestCkTileMxGemmPipeline : public ::testing::Test using DsLayout = ck_tile::tuple<>; using DsDataType = ck_tile::tuple<>; - static constexpr bool Persistent = false; - static constexpr bool ClusterLaunch = false; + static constexpr bool Persistent = false; + static constexpr bool ClusterLaunch = + ck_tile::tuple_element_or_default_t::value; static constexpr ck_tile::index_t ScaleBlockSize = 32; @@ -205,6 +206,14 @@ class TestCkTileMxGemmPipeline : public ::testing::Test constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; + // if cluster launch is enabled, set cluster dim to 2x2x1 + constexpr ck_tile::index_t kClusterSizeM = + std::conditional_t, ck_tile::number<1>>{}; + constexpr ck_tile::index_t kClusterSizeN = + std::conditional_t, ck_tile::number<1>>{}; + constexpr ck_tile::index_t kClusterSizeK = + std::conditional_t, ck_tile::number<1>>{}; + constexpr bool kPadM = PadM; constexpr bool kPadN = PadN; constexpr bool kPadK = PadK; @@ -222,14 +231,27 @@ class TestCkTileMxGemmPipeline : public ::testing::Test static constexpr bool StructuredSparsity = false; static constexpr bool NumWaveGroup = 1; - constexpr int kBlockPerCu = 1; + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t 
TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; - using GemmShape = + using GemmShape = std::conditional_t< + ClusterLaunch, + ck_tile::ClusterTileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>, ck_tile::TileGemmShape, ck_tile::sequence, - ck_tile::sequence>; + ck_tile::sequence>>; - using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + using TilePartitioner = + std::conditional_t, + ck_tile::GemmSpatiallyLocalTilePartitioner>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits(Kernel{}, grids, blocks, 0, kargs)); + if constexpr(ClusterLaunch) + { + dim3 clusters = Kernel::ClusterSize(); + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, clusters, grids, blocks, 0, kargs)); + } + else + { + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } } public: