mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#5222 (commit 4fe0911)
[CK_TILE] Fix MMA layout test to match amdgcn_mma OpFamily parameter (#5222) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary - PR #4837 added `MmaOpFamily OpFamily_` as a new template parameter to `amdgcn_mma` and `MmaDefaultSelector`, but the MMA layout test (PR #4495) was not updated to include it - Add the missing `OpFamily_` parameter to all three `RegisterMapTraits` partial specializations (gfx9, gfx11, gfx12) and all `MmaDefaultSelector` usages - Fixes build failure: `template argument for non-type template parameter must be an expression` ## Test plan - [x] Verified test compiles cleanly with ROCm 7.1.1 clang++ targeting gfx90a - [x] `test_amdgcn_mma_layout` gfx90a (MFMA): PASSED - [x] `test_amdgcn_mma_layout` gfx1201 (WMMA): SKIPPED (no device) - [x] `test_amdgcn_mma_layout` gfx1100 (WMMA): SKIPPED (no device) - [x] CI validation on all GPU targets 🤖 Generated with [Claude Code](https://claude.com/claude-code)
This commit is contained in:
committed by
assistant-librarian[bot]
parent
a7b894544e
commit
d7836ff0b2
@@ -80,7 +80,8 @@ struct MmaLayoutTestKernel
|
||||
BlockM,
|
||||
BlockN,
|
||||
BlockK,
|
||||
decltype(ck_tile::core::arch::get_compiler_target())>;
|
||||
decltype(ck_tile::core::arch::get_compiler_target()),
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
using MmaOp = typename Selector::SelectedOp;
|
||||
using MmaTraits = mma::MmaOpTraits<MmaOp>;
|
||||
|
||||
@@ -253,12 +254,30 @@ using MmaGfx90aCompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx9_
|
||||
using MmaGfx1100CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx11_target<
|
||||
ck_tile::core::arch::amdgcn_target_id::GFX1100>());
|
||||
|
||||
using MmaGfx1201Selector = mma::
|
||||
MmaDefaultSelector<ck::fp16_t, ck::fp16_t, ck::fp32_t, 16u, 16u, 16u, MmaGfx1201CompilerTarget>;
|
||||
using MmaGfx90aSelector = mma::
|
||||
MmaDefaultSelector<ck::fp16_t, ck::fp16_t, ck::fp32_t, 16u, 16u, 16u, MmaGfx90aCompilerTarget>;
|
||||
using MmaGfx1100Selector = mma::
|
||||
MmaDefaultSelector<ck::fp16_t, ck::fp16_t, ck::fp32_t, 16u, 16u, 16u, MmaGfx1100CompilerTarget>;
|
||||
using MmaGfx1201Selector = mma::MmaDefaultSelector<ck::fp16_t,
|
||||
ck::fp16_t,
|
||||
ck::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
MmaGfx1201CompilerTarget,
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
using MmaGfx90aSelector = mma::MmaDefaultSelector<ck::fp16_t,
|
||||
ck::fp16_t,
|
||||
ck::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
MmaGfx90aCompilerTarget,
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
using MmaGfx1100Selector = mma::MmaDefaultSelector<ck::fp16_t,
|
||||
ck::fp16_t,
|
||||
ck::fp32_t,
|
||||
16u,
|
||||
16u,
|
||||
16u,
|
||||
MmaGfx1100CompilerTarget,
|
||||
mma::MmaOpFamily::DENSE>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
|
||||
@@ -68,7 +68,9 @@ struct RegisterMap
|
||||
/**
|
||||
* @brief RegisterMapTraits for GFX12 WMMA 16x16x16_F16_F16_F32_GFX12
|
||||
*/
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
ck_tile::core::arch::mma::MmaOpFamily OpFamily_>
|
||||
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
@@ -78,6 +80,7 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily_,
|
||||
ck_tile::core::arch::enable_if_target_family_gfx12_t<CompilerTarget>>>
|
||||
{
|
||||
using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
|
||||
@@ -87,7 +90,8 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget>;
|
||||
CompilerTarget,
|
||||
OpFamily_>;
|
||||
|
||||
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
|
||||
static constexpr index_t WaveSize =
|
||||
@@ -147,7 +151,9 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
/**
|
||||
* @brief RegisterMapTraits for GFX9 MFMA 16x16x16_F16_F16_F32_GFX9
|
||||
*/
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
ck_tile::core::arch::mma::MmaOpFamily OpFamily_>
|
||||
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
@@ -157,6 +163,7 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily_,
|
||||
ck_tile::core::arch::enable_if_target_family_gfx9_t<CompilerTarget>>>
|
||||
{
|
||||
using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
|
||||
@@ -166,7 +173,8 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget>;
|
||||
CompilerTarget,
|
||||
OpFamily_>;
|
||||
|
||||
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
|
||||
static constexpr index_t WaveSize =
|
||||
@@ -212,7 +220,9 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
/**
|
||||
* @brief RegisterMapTraits for GFX11 WMMA 16x16x16_F16_F16_F32_GFX11
|
||||
*/
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
template <typename CtrlFlags,
|
||||
typename CompilerTarget,
|
||||
ck_tile::core::arch::mma::MmaOpFamily OpFamily_>
|
||||
struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
@@ -222,6 +232,7 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget,
|
||||
OpFamily_,
|
||||
ck_tile::core::arch::enable_if_target_family_gfx11_t<CompilerTarget>>>
|
||||
{
|
||||
using MmaOp = ck_tile::core::arch::mma::amdgcn_mma<ck_tile::fp16_t,
|
||||
@@ -231,7 +242,8 @@ struct RegisterMapTraits<ck_tile::core::arch::mma::amdgcn_mma<
|
||||
16u,
|
||||
16u,
|
||||
CtrlFlags,
|
||||
CompilerTarget>;
|
||||
CompilerTarget,
|
||||
OpFamily_>;
|
||||
|
||||
using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits<MmaOp>;
|
||||
static constexpr index_t WaveSize =
|
||||
|
||||
Reference in New Issue
Block a user