From 0ebeb88ba97c056d4c9cf0056ad7d0e0b23e3917 Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski <77888887+wj-laskowski@users.noreply.github.com> Date: Mon, 27 Apr 2026 13:57:51 +0200 Subject: [PATCH] [CK Tile] Adding WMMA wrappers for dense builtins (#5801) ## Motivation This PR is part of the [WMMA/MFMA] unification work. It's the first in a series of PRs that add all the necessary MMA builtins as `amdgcn_mma` structs. ## Technical Details This change adds new specializations for WMMA dense builtins. In total, we now have 9 RDNA4 builtins and 3 RDNA3 builtins. ## Test Plan All the new wrappers were added to the test suite in `test_amdgcn_mma_layout.inc`. ## Test Result Tests pass locally; waiting for CI. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Yung-sheng Tu --- .../example_tile_distr_enc_calc.cpp | 4 +- .../core/arch/mma/sparse/sparse_traits.hpp | 4 - .../core/arch/mma/sparse/wmma/selector.hpp | 3 +- include/ck_tile/core/arch/mma/wmma/wmma.hpp | 23 -- .../ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp | 81 ++++--- .../ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp | 203 +++++++++++++++++- .../core/arch/mma/wmma/wmma_selector.hpp | 4 +- .../core/arch/mma/wmma/wmma_traits.hpp | 14 ++ .../core/arch/mma/test_amdgcn_mma_layout.inc | 36 ++-- 9 files changed, 296 insertions(+), 76 deletions(-) diff --git a/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp b/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp index 6de7af2cbd..7559ac6f0c 100644 --- a/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp +++ b/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp @@ -77,8 +77,8 @@ using Intrinsics = ck_tile::tuple< amdgcn_mma, // mfma_f32_4x4x4f16 amdgcn_mma, // mfma_f32_4x4x4f16 amdgcn_mma, // mfma_f32_16x16x32_f16 - amdgcn_mma, Target11, MmaOpFamily::DENSE>, // 
wmma_f32_16x16x16_f16_w32 - amdgcn_mma, Target12, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32_gfx12 + amdgcn_mma, // wmma_f32_16x16x16_f16_w32 + amdgcn_mma // wmma_f32_16x16x16_f16_w32_gfx12 >; // clang-format on diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp index a551d9b08c..2528105a6a 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp @@ -79,8 +79,4 @@ concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) { }; #endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -struct DefaultSparseWmmaCtrlFlags -{ -}; - } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp b/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp index 8b4803b6bf..cc3571a778 100644 --- a/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp +++ b/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp @@ -7,6 +7,7 @@ #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/arch/mma/wmma/wmma_traits.hpp" namespace ck_tile::core::arch::mma { @@ -41,7 +42,7 @@ struct SparseWmmaDefaultSelector WaveTileM, WaveTileN, WaveTileKTest, - DefaultSparseWmmaCtrlFlags, + DefaultWmmaCtrlFlags, CompilerTarget, MmaOpFamily::SPARSE>; diff --git a/include/ck_tile/core/arch/mma/wmma/wmma.hpp b/include/ck_tile/core/arch/mma/wmma/wmma.hpp index ae5269dcb8..9b649717ff 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma.hpp @@ -3,29 +3,6 @@ #pragma once -namespace ck_tile::core::arch::mma { - -/** - * @enum WmmaCtrlFlags - * @brief Common wmma control flags for gfx11 and gfx12 - */ -enum struct WmmaCtrlFlags : bool -{ - // Only has an effect on gfx11 when the accumulator is 16-bit - // Determines which half of the 32-bit accum register to 
use - // Low = bits [15:0] - // High = bits[31:16] - LOW = false, - HIGH = true, - - // Only has an effect on gfx11 / 12 when the input is 8-bit int - // Signage indicator of inputs / accum - UNSIGNED = false, - SIGNED = true -}; - -} // namespace ck_tile::core::arch::mma - // Include the architecture-specific WMMA implementations and traits #include "wmma_gfx11.hpp" #include "wmma_gfx12.hpp" diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp index c86190573e..86f99a3ac5 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp @@ -8,8 +8,8 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" -#include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" namespace ck_tile::core::arch::mma { // TODO: Specifically for gfx11 wmma, we need to deal with quirks such as: @@ -36,30 +36,6 @@ namespace ck_tile::core::arch::mma { // For flexibility, it is recommended that for each backend wrapper it supports at least // one packed register for each input to be able to process smaller K values by padding. -/** - * @class DefaultWmmaFlags - * @brief Generates default WMMA control flags based on data types. - * @tparam ADataType Data type of matrix A - * @tparam BDataType Data type of matrix B - * @tparam CDataType Data type of the accumulator - */ -template -struct DefaultWmmaCtrlFlags -{ - // Generate default flags for signage - // Only used currently for integer inputs / accum in gfx11 / gfx12 - constexpr static WmmaCtrlFlags InputSignA = - std::is_signed_v ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED; - constexpr static WmmaCtrlFlags InputSignB = - std::is_signed_v ? WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED; - constexpr static WmmaCtrlFlags AccumSign = - std::is_signed_v ? 
WmmaCtrlFlags::SIGNED : WmmaCtrlFlags::UNSIGNED; - - // Generate default flags for accumulator destination bits. - // Only used if accumulation size is 16-bit in gfx11 - constexpr static WmmaCtrlFlags AccumBits = WmmaCtrlFlags::LOW; -}; - /** * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX11 @@ -76,11 +52,62 @@ struct amdgcn_mma // clang-format on { - CK_TILE_DEVICE static auto - exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32(aVec, bVec, cVec)}; } }; +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX11 + * architecture. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma()>> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) + { + return {__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(aVec, bVec, cVec)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX11 + * architecture. 
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma()>> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) + { + return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // A signedness + bit_cast(aVec), + true, // B signedness + bit_cast(bVec), + cVec, + CtrlFlags::Clamp)}; + } +}; + } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp index 0a74bf8d65..941677cba9 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp @@ -8,8 +8,8 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" -#include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" namespace ck_tile::core::arch::mma { @@ -22,7 +22,7 @@ namespace ck_tile::core::arch::mma { /** * @struct amdgcn_mma - * @brief Specialization of amdgcn_wmma for fp16_t, fp16_t, fp32_t MMA operation on GFX12 + * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12 * architecture. 
* @tparam CtrlFlags Control flags for the WMMA operation * @tparam CompilerTarget Current compiler target @@ -32,15 +32,208 @@ namespace ck_tile::core::arch::mma { template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma()>> : amdgcn_mma_base // clang-format on { - CK_TILE_DEVICE static auto - exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { return {__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)}; } }; +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX12 + * architecture. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma()>> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) + { + return {__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(aVec, bVec, cVec)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp16_t MMA operation on GFX12 + * architecture. 
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma()>> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) + { + return {__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(aVec, bVec, cVec)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for bf16_t, bf16_t, bf16_t MMA operation on GFX12 + * architecture. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma()>> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) + { + return {__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12(aVec, bVec, cVec)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX12 + * architecture. 
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma()>> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) + { + return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // A signedness + bit_cast(aVec), + true, // B signedness + bit_cast(bVec), + cVec, + CtrlFlags::Clamp)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX12 + * architecture. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma()>> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) + { + return {__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12( + bit_cast(aVec), bit_cast(bVec), cVec)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX12 + * architecture. 
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma()>> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) + { + return {__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12( + bit_cast(aVec), bit_cast(bVec), cVec)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX12 + * architecture. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma()>> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) + { + return {__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12( + bit_cast(aVec), bit_cast(bVec), cVec)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX12 + * architecture. 
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma()>> +: amdgcn_mma_base +// clang-format on +{ + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) + { + return {__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12( + bit_cast(aVec), bit_cast(bVec), cVec)}; + } +}; + } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp index f8616ad19c..0d6efe9b07 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp @@ -36,7 +36,7 @@ struct WmmaDefaultSelector { private: // By default, let's assume no special flags for WMMA - using CtrlFlags = DefaultWmmaCtrlFlags; + using CtrlFlags = DefaultWmmaCtrlFlags; // Define our candidate WMMA implementation for the current parameters using CandidateOp = amdgcn_mma { // By default, let's assume no special flags for WMMA - using CtrlFlags = DefaultWmmaCtrlFlags; + using CtrlFlags = DefaultWmmaCtrlFlags; // Default unsupported pass-through if no instruction is found using SelectedOp = amdgcn_mma static constexpr bool is_mma_op_wmma_v = is_mma_op_wmma::value; +/** + * @struct DefaultWmmaCtrlFlags + * @brief Default WMMA control flags for dense and sparse WMMA operations. + */ +struct DefaultWmmaCtrlFlags +{ + constexpr static bool Clamp = false; + + // Only has an effect on gfx11 when the accumulator is 16-bit. + // Determines which half of the 32-bit accum register to use for the 16-bit result. 
+ // false = low bits [15:0], true = high bits [31:16] + constexpr static bool UseHighAccumBits = true; +}; + } // namespace ck_tile::core::arch::mma diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc index e757ff9cf2..e610a84518 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc @@ -3,19 +3,16 @@ #pragma once +#include +#include + #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/mfma/mfma.hpp" -#include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" -#include "ck_tile/core/arch/mma/scale/scale.hpp" #include "ck_tile/core/arch/mma/sparse/sparse.hpp" #include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp" #include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/arch/mma/wmma/wmma.hpp" -#include "ck_tile/core/numeric/float8.hpp" -#include "ck_tile/core/numeric/half.hpp" -#include "ck_tile/core/numeric/integer.hpp" -// #include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/host/device_memory.hpp" @@ -23,9 +20,9 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/stream_config.hpp" -#include -#include - +#include +#include +#include #include #include #include @@ -40,7 +37,12 @@ using namespace mma; using F8 = fp8_t; using BF8 = bf8_t; using F16 = fp16_t; +using BF16 = bf16_t; using F32 = fp32_t; +using I8 = int8_t; +using FP8 = fp8_t; +using BF8 = bf8_t; +using I32 = int32_t; using Target908 = decltype(make_amdgcn_gfx9_target()); using Target942 = decltype(make_amdgcn_gfx9_target()); using Target950 = decltype(make_amdgcn_gfx9_target()); @@ -255,11 +257,21 @@ using Gfx950Intrinsics = ::testing::Types< // amdgcn_mma // mfma_scale_f32_32x32x64_f8f6f4 >; using 
Gfx11Intrinsics = ::testing::Types< - amdgcn_mma, Target11, MmaOpFamily::DENSE> // wmma_f32_16x16x16_f16_w32 + amdgcn_mma, // wmma_f32_16x16x16_f16_w32, + amdgcn_mma, // wmma_f32_16x16x16_bf16_w32, + amdgcn_mma // wmma_i32_16x16x16_iu8_w32 >; using Gfx12Intrinsics = ::testing::Types< - amdgcn_mma, Target12, MmaOpFamily::DENSE>, // wmma_f32_16x16x16_f16_w32_gfx12 - amdgcn_mma // swmmac_f32_16x16x32_f16_w32 + amdgcn_mma, // wmma_f32_16x16x16_f16_w32_gfx12, + amdgcn_mma, // wmma_f32_16x16x16_bf16_w32_gfx12, + amdgcn_mma, // wmma_f16_16x16x16_f16_w32_gfx12, + amdgcn_mma, // wmma_bf16_16x16x16_bf16_w32_gfx12, + amdgcn_mma, // wmma_i32_16x16x16_iu8_w32_gfx12, + amdgcn_mma, // wmma_f32_16x16x16_fp8_fp8_w32_gfx12, + amdgcn_mma, // wmma_f32_16x16x16_fp8_bf8_w32_gfx12, + amdgcn_mma, // wmma_f32_16x16x16_bf8_fp8_w32_gfx12, + amdgcn_mma, // wmma_f32_16x16x16_bf8_bf8_w32_gfx12 + amdgcn_mma // swmmac_f32_16x16x32_f16_w32 >; // clang-format on