[rocm-libraries] ROCm/rocm-libraries#5222 (commit 4fe0911)

[CK_TILE] Fix MMA layout test to match amdgcn_mma OpFamily
 parameter (#5222)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

- PR #4837 added `MmaOpFamily OpFamily_` as a new template parameter to
`amdgcn_mma` and `MmaDefaultSelector`, but the MMA layout test (PR
#4495) was not updated to include it
- Add the missing `OpFamily_` parameter to all three `RegisterMapTraits`
partial specializations (gfx9, gfx11, gfx12) and all
`MmaDefaultSelector` usages
- Fixes build failure: `template argument for non-type template
parameter must be an expression`

## Test plan

- [x] Verified test compiles cleanly with ROCm 7.1.1 clang++ targeting
gfx90a
- [x] `test_amdgcn_mma_layout` gfx90a (MFMA): PASSED
- [x] `test_amdgcn_mma_layout` gfx1201 (WMMA): SKIPPED (no device)
- [x] `test_amdgcn_mma_layout` gfx1100 (WMMA): SKIPPED (no device)
- [x] CI validation on all GPU targets

🤖 Generated with [Claude Code](https://claude.com/claude-code)
This commit is contained in:
Christopher Millette
2026-03-09 05:18:50 +00:00
committed by assistant-librarian[bot]
parent a7b894544e
commit d7836ff0b2
2 changed files with 44 additions and 13 deletions

View File

@@ -80,7 +80,8 @@ struct MmaLayoutTestKernel
BlockM,
BlockN,
BlockK,
decltype(ck_tile::core::arch::get_compiler_target())>;
decltype(ck_tile::core::arch::get_compiler_target()),
mma::MmaOpFamily::DENSE>;
using MmaOp = typename Selector::SelectedOp;
using MmaTraits = mma::MmaOpTraits<MmaOp>;
@@ -253,12 +254,30 @@ using MmaGfx90aCompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx9_
using MmaGfx1100CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx11_target<
ck_tile::core::arch::amdgcn_target_id::GFX1100>());
using MmaGfx1201Selector = mma::
MmaDefaultSelector<ck::fp16_t, ck::fp16_t, ck::fp32_t, 16u, 16u, 16u, MmaGfx1201CompilerTarget>;
using MmaGfx90aSelector = mma::
MmaDefaultSelector<ck::fp16_t, ck::fp16_t, ck::fp32_t, 16u, 16u, 16u, MmaGfx90aCompilerTarget>;
using MmaGfx1100Selector = mma::
MmaDefaultSelector<ck::fp16_t, ck::fp16_t, ck::fp32_t, 16u, 16u, 16u, MmaGfx1100CompilerTarget>;
using MmaGfx1201Selector = mma::MmaDefaultSelector<ck::fp16_t,
ck::fp16_t,
ck::fp32_t,
16u,
16u,
16u,
MmaGfx1201CompilerTarget,
mma::MmaOpFamily::DENSE>;
using MmaGfx90aSelector = mma::MmaDefaultSelector<ck::fp16_t,
ck::fp16_t,
ck::fp32_t,
16u,
16u,
16u,
MmaGfx90aCompilerTarget,
mma::MmaOpFamily::DENSE>;
using MmaGfx1100Selector = mma::MmaDefaultSelector<ck::fp16_t,
ck::fp16_t,
ck::fp32_t,
16u,
16u,
16u,
MmaGfx1100CompilerTarget,
mma::MmaOpFamily::DENSE>;
// clang-format off
using KernelTypes = ::testing::Types<

View File

@@ -68,7 +68,9 @@ struct RegisterMap
/**
* @brief RegisterMapTraits for GFX12 WMMA 16x16x16_F16_F16_F32_GFX12
*/
template <typename CtrlFlags, typename CompilerTarget>
template <typename CtrlFlags,
typename CompilerTarget,
ck_tile::core::arch::mma::MmaOpFamily OpFamily_>
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
ck_tile::fp16_t,
ck_tile::fp16_t,
@@ -78,6 +80,7 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
16u,
CtrlFlags,
CompilerTarget,
OpFamily_,
ck_tile::core::arch::enable_if_target_family_gfx12_t<CompilerTarget>>>
{
using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
@@ -87,7 +90,8 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
16u,
16u,
CtrlFlags,
CompilerTarget>;
CompilerTarget,
OpFamily_>;
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
static constexpr index_t WaveSize =
@@ -147,7 +151,9 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
/**
* @brief RegisterMapTraits for GFX9 MFMA 16x16x16_F16_F16_F32_GFX9
*/
template <typename CtrlFlags, typename CompilerTarget>
template <typename CtrlFlags,
typename CompilerTarget,
ck_tile::core::arch::mma::MmaOpFamily OpFamily_>
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
ck_tile::fp16_t,
ck_tile::fp16_t,
@@ -157,6 +163,7 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
16u,
CtrlFlags,
CompilerTarget,
OpFamily_,
ck_tile::core::arch::enable_if_target_family_gfx9_t<CompilerTarget>>>
{
using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
@@ -166,7 +173,8 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
16u,
16u,
CtrlFlags,
CompilerTarget>;
CompilerTarget,
OpFamily_>;
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
static constexpr index_t WaveSize =
@@ -212,7 +220,9 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
/**
* @brief RegisterMapTraits for GFX11 WMMA 16x16x16_F16_F16_F32_GFX11
*/
template <typename CtrlFlags, typename CompilerTarget>
template <typename CtrlFlags,
typename CompilerTarget,
ck_tile::core::arch::mma::MmaOpFamily OpFamily_>
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
ck_tile::fp16_t,
ck_tile::fp16_t,
@@ -222,6 +232,7 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
16u,
CtrlFlags,
CompilerTarget,
OpFamily_,
ck_tile::core::arch::enable_if_target_family_gfx11_t<CompilerTarget>>>
{
using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
@@ -231,7 +242,8 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
16u,
16u,
CtrlFlags,
CompilerTarget>;
CompilerTarget,
OpFamily_>;
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
static constexpr index_t WaveSize =