diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp index 546148be62..c411aaa8f4 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp @@ -80,7 +80,8 @@ struct MmaLayoutTestKernel BlockM, BlockN, BlockK, - decltype(ck_tile::core::arch::get_compiler_target())>; + decltype(ck_tile::core::arch::get_compiler_target()), + mma::MmaOpFamily::DENSE>; using MmaOp = typename Selector::SelectedOp; using MmaTraits = mma::MmaOpTraits; @@ -253,12 +254,30 @@ using MmaGfx90aCompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx9_ using MmaGfx1100CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx11_target< ck_tile::core::arch::amdgcn_target_id::GFX1100>()); -using MmaGfx1201Selector = mma:: - MmaDefaultSelector; -using MmaGfx90aSelector = mma:: - MmaDefaultSelector; -using MmaGfx1100Selector = mma:: - MmaDefaultSelector; +using MmaGfx1201Selector = mma::MmaDefaultSelector; +using MmaGfx90aSelector = mma::MmaDefaultSelector; +using MmaGfx1100Selector = mma::MmaDefaultSelector; // clang-format off using KernelTypes = ::testing::Types< diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp index cb14e1676d..850435d256 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp @@ -68,7 +68,9 @@ struct RegisterMap /** * @brief RegisterMapTraits for GFX12 WMMA 16x16x16_F16_F16_F32_GFX12 */ -template +template struct RegisterMapTraits>> { using MmaOp = ck_tile::core::arch::mma::amdgcn_mma; + CompilerTarget, + OpFamily_>; using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits; static constexpr index_t WaveSize = @@ -147,7 +151,9 @@ struct RegisterMapTraits +template struct RegisterMapTraits>> { using MmaOp = ck_tile::core::arch::mma::amdgcn_mma; + CompilerTarget, + OpFamily_>; using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits; static constexpr index_t WaveSize = @@ -212,7 +220,9 @@ struct RegisterMapTraits +template struct RegisterMapTraits>> { using MmaOp = ck_tile::core::arch::mma::amdgcn_mma; + CompilerTarget, + OpFamily_>; using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits; static constexpr index_t WaveSize =