mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-24 00:45:34 +00:00
[rocm-libraries] ROCm/rocm-libraries#8259 (commit df03f10)
Add cluster launch in test ck_tile mx gemm tdm wmma ## Motivation Add cluster launch test in test_ck_tile_mx_gemm_pipeline_tdm_wmma on gfx1250, so that we can check the performance on gfx1250 hardware. ## Technical Details Added Out-of-bounds guard in RunGemm of MxGemmKernel to skip blocks padded by cluster alignment. Add ClusterEnable/ClusterDisable aliases and extend the tuple in test_mx_gemm_pipeline_kernel_types.hpp by adding two kernel types with ClusterEnable for F8 CompTDMV1 and CompTDMV2 respectively. The existing F4 non-ClusterLaunch kernel types have issue to be fixed, so this PR does not include F4 cases. Read ClusterLaunch from the tuple in test_mx_gemm_pipeline_util.hpp. Update invoke_mx_gemm to branch on ClusterLaunch, including Add cluster size constants, Switch GemmShape type, TilePartitioner type, and the kernel launch call. ## Test Plan Tested the changes on gfx1250 FFM. ## Test Result The added kernel types (instances) passed the tests on gfx1250 FFM. ## Submission Checklist - [x ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
359f664b25
commit
276863ca87
@@ -231,6 +231,13 @@ struct MxGemmKernel
|
||||
bs_scale_ptr[i] = reinterpret_cast<const int32_t*>(kargs.bs_scale_ptr[i]);
|
||||
});
|
||||
|
||||
// cluster launch pads grid to cluster boundaries; skip out-of-bound blocks
|
||||
if constexpr(BaseKernel::ClusterLaunch)
|
||||
{
|
||||
if(block_idx_m >= kargs.M || block_idx_n >= kargs.N)
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& as_block_window = BaseKernel::MakeABlockWindows(
|
||||
as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
|
||||
const auto& bs_block_window = BaseKernel::MakeBBlockWindows(
|
||||
|
||||
@@ -24,6 +24,9 @@ using I64 = ck_tile::number<64>;
|
||||
using I128 = ck_tile::number<128>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
|
||||
using ClusterEnable = std::true_type;
|
||||
using ClusterDisable = std::false_type;
|
||||
|
||||
// clang-format off
|
||||
// MX GEMM kernel types using TDM pipeline with scale support
|
||||
// Tuple format:
|
||||
@@ -43,6 +46,8 @@ using KernelTypesMxGemmCompTDMWmma = ::testing::Types<
|
||||
std::tuple< Row, Col, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2>,
|
||||
std::tuple< Row, Row, Row, BF8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2>,
|
||||
std::tuple< Row, Row, Row, F4, F4, E5M3, E5M3, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1>,
|
||||
std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1>
|
||||
std::tuple< Col, Row, Row, F4, F8, E5M3, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV1, ClusterEnable>,
|
||||
std::tuple< Row, Col, Row, F8, F8, E8M0, E8M0, F32, F16, I64, I64, I128, I32, I32, Intrawave, CompTDMV2, ClusterEnable>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
@@ -191,8 +191,9 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
|
||||
static constexpr bool Persistent = false;
|
||||
static constexpr bool ClusterLaunch = false;
|
||||
static constexpr bool Persistent = false;
|
||||
static constexpr bool ClusterLaunch =
|
||||
ck_tile::tuple_element_or_default_t<Tuple, 16, std::false_type>::value;
|
||||
|
||||
static constexpr ck_tile::index_t ScaleBlockSize = 32;
|
||||
|
||||
@@ -205,6 +206,14 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
// if cluster launch is enabled, set cluster dim to 2x2x1
|
||||
constexpr ck_tile::index_t kClusterSizeM =
|
||||
std::conditional_t<ClusterLaunch, ck_tile::number<2>, ck_tile::number<1>>{};
|
||||
constexpr ck_tile::index_t kClusterSizeN =
|
||||
std::conditional_t<ClusterLaunch, ck_tile::number<2>, ck_tile::number<1>>{};
|
||||
constexpr ck_tile::index_t kClusterSizeK =
|
||||
std::conditional_t<ClusterLaunch, ck_tile::number<1>, ck_tile::number<1>>{};
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
@@ -222,14 +231,27 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
static constexpr bool StructuredSparsity = false;
|
||||
static constexpr bool NumWaveGroup = 1;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
using GemmShape = std::conditional_t<
|
||||
ClusterLaunch,
|
||||
ck_tile::ClusterTileGemmShape<
|
||||
ck_tile::sequence<kClusterSizeM, kClusterSizeN, kClusterSizeK>,
|
||||
ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>,
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape, 8, 4>;
|
||||
using TilePartitioner =
|
||||
std::conditional_t<ClusterLaunch,
|
||||
ck_tile::GemmClusterTilePartitioner<GemmShape>,
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
TileParitionerGroupNum,
|
||||
TileParitionerM01>>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
@@ -305,8 +327,17 @@ class TestCkTileMxGemmPipeline : public ::testing::Test
|
||||
<< blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
if constexpr(ClusterLaunch)
|
||||
{
|
||||
dim3 clusters = Kernel::ClusterSize();
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, clusters, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
Reference in New Issue
Block a user