From db05d611368c8ce935005ead4a7a404390a19afa Mon Sep 17 00:00:00 2001 From: chris-tsiaousis-hpc <253485634+chris-tsiaousis-hpc@users.noreply.github.com> Date: Wed, 3 Jun 2026 14:35:18 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#6212 (commit ccee58d) =?UTF-8?q?[CK=20TILE]=20Unification=20Work=20=E2=80=93=20?= =?UTF-8?q?More=20accurate=20tests=20for=20MmaPipelines=20(#6212)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation This PR solves several issues: #### More accurate tests for MmaPipelines The current tests for the MmaPipelines (test_amdgcn_sparse_mma, test_amdgcn_wavewise_mma) use explicit input fragment vectors filled with 1s, and only check the output of a single lane. We should have tests that actually use the MmaPipelines with non-trivial input matrices and verify the complete output. Some other aspects of the current MmaPipelines tests that I noticed and deserve some attention: 1. There is sometimes iteration over K outside of the pipeline, which is then included in WaveTileK or FragK, which is not correct. We should remove it, move K iteration inside of the pipeline, or be more clear about this outer-K loop size and how it propagates downwards. 2. There is very tight coupling between the kernel, gtest code, and test_pipeline helper, requiring a lot of information and functions to be passed back and forth. 3. The test_pipeline helper is doing a bunch of register-related logic on the host (related to point 1) 4. Without this register logic the only thing it does is check the device, call the kernel, and check the output, but with a lot of boilerplate. #### Test helper for detecting target arch at HOST runtime There is a really apparent issue we faced while writing tests: Scenario: 1. Compile a test that supports both gfx950 and gfx1201 for gfx950 2. Run the test on a server that only has gfx1201 GPU Actual: Segmentation fault Expected: The test can correctly detect from HOST runtime that the DEVICE target_id was different and skips the test. Notes: The only way of detecting the COMPILER_TARGET_ID in the existing "arch" framework is launching a kernel and calling `get_compiler_target()` (so, from a DEVICE code). This will create a segmentation fault if the current arch differs from the target arch. To cope with this issue, we propose to export the compiler target(s) (note they can be many) through `projects/composablekernel/test/ck_tile/core/arch/CMakeLists.txt` and define a test helper to deal with such cases. #### Add composition support to Transforms We have a small number of Transforms which act on MmaOp input and output data, before and after the MmaOp call respectively. These are currently implemented to work on an MmaTile level, but in theory they are also supposed to work at a WaveTile level, i.e. after composition of multiple MmaTiles to create larger effective MNK dimensions. Currently the composed MmaTiles look like 2D C-style arrays of the individual MmaTile level register vectors (see WaveWiseMmaPipeline). The transforms should be able to take these and perform the proper transforms to the whole WaveTile at once. This might allow for better performing transformations. Note: This PR handles the SparseTransform case and if we don't end up doing scale as a transformation, there isn't really much left to do. If we end up having only the sparse transform as a non-trivial transform, then we could also consider removing the Transform framework. --- .../example_tile_distr_enc_calc.cpp | 2 +- include/ck_tile/core.hpp | 1 + include/ck_tile/core/arch/mma/amdgcn_mma.hpp | 7 - .../core/arch/mma/mfma/mfma_selector.hpp | 1 + .../core/arch/mma/mfma/mfma_transforms.hpp | 1 + include/ck_tile/core/arch/mma/mma.hpp | 8 + .../ck_tile/core/arch/mma/mma_pipeline.hpp | 39 +- .../ck_tile/core/arch/mma/mma_wavewise.hpp | 2 +- .../arch/mma/scale/scale_mma_pipeline.hpp | 101 ++- .../arch/mma/sparse/sparse_mma_pipeline.hpp | 170 ++++- .../arch/mma/sparse/sparse_transforms.hpp | 99 ++- test/ck_tile/core/arch/mma/CMakeLists.txt | 116 +++- .../arch/mma/get_cmake_targets_helper.hpp | 87 +++ .../core/arch/mma/get_wave_size_helper.hpp | 34 - .../mma/pipeline/pipeline_tests_helper.hpp | 590 ++++++++++++------ .../mma/pipeline/test_amdgcn_mma_pipeline.cpp | 2 +- .../mma/pipeline/test_amdgcn_scale_mma.cpp | 403 +++++++++--- .../mma/pipeline/test_amdgcn_sparse_mma.cpp | 379 +++++++---- .../mma/pipeline/test_amdgcn_wavewise_mma.cpp | 184 ++++-- .../ck_tile/core/arch/mma/test_amdgcn_mma.cpp | 9 +- 20 files changed, 1646 insertions(+), 589 deletions(-) create mode 100644 include/ck_tile/core/arch/mma/mma.hpp create mode 100644 test/ck_tile/core/arch/mma/get_cmake_targets_helper.hpp delete mode 100644 test/ck_tile/core/arch/mma/get_wave_size_helper.hpp diff --git a/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp b/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp index a491c0d2b9..9e62f6e939 100644 --- a/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp +++ b/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp @@ -5,7 +5,7 @@ #include #include #include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma.hpp" #include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp" #include "ck_tile/core/container/tuple.hpp" diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 4afba77d6a..2b7066cabf 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -24,6 +24,7 @@ #include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp" +#include "ck_tile/core/arch/mma/mma.hpp" #include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/arch/mma/mma_pipeline.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index 5985f63440..4cb28762ba 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -393,10 +393,3 @@ CK_TILE_HOST_DEVICE void print(amdgcn_mma= 23 #pragma clang diagnostic pop #endif - -// Include the implementations -#include "wmma/wmma.hpp" // should be included before the below headers - -#include "mfma/mfma.hpp" -#include "scale/scale.hpp" -#include "sparse/sparse.hpp" diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp index 2140e3317a..0fa1bada78 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/numeric/vector_type.hpp" diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp index 5a3fc9a7e4..9609ed5116 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_transforms.hpp" namespace ck_tile::core::arch::mma { diff --git a/include/ck_tile/core/arch/mma/mma.hpp b/include/ck_tile/core/arch/mma/mma.hpp new file mode 100644 index 0000000000..ec38fe78e3 --- /dev/null +++ b/include/ck_tile/core/arch/mma/mma.hpp @@ -0,0 +1,8 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "wmma/wmma.hpp" +#include "mfma/mfma.hpp" +#include "sparse/sparse.hpp" diff --git a/include/ck_tile/core/arch/mma/mma_pipeline.hpp b/include/ck_tile/core/arch/mma/mma_pipeline.hpp index de01760620..0994f2d178 100644 --- a/include/ck_tile/core/arch/mma/mma_pipeline.hpp +++ b/include/ck_tile/core/arch/mma/mma_pipeline.hpp @@ -270,12 +270,20 @@ struct MmaPipelineBase { if constexpr(MmaOpTraits::IsSupported) { - auto transformed_inputs = applyTransformsToInputs( - hasFlag() ? std::forward(b) - : std::forward(a), - hasFlag() ? std::forward(a) - : std::forward(b), - std::forward(accum)); + constexpr bool swap_a_and_b = hasFlag(); + + auto transformed_inputs = [&]() { + if constexpr(swap_a_and_b) + { + return applyTransformsToInputs( + std::forward(b), std::forward(a), std::forward(accum)); + } + else + { + return applyTransformsToInputs( + std::forward(a), std::forward(b), std::forward(accum)); + } + }(); Derived::execImpl(transformed_inputs); @@ -302,17 +310,18 @@ struct MmaPipelineBase { if constexpr(MmaOpTraits::IsSupported) { - // TODO: c++20: Call template functions with MmaPipelineOptionFlags directly + static_assert(MmaOpTraits::IsScale, + "This exec variant is intended for scale policy structs"); + constexpr bool swap_a_and_b = hasFlag(); + auto transformed_inputs = applyTransformsToInputs( - hasFlag() ? std::forward(b) - : std::forward(a), - hasFlag() ? std::forward(a) - : std::forward(b), + swap_a_and_b ? std::forward(b) : std::forward(a), + swap_a_and_b ? std::forward(a) : std::forward(b), std::forward(accum), - hasFlag() ? std::forward(scale_B) - : std::forward(scale_A), - hasFlag() ? std::forward(scale_A) - : std::forward(scale_B)); + swap_a_and_b ? std::forward(scale_B) + : std::forward(scale_A), + swap_a_and_b ? std::forward(scale_A) + : std::forward(scale_B)); Derived::execImpl(transformed_inputs); diff --git a/include/ck_tile/core/arch/mma/mma_wavewise.hpp b/include/ck_tile/core/arch/mma/mma_wavewise.hpp index 9fbbab411e..bc7e383f6d 100644 --- a/include/ck_tile/core/arch/mma/mma_wavewise.hpp +++ b/include/ck_tile/core/arch/mma/mma_wavewise.hpp @@ -169,7 +169,7 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase::SelectedOp, typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = typename MmaTransformsDefaultSelector::SelectedTransforms> // clang-format off -struct ScaleMmaPipeline : public MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline> +struct ScaleMmaPipeline : public MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline> { - using Base = MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline>; + using Base = MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline>; // clang-format on using MmaOp = MmaOp_; // Expose the selected MmaOp - // Expose caller-side vector types - using AVecType = typename MmaOp::AVecType; - using BVecType = typename MmaOp::BVecType; - using CVecType = typename MmaOp::CVecType; + // Fragment dimensions (from the hardware MmaOp) + constexpr static uint32_t FragM = MmaOp::kM; + constexpr static uint32_t FragN = MmaOp::kN; + constexpr static uint32_t FragK = MmaOp::kK; - // Expose internal vector types + // Fragment counts for decomposition + constexpr static uint32_t FragsM = WaveTileM / FragM; + constexpr static uint32_t FragsN = WaveTileN / FragN; + constexpr static uint32_t FragsK = WaveTileK / FragK; + + // Vector types for packed registers in each fragment using InternalAVecT = typename MmaOp::AVecType; using InternalBVecT = typename MmaOp::BVecType; using InternalCVecT = typename MmaOp::CVecType; + // Buffer types for WaveTiles + using AVecType = InternalAVecT[FragsM][FragsK]; + using BVecType = InternalBVecT[FragsN][FragsK]; + using CVecType = InternalCVecT[FragsM][FragsN]; + // Transforms using ATransform = typename MmaTransforms::ATransform; using BTransform = typename MmaTransforms::BTransform; using CTransform = typename MmaTransforms::CTransform; using DTransform = typename MmaTransforms::DTransform; + // Sanity checks + static_assert(WaveTileM >= FragM, "WaveTileM must be >= FragM"); + static_assert(WaveTileN >= FragN, "WaveTileN must be >= FragN"); + static_assert(WaveTileK >= FragK, "WaveTileK must be >= FragK"); + static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM"); + static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN"); + static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK"); + template (MmaPipelineOpt CK_TILE_DEVICE static void execImpl(std::tuple& vecs) { - auto& [a_vec, b_vec, c_vec, scale_A, scale_B] = vecs; - c_vec = MmaOp::exec(a_vec, b_vec, c_vec, scale_A, scale_B); + auto& [a_frag, b_frag, c_frag, scale_A, scale_B] = vecs; + + if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR) + { + for(uint32_t bm = 0u; bm < FragsM; ++bm) + { + for(uint32_t bn = 0u; bn < FragsN; ++bn) + { + for(uint32_t bk = 0u; bk < FragsK; ++bk) + { + c_frag[bm][bn] = MmaOp::exec( + a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn], scale_A, scale_B); + } + } + } + } + else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR) + { + for(uint32_t bn = 0u; bn < FragsN; ++bn) + { + for(uint32_t bm = 0u; bm < FragsM; ++bm) + { + for(uint32_t bk = 0u; bk < FragsK; ++bk) + { + c_frag[bm][bn] = MmaOp::exec( + a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn], scale_A, scale_B); + } + } + } + } + else + { + static_assert(false, "Invalid accumulation policy"); + } } }; diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp index d57f544a41..7b2f24dea8 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp @@ -5,10 +5,10 @@ #include "ck_tile/core/arch/mma/mma_pipeline.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/arch/mma/mma_wavewise.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include -#include namespace ck_tile::core::arch::mma { @@ -20,12 +20,33 @@ constexpr inline int getPipelineFlags() } } // namespace sparse::detail +/** + * @class SparseMmaPipeline + * @brief Driver for the wave-tile sparse Mma operation. Given a backend MmaOp implementation + * (e.g., smfmac), this class performs fragment-wise (MmaTile) decomposition to matrix-multiply + * input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and accumulates + * results into output WaveTile (C: WaveTileM x WaveTileN). + * Like WaveWiseMmaPipeline, this decomposes WaveTile dimensions into fragments and iterates + * internally over FragsM × FragsN × FragsK. The A operand is provided in uncompressed form; + * 2:4 structured sparsity compression (SparseCompressTransform) is applied. + * @tparam ADataType Data type of input WaveTile A + * @tparam BDataType Data type of input WaveTile B + * @tparam CDataType Data type of input/output WaveTile C (accumulator) + * @tparam WaveTileM Mma WaveTile M dimension + * @tparam WaveTileN Mma WaveTile N dimension + * @tparam WaveTileK Mma WaveTile K dimension + * @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order) + * @tparam CompilerTarget The compiler target + * @tparam MmaOp_ Backend wrapper class that will perform the mma op + * @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles + */ template ::SelectedOp, typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = typename MmaTransformsDefaultSelector::SelectedTransforms> // clang-format off -struct SparseMmaPipeline : public MmaPipelineBase> +struct SparseMmaPipeline : public MmaPipelineBase> { - using Base = MmaPipelineBase>; + using Base = MmaPipelineBase>; // clang-format on static_assert(!Base::template hasFlag(), @@ -52,48 +73,153 @@ struct SparseMmaPipeline : public MmaPipelineBase; static constexpr index_t ASize = AVecTraits::vector_size * MmaOp::kCompressionRatio; using AVecType = ext_vector_t; }; + using ExternalAFragVecT = typename ExternalAVecCalculator::AVecType; - // Expose caller-side vector types - using AVecType = typename ExternalAVecCalculator::AVecType; - using BVecType = typename MmaOp::BVecType; - using CVecType = typename MmaOp::CVecType; + // Scalar type of A + using AScalarT = typename ExternalAVecCalculator::AVecTraits::scalar_type; - // Expose internal vector types - using InternalAVecT = typename MmaOp::AVecType; + // Per-fragment sizes + static constexpr uint32_t ExternalAFragSize = ExternalAVecCalculator::ASize; + static constexpr uint32_t InternalAFragSize = + vector_traits::vector_size; + + // Full wave-tile sizes (all fragments combined) + static constexpr uint32_t TotalUncompressedElems = FragsM * FragsK * ExternalAFragSize; + static constexpr uint32_t TotalCompressedElems = + TotalUncompressedElems / MmaOp::kCompressionRatio; + + // Variable-length idx type for the whole wave-tile (spans multiple int32_t words if needed) + static constexpr index_t IdxNumWords = sparse::detail::idx_words_needed; + using IdxType = sparse::detail::SparseIdxPack; + + // Per-fragment compressed vector type (for individual MmaOp::exec calls) + using FragAVecT = typename MmaOp::AVecType; + + // Internal vector types used by the base class formatBuffer. + // InternalAVecT matches the full compressed wave-tile so the base class can + // format the SparseCompressTransform result via formatBuffer. + using InternalAVecT = ext_vector_t; using InternalBVecT = typename MmaOp::BVecType; using InternalCVecT = typename MmaOp::CVecType; + // Buffer types for WaveTiles (caller-facing). + // A is a single flat uncompressed vector covering the whole wave-tile. + // The base class compresses it in one pass via ATransform. + using AVecType = ext_vector_t; + using BVecType = InternalBVecT[FragsN][FragsK]; + using CVecType = InternalCVecT[FragsM][FragsN]; + // Transforms using ATransform = typename MmaTransforms::ATransform; using BTransform = typename MmaTransforms::BTransform; using CTransform = typename MmaTransforms::CTransform; using DTransform = typename MmaTransforms::DTransform; + // Sanity checks + static_assert(WaveTileM >= FragM, "WaveTileM must be >= FragM"); + static_assert(WaveTileN >= FragN, "WaveTileN must be >= FragN"); + static_assert(WaveTileK >= FragK, "WaveTileK must be >= FragK"); + static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM"); + static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN"); + static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK"); + template CK_TILE_DEVICE static void - execImpl(std::tuple& vecs) + execImpl(std::tuple& transformedInputs) { + auto& [a, b_frag, c_frag] = transformedInputs; + auto& [a_compressed_whole, idx] = a; + + // Validate that the ATransform result and per-fragment reinterpretation are correct checkATransformResult(); - auto& [a_result, b_vec, c_vec] = vecs; - auto& [a_vec, idx] = a_result; - c_vec = MmaOp::exec(a_vec, b_vec, c_vec, idx); + + // Reinterpret the full compressed vector as per-fragment arrays + auto* a_frags = ck_tile::bit_cast(&a_compressed_whole); + + // Accumulation loop with per-fragment idx extraction + if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR) + { + for(uint32_t bm = 0u; bm < FragsM; ++bm) + { + for(uint32_t bn = 0u; bn < FragsN; ++bn) + { + for(uint32_t bk = 0u; bk < FragsK; ++bk) + { + c_frag[bm][bn] = MmaOp::exec( + a_frags[bm][bk], + b_frag[bn][bk], + c_frag[bm][bn], + sparse::detail::extract_fragment_idx( + idx, bm, bk)); + } + } + } + } + else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR) + { + for(uint32_t bn = 0u; bn < FragsN; ++bn) + { + for(uint32_t bm = 0u; bm < FragsM; ++bm) + { + for(uint32_t bk = 0u; bk < FragsK; ++bk) + { + c_frag[bm][bn] = MmaOp::exec( + a_frags[bm][bk], + b_frag[bn][bk], + c_frag[bm][bn], + sparse::detail::extract_fragment_idx( + idx, bm, bk)); + } + } + } + } + else + { + static_assert(false, "Invalid accumulation policy"); + } } private: - // Type check helper - not a device function, so std::declval is available + // Compile-time validation of ATransform result and per-fragment reinterpretation. + // Ensures the compressed vector returned by ATransform::exec can be safely + // reinterpreted as FragAVecT[FragsM][FragsK] for per-fragment MmaOp dispatch. template static constexpr void checkATransformResult() { using ExternalAvecRef = std::add_lvalue_reference_t; static_assert(std::is_same_v()))>); + decltype(ATransform::exec(std::declval()))>, + "ATransformResult must match the return type of ATransform::exec"); + + using CompressedVecType = + std::remove_reference_t>; + static_assert(sizeof(CompressedVecType) == sizeof(FragAVecT) * FragsM * FragsK, + "Compressed A vector size must equal sizeof(FragAVecT[FragsM][FragsK])"); + + static_assert(alignof(CompressedVecType) >= alignof(FragAVecT), + "Compressed vector alignment must be >= FragAVecT alignment " + "for safe reinterpret_cast to per-fragment array"); + + using ActualIdxType = std::tuple_element_t<1, ATransformResult>; + static_assert(std::is_same_v, + "Sparsity index type must match SparseIdxPack"); } }; diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp index 4b0effc2bf..f89a062240 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp @@ -13,6 +13,30 @@ namespace ck_tile::core::arch::mma { namespace sparse::detail { + +/// Number of int32_t words needed to store CompressedSize 2-bit idx fields. +template +static constexpr index_t idx_words_needed = (CompressedSize * 2 + 31) / 32; + +/** + * @class SparseIdxPack + * @brief Variable-length container for 2:4 structured sparsity index metadata. + * + * Each compressed element produces a 2-bit index field encoding the original + * position (0–3) within its group of 4. When composing multiple MMA fragments + * in M and K dimensions within a WaveTile, the total number of index bits can + * exceed 32. This struct packs the index fields into an array of int32_t words, + * sized at compile time. + * + * @tparam NumWords Number of int32_t words needed to store all index fields. + */ +template +struct SparseIdxPack +{ + static_assert(NumWords > 0, "SparseIdxPack requires at least 1 word"); + int32_t words[NumWords] = {}; +}; + /** * @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero * elements into lower part of a_vec to half its effective size. @@ -20,21 +44,29 @@ namespace sparse::detail { * @tparam ADataType The data type of a_vec * @tparam CompressedSize The target compression size * @tparam AVec The vector type of a_vec (deduced) - * @return Packed 32‑bit word containing **CompressedSize** 2‑bit fields. - * Each field encodes the original position (0–3) of the corresponding - * non‑zero element in the input. If fewer than CompressedSize - * non‑zeros are found, remaining fields default to 2 (see below). + * @return SparseIdxPack containing **CompressedSize** 2‑bit fields packed + * across one or more int32_t words. Each field encodes the original + * position (0–3) of the corresponding non‑zero element in the input. + * If fewer than CompressedSize non‑zeros are found, remaining fields + * default to 2 (see below). */ template -static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec) +static CK_TILE_DEVICE auto compress_a_impl(AVec& a_vec) { - // idx holds one 2‑bit index per output element (total CompressedSize entries). + static constexpr index_t NumIdxWords = idx_words_needed; + // idx holds one 2‑bit index per output element (total CompressedSize entries), + // packed across NumIdxWords int32_t words. // It is initialized to the pattern 0b10 for every field. This matches // what the hardware expects when there are fewer than two non‑zero values // in a 4‑element group – the unused output is treated as coming from slot 2. // The loop below will clear and set each field as real non‑zeros are seen. - int32_t idx = 0; - static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2u << (2u * k)); }); + SparseIdxPack idx{}; + static_for<0, CompressedSize, 1>{}([&](auto k) { + constexpr uint32_t bit_pos = static_cast(k) * 2u; + constexpr uint32_t word = bit_pos / 32u; + constexpr uint32_t shift = bit_pos % 32u; + idx.words[word] |= static_cast(2u << shift); + }); static_for<0, CompressedSize / 2, 1>{}([&](auto i) { ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; @@ -45,8 +77,13 @@ static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec) { nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; // clear the two‑bit field for this output and insert j - idx &= ~(0b11u << (2u * (i * 2 + non_zero_pos))); - idx |= static_cast(j) << (2u * (i * 2 + non_zero_pos)); + const uint32_t field_idx = + static_cast(i) * 2u + static_cast(non_zero_pos); + const uint32_t bit_pos = field_idx * 2u; + const uint32_t word = bit_pos / 32u; + const uint32_t shift = bit_pos % 32u; + idx.words[word] &= ~static_cast(0b11u << shift); + idx.words[word] |= static_cast(static_cast(j) << shift); ++non_zero_pos; } }); @@ -56,6 +93,40 @@ static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec) return idx; } +/** + * @brief Extract the per-fragment sparsity index from a packed idx pack. + * After whole-wave-tile compression, the returned idx packs 2-bit fields for + * every compressed output element across one or more int32_t words. + * @return A single int32_t with this fragment's 2-bit fields at the + * least-significant positions, suitable for passing to the MMA builtin. + */ +template +static CK_TILE_DEVICE int32_t extract_fragment_idx(const SparseIdxPack& idx, + uint32_t m, + uint32_t k) +{ + static constexpr uint32_t IdxBitsPerFrag = FragCompressedSize * 2; + const auto fragLinearIdx = m * FragsK + k; + const auto totalBitOffset = fragLinearIdx * IdxBitsPerFrag; + const auto wordIdx = totalBitOffset / 32u; + const auto bitInWord = totalBitOffset % 32u; + + uint32_t result = static_cast(idx.words[wordIdx]) >> bitInWord; + + // If fragment bits span a word boundary, stitch in bits from the next word. + // (This is a safety measure; it should not occur when IdxBitsPerFrag is a + // power-of-2 divisor of 32, which is always the case for current MMA ops.) + if constexpr(NumIdxWords > 1) + { + if(bitInWord != 0 && bitInWord + IdxBitsPerFrag > 32u) + { + result |= static_cast(idx.words[wordIdx + 1]) << (32u - bitInWord); + } + } + + return static_cast(result); +} + } // namespace sparse::detail /** @@ -75,15 +146,15 @@ struct SparseCompressTransform static constexpr auto VecN = VecTraits::vector_size; static constexpr index_t CompressedSize = VecN / CompressionRatio; using VecCompressed = ext_vector_t; + using IdxType = + sparse::detail::SparseIdxPack>; static_assert(VecN % CompressionRatio == 0, "VecN must be divisible by CompressionRatio"); static_assert(CompressedSize > 0, "CompressedSize must be > 0"); - const auto idx = sparse::detail::compress_a_impl(v); + auto idx = sparse::detail::compress_a_impl(v); - // TODO c++20: Use bit_cast - return std::tuple( - *std::launder(reinterpret_cast(&v)), idx); + return std::tuple(*ck_tile::bit_cast(&v), idx); } }; diff --git a/test/ck_tile/core/arch/mma/CMakeLists.txt b/test/ck_tile/core/arch/mma/CMakeLists.txt index e65a76c134..7f7817b5bf 100644 --- a/test/ck_tile/core/arch/mma/CMakeLists.txt +++ b/test/ck_tile/core/arch/mma/CMakeLists.txt @@ -7,19 +7,109 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx9|gfx120") - add_gtest_executable(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp) +# --------------------------------------------------------------------------- +# Map GPU target strings to hex amdgcn_target_id values (arch.hpp). +# Builds a -DCK_CMAKE_GPU_TARGET_IDS=0xHHHH,... definition that host-side +# test code can consume without launching a device kernel. +# --------------------------------------------------------------------------- +function(_ck_gpu_target_string_to_id TARGET_STR OUT_VAR) + string(TOLOWER "${TARGET_STR}" _tgt) + string(REGEX REPLACE ":.*" "" _tgt "${_tgt}") + # GFX9 + if(_tgt STREQUAL "gfx908") + set(${OUT_VAR} "0x0908" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx90a") + set(${OUT_VAR} "0x090A" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx942") + set(${OUT_VAR} "0x0942" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx950") + set(${OUT_VAR} "0x0950" PARENT_SCOPE) + # GFX10.3 + elseif(_tgt STREQUAL "gfx1030") + set(${OUT_VAR} "0x1030" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1031") + set(${OUT_VAR} "0x1031" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1032") + set(${OUT_VAR} "0x1032" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1033") + set(${OUT_VAR} "0x1033" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1034") + set(${OUT_VAR} "0x1034" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1035") + set(${OUT_VAR} "0x1035" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1036") + set(${OUT_VAR} "0x1036" PARENT_SCOPE) + elseif(_tgt MATCHES "^gfx10-3-generic$") + set(${OUT_VAR} "0x103F" PARENT_SCOPE) + # GFX11 + elseif(_tgt STREQUAL "gfx1100") + set(${OUT_VAR} "0x1100" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1101") + set(${OUT_VAR} "0x1101" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1102") + set(${OUT_VAR} "0x1102" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1103") + set(${OUT_VAR} "0x1103" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1150") + set(${OUT_VAR} "0x1150" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1151") + set(${OUT_VAR} "0x1151" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1152") + set(${OUT_VAR} "0x1152" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1153") + set(${OUT_VAR} "0x1153" PARENT_SCOPE) + elseif(_tgt MATCHES "^gfx11-generic$") + set(${OUT_VAR} "0x11FF" PARENT_SCOPE) + # GFX12 + elseif(_tgt STREQUAL "gfx1200") + set(${OUT_VAR} "0x1200" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1201") + set(${OUT_VAR} "0x1201" PARENT_SCOPE) + elseif(_tgt MATCHES "^gfx12-generic$") + set(${OUT_VAR} "0x12FF" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1250") + set(${OUT_VAR} "0x1250" PARENT_SCOPE) + else() + message(WARNING "_ck_gpu_target_string_to_id: unknown GPU target '${TARGET_STR}', skipping") + set(${OUT_VAR} "" PARENT_SCOPE) + endif() +endfunction() +function(_ck_add_gpu_target_ids_define TARGET_NAME) + get_property(_archs TARGET ${TARGET_NAME} PROPERTY HIP_ARCHITECTURES) + string(REPLACE "," ";" _archs "${_archs}") + set(_hex_ids) + foreach(_tgt IN LISTS _archs) + _ck_gpu_target_string_to_id("${_tgt}" _hex) + if(_hex AND NOT _hex STREQUAL "0x0000") + list(APPEND _hex_ids "${_hex}") + endif() + endforeach() + list(JOIN _hex_ids "," _hex_str) + if(_hex_str) + target_compile_definitions(${TARGET_NAME} PRIVATE "CK_CMAKE_GPU_TARGET_IDS=${_hex_str}") + endif() +endfunction() + +# Convenience: add_gtest_executable + inject CK_CMAKE_GPU_TARGET_IDS +macro(_add_mma_gtest TEST_NAME) + add_gtest_executable(${TEST_NAME} ${ARGN}) + _ck_add_gpu_target_ids_define(${TEST_NAME}) +endmacro() +# --------------------------------------------------------------------------- + +if(GPU_TARGETS MATCHES "gfx9|gfx120") + _add_mma_gtest(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp) target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() if(GPU_TARGETS MATCHES "gfx950") - add_gtest_executable(test_amdgcn_scale_mma pipeline/test_amdgcn_scale_mma.cpp) + _add_mma_gtest(test_amdgcn_scale_mma pipeline/test_amdgcn_scale_mma.cpp) target_compile_options(test_amdgcn_scale_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() if(GPU_TARGETS MATCHES "gfx9") - add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp) + _add_mma_gtest(test_amdgcn_mma test_amdgcn_mma.cpp) target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_amdgcn_wavewise_mma pipeline/test_amdgcn_wavewise_mma.cpp) + _add_mma_gtest(test_amdgcn_wavewise_mma pipeline/test_amdgcn_wavewise_mma.cpp) target_compile_options(test_amdgcn_wavewise_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") @@ -37,45 +127,45 @@ macro(set_mma_test_arch_define target_name) endmacro() if(GPU_TARGETS MATCHES "gfx9") - add_gtest_executable(test_amdgcn_mma_layout_gfx9 test_amdgcn_mma_layout_gfx9.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx9 test_amdgcn_mma_layout_gfx9.cpp) target_compile_options(test_amdgcn_mma_layout_gfx9 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) set_mma_test_arch_define(test_amdgcn_mma_layout_gfx9) endif() if(GPU_TARGETS MATCHES "gfx908|gfx90a") - add_gtest_executable(test_amdgcn_mma_layout_gfx908_and_gfx90a test_amdgcn_mma_layout_gfx908_and_gfx90a.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx908_and_gfx90a test_amdgcn_mma_layout_gfx908_and_gfx90a.cpp) target_compile_options(test_amdgcn_mma_layout_gfx908_and_gfx90a PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) set_mma_test_arch_define(test_amdgcn_mma_layout_gfx908_and_gfx90a) endif() if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") - add_gtest_executable(test_amdgcn_mma_layout_gfx90a_and_higher test_amdgcn_mma_layout_gfx90a_and_higher.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx90a_and_higher test_amdgcn_mma_layout_gfx90a_and_higher.cpp) target_compile_options(test_amdgcn_mma_layout_gfx90a_and_higher PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) set_mma_test_arch_define(test_amdgcn_mma_layout_gfx90a_and_higher) endif() if(GPU_TARGETS MATCHES "gfx942|gfx950") - add_gtest_executable(test_amdgcn_mma_layout_gfx942_and_higher test_amdgcn_mma_layout_gfx942_and_higher.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx942_and_higher test_amdgcn_mma_layout_gfx942_and_higher.cpp) target_compile_options(test_amdgcn_mma_layout_gfx942_and_higher PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS} -Wno-header-hygiene) set_mma_test_arch_define(test_amdgcn_mma_layout_gfx942_and_higher) endif() if(GPU_TARGETS MATCHES "gfx950") - add_gtest_executable(test_amdgcn_mma_layout_gfx950 test_amdgcn_mma_layout_gfx950.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx950 test_amdgcn_mma_layout_gfx950.cpp) target_compile_options(test_amdgcn_mma_layout_gfx950 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) set_mma_test_arch_define(test_amdgcn_mma_layout_gfx950) endif() if(GPU_TARGETS MATCHES "gfx11") - add_gtest_executable(test_amdgcn_mma_layout_gfx11 test_amdgcn_mma_layout_gfx11.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx11 test_amdgcn_mma_layout_gfx11.cpp) target_compile_options(test_amdgcn_mma_layout_gfx11 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() if(GPU_TARGETS MATCHES "gfx120") - add_gtest_executable(test_amdgcn_mma_layout_gfx12 test_amdgcn_mma_layout_gfx12.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx12 test_amdgcn_mma_layout_gfx12.cpp) target_compile_options(test_amdgcn_mma_layout_gfx12 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() -add_gtest_executable(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp) +_add_mma_gtest(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp) target_compile_options(test_amdgcn_mma_pipeline PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/test/ck_tile/core/arch/mma/get_cmake_targets_helper.hpp b/test/ck_tile/core/arch/mma/get_cmake_targets_helper.hpp new file mode 100644 index 0000000000..eeac607a1b --- /dev/null +++ b/test/ck_tile/core/arch/mma/get_cmake_targets_helper.hpp @@ -0,0 +1,87 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include + +namespace ck_tile::core::arch::testing { + +static CK_TILE_HOST auto getCMakeGpuTargetIds() +{ + using ck_tile::core::arch::amdgcn_target_id; +#ifdef CK_CMAKE_GPU_TARGET_IDS + constexpr uint32_t ids[] = {CK_CMAKE_GPU_TARGET_IDS}; + std::unordered_set result; + for(auto id : ids) + result.insert(static_cast(id)); + return result; +#else + return std::unordered_set{}; +#endif +} + +template +static CK_TILE_HOST bool dispatchCompilerTarget(ck_tile::core::arch::amdgcn_target_id id, + Func&& func) +{ + using namespace ck_tile::core::arch; + + // clang-format off + switch(id) + { + case amdgcn_target_id::GFX908: func(make_amdgcn_gfx9_target()); return true; + case amdgcn_target_id::GFX90A: func(make_amdgcn_gfx9_target()); return true; + case amdgcn_target_id::GFX942: func(make_amdgcn_gfx9_target()); return true; + case amdgcn_target_id::GFX950: func(make_amdgcn_gfx9_target()); return true; + case amdgcn_target_id::GFX1030: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1031: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1032: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1033: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1034: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1035: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1036: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX103_GENERIC: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1100: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1101: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1102: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1103: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1150: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1151: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1152: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1153: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX11_GENERIC: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1200: func(make_amdgcn_gfx12_target()); return true; + case amdgcn_target_id::GFX1201: func(make_amdgcn_gfx12_target()); return true; + case amdgcn_target_id::GFX12_GENERIC: func(make_amdgcn_gfx12_target()); return true; + case amdgcn_target_id::GFX1250: func(make_amdgcn_gfx12_target()); return true; + case amdgcn_target_id::HOST: return false; + } + // clang-format on + __builtin_unreachable(); +} + +static CK_TILE_HOST constexpr int32_t getCMakeWaveSize() +{ + using ck_tile::core::arch::amdgcn_target_id; +#ifdef CK_CMAKE_GPU_TARGET_IDS + constexpr uint32_t ids[] = {CK_CMAKE_GPU_TARGET_IDS}; + constexpr index_t targets_size = sizeof(ids) / sizeof(ids[0]); + static_assert(targets_size > 0); + constexpr auto first_target_id = static_cast(ids[0]); + if constexpr(first_target_id >= amdgcn_target_id::GFX908 && + first_target_id <= amdgcn_target_id::GFX950) + { + return 64; + } + else + { + return 32; + } +#else + static_assert(false, "Configure CK_CMAKE_GPU_TARGET_IDS before calling this function."); + return 0; +#endif +} +} // namespace ck_tile::core::arch::testing diff --git a/test/ck_tile/core/arch/mma/get_wave_size_helper.hpp b/test/ck_tile/core/arch/mma/get_wave_size_helper.hpp deleted file mode 100644 index 84a3f955e5..0000000000 --- a/test/ck_tile/core/arch/mma/get_wave_size_helper.hpp +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck_tile/core/arch/arch.hpp" -#include -#include "ck_tile/host/hip_check_error.hpp" - -namespace { - -__global__ void getWaveSizeForSelectedOp(uint32_t* waveSize) -{ - using CompilerTarget = decltype(ck_tile::core::arch::get_compiler_target()); - - if(waveSize) - *waveSize = static_cast(CompilerTarget::WAVE_SIZE_ID); -} - -static __host__ uint32_t getDeviceWaveSize() -{ - uint32_t* d_wave_size; - HIP_CHECK_ERROR(hipMalloc(&d_wave_size, sizeof(uint32_t))); - getWaveSizeForSelectedOp<<<1, 64>>>(d_wave_size); - HIP_CHECK_ERROR(hipDeviceSynchronize()); - uint32_t wave_size; - HIP_CHECK_ERROR(hipMemcpy(&wave_size, d_wave_size, sizeof(uint32_t), hipMemcpyDeviceToHost)); - return wave_size; -} - -} // namespace diff --git a/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp b/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp index 8460100aa9..edd5828a5b 100644 --- a/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp +++ b/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp @@ -10,223 +10,421 @@ #include #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp" +#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/numeric/type_convert.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/host/kernel_launch.hpp" #include -#include "../get_wave_size_helper.hpp" +#include "../get_cmake_targets_helper.hpp" -template -struct MmaPipelineTest +namespace mma_pipeline_test { + +using namespace ck_tile; +using namespace ck_tile::core::arch; +using namespace ck_tile::core::arch::mma; +using namespace ck_tile::core::arch::testing; + +inline bool hipTargetMatchesCmakeTargets(amdgcn_target_id arch) { - using AType = AType_; - using BType = BType_; - using CType = CType_; - using ScaleAType = ScaleAType_; - using ScaleBType = ScaleBType_; - static constexpr auto WaveTileM = WaveTileM_; - static constexpr auto WaveTileN = WaveTileN_; - static constexpr auto WaveTileK = WaveTileK_; - - void test_pipeline(std::function shouldSkip, - std::function kernel, - std::function getExpected, - std::function aInitializer = nullptr) + const auto cmake_targets = getCMakeGpuTargetIds(); + if(cmake_targets.count(arch) == 0) { - using namespace ck_tile; - using namespace ck_tile::core::arch; - - int devCount; - hipDevice_t dev; - HIP_CHECK_ERROR(hipGetDevice(&dev)); - HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); - - hipDeviceProp_t devProp; - HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); - - auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); - bool hasDevice = static_cast(devCount > 0); - int deviceWarpSize = devProp.warpSize; - - if(!hasDevice || shouldSkip(currentArchId)) + // gfx12-generic and gfx11-generic make no difference with the specialized archs. + // Some CI pipelines make use of that and configure the project with the generic + // flags besides compiling for (f.e.) gfx1201. + if(arch >= amdgcn_target_id::GFX1200 && arch <= amdgcn_target_id::GFX12_GENERIC) { - GTEST_SKIP() << "No HIP device found. Skipping test."; + return (cmake_targets.count(amdgcn_target_id::GFX12_GENERIC) > 0); } - - // WaveTile size, also the expected fragment size (MmaTile) from the selector. - // Note: Actual FragK might be slightly different due to hardware implementation, but the - // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is - // correct. - static constexpr uint32_t FragM = WaveTileM; - static constexpr uint32_t FragN = WaveTileN; - static constexpr uint32_t FragK = WaveTileK; - - // The number of elements per thread - uint32_t AElements = FragM * FragK / deviceWarpSize; - uint32_t BElements = FragN * FragK / deviceWarpSize; - uint32_t CElements = FragM * FragN / deviceWarpSize; - - uint32_t ASize = AElements * sizeof(AType); - uint32_t BSize = BElements * sizeof(BType); - uint32_t CSize = CElements * sizeof(CType); - - // Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's - std::vector h_a(AElements); - if(aInitializer) + else if(arch >= amdgcn_target_id::GFX1100 && arch <= amdgcn_target_id::GFX11_GENERIC) { - for(size_t i = 0; i < AElements; ++i) - h_a[i] = aInitializer(i); + return (cmake_targets.count(amdgcn_target_id::GFX11_GENERIC) > 0); } - else + } + return true; +} +template +void reference_matmul(std::vector& C, + const std::vector& A, + const std::vector& B, + uint32_t M, + uint32_t N, + uint32_t K) +{ + for(uint32_t m = 0; m < M; ++m) + { + for(uint32_t n = 0; n < N; ++n) { - std::fill(h_a.begin(), h_a.end(), type_convert(1)); + float acc = 0.0f; + for(uint32_t k = 0; k < K; ++k) + { + acc += type_convert(A[m * K + k]) * type_convert(B[k * N + n]); + } + C[m * N + n] = static_cast(acc); } - std::vector h_b(BElements, type_convert(1)); - std::vector h_c(CElements, type_convert(0)); - std::vector h_out(CElements, type_convert(0)); + } +} - AType* d_a; - BType* d_b; - CType* d_c; - CType* d_out; +template +T deterministic_value(uint32_t row, uint32_t col, uint32_t minor_dim) +{ + float v = static_cast((row * minor_dim + col) % 7 + 1) * 0.25f; + return type_convert(v); +} - HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); - HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); - HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); - HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); - - // Copy inputs to device - HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); - - const auto wave_size = getDeviceWaveSize(); - kernel(wave_size, d_a, d_b, d_c, d_out); - HIP_CHECK_ERROR(hipDeviceSynchronize()); - - HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); - - // Verify output against expected value for all elements - for(size_t i = 0; i < CElements; ++i) +// Apply 2:4 sparsity pattern to A matrix in-place (for sparse pipeline tests). +// Every group of 4 consecutive K elements keeps slots 0 and 2, zeros slots 1 and 3. +template +void apply_sparse_pattern(std::vector& A, uint32_t M, uint32_t K) +{ + for(uint32_t m = 0; m < M; ++m) + { + for(uint32_t k = 0; k < K; k += 4) { - EXPECT_NEAR(h_out[i], getExpected(FragK), 1e-3); + // Keep slots 0, 2. Zero out slots 1, 3. + if(k + 1 < K) + A[m * K + k + 1] = static_cast(0); + if(k + 3 < K) + A[m * K + k + 3] = static_cast(0); } + } +} - HIP_CHECK_ERROR(hipFree(d_a)); - HIP_CHECK_ERROR(hipFree(d_b)); - HIP_CHECK_ERROR(hipFree(d_c)); - HIP_CHECK_ERROR(hipFree(d_out)); +// Fill per-lane A fragments from logical A[M][K] matrix. +// For dense pipelines: AVecType = InternalAVecT[FragsM][FragsK] +// For sparse pipelines: AVecType = ExternalAFragVecT[FragsM][FragsK] (uncompressed) +template +void fill_a_fragments(typename Pipeline::AVecType* a_per_lane, + const std::vector& A_matrix, + uint32_t K, + uint32_t waveSize) +{ + using MmaOp = typename Pipeline::MmaOp; + using ARegMap = TileDistrEncRegMap::AWarpDstrEncoding>; + using AFragScalar = typename vector_traits::scalar_type; + + constexpr uint32_t FragM = Pipeline::FragM; + constexpr uint32_t FragK = Pipeline::FragK; + constexpr uint32_t FragsM = Pipeline::FragsM; + constexpr uint32_t FragsK = Pipeline::FragsK; + + constexpr uint32_t kCompressionRatio = MmaOp::kCompressionRatio; + + // The A register map maps (lane, vec_idx) -> (m_within_frag, k_within_frag) + // For sparse: k_within_frag is in the compressed K domain (K / kCompressionRatio) + constexpr index_t a_vec_size = ARegMap::num_vector_items; + constexpr index_t external_a_frag_vec_size = a_vec_size * kCompressionRatio; + + for(uint32_t lane = 0; lane < waveSize; ++lane) + { + auto* lane_a = reinterpret_cast(&a_per_lane[lane]); + + for(uint32_t bm = 0; bm < FragsM; ++bm) + { + for(uint32_t bk = 0; bk < FragsK; ++bk) + { + uint32_t frag_offset = (bm * FragsK + bk) * external_a_frag_vec_size; + + if constexpr(kCompressionRatio > 1) + { + // Sparse: fill external (uncompressed) vector + for(index_t ev = 0; ev < external_a_frag_vec_size; ++ev) + { + index_t compressed_v = ev / kCompressionRatio; + index_t sub_pos = ev % kCompressionRatio; + + auto coords = + ARegMap::calc_matrix_indices_from_lane_vector(lane, compressed_v); + uint32_t m_local = coords[0]; + uint32_t k_compressed = coords[1]; + uint32_t k_local = k_compressed * kCompressionRatio + sub_pos; + + uint32_t m_global = bm * FragM + m_local; + uint32_t k_global = bk * FragK + k_local; + + lane_a[frag_offset + ev] = + static_cast(A_matrix[m_global * K + k_global]); + } + } + else + { + // Dense/Scale: direct mapping + for(index_t v = 0; v < a_vec_size; ++v) + { + auto coords = ARegMap::calc_matrix_indices_from_lane_vector(lane, v); + uint32_t m_local = coords[0]; + uint32_t k_local = coords[1]; + + uint32_t m_global = bm * FragM + m_local; + uint32_t k_global = bk * FragK + k_local; + + lane_a[frag_offset + v] = + static_cast(A_matrix[m_global * K + k_global]); + } + } + } + } + } +} + +// Fill per-lane B fragments from logical B[K][N] matrix. +// BVecType = InternalBVecT[FragsN][FragsK] +template +void fill_b_fragments(typename Pipeline::BVecType* b_per_lane, + const std::vector& B_matrix, + uint32_t N, + uint32_t waveSize) +{ + using MmaOp = typename Pipeline::MmaOp; + using BRegMap = TileDistrEncRegMap::BWarpDstrEncoding>; + using BFragScalar = typename vector_traits::scalar_type; + + constexpr uint32_t FragN = Pipeline::FragN; + constexpr uint32_t FragK = Pipeline::FragK; + constexpr uint32_t FragsN = Pipeline::FragsN; + constexpr uint32_t FragsK = Pipeline::FragsK; + + constexpr index_t b_vec_size = BRegMap::num_vector_items; + + for(uint32_t lane = 0; lane < waveSize; ++lane) + { + auto* lane_b = reinterpret_cast(&b_per_lane[lane]); + + for(uint32_t bn = 0; bn < FragsN; ++bn) + { + for(uint32_t bk = 0; bk < FragsK; ++bk) + { + uint32_t frag_offset = (bn * FragsK + bk) * b_vec_size; + + for(index_t v = 0; v < b_vec_size; ++v) + { + auto coords = BRegMap::calc_matrix_indices_from_lane_vector(lane, v); + uint32_t n_local = coords[0]; + uint32_t k_local = coords[1]; + + uint32_t n_global = bn * FragN + n_local; + uint32_t k_global = bk * FragK + k_local; + + // B matrix is stored as B[K][N] + lane_b[frag_offset + v] = + static_cast(B_matrix[k_global * N + n_global]); + } + } + } + } +} + +// Extract C matrix from per-lane C fragments. +// CVecType = InternalCVecT[FragsM][FragsN] +template +void extract_c_matrix(const typename Pipeline::CVecType* c_per_lane, + std::vector& C_matrix, + uint32_t N, + uint32_t waveSize) +{ + using MmaOp = typename Pipeline::MmaOp; + using CRegMap = TileDistrEncRegMap::CWarpDstrEncoding>; + using CFragScalar = typename vector_traits::scalar_type; + + constexpr uint32_t FragM = Pipeline::FragM; + constexpr uint32_t FragN = Pipeline::FragN; + constexpr uint32_t FragsM = Pipeline::FragsM; + constexpr uint32_t FragsN = Pipeline::FragsN; + + constexpr index_t c_vec_size = CRegMap::num_vector_items; + + for(uint32_t lane = 0; lane < waveSize; ++lane) + { + auto* lane_c = reinterpret_cast(&c_per_lane[lane]); + + for(uint32_t bm = 0; bm < FragsM; ++bm) + { + for(uint32_t bn = 0; bn < FragsN; ++bn) + { + uint32_t frag_offset = (bm * FragsN + bn) * c_vec_size; + + for(index_t v = 0; v < c_vec_size; ++v) + { + auto coords = CRegMap::calc_matrix_indices_from_lane_vector(lane, v); + uint32_t m_local = coords[0]; + uint32_t n_local = coords[1]; + + uint32_t m_global = bm * FragM + m_local; + uint32_t n_global = bn * FragN + n_local; + + C_matrix[m_global * N + n_global] = + static_cast(lane_c[frag_offset + v]); + } + } + } + } +} + +/// Internal: runs the test with a fully resolved Pipeline type. +/// Called from run_pipeline_matrix_test after dispatching on compiler target. +template +void run_pipeline_matrix_test_impl(uint32_t M, + uint32_t N, + uint32_t K, + uint32_t waveSize, + KernelType kernel, + bool isSparse, + bool transposeExpected = false, + float referenceScale = 1.0f) +{ + std::vector A_matrix(M * K); + std::vector B_matrix(K * N); + std::vector C_expected(M * N, static_cast(0)); + std::vector C_actual(M * N, static_cast(0)); + + for(uint32_t m = 0; m < M; ++m) + for(uint32_t k = 0; k < K; ++k) + A_matrix[m * K + k] = deterministic_value(m, k, K); + + for(uint32_t k = 0; k < K; ++k) + for(uint32_t n = 0; n < N; ++n) + B_matrix[k * N + n] = deterministic_value(k, n, N); + + if(isSparse) + { + apply_sparse_pattern(A_matrix, M, K); } - void - test_pipeline(std::function shouldSkip, - std::function kernel, - std::function getExpected, - std::function aInitializer = nullptr) + reference_matmul(C_expected, A_matrix, B_matrix, M, N, K); + + using AVecType = typename Pipeline::AVecType; + using BVecType = typename Pipeline::BVecType; + using CVecType = typename Pipeline::CVecType; + + const size_t a_buf_size = waveSize * sizeof(AVecType); + const size_t b_buf_size = waveSize * sizeof(BVecType); + const size_t c_buf_size = waveSize * sizeof(CVecType); + + std::vector h_a(a_buf_size, 0); + std::vector h_b(b_buf_size, 0); + std::vector h_c(c_buf_size, 0); + + fill_a_fragments(reinterpret_cast(h_a.data()), A_matrix, K, waveSize); + fill_b_fragments(reinterpret_cast(h_b.data()), B_matrix, N, waveSize); + + void *d_a, *d_b, *d_c; + HIP_CHECK_ERROR(hipMalloc(&d_a, a_buf_size)); + HIP_CHECK_ERROR(hipMalloc(&d_b, b_buf_size)); + HIP_CHECK_ERROR(hipMalloc(&d_c, c_buf_size)); + + HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), a_buf_size, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), b_buf_size, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemset(d_c, 0, c_buf_size)); + + ck_tile::launch_kernel(ck_tile::stream_config{}, + ck_tile::make_kernel(kernel, dim3(1), dim3(waveSize), 0, d_a, d_b, d_c)); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + HIP_CHECK_ERROR(hipMemcpy(h_c.data(), d_c, c_buf_size, hipMemcpyDeviceToHost)); + extract_c_matrix( + reinterpret_cast(h_c.data()), C_actual, N, waveSize); + + for(uint32_t m = 0; m < M; ++m) { - using namespace ck_tile; - using namespace ck_tile::core::arch; - - int devCount; - hipDevice_t dev; - HIP_CHECK_ERROR(hipGetDevice(&dev)); - HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); - - hipDeviceProp_t devProp; - HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); - - auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); - bool hasDevice = static_cast(devCount > 0); - int deviceWarpSize = devProp.warpSize; - - if(!hasDevice || shouldSkip(currentArchId)) + for(uint32_t n = 0; n < N; ++n) { - GTEST_SKIP() << "No HIP device found. Skipping test."; + // When transposeExpected is true, the kernel computes C^T via SwapAB, + // so compare actual C[m][n] against reference C[n][m]. + constexpr float relative_tolerance = 1e-2f; + constexpr float absolute_tolerance = 1e-3f; + + float expected = transposeExpected ? static_cast(C_expected[n * M + m]) + : static_cast(C_expected[m * N + n]); + expected *= referenceScale; + float actual = static_cast(C_actual[m * N + n]); + EXPECT_NEAR( + actual, expected, std::abs(expected) * relative_tolerance + absolute_tolerance) + << "Mismatch at C[" << m << "][" << n << "]"; } - - // WaveTile size, also the expected fragment size (MmaTile) from the selector. - // Note: Actual FragK might be slightly different due to hardware implementation, but the - // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is - // correct. - static constexpr uint32_t FragM = WaveTileM; - static constexpr uint32_t FragN = WaveTileN; - static constexpr uint32_t FragK = WaveTileK; - - // The number of elements per thread - uint32_t AElements = FragM * FragK / deviceWarpSize / numeric_traits::PackedSize; - uint32_t BElements = FragN * FragK / deviceWarpSize / numeric_traits::PackedSize; - uint32_t CElements = FragM * FragN / deviceWarpSize; - - uint32_t ASize = AElements * sizeof(AType); - uint32_t BSize = BElements * sizeof(BType); - uint32_t CSize = CElements * sizeof(CType); - uint32_t ScaleASize = 1 * sizeof(ScaleAType); - uint32_t ScaleBSize = 1 * sizeof(ScaleBType); - - // Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's - std::vector h_a(AElements); - if(aInitializer) - { - for(size_t i = 0; i < AElements; ++i) - h_a[i] = aInitializer(i); - } - else - { - std::fill(h_a.begin(), h_a.end(), type_convert(1.0f)); - } - std::vector h_b(BElements, type_convert(1.0f)); - std::vector h_c(CElements, type_convert(0.0f)); - std::vector h_out(CElements, type_convert(0.0f)); - // The actual scale is computed as pow(2, scale - 127), so: - // 126 -> 2^-1 and 129 -> 2^2. - ScaleAType h_scale_a = 126; - ScaleBType h_scale_b = 129; - - AType* d_a; - BType* d_b; - CType* d_c; - CType* d_out; - ScaleAType* d_scale_a; - ScaleBType* d_scale_b; - - HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); - HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); - HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); - HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); - HIP_CHECK_ERROR(hipMalloc(&d_scale_a, ScaleASize)); - HIP_CHECK_ERROR(hipMalloc(&d_scale_b, ScaleBSize)); - - // Copy inputs to device - HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_scale_a, &h_scale_a, ScaleASize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_scale_b, &h_scale_b, ScaleBSize, hipMemcpyHostToDevice)); - - const auto wave_size = getDeviceWaveSize(); - kernel(wave_size, d_a, d_b, d_c, d_out, d_scale_a, d_scale_b); - HIP_CHECK_ERROR(hipDeviceSynchronize()); - - HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); - - // Verify output against expected value for all elements - for(size_t i = 0; i < CElements; ++i) - { - EXPECT_NEAR(h_out[i], getExpected(FragK, h_scale_a, h_scale_b), 1e-3); - } - - HIP_CHECK_ERROR(hipFree(d_a)); - HIP_CHECK_ERROR(hipFree(d_b)); - HIP_CHECK_ERROR(hipFree(d_c)); - HIP_CHECK_ERROR(hipFree(d_out)); - HIP_CHECK_ERROR(hipFree(d_scale_a)); - HIP_CHECK_ERROR(hipFree(d_scale_b)); } -}; + + HIP_CHECK_ERROR(hipFree(d_a)); + HIP_CHECK_ERROR(hipFree(d_b)); + HIP_CHECK_ERROR(hipFree(d_c)); +} + +/// @tparam PipelineFactory A template template that, given a CompilerTarget type, produces +/// the Pipeline type: PipelineFactory::type +/// @tparam KernelType Kernel functor struct with kBlockSize and __device__ operator() +/// @tparam AScalar Scalar type for A matrix (e.g., fp16_t) +/// @tparam BScalar Scalar type for B matrix (e.g., fp16_t) +/// @tparam CScalar Scalar type for C matrix (e.g., fp32_t) +/// @param M WaveTile M dimension +/// @param N WaveTile N dimension +/// @param K WaveTile K dimension +/// @param shouldSkip Predicate returning true if current device should skip +/// @param kernel Kernel functor instance to launch via make_kernel +/// @param isSparse Whether to apply 2:4 sparsity pattern to A +/// @param transposeExpected When true, compare against transposed reference (for +/// SwapAB/TransposeC) +/// @param referenceScale Scalar multiplier applied to the reference matmul result before +/// comparison (e.g., to account for scale-MMA scaling factors) +template