From 9a564f7e5fed79f5090495dd13e3502f3da5a0b0 Mon Sep 17 00:00:00 2001 From: chris-tsiaousis-hpc Date: Thu, 5 Mar 2026 20:52:04 +0100 Subject: [PATCH] [CK TILE] Unification of sparse MFMA/WMMA policy structs (#4837) ## Motivation The existing unification work supports DENSE intrinsics. In this PR we enable support for SPARSE as well as SCALE intrinsics and add an example SPARSE implementation. ## Technical Details Mostly trivial changes. One framework change is that the desired `MmaOpFamily` is passed to the `MmaDefaultSelector`. As my relevant commit explains, we do not support a fallback family at the moment, but it is something we can consider. ## Test Plan Added a new test for the relevant sparse specializations. ## Test Result Test should pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Signed-off-by: Chris Tsiaousis --- include/ck_tile/core/arch/arch.hpp | 6 + include/ck_tile/core/arch/mma/amdgcn_mma.hpp | 6 +- .../ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp | 9 +- .../core/arch/mma/mfma/mfma_selector.hpp | 14 +- include/ck_tile/core/arch/mma/mma.hpp | 4 +- .../ck_tile/core/arch/mma/mma_op_family.hpp | 48 +++ .../ck_tile/core/arch/mma/mma_selector.hpp | 16 +- include/ck_tile/core/arch/mma/mma_traits.hpp | 38 ++- .../core/arch/mma/sparse/mfma/selector.hpp | 151 ++++++++++ .../core/arch/mma/sparse/mfma/sparse_gfx9.hpp | 108 +++++++ .../ck_tile/core/arch/mma/sparse/sparse.hpp | 68 +++++ .../core/arch/mma/sparse/sparse_selector.hpp | 7 + .../arch/mma/sparse/sparse_transforms.hpp | 48 +++ .../core/arch/mma/sparse/wmma/selector.hpp | 134 +++++++++ .../arch/mma/sparse/wmma/sparse_gfx12.hpp | 73 +++++ .../ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp | 7 +- .../ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp | 5 +- .../core/arch/mma/wmma/wmma_selector.hpp | 22 +- .../ops/gemm/warp/warp_gemm_smfmac_impl.hpp | 91 +++--- test/ck_tile/core/arch/mma/CMakeLists.txt | 4 + 
.../core/arch/mma/get_wave_size_helper.hpp | 34 +++ .../ck_tile/core/arch/mma/test_amdgcn_mma.cpp | 95 ++++-- .../core/arch/mma/test_amdgcn_sparse_mma.cpp | 274 ++++++++++++++++++ 23 files changed, 1173 insertions(+), 89 deletions(-) create mode 100644 include/ck_tile/core/arch/mma/mma_op_family.hpp create mode 100644 include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp create mode 100644 include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp create mode 100644 include/ck_tile/core/arch/mma/sparse/sparse.hpp create mode 100644 include/ck_tile/core/arch/mma/sparse/sparse_selector.hpp create mode 100644 include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp create mode 100644 include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp create mode 100644 include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp create mode 100644 test/ck_tile/core/arch/mma/get_wave_size_helper.hpp create mode 100644 test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 40bdb2ff31..5069172386 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -823,6 +823,12 @@ using enable_if_target_wave64_t = #endif // __cplusplus <= 201703L +template +constexpr bool all_types_void = std::conjunction_v...>; + +template +using enable_if_all = std::enable_if_t>; + } // namespace core::arch CK_TILE_HOST bool is_wave32() diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index 7801e8ed3c..52943dc2e4 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/ignore.hpp" @@ -82,11 +83,13 @@ template struct amdgcn_mma { // The base instance is 
unsupported because there is no __builtin to wrap. - using OpType = Unsupported; + using OpType = Unsupported; + static constexpr MmaOpFamily OpFamily = MmaOpFamily::UNDEFINED; // Interface types for A, B, C vectors types using AVecType = ext_vector_t; @@ -122,3 +125,4 @@ struct amdgcn_mma // Include the implementations #include "wmma/wmma.hpp" #include "mfma/mfma.hpp" +#include "sparse/sparse.hpp" diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp index 7db76b1919..225ceb60f5 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/numeric/vector_type.hpp" namespace ck_tile::core::arch::mma { @@ -68,10 +69,12 @@ struct amdgcn_mma> { // Mfma operation type - using OpType = MfmaOp; + using OpType = MfmaOp; + static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE; // Register types using AVecType = ext_vector_t; @@ -125,9 +128,11 @@ struct amdgcn_mma> { - using OpType = MfmaOp; + using OpType = MfmaOp; + static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE; // Packed register types using AVecType = ext_vector_t; diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp index 5c87419d0c..051b9d30ff 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp @@ -52,7 +52,8 @@ struct MfmaDefaultSelector BlockN, BlockKTest, DefaultMfmaCtrlFlags, // By default, let's assume no special flags for MFMA - CompilerTarget>; + CompilerTarget, + MmaOpFamily::DENSE>; using CandidateTraits = MmaOpTraits; public: @@ -98,7 +99,8 @@ struct MfmaDefaultSelector; + CompilerTarget, + MmaOpFamily::DENSE>; }; /** @@ -114,6 +116,7 
@@ struct MfmaDefaultSelector // TODO: c++20 amdgcn_target_arch_id CompilerTarget> + typename CompilerTarget, + MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget> struct MmaDefaultSelector> + OpFamily, + enable_if_all, + std::enable_if_t>> { private: // Provide the default depth-K search strategy for each class of common MFMA shapes. diff --git a/include/ck_tile/core/arch/mma/mma.hpp b/include/ck_tile/core/arch/mma/mma.hpp index 2a5de37550..9b38ff9b18 100644 --- a/include/ck_tile/core/arch/mma/mma.hpp +++ b/include/ck_tile/core/arch/mma/mma.hpp @@ -57,6 +57,7 @@ template ::SelectedOp, + CompilerTarget, + OpFamily>::SelectedOp, typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = typename MmaTransformsDefaultSelector::SelectedTransforms> struct WaveWiseMma diff --git a/include/ck_tile/core/arch/mma/mma_op_family.hpp b/include/ck_tile/core/arch/mma/mma_op_family.hpp new file mode 100644 index 0000000000..bdbd834ff5 --- /dev/null +++ b/include/ck_tile/core/arch/mma/mma_op_family.hpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once + +namespace ck_tile::core::arch::mma { + +/** + * @enum MmaOpFamily + * @brief Enumeration that defines the supported MMA op families + */ +enum struct MmaOpFamily +{ + UNDEFINED = 0, + DENSE, + SPARSE, + SCALE, +}; + +/** + * @class is_mma_op_of_family + * @brief Meta-function to check if MmaOp is of the specified MmaOpFamily + * @tparam Family Desired MMA op family + * @tparam MmaOp amdgcn struct specialization type + */ +template +struct is_mma_op_of_family : std::false_type +{ +}; + +/** + * @struct is_mma_op_of_family + * @brief Specialization for Family == MmaOp::OpFamily detection + */ +template +struct is_mma_op_of_family> + : std::true_type +{ +}; + +/** + * @brief Convenience evaluator for is_mma_op_of_family trait + * @tparam Family Desired MMA op family + * @tparam MmaOp The amdgcn struct specialization type to check + */ +template +static constexpr bool is_mma_op_of_family_v = is_mma_op_of_family::value; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mma_selector.hpp b/include/ck_tile/core/arch/mma/mma_selector.hpp index eae0f705df..1bb206283b 100644 --- a/include/ck_tile/core/arch/mma/mma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mma_selector.hpp @@ -3,6 +3,7 @@ #pragma once #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" namespace ck_tile::core::arch::mma { @@ -15,11 +16,12 @@ namespace ck_tile::core::arch::mma { * architecture.
* @tparam ADataType Data type of matrix A * @tparam BDataType Data type of matrix B - * @tparam CDataType Data type of the accumulator + * @tparam CDataType Data type of the accumulator * @tparam FragM Fragment M dimension * @tparam FragN Fragment N dimension * @tparam FragK Fragment K dimension * @tparam CompilerTarget The compiler target + * @tparam OpFamily The MMA operation family * @tparam Enable SFINAE enabler * @note Here we distinguish that Fragment MNK sizes from Block MNK sizes used in the actual MMA * operation. Fragment sizes correspond to the overall tile size being computed, while Block sizes @@ -34,14 +36,22 @@ template // TODO c++20 requires struct MmaDefaultSelector { // By default, no selection is made, and we fall back to a pass-through unsupported // implementation. This is because we do not have any knowledge of the target architecture here. - using SelectedOp = - amdgcn_mma>; + using SelectedOp = amdgcn_mma, + MmaOpFamily::UNDEFINED>; }; #if CK_TILE_CONCEPTS diff --git a/include/ck_tile/core/arch/mma/mma_traits.hpp b/include/ck_tile/core/arch/mma/mma_traits.hpp index 7bcf95ac55..fca2dd058c 100644 --- a/include/ck_tile/core/arch/mma/mma_traits.hpp +++ b/include/ck_tile/core/arch/mma/mma_traits.hpp @@ -1,6 +1,8 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
// SPDX-License-Identifier: MIT #pragma once + +#include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "amdgcn_mma.hpp" #include "ck_tile/core/arch/arch.hpp" #include "mfma/mfma_traits.hpp" @@ -69,6 +71,7 @@ concept MmaOpParamsI = requires(MmaOpParams op) { { MmaOpParams::BlockN } -> std::convertible_to; { MmaOpParams::BlockK } -> std::convertible_to; { MmaOpParams::GfxTargetId } -> std::convertible_to; + { MmaOpParams::Family } -> std::convertible_to; }; #endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER @@ -92,7 +95,8 @@ template + typename CompilerTarget_, + MmaOpFamily OpFamily_> // TODO: c++20 amdgcn_target_arch_id CompilerTarget_> struct MmaOpParams> + CompilerTarget_, + OpFamily_>> { // Capture incoming template parameters - using ADataType = ADataType_; - using BDataType = BDataType_; - using CDataType = CDataType_; - static constexpr uint32_t BlockM = BlockM_; - static constexpr uint32_t BlockN = BlockN_; - static constexpr uint32_t BlockK = BlockK_; - using CtrlFlags = CtrlFlags_; - using CompilerTarget = CompilerTarget_; + using ADataType = ADataType_; + using BDataType = BDataType_; + using CDataType = CDataType_; + static constexpr uint32_t BlockM = BlockM_; + static constexpr uint32_t BlockN = BlockN_; + static constexpr uint32_t BlockK = BlockK_; + using CtrlFlags = CtrlFlags_; + using CompilerTarget = CompilerTarget_; + static constexpr auto MmaOpFamily = OpFamily_; // TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_; }; @@ -131,6 +137,8 @@ struct MmaOpTraits : public MmaOpParams using BVecType = typename MmaOp::BVecType; using CVecType = typename MmaOp::CVecType; + static constexpr MmaOpFamily OpFamily = MmaOp::OpFamily; + // Capture layout parameters static constexpr index_t kAMBlock = MmaOp::kAMBlock; static constexpr index_t kBNBlock = MmaOp::kBNBlock; @@ -144,9 +152,13 @@ struct MmaOpTraits : public MmaOpParams static constexpr index_t kCM1PerLane = MmaOp::kCM1PerLane; // Additional traits to identify the 
type of MmaOp at compile time - constexpr static bool IsMfma = is_mma_op_mfma_v; - constexpr static bool IsWmma = is_mma_op_wmma_v; - constexpr static bool IsSupported = is_mma_op_supported_v; + constexpr static bool IsMfma = is_mma_op_mfma_v; + constexpr static bool IsWmma = is_mma_op_wmma_v; + constexpr static bool IsDense = OpFamily == MmaOpFamily::DENSE; + constexpr static bool IsSparse = OpFamily == MmaOpFamily::SPARSE; + constexpr static bool IsScale = OpFamily == MmaOpFamily::SCALE; + constexpr static bool IsSupported = + is_mma_op_supported_v && OpFamily != MmaOpFamily::UNDEFINED; }; } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp b/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp new file mode 100644 index 0000000000..92d14a257d --- /dev/null +++ b/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp @@ -0,0 +1,151 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" + +namespace ck_tile::core::arch::mma { + +/** + * @class SparseMfmaDefaultSelector + * @brief Implements a default sparse MFMA selector strategy. The SelectedOp can be unsupported. 
+ * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam BlockM Size of the M dimension + * @tparam BlockN Size of the N dimension + * @tparam BlockKTest Size of the K dimension + * @tparam CompilerTarget The compiler target + */ +template +// TODO: c++20 amdgcn_target_arch_id CompilerTarget> +// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) && is_power_of_two_integer(BlockKTest)) +struct SparseMfmaDefaultSelector +{ + private: + // Define our candidate MFMA implementation for the current parameters + using CandidateOp = amdgcn_mma; + + using CandidateTraits = MmaOpTraits; + + public: + // If the candidate is supported (e.g., a backend implementation exists), then select it. + // Otherwise, fall back to the unsupported pass-through implementation. + using SelectedOp = std::conditional_t, + MmaOpFamily::UNDEFINED>>; +}; + +/** + * @struct MmaDefaultSelector + * @brief Implements the CDNA default MMA selector strategy for sparse MFMA. + * If no supported instruction is found, falls back to an unsupported pass-through implementation. + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam FragM Size of the M dimension of the fragment to decompose + * @tparam FragN Size of the N dimension of the fragment to decompose + * @tparam FragK Size of the K dimension of the fragment to decompose + * @tparam CompilerTarget The compiler target + * @tparam OpFamily The MMA operation family + */ +template +// TODO: c++20 amdgcn_target_arch_id CompilerTarget> +// TODO: c++20 requires +struct MmaDefaultSelector, + std::enable_if_t>> +{ + private: + // Provide the default depth-K search strategy for each class of common MFMA shapes. + // Start searching from the largest K dimension MFMA shape down to the smallest. 
+ using CandidateOp16x16 = typename SparseMfmaDefaultSelector::SelectedOp; + using CandidateOp32x32 = typename SparseMfmaDefaultSelector::SelectedOp; + + // Default operation triggers pass-through + using DefaultOp = typename SparseMfmaDefaultSelector::SelectedOp; + + // Traits for each candidate + using CandidateTraits16x16 = MmaOpTraits; + using CandidateTraits32x32 = MmaOpTraits; + + // Check if each candidate is supported for the given fragment sizes + // For this case, we require the fragment sizes to be multiples of the MFMA shape + static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported && + (FragM % CandidateTraits16x16::BlockM == 0u) && + (FragN % CandidateTraits16x16::BlockN == 0u) && + (FragK % CandidateTraits16x16::BlockK == 0u); + static constexpr bool IsSupported32x32 = CandidateTraits32x32::IsSupported && + (FragM % CandidateTraits32x32::BlockM == 0u) && + (FragN % CandidateTraits32x32::BlockN == 0u) && + (FragK % CandidateTraits32x32::BlockK == 0u); + + public: + // Select the largest supported MFMA operation for the given fragment shape + using SelectedOp = + std::conditional_t>; +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp b/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp new file mode 100644 index 0000000000..89fb6688c0 --- /dev/null +++ b/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp @@ -0,0 +1,108 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" + +namespace ck_tile::core::arch::mma { + +/** + * @struct DefaultSparseMfmaCtrlFlags + * @brief Default MFMA sparse flags, select (VGPR[srcC][7..0]) if srcC is + * 16-bit or (VGPR[srcC][15..0]) if srcC is 8-bit. 
+ */ +struct DefaultSparseMfmaCtrlFlags +{ + static constexpr SparseCompressionIndex CompressionIndex = SparseCompressionIndex::FIRST; +}; + +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER +#include + +/** + * @concept SparseMfmaCtrlFlags + * @brief Expresses the interface of required members for each CtrlFlags type + */ +template +concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) { + // Flag members for sparse MFMA instructions + { CtrlFlags::CompressionIndex } -> std::convertible_to; +}; + +#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for Sparse MFMA (SMFMA) on GFX942, GFX950 targets + * + * This specialization implements the SMFMA instruction for fp16_t A and B + * matrices with structured sparsity, fp32_t accumulator, with 16x16x32 block sizes. + * + * @tparam CtrlFlags Control flags for the Sparse MFMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +struct amdgcn_mma< + fp16_t, + fp16_t, + fp32_t, + 16u, + 16u, + 32u, + CtrlFlags, + CompilerTarget, + MmaOpFamily::SPARSE, + std::enable_if_t> +{ + using OpType = MfmaOp; + static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE; + + static constexpr index_t ABVecN = 8; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + static constexpr index_t kCompressionRatio = 2; + + CK_TILE_DEVICE static auto + exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType 
+ { + static constexpr index_t CompressedSize = ABVecN / kCompressionRatio; + using AVecCompressed = ext_vector_t; + static_assert(CompressedSize == 4); + // TODO: Compressing A on-the-fly should be OK for now, but we need to validate + // and evaluate changing this to a transform at a higher level. + // aVec not being const can cause problems when running multiple intrinsics. + const int32_t idx = ck_tile::compress_a_impl(aVec); + + const AVecCompressed a_vec_pruned = {aVec[0], aVec[1], aVec[2], aVec[3]}; + + using namespace sparse::detail; + static constexpr BuiltinParams PARAMS = getBuiltinParams(); + return {__builtin_amdgcn_smfmac_f32_16x16x32_f16( + a_vec_pruned, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + } +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/sparse.hpp b/include/ck_tile/core/arch/mma/sparse/sparse.hpp new file mode 100644 index 0000000000..5adadd371b --- /dev/null +++ b/include/ck_tile/core/arch/mma/sparse/sparse.hpp @@ -0,0 +1,68 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile::core::arch::mma { + +/** + * @enum SparseCompressionIndex + * @brief Indicates which set of sparse-indices within a VGPR starting at srcC + * containing 8-bits (for 16-bit source data) or 16-bits (for 8-bit source data) + * of index information for a lane. \see DefaultSparseMfmaCtrlFlags + */ +enum struct SparseCompressionIndex : int +{ + FIRST = 0, // Uses bits [7:0] or [15..0], for 16 and 8 bit data respectively + SECOND = 1, // Uses bits [15:8] or [31:16], for 16 and 8 bit data respectively + THIRD = 2, // Uses bits [23:16] + FOURTH = 3, // Uses bits [31:24] +}; + +namespace sparse::detail { + +/** + * @struct BuiltinParams + * @brief Translates the SparseCompressionIndex to the correct CBSZ and ABID pairs for sparse + * builtins. 
The actual behavior of the builtin depends on the input data type: 16-bit source data: + * If CBSZ=0, ABID selects one of four 8-bit sets of sparse-indices within a VGPR starting at srcC + * containing 8-bits of index information for a lane. If CBSZ!=0 the very first is selected + * (VGPR[srcC][7..0]). + * + * 8-bit source data: + * If CBSZ=0, ABID selects one of two 16-bit sets of sparse-indices within a VGPR starting at srcC + * containing 16-bits of index information for a lane. If CBSZ!=0; the very first is selected + * (VGPR[srcC][15..0]). + */ +struct BuiltinParams +{ + int UseFirstIndex; // CBSZ + int ByteIndexToOverride; // ABID +}; + +template +static constexpr BuiltinParams getBuiltinParams() +{ + BuiltinParams params; + if constexpr(Idx == SparseCompressionIndex::FIRST) + { + params.UseFirstIndex = 1; + params.ByteIndexToOverride = 0; + } + else + { + params.UseFirstIndex = 0; + params.ByteIndexToOverride = static_cast(Idx); + } + return params; +} + +} // namespace sparse::detail + +} // namespace ck_tile::core::arch::mma + +// Include sparse MFMA traits and architecture-specific implementations +#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp" +#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp" diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_selector.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_selector.hpp new file mode 100644 index 0000000000..a38bbc460b --- /dev/null +++ b/include/ck_tile/core/arch/mma/sparse/sparse_selector.hpp @@ -0,0 +1,7 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp" +#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp" diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp new file mode 100644 index 0000000000..7da8f4f616 --- /dev/null +++ b/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_transforms.hpp" + +namespace ck_tile::core::arch::mma { + +/** + * @struct MmaDefaultTransformsSparse + * @brief Implements the default transforms for Sparse + * + * For 2:4 structured sparsity with inline register metadata: + * - ATransform: Pass-through (sparse operands formatted in Exec) TODO! + * - BTransform: Pass-through (sparse operands already formatted) + * - CTransform: Pass-through (input accumulator) + * - DTransform: Pass-through (output accumulator as-is) + */ +struct MmaDefaultTransformsSparse +{ + using ATransform = PassThroughTransform; + using BTransform = PassThroughTransform; + using CTransform = PassThroughTransform; + using DTransform = PassThroughTransform; +}; + +/** + * @class MmaTransformsDefaultSelector + * @brief Specialization for Sparse MFMA transforms + * Provides default transform selection for sparse operations + * + * @tparam MmaOp Sparse MMA operation + * @tparam CompilerTarget The compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires(is_mma_op_sparse(MmaOp)) +template +struct MmaTransformsDefaultSelector> +{ + using SelectedTransforms = MmaDefaultTransformsSparse; +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp b/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp 
new file mode 100644 index 0000000000..802e132083 --- /dev/null +++ b/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp @@ -0,0 +1,134 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" + +namespace ck_tile::core::arch::mma { + +/** + * @class SparseWmmaDefaultSelector + * @brief Implements a default sparse WMMA selector strategy. The SelectedOp can be unsupported. + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam BlockM Size of the M dimension + * @tparam BlockN Size of the N dimension + * @tparam BlockKTest Size of the K dimension + * @tparam CompilerTarget The compiler target + */ +template +// TODO: c++20 amdgcn_target_arch_id CompilerTarget> +// TODO: c++20 requires(is_target_arch_rdna(CompilerTarget) && is_power_of_two_integer(BlockKTest)) +struct SparseWmmaDefaultSelector +{ + private: + // Define our candidate WMMA implementation for the current parameters + using CandidateOp = amdgcn_mma; + + using CandidateTraits = MmaOpTraits; + + public: + // If the candidate is supported (e.g., a backend implementation exists), then select it. + // Otherwise, fall back to the unsupported pass-through implementation. + using SelectedOp = std::conditional_t, + MmaOpFamily::UNDEFINED>>; +}; + +/** + * @struct MmaDefaultSelector + * @brief Implements the RDNA default MMA selector strategy for sparse WMMA. + * If no supported instruction is found, falls back to an unsupported pass-through implementation. 
+ * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam FragM Size of the M dimension of the fragment to decompose + * @tparam FragN Size of the N dimension of the fragment to decompose + * @tparam FragK Size of the K dimension of the fragment to decompose + * @tparam CompilerTarget The compiler target + * @tparam OpFamily The MMA operation family + */ +template +// TODO: c++20 amdgcn_target_arch_id CompilerTarget> +// TODO: c++20 requires +struct MmaDefaultSelector, + std::enable_if_t>> +{ + private: + // Provide the default depth-K search strategy for each class of common WMMA shapes. + // Start searching from the largest K dimension WMMA shape down to the smallest. + using CandidateOp16x16 = typename SparseWmmaDefaultSelector::SelectedOp; + + // Default operation triggers pass-through + using DefaultOp = typename SparseWmmaDefaultSelector::SelectedOp; + + // Traits for each candidate + using CandidateTraits16x16 = MmaOpTraits; + + // Check if each candidate is supported for the given fragment sizes + // For this case, we require the fragment sizes to be multiples of the WMMA shape + static constexpr bool IsSupported16x16 = CandidateTraits16x16::IsSupported && + (FragM % CandidateTraits16x16::BlockM == 0u) && + (FragN % CandidateTraits16x16::BlockN == 0u) && + (FragK % CandidateTraits16x16::BlockK == 0u); + + public: + // Select the largest supported WMMA operation for the given fragment shape + using SelectedOp = std::conditional_t; +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp new file mode 100644 index 0000000000..a1406a7f8c --- /dev/null +++ b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp @@ -0,0 +1,73 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" + +namespace ck_tile::core::arch::mma { + +struct DefaultSparseWmmaCtrlFlags +{ +}; + +// TODO: c++20 template +// TODO: c++20 requires +template +struct amdgcn_mma> +{ + using OpType = WmmaOp; + static constexpr MmaOpFamily OpFamily = MmaOpFamily::SPARSE; + + static constexpr index_t ABVecN = 16; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + static constexpr index_t kCompressionRatio = 2; + + CK_TILE_DEVICE static auto + exec(AVecType& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType + { + static constexpr index_t CompressedSize = ABVecN / kCompressionRatio; + using AVecCompressed = ext_vector_t; + static_assert(CompressedSize == 8); + // TODO: Compressing A on-the-fly should be OK for now, but we need to validate + // and evaluate changing this to a transform at a higher level. + // aVec not being const can cause problems when running multiple intrinsics. 
+ const int32_t idx = ck_tile::compress_a_impl(aVec); + + const AVecCompressed a_vec_pruned = { + aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]}; + + return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(a_vec_pruned, bVec, cVec, idx)}; + } +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp index 58fe51abb7..568a55c659 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/numeric/vector_type.hpp" namespace ck_tile::core::arch::mma { @@ -77,10 +78,12 @@ struct amdgcn_mma> + MmaOpFamily::DENSE, + std::enable_if_t()>> { // Wmma operation type - using OpType = WmmaOp; + using OpType = WmmaOp; + static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE; // Register types (duplicated input / b32 accum) using AVecType = ext_vector_t; diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp index 71e518d4a3..f047862a06 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/numeric/vector_type.hpp" namespace ck_tile::core::arch::mma { @@ -37,10 +38,12 @@ struct amdgcn_mma> { // Wmma operation type - using OpType = WmmaOp; + using OpType = WmmaOp; + static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE; // Register types using AVecType = ext_vector_t; diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp 
b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp index e758ad9a5f..367aa2677f 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp @@ -46,7 +46,8 @@ struct WmmaDefaultSelector BlockN, BlockKTest, CtrlFlags, - CompilerTarget>; + CompilerTarget, + MmaOpFamily::DENSE>; using CandidateTraits = MmaOpTraits; @@ -91,8 +92,15 @@ struct WmmaDefaultSelector; // Default unsupported pass-through if no instruction is found - using SelectedOp = - amdgcn_mma; + using SelectedOp = amdgcn_mma; }; /** @@ -108,6 +116,7 @@ struct WmmaDefaultSelector + typename CompilerTarget, + MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget> // TODO: c++20 requires struct MmaDefaultSelector> + OpFamily, + enable_if_all, + std::enable_if_t>> { private: // Provide the default depth-K search strategy for each class of common WMMA shapes. diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp index 3d64e148c4..9b72839755 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp @@ -4,8 +4,54 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/tensor/static_distributed_tensor.hpp" namespace ck_tile { +/** + * @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero + * elements into lower part of a_vec to half its effective size. + * @param a_vec Vector to be compressed. + * @tparam ADataType The data type of a_vec + * @tparam CompressedSize The target compression size + * @tparam AVec The vector type of a_vec (deduced) + * @return Packed 32‑bit word containing **CompressedSize** 2‑bit fields. + * Each field encodes the original position (0–3) of the corresponding + * non‑zero element in the input. 
If fewer than CompressedSize + non‑zeros are found, remaining fields default to 2 (see below). + */ +template +static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec) +{ + // idx holds one 2‑bit index per output element (total CompressedSize entries). + // NOTE(review): every field is seeded with 0b10, but the previous code seeded + // 0b11101110 (alternating slots 2,3) to match the nonzero_elems defaults + // {a_vec[i*4+2], a_vec[i*4+3]} — confirm slot‑2‑only defaults are intended. + // The loop below will clear and set each field as real non‑zeros are seen. + int32_t idx = 0; + static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); }); + + static_for<0, CompressedSize / 2, 1>{}([&](auto i) { + ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; + int32_t non_zero_pos = 0; + + static_for<0, 3, 1>{}([&](auto j) { + if(a_vec[i * 4 + j] != 0.0f) + { + nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; + // clear the two‑bit field for this output and insert j + idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos)); + idx |= j << 2 * (i * 2 + non_zero_pos); + ++non_zero_pos; + } + }); + a_vec[i * 2] = nonzero_elems[0]; + a_vec[i * 2 + 1] = nonzero_elems[1]; + }); + + return idx; +} + template struct WarpGemmSmfmacImpl { @@ -41,37 +87,10 @@ struct WarpGemmSmfmacImpl return WarpGemmAttribute_::get_num_of_access(); } - //---------------------------------------------------------------------------------------------- - /// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero - /// elements into lower part of a_vec to half its effective size. - /// - /// @param a_vec Vector to be compressed. 
- /// - /// @return Four 2-bit indexes of non-zero elements locations - /// - template - CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const + template + CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec) { - int32_t idx = 0b11101110; - - static_for<0, 2, 1>{}([&](auto i) { - ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; - int32_t non_zero_pos = 0; - - static_for<0, 3, 1>{}([&](auto j) { - if(a_vec[i * 4 + j] != 0.0f) - { - nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; - idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos)); - idx |= j << 2 * (i * 2 + non_zero_pos); - ++non_zero_pos; - } - }); - a_vec[i * 2] = nonzero_elems[0]; - a_vec[i * 2 + 1] = nonzero_elems[1]; - }); - - return idx; + return compress_a_impl(a_vec); } template @@ -84,10 +103,11 @@ struct WarpGemmSmfmacImpl constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio; using AVec = ext_vector_t; - using AVecCompressed = - ext_vector_t; - using BVec = ext_vector_t; - using CVec = ext_vector_t; + static constexpr index_t CompressedSize = + ATensor::get_thread_buffer_size() / CompressionRatio; + using AVecCompressed = ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; constexpr auto I0 = number<0>{}; @@ -95,8 +115,9 @@ struct WarpGemmSmfmacImpl const auto b_vec = b.get_thread_buffer().template get_as()[I0]; auto c_vec = c.get_thread_buffer().template get_as()[I0]; - const int32_t idx = compress_a(a_vec); + const int32_t idx = compress_a_vec(a_vec); + static_assert(CompressedSize == 4); // @TODO can we simply set a_vec_pruned to a_vec[0:3]? 
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]}; diff --git a/test/ck_tile/core/arch/mma/CMakeLists.txt b/test/ck_tile/core/arch/mma/CMakeLists.txt index f5ecbf7f8b..77691735bd 100644 --- a/test/ck_tile/core/arch/mma/CMakeLists.txt +++ b/test/ck_tile/core/arch/mma/CMakeLists.txt @@ -7,6 +7,10 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() +if(GPU_TARGETS MATCHES "gfx9|gfx12") + add_gtest_executable(test_amdgcn_sparse_mma test_amdgcn_sparse_mma.cpp) + target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() if(GPU_TARGETS MATCHES "gfx9") add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp) target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/test/ck_tile/core/arch/mma/get_wave_size_helper.hpp b/test/ck_tile/core/arch/mma/get_wave_size_helper.hpp new file mode 100644 index 0000000000..84a3f955e5 --- /dev/null +++ b/test/ck_tile/core/arch/mma/get_wave_size_helper.hpp @@ -0,0 +1,34 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck_tile/core/arch/arch.hpp" +#include +#include "ck_tile/host/hip_check_error.hpp" + +namespace { + +__global__ void getWaveSizeForSelectedOp(uint32_t* waveSize) +{ + using CompilerTarget = decltype(ck_tile::core::arch::get_compiler_target()); + + if(waveSize) + *waveSize = static_cast(CompilerTarget::WAVE_SIZE_ID); +} + +static __host__ uint32_t getDeviceWaveSize() +{ + uint32_t* d_wave_size; + HIP_CHECK_ERROR(hipMalloc(&d_wave_size, sizeof(uint32_t))); + getWaveSizeForSelectedOp<<<1, 64>>>(d_wave_size); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + uint32_t wave_size; + HIP_CHECK_ERROR(hipMemcpy(&wave_size, d_wave_size, sizeof(uint32_t), hipMemcpyDeviceToHost)); + return wave_size; +} + +} // namespace diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp index c7093e3477..2ed8b96f19 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp @@ -11,6 +11,8 @@ #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/host/hip_check_error.hpp" +#include "get_wave_size_helper.hpp" + using namespace ck_tile; using namespace ck_tile::core::arch; using namespace ck_tile::core::arch::mma; @@ -47,10 +49,12 @@ struct amdgcn_mma> { // Mfma operation type - using OpType = DummyOpType; + using OpType = DummyOpType; + static constexpr MmaOpFamily OpFamily = MmaOpFamily::DENSE; // Register types using AVecType = ext_vector_t; @@ -81,8 +85,15 @@ struct amdgcn_mma template -using DummyAmdgcnMma = - amdgcn_mma; +using DummyAmdgcnMma = amdgcn_mma; /*! 
@struct MmaDefaultSelector * @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can test them both @@ -93,6 +104,7 @@ using DummyAmdgcnMma = * @tparam FragN Size of the N dimension of the fragment to decompose * @tparam FragK Size of the K dimension of the fragment to decompose * @tparam CompilerTarget The compiler target + * @tparam OpFamily The MMA operation family */ template + typename CompilerTarget, + MmaOpFamily OpFamily> // TODO: c++20 amdgcn_target_arch_id CompilerTarget> // TODO: requires struct MmaDefaultSelector> + OpFamily, + enable_if_all, + std::enable_if_t>> { using SelectedOp = DummyAmdgcnMma; }; @@ -128,6 +143,8 @@ TEST(TestAmdgcnMma, ArchSupported) // Check OpType EXPECT_TRUE( (std::is_same::value)); // OpType is DummyOpType + // Check OpFamily + EXPECT_TRUE((is_mma_op_of_family_v)); // Check AVecType, BVecType, CVecType EXPECT_TRUE((std::is_same>::value)); @@ -157,6 +174,8 @@ TEST(TestAmdgcnMma, ArchUnsupported) // OpType should be Unsupported EXPECT_TRUE((std::is_same::value)); + // OpFamily should be Undefined + EXPECT_TRUE((is_mma_op_of_family_v)); // AVecType, BVecType, CVecType should match default EXPECT_TRUE((std::is_same>::value)); @@ -367,6 +386,7 @@ TEST(TestAmdgcnMma, MmaOpTraitsUnsupportedMembers) EXPECT_TRUE((std::is_same>::value)); EXPECT_TRUE((std::is_same>::value)); EXPECT_TRUE((std::is_same>::value)); + EXPECT_EQ(Traits::OpFamily, MmaOpFamily::UNDEFINED); EXPECT_EQ(Traits::kAMBlock, 0); EXPECT_EQ(Traits::kBNBlock, 0); EXPECT_EQ(Traits::kAMLane, 0); @@ -386,9 +406,14 @@ TEST(TestAmdgcnMma, MmaOpTraitsUnsupportedMembers) TEST(TestAmdgcnMma, MmaDefaultSelectorSupported) { // Direct selection of the supported dummy instruction - using SelectedMma = - typename MmaDefaultSelector:: - SelectedOp; + using SelectedMma = typename MmaDefaultSelector::SelectedOp; // Should select DummyAmdgcnMma specialization EXPECT_TRUE((std::is_same>::value)); // OpType should be DummyOpType @@ -401,8 +426,14 @@ 
TEST(TestAmdgcnMma, MmaDefaultSelectorSupported) TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported) { // Direct selection of the unsupported dummy instruction - using SelectedMma = - MmaDefaultSelector>::SelectedOp; + using SelectedMma = MmaDefaultSelector, + MmaOpFamily::UNDEFINED>::SelectedOp; // OpType should be Unsupported EXPECT_TRUE((std::is_same::value)); // IsSupported should be false @@ -414,9 +445,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupported) TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment) { // Select indirectly with a fragment size of 256x128x64 - using SelectedMma = - MmaDefaultSelector:: - SelectedOp; + using SelectedMma = MmaDefaultSelector::SelectedOp; // Should select DummyAmdgcnMma specialization EXPECT_TRUE((std::is_same>::value)); // OpType should be DummyOpType @@ -429,8 +465,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorSupportedFragment) TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedFragment) { // This should fall back to unsupported since DummyAmdgcnMma only supports 16x16x16 - using SelectedMma = - MmaDefaultSelector::SelectedOp; + using SelectedMma = MmaDefaultSelector::SelectedOp; EXPECT_FALSE((std::is_same::value)); EXPECT_TRUE(MmaOpTraits::IsSupported); } @@ -438,8 +480,14 @@ TEST(TestAmdgcnMma, MmaDefaultSelectorUnsupportedFragment) // Test MmaDefaultSelector for a different data type (fp16_t) and unsupported arch TEST(TestAmdgcnMma, MmaDefaultSelectorFp16Unsupported) { - using SelectedMma = - MmaDefaultSelector>::SelectedOp; + using SelectedMma = MmaDefaultSelector, + MmaOpFamily::UNDEFINED>::SelectedOp; // Should select default amdgcn_mma (Unsupported) EXPECT_TRUE((std::is_same::value)); EXPECT_FALSE(MmaOpTraits::IsSupported); @@ -464,7 +512,8 @@ __global__ void test_accum_over_k(void* a, void* b, void* c, void* out) FragM, FragN, FragK, - decltype(get_compiler_target())>; + decltype(get_compiler_target()), + MmaOpFamily::DENSE>; using MmaOp = typename Selector::SelectedOp; using MmaTraits = MmaOpTraits; @@ -561,8 
+610,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real) HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); - // Need at least 1 WG with 64 threads to get defined MFMA/WMMA behaviour - test_accum_over_k<<<1, 64>>>(d_a, d_b, d_c, d_out); + const auto wave_size = getDeviceWaveSize(); + test_accum_over_k + <<<1, wave_size>>>(d_a, d_b, d_c, d_out); HIP_CHECK_ERROR(hipDeviceSynchronize()); HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); @@ -661,8 +711,9 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real) HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); - // Need at least 1 WG with 64 threads to get defined MFMA/WMMA behaviour - test_accum_over_k<<<1, 64>>>(d_a, d_b, d_c, d_out); + const auto wave_size = getDeviceWaveSize(); + test_accum_over_k + <<<1, wave_size>>>(d_a, d_b, d_c, d_out); HIP_CHECK_ERROR(hipDeviceSynchronize()); HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp new file mode 100644 index 0000000000..735eac09b0 --- /dev/null +++ b/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp @@ -0,0 +1,274 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" +#include +#include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +#include "get_wave_size_helper.hpp" + +using namespace ck_tile; +using namespace ck_tile::core::arch; +using namespace ck_tile::core::arch::mma; + +using CompilerTargetGfx950 = decltype(make_amdgcn_gfx9_target()); + +TEST(SparseMMATrait, SparseMfmaGfx950Specialization) +{ + // Test fp16 → fp32 sparse MFMA for GFX950 (16x16x32) + using TestSparseMfma16x16 = amdgcn_mma; + + static_assert(std::is_same_v && + TestSparseMfma16x16::OpFamily == MmaOpFamily::SPARSE, + "GFX950 sparse 16x16x32 should have SparseMFMAOp type"); + + static_assert(is_mma_op_of_family_v, + "GFX950 sparse 16x16x32 should be detected as Sparse"); + + std::cout << "GFX950 sparse MFMA specialization is correct" << std::endl; +} + +TEST(SparseMMATrait, MmaOpTraitsIntegration) +{ + // Create a sparse MMA op (16x16x32 fp16 specialization) + using TestSparseMmma = amdgcn_mma; + + // Get its traits + using TestTraits = MmaOpTraits; + + // Verify trait detection + static_assert(TestTraits::IsSparse, "Sparse MMA should be detected as sparse"); + static_assert(TestTraits::IsSupported, "Sparse MMA specialization should be supported"); + static_assert(TestTraits::IsMfma, "Sparse MFMA should be detected as MFMA"); + static_assert(!TestTraits::IsWmma, "Sparse MFMA should not be detected as WMMA"); + + std::cout << "MmaOpTraits correctly integrates sparse operations" << std::endl; +} + +TEST(SparseMMATrait, DenseVsSparseDistinction) +{ + // Dense MFMA from mfma/mfma_gfx9.hpp + using DenseMfma = amdgcn_mma; + + // Sparse MFMA on GFX950 + using SparseMfma = amdgcn_mma; + + // Verify they have different operation types 
+ static_assert(std::is_same_v && + DenseMfma::OpFamily != SparseMfma::OpFamily, + "Dense and Sparse MFMA should have the same OpType tags and different OpFamily"); + + // Verify traits correctly identify them + static_assert(MmaOpTraits::IsMfma && MmaOpTraits::IsDense && + !MmaOpTraits::IsSparse && !MmaOpTraits::IsScale && + MmaOpTraits::IsSupported, + "Dense MFMA should be identified correctly"); + + static_assert(MmaOpTraits::IsSparse && MmaOpTraits::IsMfma && + !MmaOpTraits::IsDense && !MmaOpTraits::IsScale && + MmaOpTraits::IsSupported, + "Sparse MFMA should be identified correctly"); + + std::cout << "Dense and sparse MMA operations are correctly distinguished" << std::endl; +} + +TEST(SparseMMATrait, SparseSelector) +{ + static_for<1, 33, 1>{}([](auto i) { + using Selected = typename MmaDefaultSelector(i), + static_cast(i), + static_cast(2 * i), + CompilerTargetGfx950, + MmaOpFamily::SPARSE>::SelectedOp; + + static constexpr bool isValid = (i == 16) || (i == 32); + if constexpr(isValid) + { + // Selector should pick a sparse MFMA implementation + static_assert(MmaOpTraits::IsSparse); + static_assert(MmaOpTraits::IsMfma); + static_assert(MmaOpTraits::IsSupported); + static_assert((std::is_same::value)); + } + else + { + // Selector should pick the unsupported pass through + static_assert(!MmaOpTraits::IsSupported); + } + }); +} + +template +__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out) +{ + using CompilerTarget = decltype(get_compiler_target()); + using Selector = MmaDefaultSelector; + + using MmaOp = typename Selector::SelectedOp; + using MmaTraits = MmaOpTraits; + + using CVecType = typename MmaOp::CVecType; + + static constexpr uint32_t kIters = FragK / MmaTraits::BlockK; + + // Initialize the accumulator + CVecType result = *reinterpret_cast(c); + + // Accumulate input AxB over FragK/BlockK iterations + for(uint32_t i = 0; i < kIters; ++i) + { + result = MmaOp::exec(*reinterpret_cast(a), + *reinterpret_cast(b), + result); 
+ } + + *reinterpret_cast(out) = result; +} + +// Live test on real hardware for sparse selection and execution. +TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real) +{ + int devCount; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); + + hipDeviceProp_t devProp; + HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); + + auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); + bool hasDevice = static_cast(devCount > 0); + int deviceWarpSize = devProp.warpSize; + + bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) && + (currentArchId <= amdgcn_target_id::GFX12_GENERIC); + bool isSupportedMfma = + (currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950); + // TODO: c++20 add check for arch id + if(!hasDevice || (currentArchId == amdgcn_target_id::HOST) || + !(isSupportedWmma || isSupportedMfma)) + { + GTEST_SKIP() << "No HIP device found. Skipping test."; + } + + using AType = fp16_t; + using BType = fp16_t; + using CType = fp32_t; + + // Fragment size, also the expected block size from the selector. + // Note: Actual blockK might be slightly different due to hardware implementation, but the + // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is + // correct. 
+ static constexpr uint32_t FragM = 16; + static constexpr uint32_t FragN = 16; + static constexpr uint32_t FragK = 32; + static constexpr uint32_t BlockM = FragM; + static constexpr uint32_t BlockN = FragN; + static constexpr uint32_t BlockK = FragK; + + // The number of elements per thread + uint32_t AElements = BlockM * BlockK / deviceWarpSize; + uint32_t BElements = BlockN * BlockK / deviceWarpSize; + uint32_t CElements = BlockM * BlockN / deviceWarpSize; + + uint32_t ASize = AElements * sizeof(AType); + uint32_t BSize = BElements * sizeof(BType); + uint32_t CSize = CElements * sizeof(CType); + + // Initialize A and B to all 1's, C to all 0's + std::vector h_a(AElements, static_cast(1)); + std::vector h_b(BElements, static_cast(1)); + std::vector h_c(CElements, static_cast(0)); + std::vector h_out(CElements, static_cast(0)); + + AType* d_a; + BType* d_b; + CType* d_c; + CType* d_out; + + HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); + HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); + HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); + HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); + + // Copy inputs to device + HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); + + const auto wave_size = getDeviceWaveSize(); + test_sparse_accum_over_k + <<<1, wave_size>>>(d_a, d_b, d_c, d_out); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); + + // Output should be FragK/2 for all elements, because the inputs are all 1's + for(size_t i = 0; i < CElements; ++i) + { + // In sparse only half of the A values are non-zero, thus the /2. 
+ CType expected = static_cast(FragK) / 2; + + EXPECT_NEAR(h_out[i], expected, 1e-3); + } + + HIP_CHECK_ERROR(hipFree(d_a)); + HIP_CHECK_ERROR(hipFree(d_b)); + HIP_CHECK_ERROR(hipFree(d_c)); + HIP_CHECK_ERROR(hipFree(d_out)); +}