[rocm-libraries] ROCm/rocm-libraries#8259 (commit df03f10)

Add cluster launch in test ck_tile mx gemm tdm wmma

## Motivation

Add cluster launch test in test_ck_tile_mx_gemm_pipeline_tdm_wmma on
gfx1250, so that we can check the performance on gfx1250 hardware.

## Technical Details

Added Out-of-bounds guard in RunGemm of MxGemmKernel to skip blocks
padded by cluster alignment.

Add ClusterEnable/ClusterDisable aliases and extend the tuple in
test_mx_gemm_pipeline_kernel_types.hpp by adding two kernel types with
ClusterEnable for F8 CompTDMV1 and CompTDMV2 respectively. The existing
F4 non-ClusterLaunch kernel types have issue to be fixed, so this PR
does not include F4 cases.

Read ClusterLaunch from the tuple in test_mx_gemm_pipeline_util.hpp.

Update invoke_mx_gemm to branch on ClusterLaunch, including Add cluster
size constants, Switch GemmShape type, TilePartitioner type, and the
kernel launch call.

## Test Plan

Tested the changes on gfx1250 FFM.

## Test Result

The added kernel types (instances) passed the tests on gfx1250 FFM.

## Submission Checklist

- [x ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
jefyang1
2026-06-11 17:33:11 +00:00
committed by assistant-librarian[bot]
parent 359f664b25
commit 276863ca87
3 changed files with 52 additions and 9 deletions

View File

@@ -231,6 +231,13 @@ struct MxGemmKernel
bs_scale_ptr[i] = reinterpret_cast<const int32_t*>(kargs.bs_scale_ptr[i]);
});
// cluster launch pads grid to cluster boundaries; skip out-of-bound blocks
if constexpr(BaseKernel::ClusterLaunch)
{
if(block_idx_m >= kargs.M || block_idx_n >= kargs.N)
return;
}
const auto& as_block_window = BaseKernel::MakeABlockWindows(
as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
const auto& bs_block_window = BaseKernel::MakeBBlockWindows(

View File

@@ -24,6 +24,9 @@ using I64 = ck_tile::number<64>;
using I128 = ck_tile::number<128>;
using I256 = ck_tile::number<256>;
using ClusterEnable = std::true_type;
using ClusterDisable = std::false_type;
// clang-format off
// MX GEMM kernel types using TDM pipeline with scale support
// Tuple format:
@@ -43,6 +46,8 @@ using KernelTypesMxGemmCompTDMWmma = ::testing::Types<
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2>,
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2>,
std::tuple< Row, Row, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1>,
std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1>
std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1>,
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, ClusterEnable>,
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, ClusterEnable>
>;
// clang-format on

View File

@@ -191,8 +191,9 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
static constexpr bool Persistent = false;
static constexpr bool ClusterLaunch = false;
static constexpr bool Persistent = false;
static constexpr bool ClusterLaunch =
ck_tile::tuple_element_or_default_t<Tuple, 16, std::false_type>::value;
static constexpr ck_tile::index_t ScaleBlockSize = 32;
@@ -205,6 +206,14 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
// if cluster launch is enabled, set cluster dim to 2x2x1
constexpr ck_tile::index_t kClusterSizeM =
std::conditional_t<ClusterLaunch, ck_tile::number<2>, ck_tile::number<1>>{};
constexpr ck_tile::index_t kClusterSizeN =
std::conditional_t<ClusterLaunch, ck_tile::number<2>, ck_tile::number<1>>{};
constexpr ck_tile::index_t kClusterSizeK =
std::conditional_t<ClusterLaunch, ck_tile::number<1>, ck_tile::number<1>>{};
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;
@@ -222,14 +231,27 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
static constexpr bool StructuredSparsity = false;
static constexpr bool NumWaveGroup = 1;
constexpr int kBlockPerCu = 1;
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape =
using GemmShape = std::conditional_t<
ClusterLaunch,
ck_tile::ClusterTileGemmShape<
ck_tile::sequence<kClusterSizeM, kClusterSizeN, kClusterSizeK>,
ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>,
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>>;
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape, 8, 4>;
using TilePartitioner =
std::conditional_t<ClusterLaunch,
ck_tile::GemmClusterTilePartitioner<GemmShape>,
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TileParitionerGroupNum,
TileParitionerM01>>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
@@ -305,8 +327,17 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
<< blocks.y << ", " << blocks.z << "}" << std::endl;
}
ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
if constexpr(ClusterLaunch)
{
dim3 clusters = Kernel::ClusterSize();
ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, clusters, grids, blocks, 0, kargs));
}
else
{
ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
}
public: