From 3ea9ce7e37dad3811508affdc2acdf644769fc07 Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski <77888887+wj-laskowski@users.noreply.github.com> Date: Fri, 22 May 2026 13:34:33 +0200 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#6567 (commit 753c7a8) [CK Tile] Adding WMMA wrappers for sparse builtins (#6567) ## Motivation This PR is part of the [WMMA/MFMA] unification work. It is the third in a series of PRs (after https://github.com/ROCm/rocm-libraries/pull/5801 and https://github.com/ROCm/rocm-libraries/pull/6014) that add all the necessary MMA builtins as amdgcn_mma structs. This PR focuses on sparse WMMA intrinsics. ## Technical Details This change adds new amdgcn_mma specializations for WMMA sparse builtins. In total, it adds wrappers for 8 sparse WMMA builtins. ## Test Plan All the new wrappers were added to the test suite in `test_amdgcn_mma_layout.inc`. ## Test Result Tests pass locally; waiting for CI. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- .../arch/mma/sparse/wmma/sparse_gfx12.hpp | 223 +++++++++++++++++- .../core/arch/mma/test_amdgcn_mma_layout.inc | 8 + 2 files changed, 229 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp index fd8c817794..2257cf7db8 100644 --- a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp +++ b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp @@ -14,23 +14,242 @@ namespace ck_tile::core::arch::mma { +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12 + * architecture. 
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ // TODO: c++20 template // TODO: c++20 requires template // clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32"; - CK_TILE_DEVICE static auto - exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(aVec, bVec, cVec, idx)}; } }; +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX12 + * architecture. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32(aVec, bVec, cVec, idx)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp16_t MMA operation on GFX12 + * architecture. 
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32(aVec, bVec, cVec, idx)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for bf16_t, bf16_t, bf16_t MMA operation on GFX12 + * architecture. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = + "__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32(aVec, bVec, cVec, idx)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX12 + * architecture. 
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32(true, // A signedness + aVec, + true, // B signedness + bVec, + cVec, + idx, + CtrlFlags::Clamp)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX12 + * architecture. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = + "__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32(aVec, bVec, cVec, idx)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX12 + * architecture. 
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = + "__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32(aVec, bVec, cVec, idx)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX12 + * architecture. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = + "__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32(aVec, bVec, cVec, idx)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX12 + * architecture. 
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = + "__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32(aVec, bVec, cVec, idx)}; + } +}; + // TODO: c++20 template // TODO: c++20 requires template diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc index 391caa924d..45656f9e7f 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc @@ -434,6 +434,14 @@ using Gfx12Intrinsics = ::testing::Types< amdgcn_mma, // wmma_i32_16x16x16_iu4_w32_gfx12 amdgcn_mma, // wmma_i32_16x16x32_iu4_w32_gfx12 amdgcn_mma, // swmmac_f32_16x16x32_f16_w32 + amdgcn_mma, // swmmac_f32_16x16x32_bf16_w32 + amdgcn_mma, // swmmac_f16_16x16x32_f16_w32 + amdgcn_mma, // swmmac_bf16_16x16x32_bf16_w32 + amdgcn_mma, // swmmac_i32_16x16x32_iu8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_fp8_fp8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_fp8_bf8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_bf8_fp8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_bf8_bf8_w32 amdgcn_mma, // swmmac_i32_16x16x32_iu4_w32 amdgcn_mma // swmmac_i32_16x16x64_iu4_w32 >;