diff --git a/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp index 6de2e6b384..89d07a149a 100644 --- a/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/mx_gemm_kernel.hpp @@ -231,6 +231,13 @@ struct MxGemmKernel bs_scale_ptr[i] = reinterpret_cast(kargs.bs_scale_ptr[i]); }); + // cluster launch pads grid to cluster boundaries; skip out-of-bound blocks + if constexpr(BaseKernel::ClusterLaunch) + { + if(block_idx_m >= kargs.M || block_idx_n >= kargs.N) + return; + } + const auto& as_block_window = BaseKernel::MakeABlockWindows( as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); const auto& bs_block_window = BaseKernel::MakeBBlockWindows( diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp index e38606c143..6c16621bb6 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_kernel_types.hpp @@ -24,6 +24,9 @@ using I64 = ck_tile::number<64>; using I128 = ck_tile::number<128>; using I256 = ck_tile::number<256>; +using ClusterEnable = std::true_type; +using ClusterDisable = std::false_type; + // clang-format off // MX GEMM kernel types using TDM pipeline with scale support // Tuple format: @@ -43,6 +46,8 @@ using KernelTypesMxGemmCompTDMWmma = ::testing::Types< std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2>, std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2>, std::tuple< Row, Row, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1>, - std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1> + std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, ClusterEnable>, + std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, ClusterEnable> >; // clang-format on diff --git a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp index 981d4c1d33..157da722b9 100644 --- a/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp @@ -191,8 +191,9 @@ class TestCkTileMxGemmPipeline : public ::testing::Test using DsLayout = ck_tile::tuple<>; using DsDataType = ck_tile::tuple<>; - static constexpr bool Persistent = false; - static constexpr bool ClusterLaunch = false; + static constexpr bool Persistent = false; + static constexpr bool ClusterLaunch = + ck_tile::tuple_element_or_default_t::value; static constexpr ck_tile::index_t ScaleBlockSize = 32; @@ -205,6 +206,14 @@ class TestCkTileMxGemmPipeline : public ::testing::Test constexpr ck_tile::index_t N_Warp = 2; constexpr ck_tile::index_t K_Warp = 1; + // if cluster launch is enabled, set cluster dim to 2x2x1 + constexpr ck_tile::index_t kClusterSizeM = + std::conditional_t, ck_tile::number<1>>{}; + constexpr ck_tile::index_t kClusterSizeN = + std::conditional_t, ck_tile::number<1>>{}; + constexpr ck_tile::index_t kClusterSizeK = + std::conditional_t, ck_tile::number<1>>{}; + constexpr bool kPadM = PadM; constexpr bool kPadN = PadN; constexpr bool kPadK = PadK; @@ -222,14 +231,27 @@ class TestCkTileMxGemmPipeline : public ::testing::Test static constexpr bool StructuredSparsity = false; static constexpr bool NumWaveGroup = 1; - constexpr int kBlockPerCu = 1; + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; - using GemmShape = + using GemmShape = std::conditional_t< + ClusterLaunch, + ck_tile::ClusterTileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>, ck_tile::TileGemmShape, ck_tile::sequence, - ck_tile::sequence>; + ck_tile::sequence>>; - using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + using TilePartitioner = + std::conditional_t, + ck_tile::GemmSpatiallyLocalTilePartitioner>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits(Kernel{}, grids, blocks, 0, kargs)); + if constexpr(ClusterLaunch) + { + dim3 clusters = Kernel::ClusterSize(); + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, clusters, grids, blocks, 0, kargs)); + } + else + { + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } } public: