From 91b7dae95a2452596f53d2372186dbd8e81a21b8 Mon Sep 17 00:00:00 2001 From: Yung-sheng Tu <112800063+yungshengtu@users.noreply.github.com> Date: Mon, 20 Apr 2026 16:28:23 +0200 Subject: [PATCH] [CK TILE] Unification of Scale MFMA/WMMA Policy Structs (#5857) ## Motivation The existing unification work supports DENSE and SPARSE intrinsics. In this PR, we enable support for SCALE intrinsics and add example SCALE implementations. ## Technical Details Adding MFMA SCALE intrinsics support, adding tests for MFMA SCALE intrinsics, and adding WMMA SCALE policy trait. Note: fp6 SCALE intrinsics support is not included in this PR, as its handling in ck_tile is currently more specialized and does not follow the same pattern as other datatypes. ## Test Plan Added new tests for the relevant SCALE specialisations. ## Test Result Test should pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- include/ck_tile/core.hpp | 7 + include/ck_tile/core/arch/mma/amdgcn_mma.hpp | 6 +- .../ck_tile/core/arch/mma/mma_pipeline.hpp | 66 ++++- .../core/arch/mma/scale/mfma/scale_gfx9.hpp | 229 +++++++++++++++ .../core/arch/mma/scale/mfma/selector.hpp | 149 ++++++++++ include/ck_tile/core/arch/mma/scale/scale.hpp | 10 + .../arch/mma/scale/scale_mma_pipeline.hpp | 77 +++++ .../core/arch/mma/scale/scale_selector.hpp | 6 + .../core/arch/mma/scale/scale_traits.hpp | 93 ++++++ .../core/arch/mma/scale/scale_transforms.hpp | 43 +++ test/ck_tile/core/arch/mma/CMakeLists.txt | 4 + .../mma/pipeline/pipeline_tests_helper.hpp | 125 +++++++- .../mma/pipeline/test_amdgcn_scale_mma.cpp | 270 ++++++++++++++++++ .../core/arch/mma/test_amdgcn_mma_layout.inc | 73 +++-- 14 files changed, 1116 insertions(+), 42 deletions(-) create mode 100644 include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp create mode 100644 include/ck_tile/core/arch/mma/scale/mfma/selector.hpp create mode 100644 include/ck_tile/core/arch/mma/scale/scale.hpp create mode 100644 include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp create mode 100644 include/ck_tile/core/arch/mma/scale/scale_selector.hpp create mode 100644 include/ck_tile/core/arch/mma/scale/scale_traits.hpp create mode 100644 include/ck_tile/core/arch/mma/scale/scale_transforms.hpp create mode 100644 test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 3a9309e41e..4085f876c6 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -25,6 +25,13 @@ #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/arch/mma/mma_transforms.hpp" #include "ck_tile/core/arch/mma/mma_wavewise.hpp" +#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp" +#include "ck_tile/core/arch/mma/scale/mfma/selector.hpp" +#include "ck_tile/core/arch/mma/scale/scale.hpp" +#include "ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp" +#include "ck_tile/core/arch/mma/scale/scale_selector.hpp" +#include "ck_tile/core/arch/mma/scale/scale_traits.hpp" +#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp" #include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp" #include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp" #include "ck_tile/core/arch/mma/sparse/sparse.hpp" diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index 072ac0bc36..c31aee0e1d 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -245,7 +245,7 @@ concept MmaOpI = requires(MmaOp op) { { MmaOp::kCMPerLane } -> std::convertible_to; { MmaOp::kCMNumAccess } -> std::convertible_to; { MmaOp::kCompressionRatio } -> std::convertible_to; -} && (HasExecSignature || HasExecSignature); +} && (HasExecSignature || HasExecSignature || HasExecSignature); #endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER @@ -303,6 +303,8 @@ struct amdgcn_mma : amdgcn_mma_base - CK_TILE_DEVICE static decltype(auto) - applyTransformsToInputs(ATransformInputs&& a, BTransformInputs&& b, CTransformInputs&& accum) + template + CK_TILE_DEVICE static decltype(auto) applyTransformsToInputs(ATransformInputs&& a, + BTransformInputs&& b, + CTransformInputs&& accum, + ExtraArgs&&... extras) { using InternalAVecT = typename Derived::InternalAVecT; using InternalBVecT = typename Derived::InternalBVecT; @@ -224,19 +230,18 @@ struct MmaPipelineBase return std::make_tuple( preApplyTransform(std::forward(a)), preApplyTransform(std::forward(b)), - preApplyTransform(std::forward(accum))); + preApplyTransform(std::forward(accum)), + std::forward(extras)...); } /** * @brief Apply the post-transform and buffer formatting to the C (accumulator) output. - * @param vecs The (A, B, C) tuple after @c execImpl; only C is consumed. + * @param c_result The accumulator to post-process. * @return The final D output in the user-facing vector type. */ - template - CK_TILE_DEVICE static auto - applyTransformToOutput(std::tuple&& vecs) + template + CK_TILE_DEVICE static auto applyTransformToOutput(CTransformResult&& c_result) { - auto&& [a_result, b_result, c_result] = vecs; static_assert(!is_std_tuple_v, "If CTransform returns more than the vector, update this function."); @@ -270,7 +275,46 @@ struct MmaPipelineBase Derived::execImpl(transformed_inputs); - return applyTransformToOutput(std::move(transformed_inputs)); + auto&& [a_result, b_result, c_result] = std::move(transformed_inputs); + return applyTransformToOutput(std::move(c_result)); + } + else + { + // Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp) + // Code should not reach here, but HOST/DEVICE compile passes are + // weirdly intertwined and instead of having constexpr in the calling + // site (tests) we do this. See also changes by this commit. + return Derived::MmaOp::exec({}, {}, {}); + } + } + + template + CK_TILE_DEVICE static decltype(auto) + exec(VecTA&& a, VecTB&& b, VecTC&& accum, ScaleADataType&& scale_A, ScaleBDataType&& scale_B) + { + if constexpr(MmaOpTraits::IsSupported) + { + // TODO: c++20: Call template functions with MmaPipelineOptionFlags directly + auto transformed_inputs = applyTransformsToInputs( + hasFlag() ? std::forward(b) + : std::forward(a), + hasFlag() ? std::forward(a) + : std::forward(b), + std::forward(accum), + hasFlag() ? std::forward(scale_B) + : std::forward(scale_A), + hasFlag() ? std::forward(scale_A) + : std::forward(scale_B)); + + Derived::execImpl(transformed_inputs); + + auto&& [a_result, b_result, c_result, scale_A_result, scale_B_result] = + std::move(transformed_inputs); + return applyTransformToOutput(std::move(c_result)); } else { diff --git a/include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp b/include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp new file mode 100644 index 0000000000..50bda33229 --- /dev/null +++ b/include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp @@ -0,0 +1,229 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/scale/scale_traits.hpp" +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +namespace ck_tile::core::arch::mma { + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets + * + * This specialization implements the Scale MFMA instruction for fp8_t A and B + * matrices with fp32_t accumulator, with 16x16x128 block sizes. + * + * @tparam CtrlFlags Control flags for the Scale MFMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) + { + return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + bit_cast(aVec), + bit_cast(bVec), + cVec, + scale::detail::ScaleDataTypeToFlag_v, + scale::detail::ScaleDataTypeToFlag_v, + static_cast(CtrlFlags::OPSEL_A), + scale_A, + static_cast(CtrlFlags::OPSEL_B), + scale_B)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets + * + * This specialization implements the Scale MFMA instruction for bf8_t A and B + * matrices with fp32_t accumulator, with 16x16x128 block sizes. + * + * @tparam CtrlFlags Control flags for the Scale MFMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) + { + return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + bit_cast(aVec), + bit_cast(bVec), + cVec, + scale::detail::ScaleDataTypeToFlag_v, + scale::detail::ScaleDataTypeToFlag_v, + static_cast(CtrlFlags::OPSEL_A), + scale_A, + static_cast(CtrlFlags::OPSEL_B), + scale_B)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets + * + * This specialization implements the Scale MFMA instruction for pk_fp4_t A and B + * matrices with fp32_t accumulator, with 16x16x128 block sizes. + * + * @tparam CtrlFlags Control flags for the Scale MFMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) + { + return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + bit_cast(aVec), + bit_cast(bVec), + cVec, + scale::detail::ScaleDataTypeToFlag_v, + scale::detail::ScaleDataTypeToFlag_v, + static_cast(CtrlFlags::OPSEL_A), + scale_A, + static_cast(CtrlFlags::OPSEL_B), + scale_B)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets + * + * This specialization implements the Scale MFMA instruction for fp8_t A and B + * matrices with fp32_t accumulator, with 32x32x64 block sizes. + * + * @tparam CtrlFlags Control flags for the Scale MFMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) + { + return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + bit_cast(aVec), + bit_cast(bVec), + cVec, + scale::detail::ScaleDataTypeToFlag_v, + scale::detail::ScaleDataTypeToFlag_v, + static_cast(CtrlFlags::OPSEL_A), + scale_A, + static_cast(CtrlFlags::OPSEL_B), + scale_B)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets + * + * This specialization implements the Scale MFMA instruction for bf8_t A and B + * matrices with fp32_t accumulator, with 32x32x64 block sizes. + * + * @tparam CtrlFlags Control flags for the Scale MFMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) + { + return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + bit_cast(aVec), + bit_cast(bVec), + cVec, + scale::detail::ScaleDataTypeToFlag_v, + scale::detail::ScaleDataTypeToFlag_v, + static_cast(CtrlFlags::OPSEL_A), + scale_A, + static_cast(CtrlFlags::OPSEL_B), + scale_B)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for Scale MFMA on GFX950 targets + * + * This specialization implements the Scale MFMA instruction for pk_fp4_t A and B + * matrices with fp32_t accumulator, with 32x32x64 block sizes. + * + * @tparam CtrlFlags Control flags for the Scale MFMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) + { + return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + bit_cast(aVec), + bit_cast(bVec), + cVec, + scale::detail::ScaleDataTypeToFlag_v, + scale::detail::ScaleDataTypeToFlag_v, + static_cast(CtrlFlags::OPSEL_A), + scale_A, + static_cast(CtrlFlags::OPSEL_B), + scale_B)}; + } +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/scale/mfma/selector.hpp b/include/ck_tile/core/arch/mma/scale/mfma/selector.hpp new file mode 100644 index 0000000000..b4f2d230ca --- /dev/null +++ b/include/ck_tile/core/arch/mma/scale/mfma/selector.hpp @@ -0,0 +1,149 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp" +#include "ck_tile/core/arch/mma/scale/scale_traits.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +#include +#include + +namespace ck_tile::core::arch::mma { + +/** + * @class ScaleMfmaDefaultSelector + * @brief Implements a default scale MFMA selector strategy. The SelectedOp can be unsupported. + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam WaveTileM Size of the M dimension + * @tparam WaveTileN Size of the N dimension + * @tparam WaveTileKTest Size of the K dimension + * @tparam CompilerTarget The compiler target + */ +template +// TODO: c++20 amdgcn_target_arch_id CompilerTarget> +// TODO: c++20 requires(is_target_arch_cdna(CompilerTarget) && +// is_power_of_two_integer(WaveTileKTest)) +struct ScaleMfmaDefaultSelector +{ + private: + // Define our candidate MFMA implementation for the current parameters + using CandidateOp = amdgcn_mma; + + public: + // If the candidate is supported (e.g., a backend implementation exists), then select it. + // Otherwise, fall back to the unsupported pass-through implementation. + using SelectedOp = std::conditional_t::IsSupported, + CandidateOp, + amdgcn_mma, + MmaOpFamily::UNDEFINED>>; +}; + +/** + * @struct MmaDefaultSelector + * @brief Implements the CDNA default MMA selector strategy for scale MFMA. + * If no supported instruction is found, falls back to an unsupported pass-through implementation. + * @tparam ADataType Data type of matrix A + * @tparam BDataType Data type of matrix B + * @tparam CDataType Data type of the accumulator + * @tparam WaveTileM Size of the M dimension of the WaveTile to decompose + * @tparam WaveTileN Size of the N dimension of the WaveTile to decompose + * @tparam WaveTileK Size of the K dimension of the WaveTile to decompose + * @tparam CompilerTarget The compiler target + * @tparam OpFamily The MMA operation family + */ +template +// TODO: c++20 amdgcn_target_arch_id CompilerTarget> +// TODO: c++20 requires +struct MmaDefaultSelector, + std::enable_if_t>> +{ + private: + // Provide the default depth-K search strategy for each class of common MFMA shapes. + // Start searching from the largest K dimension MFMA shape down to the smallest. + using CandidateOp16x16 = typename ScaleMfmaDefaultSelector::SelectedOp; + using CandidateOp32x32 = typename ScaleMfmaDefaultSelector::SelectedOp; + + // Default operation triggers pass-through + using DefaultOp = typename ScaleMfmaDefaultSelector::SelectedOp; + + // Check if each candidate is supported for the given fragment sizes + // For this case, we require the fragment sizes to be multiples of the MFMA shape + static constexpr bool IsSupported16x16 = + MmaOpTraits::IsSupported && (WaveTileM % CandidateOp16x16::kM == 0u) && + (WaveTileN % CandidateOp16x16::kN == 0u) && (WaveTileK % CandidateOp16x16::kK == 0u); + static constexpr bool IsSupported32x32 = + MmaOpTraits::IsSupported && (WaveTileM % CandidateOp32x32::kM == 0u) && + (WaveTileN % CandidateOp32x32::kN == 0u) && (WaveTileK % CandidateOp32x32::kK == 0u); + + public: + // Select the largest supported MFMA operation for the given fragment shape + using SelectedOp = + std::conditional_t>; +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/scale/scale.hpp b/include/ck_tile/core/arch/mma/scale/scale.hpp new file mode 100644 index 0000000000..8e6c70a6f7 --- /dev/null +++ b/include/ck_tile/core/arch/mma/scale/scale.hpp @@ -0,0 +1,10 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +// Include scale MFMA traits and architecture-specific implementations +#include "ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp" +#include "ck_tile/core/arch/mma/scale/scale_selector.hpp" +#include "ck_tile/core/arch/mma/scale/scale_traits.hpp" +#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp" diff --git a/include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp b/include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp new file mode 100644 index 0000000000..f582c27a13 --- /dev/null +++ b/include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp @@ -0,0 +1,77 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_pipeline.hpp" +#include "ck_tile/core/arch/mma/scale/scale_selector.hpp" +#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp" +#include "ck_tile/core/config.hpp" + +#include +#include +#include +#include + +namespace ck_tile::core::arch::mma { + +template ::SelectedOp, + typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = + typename MmaTransformsDefaultSelector::SelectedTransforms> +// clang-format off +struct ScaleMmaPipeline : public MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline> +{ + using Base = MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline>; + // clang-format on + + using MmaOp = MmaOp_; // Expose the selected MmaOp + + // Expose caller-side vector types + using AVecType = typename MmaOp::AVecType; + using BVecType = typename MmaOp::BVecType; + using CVecType = typename MmaOp::CVecType; + + // Expose internal vector types + using InternalAVecT = typename MmaOp::AVecType; + using InternalBVecT = typename MmaOp::BVecType; + using InternalCVecT = typename MmaOp::CVecType; + + // Transforms + using ATransform = typename MmaTransforms::ATransform; + using BTransform = typename MmaTransforms::BTransform; + using CTransform = typename MmaTransforms::CTransform; + using DTransform = typename MmaTransforms::DTransform; + + template + CK_TILE_DEVICE static void + execImpl(std::tuple& vecs) + { + auto& [a_vec, b_vec, c_vec, scale_A, scale_B] = vecs; + c_vec = MmaOp::exec(a_vec, b_vec, c_vec, scale_A, scale_B); + } +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/scale/scale_selector.hpp b/include/ck_tile/core/arch/mma/scale/scale_selector.hpp new file mode 100644 index 0000000000..087e813d6d --- /dev/null +++ b/include/ck_tile/core/arch/mma/scale/scale_selector.hpp @@ -0,0 +1,6 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/mma/scale/mfma/selector.hpp" diff --git a/include/ck_tile/core/arch/mma/scale/scale_traits.hpp b/include/ck_tile/core/arch/mma/scale/scale_traits.hpp new file mode 100644 index 0000000000..57530ef74c --- /dev/null +++ b/include/ck_tile/core/arch/mma/scale/scale_traits.hpp @@ -0,0 +1,93 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" +// #include "ck_tile/core/numeric/pk_fp6.hpp" + +#include +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER +#include +#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + +namespace ck_tile::core::arch::mma { + +namespace scale::detail { + +template +struct ScaleDataTypeToFlag; + +template <> +struct ScaleDataTypeToFlag // e4m3 +{ + static constexpr std::int32_t value = 0; +}; + +template <> +struct ScaleDataTypeToFlag // e5m2 +{ + static constexpr std::int32_t value = 1; +}; + +// template <> +// struct ScaleDataTypeToFlag> // e2m3 +// { +// static constexpr std::int32_t value = 2; +// }; + +// template <> +// struct ScaleDataTypeToFlag // e3m2 +// { +// static constexpr std::int32_t value = 3; +// }; + +template <> +struct ScaleDataTypeToFlag // e2m1 +{ + static constexpr std::int32_t value = 4; +}; + +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + +/** + * @concept ScaleMfmaDataTypeToFlag + * @brief Expresses the interface of required members for each DataTypeToFlag type on Gfx9 + */ +template +concept ScaleMfmaDataTypeToFlag = requires(DataTypeToFlag dataTypeToFlag) { + // Flag members for scale MFMA instructions + { DataTypeToFlag::value } -> std::convertible_to; +}; + +#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + +template +inline constexpr std::int32_t ScaleDataTypeToFlag_v = ScaleDataTypeToFlag::value; + +} // namespace scale::detail + +struct DefaultScaleMfmaCtrlFlags +{ + static constexpr std::int32_t OPSEL_A = 0; + static constexpr std::int32_t OPSEL_B = 0; +}; + +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + +/** + * @concept ScaleMfmaCtrlFlags + * @brief Expresses the interface of required members for each CtrlFlags type on Gfx9 + */ +template +concept ScaleMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) { + // Flag members for scale MFMA instructions + { CtrlFlags::OPSEL_A } -> std::convertible_to; + { CtrlFlags::OPSEL_B } -> std::convertible_to; +}; + +#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/scale/scale_transforms.hpp b/include/ck_tile/core/arch/mma/scale/scale_transforms.hpp new file mode 100644 index 0000000000..2270011c09 --- /dev/null +++ b/include/ck_tile/core/arch/mma/scale/scale_transforms.hpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_transforms.hpp" + +#include + +namespace ck_tile::core::arch::mma { + +/** + * @struct MmaDefaultTransformsScale + * @brief Implements the default MMA transforms for Scale + */ +struct MmaDefaultTransformsScale +{ + using ATransform = PassThroughTransform; + using BTransform = PassThroughTransform; + using CTransform = PassThroughTransform; + using DTransform = PassThroughTransform; +}; + +/** + * @struct MmaTransformsDefaultSelector + * @brief Specialization for Scale MFMA transforms + * Provides default transform selection for scale operations + * + * @tparam MmaOp Scale MMA operation + * @tparam CompilerTarget The compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires(is_mma_op_scale(MmaOp)) +template +struct MmaTransformsDefaultSelector> +{ + using SelectedTransforms = MmaDefaultTransformsScale; +}; + +} // namespace ck_tile::core::arch::mma diff --git a/test/ck_tile/core/arch/mma/CMakeLists.txt b/test/ck_tile/core/arch/mma/CMakeLists.txt index d93de32fea..34b1142cfc 100644 --- a/test/ck_tile/core/arch/mma/CMakeLists.txt +++ b/test/ck_tile/core/arch/mma/CMakeLists.txt @@ -11,6 +11,10 @@ if(GPU_TARGETS MATCHES "gfx9|gfx12") add_gtest_executable(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp) target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() +if(GPU_TARGETS MATCHES "gfx950") + add_gtest_executable(test_amdgcn_scale_mma pipeline/test_amdgcn_scale_mma.cpp) + target_compile_options(test_amdgcn_scale_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() if(GPU_TARGETS MATCHES "gfx9") add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp) target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp b/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp index a23cf08b1e..8460100aa9 100644 --- a/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp +++ b/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp @@ -10,23 +10,27 @@ #include #include "ck_tile/core/arch/arch.hpp" -#include -#include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/core/numeric/type_convert.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include #include "../get_wave_size_helper.hpp" -template +template struct MmaPipelineTest { using AType = AType_; using BType = BType_; using CType = CType_; + using ScaleAType = ScaleAType_; + using ScaleBType = ScaleBType_; static constexpr auto WaveTileM = WaveTileM_; static constexpr auto WaveTileN = WaveTileN_; static constexpr auto WaveTileK = WaveTileK_; @@ -120,4 +124,109 @@ struct MmaPipelineTest HIP_CHECK_ERROR(hipFree(d_c)); HIP_CHECK_ERROR(hipFree(d_out)); } + + void + test_pipeline(std::function shouldSkip, + std::function kernel, + std::function getExpected, + std::function aInitializer = nullptr) + { + using namespace ck_tile; + using namespace ck_tile::core::arch; + + int devCount; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); + + hipDeviceProp_t devProp; + HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); + + auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); + bool hasDevice = static_cast(devCount > 0); + int deviceWarpSize = devProp.warpSize; + + if(!hasDevice || shouldSkip(currentArchId)) + { + GTEST_SKIP() << "No HIP device found. Skipping test."; + } + + // WaveTile size, also the expected fragment size (MmaTile) from the selector. + // Note: Actual FragK might be slightly different due to hardware implementation, but the + // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is + // correct. + static constexpr uint32_t FragM = WaveTileM; + static constexpr uint32_t FragN = WaveTileN; + static constexpr uint32_t FragK = WaveTileK; + + // The number of elements per thread + uint32_t AElements = FragM * FragK / deviceWarpSize / numeric_traits::PackedSize; + uint32_t BElements = FragN * FragK / deviceWarpSize / numeric_traits::PackedSize; + uint32_t CElements = FragM * FragN / deviceWarpSize; + + uint32_t ASize = AElements * sizeof(AType); + uint32_t BSize = BElements * sizeof(BType); + uint32_t CSize = CElements * sizeof(CType); + uint32_t ScaleASize = 1 * sizeof(ScaleAType); + uint32_t ScaleBSize = 1 * sizeof(ScaleBType); + + // Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's + std::vector h_a(AElements); + if(aInitializer) + { + for(size_t i = 0; i < AElements; ++i) + h_a[i] = aInitializer(i); + } + else + { + std::fill(h_a.begin(), h_a.end(), type_convert(1.0f)); + } + std::vector h_b(BElements, type_convert(1.0f)); + std::vector h_c(CElements, type_convert(0.0f)); + std::vector h_out(CElements, type_convert(0.0f)); + // The actual scale is computed as pow(2, scale - 127), so: + // 126 -> 2^-1 and 129 -> 2^2. + ScaleAType h_scale_a = 126; + ScaleBType h_scale_b = 129; + + AType* d_a; + BType* d_b; + CType* d_c; + CType* d_out; + ScaleAType* d_scale_a; + ScaleBType* d_scale_b; + + HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); + HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); + HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); + HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); + HIP_CHECK_ERROR(hipMalloc(&d_scale_a, ScaleASize)); + HIP_CHECK_ERROR(hipMalloc(&d_scale_b, ScaleBSize)); + + // Copy inputs to device + HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_scale_a, &h_scale_a, ScaleASize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_scale_b, &h_scale_b, ScaleBSize, hipMemcpyHostToDevice)); + + const auto wave_size = getDeviceWaveSize(); + kernel(wave_size, d_a, d_b, d_c, d_out, d_scale_a, d_scale_b); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); + + // Verify output against expected value for all elements + for(size_t i = 0; i < CElements; ++i) + { + EXPECT_NEAR(h_out[i], getExpected(FragK, h_scale_a, h_scale_b), 1e-3); + } + + HIP_CHECK_ERROR(hipFree(d_a)); + HIP_CHECK_ERROR(hipFree(d_b)); + HIP_CHECK_ERROR(hipFree(d_c)); + HIP_CHECK_ERROR(hipFree(d_out)); + HIP_CHECK_ERROR(hipFree(d_scale_a)); + HIP_CHECK_ERROR(hipFree(d_scale_b)); + } }; diff --git a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp new file mode 100644 index 0000000000..a9adeba7d7 --- /dev/null +++ b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp @@ -0,0 +1,270 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "pipeline_tests_helper.hpp" + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" +#include "ck_tile/core/utility/functional.hpp" + +#include + +#include +#include +#include +#include + +using namespace ck_tile; +using namespace ck_tile::core::arch; +using namespace ck_tile::core::arch::mma; + +using CompilerTargetGfx950 = decltype(make_amdgcn_gfx9_target()); + +template +void ScaleMfmaGfx950Specialization_impl() +{ + using TestScaleMma = amdgcn_mma; + + static_assert(std::is_same_v && + TestScaleMma::OpFamily == MmaOpFamily::SCALE, + "GFX950 scale intrinsic should have ScaleMFMAOp type"); + + static_assert(is_mma_op_of_family_v, + "GFX950 scale intrinsic should be detected as Scale"); + + // Get its traits + using TestTraits = MmaOpTraits; + + // Verify trait detection + static_assert(TestTraits::IsScale, "Scale MMA should be detected as scale"); + static_assert(TestTraits::IsSupported, "Scale MMA specialization should be supported"); + static_assert(TestTraits::IsMfma, "Scale MFMA should be detected as MFMA"); + static_assert(!TestTraits::IsWmma, "Scale MFMA should not be detected as WMMA"); +} + +TEST(ScaleMMATrait, ScaleMfmaGfx950Specialization) +{ + // Test fp8 → fp32 scale MFMA for GFX950 (16x16x128) + ScaleMfmaGfx950Specialization_impl(); + // Test bf8 → fp32 scale MFMA for GFX950 (16x16x128) + ScaleMfmaGfx950Specialization_impl(); + // Test fp4 → fp32 scale MFMA for GFX950 (16x16x128) + ScaleMfmaGfx950Specialization_impl(); + // Test fp8 → fp32 scale MFMA for GFX950 (32x32x64) + ScaleMfmaGfx950Specialization_impl(); + // Test bf8 → fp32 scale MFMA for GFX950 (32x32x64) + ScaleMfmaGfx950Specialization_impl(); + // Test fp4 → fp32 scale MFMA for GFX950 (32x32x64) + ScaleMfmaGfx950Specialization_impl(); + + std::cout << "GFX950 scale MFMA specialization is correct" << std::endl; +} + +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER +template +void TestConceptRequirements_impl() +{ + using TestScaleMma = amdgcn_mma; + static_assert(MmaOpI); +} +#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + +TEST(ScaleMMATrait, TestConceptRequirements) +{ +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + TestConceptRequirements_impl(); + TestConceptRequirements_impl(); + TestConceptRequirements_impl(); + TestConceptRequirements_impl(); + TestConceptRequirements_impl(); + TestConceptRequirements_impl(); +#else + GTEST_SKIP() << "Not compiled with concepts. Skipping test."; +#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER +} + +template +void ScaleSelector_impl() +{ + static_for<2, 14, 6>{}([](auto k_factor) { + static_for<1, 33, 1>{}([&](auto i) { + using Selected = typename MmaDefaultSelector(i), + static_cast(i), + static_cast(k_factor * i), + CompilerTargetGfx950, + MmaOpFamily::SCALE>::SelectedOp; + static constexpr bool isValid = (i == 16 && k_factor == 8) || (i == 32); + if constexpr(isValid) + { + // Selector should pick a scale MFMA implementation + static_assert(MmaOpTraits::IsScale); + static_assert(MmaOpTraits::IsMfma); + static_assert(MmaOpTraits::IsSupported); + static_assert((std::is_same::value)); + } + else + { + // Selector should pick the unsupported pass through + static_assert(!MmaOpTraits::IsSupported); + } + }); + }); +} + +TEST(ScaleMMATrait, ScaleSelector) +{ + ScaleSelector_impl(); + ScaleSelector_impl(); + ScaleSelector_impl(); +} + +template +__global__ void +test_scale_accum_over_k(void* a, void* b, void* c, void* out, void* scale_A, void* scale_B) +{ + using Pipeline = ScaleMmaPipeline; + + using AVecType = typename Pipeline::AVecType; + using BVecType = typename Pipeline::BVecType; + using CVecType = typename Pipeline::CVecType; + + // NOTE: WaveTileK is used as a Pipeline template parameter, but the K iteration is + // happening outside the Pipeline. This is a bit incorrect currently. + static constexpr std::uint32_t kIters = WaveTileK / Pipeline::MmaOp::kK; + + // Initialize the accumulator + CVecType result = *reinterpret_cast(c); + + // Accumulate input AxB over WaveTileK/FragK iterations + for(std::uint32_t i = 0; i < kIters; ++i) + { + result = Pipeline::exec(*reinterpret_cast(a), + *reinterpret_cast(b), + result, + *reinterpret_cast(scale_A), + *reinterpret_cast(scale_B)); + } + + *reinterpret_cast(out) = result; +} + +template +void MmaSelector_Scale_Real_impl() +{ + using TestType = MmaPipelineTest; + TestType test; + const auto should_skip = [](amdgcn_target_id currentArchId) { + bool isSupportedWmma = false; + bool isSupportedMfma = (currentArchId == amdgcn_target_id::GFX950); + return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma)); + }; + const std::function + validator = + [](std::uint32_t fragK, TestType::ScaleAType scale_A, TestType::ScaleBType scale_B) { + fp32_t actual_scale_A = std::powf(2.0f, scale_A - 127.0f); + fp32_t actual_scale_B = std::powf(2.0f, scale_B - 127.0f); + return static_cast(fragK) * actual_scale_A * actual_scale_B; + }; + const auto kernel = [](std::uint32_t waveSize, + void* a, + void* b, + void* c, + void* out, + void* scale_A, + void* scale_B) { + test_scale_accum_over_k + <<<1, waveSize>>>(a, b, c, out, scale_A, scale_B); + }; + test.test_pipeline(should_skip, kernel, validator); +} + +// Live test on real hardware for scale selection and execution. +TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_16x16x128_Real) +{ + MmaSelector_Scale_Real_impl(); +} + +// Live test on real hardware for scale selection and execution. +TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_16x16x128_Real) +{ + MmaSelector_Scale_Real_impl(); +} + +// Live test on real hardware for scale selection and execution. +TEST(ScaleMMATrait, MmaSelector_Scale_F4_F4_F32_16x16x128_Real) +{ + MmaSelector_Scale_Real_impl(); +} + +// Live test on real hardware for scale selection and execution. +TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_32x32x64_Real) +{ + MmaSelector_Scale_Real_impl(); +} + +// Live test on real hardware for scale selection and execution. +TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_32x32x64_Real) +{ + MmaSelector_Scale_Real_impl(); +} + +// Live test on real hardware for scale selection and execution. +TEST(ScaleMMATrait, MmaSelector_Scale_F4_F4_F32_32x32x64_Real) +{ + MmaSelector_Scale_Real_impl(); +} diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc index ec8ea2a830..e757ff9cf2 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc @@ -3,18 +3,32 @@ #pragma once -#include -#include - -#include "ck_tile/host/hip_check_error.hpp" -#include "ck_tile/host/stream_config.hpp" -#include "ck_tile/host/device_memory.hpp" -#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mfma/mfma.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/arch/mma/scale/scale.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse.hpp" +#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp" +#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" +#include "ck_tile/core/arch/mma/wmma/wmma.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/integer.hpp" +// #include "ck_tile/core/numeric/pk_fp4.hpp" +#include "ck_tile/core/numeric/type_convert.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/host/device_memory.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/stream_config.hpp" + +#include +#include -#include -#include #include +#include +#include namespace { @@ -22,6 +36,9 @@ using namespace ck_tile; using namespace ck_tile::core::arch; using namespace mma; +// using F4 = pk_fp4_t; +using F8 = fp8_t; +using BF8 = bf8_t; using F16 = fp16_t; using F32 = fp32_t; using Target908 = decltype(make_amdgcn_gfx9_target()); @@ -80,6 +97,10 @@ struct MmaLayoutTestKernel BVecType b_frag{}; CVecType c_frag{}; uint32_t sparse_idx{}; + // The actual scale is computed as pow(2, scale - 127), so: + // 125 -> 2^-2 and 129 -> 2^2. + int scale_A = 125; + int scale_B = 129; static_assert(MmaOp::kCompressionRatio <= 2); // Allow only 4:2 compression (or no). // get (m, k, n), where "1" should be placed for this block @@ -97,7 +118,7 @@ struct MmaLayoutTestKernel // direction and we just put our "1" in the k / 2 position (rounded down). if(a_coords[0] == m && a_coords[1] == (k / MmaOp::kCompressionRatio)) { - a_frag[v] = 1; + a_frag[v] = type_convert(1.0f); // Calc an appropriate sparse idx value for a single 1 in position k. We use a // baseline index of 0x88888888. This sends each compressed index i to @@ -114,7 +135,7 @@ struct MmaLayoutTestKernel auto b_coords = BRegMap::calc_matrix_indices_from_lane_vector(lane, v); if(b_coords[0] == n && b_coords[1] == k) { - b_frag[v] = 1; + b_frag[v] = type_convert(1.0f); } } @@ -122,6 +143,10 @@ struct MmaLayoutTestKernel { c_frag = MmaOp::exec(a_frag, b_frag, c_frag, sparse_idx); } + else if constexpr(MmaOpTraits::IsScale) + { + c_frag = MmaOp::exec(a_frag, b_frag, c_frag, scale_A, scale_B); + } else { c_frag = MmaOp::exec(a_frag, b_frag, c_frag); @@ -211,24 +236,30 @@ void run_mma_layout_test() // Lists of intrinsics to test. // clang-format off using Gfx9Intrinsics = ::testing::Types< - amdgcn_mma, // mfma_f32_16x16x16f16 - amdgcn_mma, // mfma_f32_32x32x4f16 - amdgcn_mma, // mfma_f32_32x32x4f16 - amdgcn_mma, // mfma_f32_4x4x4f16 - amdgcn_mma // mfma_f32_4x4x4f16 + amdgcn_mma, // mfma_f32_16x16x16f16 + amdgcn_mma, // mfma_f32_32x32x4f16 + amdgcn_mma, // mfma_f32_32x32x4f16 + amdgcn_mma, // mfma_f32_4x4x4f16 + amdgcn_mma // mfma_f32_4x4x4f16 >; using Gfx942Intrinsics = ::testing::Types< - amdgcn_mma // smfmac_f32_16x16x32_f16 + amdgcn_mma // smfmac_f32_16x16x32_f16 >; using Gfx950Intrinsics = ::testing::Types< - amdgcn_mma // mfma_f32_16x16x32_f16 + amdgcn_mma, // mfma_f32_16x16x32_f16 + amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + // amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 + amdgcn_mma // mfma_scale_f32_32x32x64_f8f6f4 + // amdgcn_mma // mfma_scale_f32_32x32x64_f8f6f4 >; using Gfx11Intrinsics = ::testing::Types< - amdgcn_mma, Target11, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32 + amdgcn_mma, Target11, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32 >; using Gfx12Intrinsics = ::testing::Types< - amdgcn_mma, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12 - amdgcn_mma // swmmac_f32_16x16x32_f16_w32 + amdgcn_mma, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12 + amdgcn_mma // swmmac_f32_16x16x32_f16_w32 >; // clang-format on