From 1bc5480f37ef4726103c186fc2d1f842bb6e36d7 Mon Sep 17 00:00:00 2001
From: Wojciech Laskowski <77888887+wj-laskowski@users.noreply.github.com>
Date: Fri, 6 Mar 2026 17:20:39 +0100
Subject: [PATCH] Adding layout test for amdgcn_mma structs (#4495)

## Motivation

Currently, the test suite for `amdgcn_mma` focuses on the design (e.g. choosing the correct specialization based on SFINAE) and a single live test that checks that the selected MmaOp runs. This PR adds a simplified GEMM test kernel that checks the exact layout of the selected MmaOp.

## Technical Details

The test in `test_amdgcn_mma_layout.cpp` launches MxKxN test cases (one per block), where each case:

1. Constructs A and B tensors on the device with a single 1 at A(m,k) and B(k,n) (the rest is all 0s).
2. Executes the MMA intrinsic.
3. Checks that C has the "1" in the expected position.

For the MMA intrinsic, it pulls an Mma op from the `amdgcn_mma` specialization selected for the given input (tile dimensions, data types).

Note 1: As a helper, `test_amdgcn_mma_layout_util.hpp` adds a register map for a given `amdgcn_mma` specialization. Register mapping is currently based on the `tile_distribution_encoding`.

Note 2: Everything is added to the test suite; there are no additions to the actual `amdgcn_mma` structs. All the extra information that is needed, but not yet provided by the `amdgcn_mma` structs, is added as boilerplate to the test header.

TODO: Rebase this PR on top of the `amdgcn_mma` refactor or clean it up after the merge.

## Test Plan

This PR solely adds a new test to the existing code.

## Test Result

Tests pass.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Kiefer van Teutem
---
 test/ck_tile/core/arch/mma/CMakeLists.txt     |   8 +
 .../core/arch/mma/test_amdgcn_mma_layout.cpp  | 286 +++++++++++++++++
 .../arch/mma/test_amdgcn_mma_layout_util.hpp  | 292 ++++++++++++++++++
 3 files changed, 586 insertions(+)
 create mode 100644 test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp
 create mode 100644 test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp

diff --git a/test/ck_tile/core/arch/mma/CMakeLists.txt b/test/ck_tile/core/arch/mma/CMakeLists.txt
index 77691735bd..964acfb02a 100644
--- a/test/ck_tile/core/arch/mma/CMakeLists.txt
+++ b/test/ck_tile/core/arch/mma/CMakeLists.txt
@@ -17,3 +17,11 @@ if(GPU_TARGETS MATCHES "gfx9")
 else()
     message(DEBUG "Skipping ck_tile_gemm tests for current target")
 endif()
+
+if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
+    add_gtest_executable(test_amdgcn_mma_layout test_amdgcn_mma_layout.cpp)
+    target_compile_options(test_amdgcn_mma_layout PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
+else()
+    message(DEBUG "Skipping gfx9|gfx11|gfx12 mma layout validation tests for current target")
+endif()
+
diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp
new file mode 100644
index 0000000000..546148be62
--- /dev/null
+++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.cpp
@@ -0,0 +1,286 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include
+#include
+
+#include "ck_tile/host/hip_check_error.hpp"
+#include "ck_tile/host/stream_config.hpp"
+#include "ck_tile/host/device_memory.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/core/arch/arch.hpp"
+#include "ck_tile/core/utility/env.hpp"
+
+#include "test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace ck  = ck_tile;
+namespace mma = ck_tile::core::arch::mma;
+
+// MMA register layout validation test for amdgcn_mma structs.
+//
+// Strategy: for every (m, k, n) triple in the tile, the test constructs a pair of input tensors
+// A and B that contain exactly one non-zero element each, placed so that their product
+// contributes to a single output element C(m, n):
+//
+//      A (M x K)            B (K x N)            C = A * B (M x N)
+//    . . . . . . . .      . . . . . . . .      . . . . . . . .
+//    . . . . . . . .      . . . . . . . .      . . . . . . . .
+//    . . . 1 . . . .      . . . . . . . .      . . . . . . . .
+//    . . . . . . . .      . . . 1 . . . .      . . . . . 1 . .
+//    . . . . . . . .      . . . . . . . .      . . . . . . . .
+//    A(m,k) = 1           B(k,n) = 1           C(m,n) = 1
+//
+// The kernel uses RegisterMap to scatter A and B into the correct (lane, vecIdx) positions
+// of the MMA fragment registers, executes the intrinsic, then uses RegisterMap again to
+// gather back into the C matrix. The position of "1" in C is checked against the expected (m, n)
+// location.
+
+namespace {
+
+/**
+ * @class MmaLayoutTestKernel
+ * @brief Device kernel that performs C = AB using a given Mma op
+ *
+ * @tparam ADataType Data type of tensor A elements
+ * @tparam BDataType Data type of tensor B elements
+ * @tparam CDataType Data type of tensor C elements
+ * @tparam BlockM M-dimension of the MMA tile
+ * @tparam BlockN N-dimension of the MMA tile
+ * @tparam BlockK K-dimension of the MMA tile
+ * @tparam BlockSize HIP block size
+ */
+template
+struct MmaLayoutTestKernel
+{
+    static constexpr int kBlockSize = BlockSize;
+
+    __device__ void operator()(uint32_t* error_flags) const
+    {
+        using Selector =
+            mma::MmaDefaultSelector;
+        using MmaOp     = typename Selector::SelectedOp;
+        using MmaTraits = mma::MmaOpTraits;
+
+        if constexpr(MmaTraits::IsSupported)
+        {
+            using AVecType = typename MmaTraits::AVecType;
+            using BVecType = typename MmaTraits::BVecType;
+            using CVecType = typename MmaTraits::CVecType;
+            constexpr uint32_t a_vec_size = vector_traits::vector_size;
+            constexpr uint32_t b_vec_size = vector_traits::vector_size;
+            constexpr uint32_t c_vec_size = vector_traits::vector_size;
+
+            const uint32_t lane = threadIdx.x;
+
+            AVecType a_frag{};
+            BVecType b_frag{};
+            CVecType c_frag{};
+
+            // get (m, k, n) where the "1" should be placed for this block
+            const uint32_t case_idx = static_cast(blockIdx.x);
+            const uint32_t m = case_idx / (MmaTraits::BlockK * MmaTraits::BlockN);
+            const uint32_t k = (case_idx / MmaTraits::BlockN) % MmaTraits::BlockK;
+            const uint32_t n = case_idx % MmaTraits::BlockN;
+
+            // place a single "1" in A/B fragments using (lane, vecIdx) -> (row, col) mapping
+            for(uint32_t v = 0; v < a_vec_size; ++v)
+            {
+                auto a_coords = RegisterMap::Register2AMap(lane, v);
+                if(static_cast(a_coords[0]) == m &&
+                   static_cast(a_coords[1]) == k)
+                {
+                    a_frag[v] = static_cast(1);
+                }
+            }
+
+            for(uint32_t v = 0; v < b_vec_size; ++v)
+            {
+                auto b_coords = RegisterMap::Register2BMap(lane, v);
+                if(static_cast(b_coords[0]) == n &&
+                   static_cast(b_coords[1]) == k)
+                {
+                    b_frag[v] = static_cast(1);
+                }
+            }
+
+            c_frag = MmaOp::exec(a_frag, b_frag, c_frag);
+
+            uint32_t err = 0;
+            const CDataType tol = static_cast(
+                1.0e-1f); // TODO: this tolerance might not be suitable for all data types and
+                          // should be revisited if we add more configurations
+            for(uint32_t v = 0; v < c_vec_size; ++v)
+            {
+                auto c_coords = RegisterMap::Register2CMap(lane, v);
+                const uint32_t i = static_cast(c_coords[0]);
+                const uint32_t j = static_cast(c_coords[1]);
+
+                const CDataType expected =
+                    (i == m && j == n) ? static_cast(1) : static_cast(0);
+                const CDataType value = static_cast(c_frag[v]);
+                if(fabsf(static_cast(value - expected)) > static_cast(tol))
+                {
+                    err = 1;
+                }
+            }
+
+            const uint32_t any_err = __any(err);
+            if(threadIdx.x == 0)
+            {
+                error_flags[case_idx] = any_err;
+            }
+        }
+    }
+};
+
+/**
+ * @brief Test driver: runs the test for a given MMA configuration.
+ *
+ * The test launches MxKxN test cases (one per block) to check all possible positions of the "1"
+ * in the A/B tensors.
+ * 1. Constructs A and B tensors with a single 1 at A(m,k) and B(k,n).
+ * 2. Executes the MMA intrinsic to compute the C tensor.
+ * 3. Checks that C has the 1 in the expected position.
+ *
+ * @tparam Selector Selector for the Mma operation
+ * @return true if the test ran on hardware; false if skipped (no device or unsupported)
+ */
+template
+bool run_mma_layout_test()
+{
+    using MmaOp     = typename Selector::SelectedOp;
+    using MmaTraits = mma::MmaOpTraits;
+    using ADataType = typename MmaTraits::ADataType;
+    using BDataType = typename MmaTraits::BDataType;
+    using CDataType = typename MmaTraits::CDataType;
+    constexpr uint32_t BlockM = MmaTraits::BlockM;
+    constexpr uint32_t BlockN = MmaTraits::BlockN;
+    constexpr uint32_t BlockK = MmaTraits::BlockK;
+    constexpr auto selector_target_id = MmaTraits::CompilerTarget::TARGET_ID;
+    constexpr auto selector_wave_size = MmaTraits::CompilerTarget::WAVE_SIZE_ID;
+
+    int device_count = 0;
+    hipDevice_t device{};
+    HIP_CHECK_ERROR(hipGetDevice(&device));
+    HIP_CHECK_ERROR(hipGetDeviceCount(&device_count));
+
+    hipDeviceProp_t props{};
+    HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device));
+
+    const auto runtime_target =
+        ck_tile::core::arch::hip_device_prop_gcn_arch_name_to_amdgcn_target_id(props.gcnArchName);
+    const bool has_device = device_count > 0;
+
+    if(!has_device || runtime_target == ck_tile::core::arch::amdgcn_target_id::HOST ||
+       runtime_target != selector_target_id ||
+       props.warpSize != static_cast(selector_wave_size))
+    {
+        return false;
+    }
+
+    constexpr uint32_t total_cases = BlockM * BlockK * BlockN;
+    ck_tile::DeviceMem d_errors(total_cases * sizeof(uint32_t));
+    std::vector h_errors(total_cases, 0u);
+
+    auto* d_error_ptr = static_cast(d_errors.GetDeviceBuffer());
+
+    std::ignore = hipGetLastError();
+
+    using Kernel = MmaLayoutTestKernel(selector_wave_size)>;
+
+    std::ignore =
+        ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
+                               ck_tile::make_kernel(Kernel{},
+                                                    dim3(total_cases),
+                                                    dim3(static_cast(selector_wave_size)),
+                                                    0,
+                                                    d_error_ptr));
+
+    HIP_CHECK_ERROR(hipMemcpyAsync(
+        h_errors.data(), d_error_ptr, d_errors.GetBufferSize(), hipMemcpyDeviceToHost));
+    HIP_CHECK_ERROR(hipStreamSynchronize(nullptr));
+
+    for(uint32_t case_idx = 0; case_idx < total_cases; ++case_idx)
+    {
+        const uint32_t m = case_idx / (BlockK * BlockN);
+        const uint32_t k = (case_idx / BlockN) % BlockK;
+        const uint32_t n = case_idx % BlockN;
+
+        EXPECT_EQ(h_errors[case_idx], 0u) << "Mismatch for m=" << m << " k=" << k << " n=" << n;
+    }
+
+    return true;
+}
+
+} // namespace
+
+// ==================== Test configurations per target ====================
+// TODO: currently we have only 1 specific target per test. This should be revisited to enable all
+// the targets within the family (gfx12, gfx11, gfx9)
+using MmaGfx1201CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx12_target<
+                                          ck_tile::core::arch::amdgcn_target_id::GFX1201>());
+using MmaGfx90aCompilerTarget  = decltype(ck_tile::core::arch::make_amdgcn_gfx9_target<
+                                          ck_tile::core::arch::amdgcn_target_id::GFX90A>());
+using MmaGfx1100CompilerTarget = decltype(ck_tile::core::arch::make_amdgcn_gfx11_target<
+                                          ck_tile::core::arch::amdgcn_target_id::GFX1100>());
+
+using MmaGfx1201Selector = mma::
+    MmaDefaultSelector;
+using MmaGfx90aSelector = mma::
+    MmaDefaultSelector;
+using MmaGfx1100Selector = mma::
+    MmaDefaultSelector;
+
+// clang-format off
+using KernelTypes = ::testing::Types<
+    MmaGfx1201Selector,
+    MmaGfx90aSelector,
+    MmaGfx1100Selector
+    >;
+// clang-format on
+
+template
+class TestMmaLayout : public ::testing::Test
+{
+};
+
+TYPED_TEST_SUITE(TestMmaLayout, KernelTypes);
+
+TYPED_TEST(TestMmaLayout, Mma_16x16x16_F16_F16_F32)
+{
+    bool executed = run_mma_layout_test();
+
+    if(!executed)
+    {
+        GTEST_SKIP() << "No supported HIP device found. Skipping test.";
+    }
+}
diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp
new file mode 100644
index 0000000000..cb14e1676d
--- /dev/null
+++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout_util.hpp
@@ -0,0 +1,292 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/core/arch/arch.hpp"
+#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
+#include "ck_tile/core/arch/mma/mma_traits.hpp"
+#include "ck_tile/core/arch/mma/mma_selector.hpp"
+#include "ck_tile/core/numeric/half.hpp"
+#include "ck_tile/core/numeric/vector_type.hpp"
+#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
+
+#include
+#include
+
+namespace {
+
+using namespace ck_tile;
+
+/**
+ * @class RegisterMapTraits
+ * @brief Traits class that defines tile_distribution_encoding for each MmaOp
+ * @tparam MmaOp amdgcn_mma specialization
+ */
+template
+struct RegisterMapTraits
+{
+    static_assert(sizeof(MmaOp) == 0, "RegisterMapTraits requires a specialization");
+};
+
+/**
+ * @class RegisterMap
+ * @brief Uses specialized RegisterMapTraits to get the encoding
+ * @tparam MmaOp amdgcn_mma specialization
+ */
+template
+struct RegisterMap
+{
+    using Traits = RegisterMapTraits;
+
+    using AMap = core::arch::mma::TileDistrEncRegMap;
+    using BMap = core::arch::mma::TileDistrEncRegMap;
+    using CMap = core::arch::mma::TileDistrEncRegMap;
+
+    CK_TILE_HOST_DEVICE static auto Register2AMap(const uint32_t lane, const uint32_t vecIdx)
+    {
+        return AMap::calc_matrix_indices_from_lane_vector(static_cast(lane),
+                                                          static_cast(vecIdx));
+    }
+
+    CK_TILE_HOST_DEVICE static auto Register2BMap(const uint32_t lane, const uint32_t vecIdx)
+    {
+        return BMap::calc_matrix_indices_from_lane_vector(static_cast(lane),
+                                                          static_cast(vecIdx));
+    }
+
+    CK_TILE_HOST_DEVICE static auto Register2CMap(const uint32_t lane, const uint32_t vecIdx)
+    {
+        return CMap::calc_matrix_indices_from_lane_vector(static_cast(lane),
+                                                          static_cast(vecIdx));
+    }
+};
+
+// ====================== Specializations per target =====================
+
+/**
+ * @brief RegisterMapTraits for GFX12 WMMA 16x16x16_F16_F16_F32_GFX12
+ */
+template
+struct RegisterMapTraits>>
+{
+    using MmaOp = ck_tile::core::arch::mma::amdgcn_mma;
+
+    using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits;
+    static constexpr index_t WaveSize =
+        static_cast(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
+    static constexpr index_t AVecSize = vector_traits::vector_size;
+    static constexpr index_t BVecSize = vector_traits::vector_size;
+    static constexpr index_t CVecSize = vector_traits::vector_size;
+
+    using kABPs2RHssMajor = sequence<2, 1>;
+    using kABPs2RHssMinor = sequence<1, 0>;
+    using kABYs2RHsMajor  = sequence<2, 2>;
+    using kABYs2RHsMinor  = sequence<0, 2>;
+    using kCPs2RHssMajor  = sequence<1, 2>;
+    using kCPs2RHssMinor  = sequence<1, 0>;
+    using kCYs2RHsMajor   = sequence<1, 1>;
+    using kCYs2RHsMinor   = sequence<0, 2>;
+
+    // TODO: remove these and fix constants in amdgcn_mma
+    static constexpr index_t kAMBlock     = 1;
+    static constexpr index_t kBNBlock     = 1;
+    static constexpr index_t kAMLane      = 16;
+    static constexpr index_t kBNLane      = 16;
+    static constexpr index_t kABK0PerLane = 1;
+    static constexpr index_t kABKLane     = 2;
+    static constexpr index_t kABK1PerLane = 8;
+    static constexpr index_t kCMLane      = 2;
+    static constexpr index_t kCNLane      = 16;
+    static constexpr index_t kCM0PerLane  = 1;
+    static constexpr index_t kCM1PerLane  = 8;
+
+    using AWarpDstrEncoding = tile_distribution_encoding<
+        sequence<1>,
+        tuple, sequence>, // <16>, <1, 2, 8>
+        tuple,
+        tuple,
+        kABYs2RHsMajor,
+        kABYs2RHsMinor>;
+
+    using BWarpDstrEncoding = tile_distribution_encoding<
+        sequence<1>,
+        tuple, sequence>, // <16>, <1, 2, 8>
+        tuple,
+        tuple,
+        kABYs2RHsMajor,
+        kABYs2RHsMinor>;
+
+    using CWarpDstrEncoding =
+        tile_distribution_encoding,
+                                   tuple,
+                                         sequence>, // <1, 2, 8>, <16>
+                                   tuple,
+                                   tuple,
+                                   kCYs2RHsMajor,
+                                   kCYs2RHsMinor>;
+};
+
+/**
+ * @brief RegisterMapTraits for GFX9 MFMA 16x16x16_F16_F16_F32_GFX9
+ */
+template
+struct RegisterMapTraits>>
+{
+    using MmaOp = ck_tile::core::arch::mma::amdgcn_mma;
+
+    using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits;
+    static constexpr index_t WaveSize =
+        static_cast(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
+    static constexpr index_t AVecSize = vector_traits::vector_size;
+    static constexpr index_t BVecSize = vector_traits::vector_size;
+    static constexpr index_t CVecSize = vector_traits::vector_size;
+
+    using kABPs2RHssMajor = sequence<2, 1>;
+    using kABPs2RHssMinor = sequence<0, 0>;
+    using kABYs2RHsMajor  = sequence<2>;
+    using kABYs2RHsMinor  = sequence<1>;
+    using kCPs2RHssMajor  = sequence<1, 2>;
+    using kCPs2RHssMinor  = sequence<0, 0>;
+    using kCYs2RHsMajor   = sequence<1>;
+    using kCYs2RHsMinor   = sequence<1>;
+
+    using AWarpDstrEncoding = tile_distribution_encoding<
+        sequence<1>,
+        tuple, sequence>,
+        tuple,
+        tuple,
+        kABYs2RHsMajor,
+        kABYs2RHsMinor>;
+
+    using BWarpDstrEncoding = tile_distribution_encoding<
+        sequence<1>,
+        tuple, sequence>,
+        tuple,
+        tuple,
+        kABYs2RHsMajor,
+        kABYs2RHsMinor>;
+
+    using CWarpDstrEncoding = tile_distribution_encoding<
+        sequence<1>,
+        tuple, sequence>,
+        tuple,
+        tuple,
+        kCYs2RHsMajor,
+        kCYs2RHsMinor>;
+};
+
+/**
+ * @brief RegisterMapTraits for GFX11 WMMA 16x16x16_F16_F16_F32_GFX11
+ */
+template
+struct RegisterMapTraits>>
+{
+    using MmaOp = ck_tile::core::arch::mma::amdgcn_mma;
+
+    using MmaTraits = ck_tile::core::arch::mma::MmaOpTraits;
+    static constexpr index_t WaveSize =
+        static_cast(MmaTraits::CompilerTarget::WAVE_SIZE_ID);
+    static constexpr index_t AVecSize = vector_traits::vector_size;
+    static constexpr index_t BVecSize = vector_traits::vector_size;
+    static constexpr index_t CVecSize = vector_traits::vector_size;
+
+    using kABPs2RHssMajor = sequence<0, 1>;
+    using kABPs2RHssMinor = sequence<0, 0>;
+    using kABYs2RHsMajor  = sequence<2>;
+    using kABYs2RHsMinor  = sequence<0>;
+    using kCPs2RHssMajor  = sequence<1, 2>;
+    using kCPs2RHssMinor  = sequence<1, 0>;
+    using kCYs2RHsMajor   = sequence<1>;
+    using kCYs2RHsMinor   = sequence<0>;
+
+    // TODO: remove these and fix constants in amdgcn_mma
+    static constexpr index_t kAMBlock     = 1;
+    static constexpr index_t kBNBlock     = 1;
+    static constexpr index_t kAMLane      = 16;
+    static constexpr index_t kBNLane      = 16;
+    static constexpr index_t kABK0PerLane = 1;
+    static constexpr index_t kABKLane     = 1;
+    static constexpr index_t kABK1PerLane = 16;
+    static constexpr index_t kCMLane      = 2;
+    static constexpr index_t kCNLane      = 16;
+    static constexpr index_t kCM0PerLane  = 8;
+    static constexpr index_t kCM1PerLane  = 1;
+
+    using AWarpDstrEncoding =
+        tile_distribution_encoding, // kRepeat
+                                   tuple, sequence>,
+                                   tuple,
+                                   tuple,
+                                   kABYs2RHsMajor,
+                                   kABYs2RHsMinor>;
+
+    using BWarpDstrEncoding =
+        tile_distribution_encoding, // kRepeat
+                                   tuple, sequence>,
+                                   tuple,
+                                   tuple,
+                                   kABYs2RHsMajor,
+                                   kABYs2RHsMinor>;
+
+    using CWarpDstrEncoding =
+        tile_distribution_encoding,
+                                   tuple, sequence>,
+                                   tuple,
+                                   tuple,
+                                   kCYs2RHsMajor,
+                                   kCYs2RHsMinor>;
+};
+
+// ========================================================================
+
+} // namespace
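
For reference, the single-"1" construction that the patch relies on can be sanity-checked on the host with a few lines of plain C++. The sketch below is illustrative only and not part of the patch; the 16x16x16 sizes and the naive_gemm helper are assumptions made for the example, not ck_tile APIs. It checks the identity behind the test: if A holds a single 1 at (m, k) and B holds a single 1 at (k, n), then C = A * B holds a single 1 at (m, n), for every (m, k, n) case index decomposed the same way as in the kernel.

#include <cassert>
#include <cstdint>
#include <vector>

// Naive reference GEMM over row-major matrices (illustrative helper, not a ck_tile API).
static std::vector<float> naive_gemm(const std::vector<float>& a,
                                     const std::vector<float>& b,
                                     uint32_t M, uint32_t K, uint32_t N)
{
    std::vector<float> c(M * N, 0.0f);
    for(uint32_t i = 0; i < M; ++i)
        for(uint32_t j = 0; j < N; ++j)
            for(uint32_t kk = 0; kk < K; ++kk)
                c[i * N + j] += a[i * K + kk] * b[kk * N + j];
    return c;
}

int main()
{
    constexpr uint32_t M = 16, K = 16, N = 16; // same tile shape as the 16x16x16 MMA ops

    // Enumerate every (m, k, n) the test kernel covers (one case per HIP block).
    for(uint32_t case_idx = 0; case_idx < M * K * N; ++case_idx)
    {
        const uint32_t m = case_idx / (K * N);
        const uint32_t k = (case_idx / N) % K;
        const uint32_t n = case_idx % N;

        std::vector<float> a(M * K, 0.0f), b(K * N, 0.0f);
        a[m * K + k] = 1.0f; // single 1 in A at (m, k)
        b[k * N + n] = 1.0f; // single 1 in B at (k, n)

        // The product must contain exactly one 1, at C(m, n); everything else stays 0.
        const auto c = naive_gemm(a, b, M, K, N);
        for(uint32_t i = 0; i < M; ++i)
            for(uint32_t j = 0; j < N; ++j)
                assert(c[i * N + j] == ((i == m && j == n) ? 1.0f : 0.0f));
    }
    return 0;
}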