[rocm-libraries] ROCm/rocm-libraries#6567 (commit 753c7a8)

[CK Tile] Adding WMMA wrappers for sparse builtins (#6567)

## Motivation

This PR is part of the [WMMA/MFMA] unification work. It is the third in
the series of PRs (after
https://github.com/ROCm/rocm-libraries/pull/5801 and
https://github.com/ROCm/rocm-libraries/pull/6014) that add all the
necessary MMA builtins as amdgcn_mma structs. This PR focuses on sparse
WMMA intrinsics.

## Technical Details

This change adds new amdgcn_mma specializations that wrap the sparse WMMA
builtins. In total, 8 sparse WMMA builtins are added.

## Test Plan

All the new wrappers were added to the test suite in
`test_amdgcn_mma_layout.inc`.

## Test Result

Tests pass locally; waiting for CI.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Wojciech Laskowski
2026-05-22 13:34:33 +02:00
committed by GitHub
parent e02c566795
commit 3ea9ce7e37
2 changed files with 229 additions and 2 deletions

View File

@@ -14,23 +14,242 @@
namespace ck_tile::core::arch::mma {
/**
 * @struct amdgcn_mma
 * @brief Sparse (MmaOpFamily::SPARSE) specialization of amdgcn_mma for fp16_t (A), fp16_t (B),
 * fp32_t (C) operands with a 16x16x32 (MxNxK) block on the GFX12 target family.
 * @tparam CtrlFlags Control flags for the WMMA operation (not read by this specialization)
 * @tparam CompilerTarget Current compiler target; the specialization is only enabled through
 * enable_if_target_family_gfx12_t<CompilerTarget>
 */
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
    /// Name of the compiler builtin that exec() forwards to.
    static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32";

    /**
     * @brief Forwards the operands unchanged to the wrapped builtin.
     *
     * Uses the same leading-return-type signature as the other sparse GFX12 specializations
     * in this file; the earlier trailing-return-type declaration has been dropped so that
     * exec() is declared exactly once.
     *
     * @param aVec A operand vector
     * @param bVec B operand vector
     * @param cVec C operand vector
     * @param idx  Index operand passed as the builtin's last argument
     *             (presumably the sparsity index for A -- TODO confirm against the ISA guide)
     * @return The builtin's result as CVecType
     */
    CK_TILE_DEVICE static CVecType
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
    {
        return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(aVec, bVec, cVec, idx)};
    }
};
/**
 * @struct amdgcn_mma
 * @brief Sparse (MmaOpFamily::SPARSE) specialization of amdgcn_mma for bf16_t (A), bf16_t (B),
 * fp32_t (C) operands with a 16x16x32 (MxNxK) block on the GFX12 target family.
 * @tparam CtrlFlags Control flags for the WMMA operation (not read by this specialization)
 * @tparam CompilerTarget Current compiler target; the specialization is only enabled through
 * enable_if_target_family_gfx12_t<CompilerTarget>
 */
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
    /// Name of the compiler builtin that exec() forwards to.
    static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32";

    /**
     * @brief Forwards the operands unchanged to the wrapped builtin.
     * @param aVec A operand vector
     * @param bVec B operand vector
     * @param cVec C operand vector
     * @param idx  Index operand passed as the builtin's last argument
     *             (presumably the sparsity index for A -- TODO confirm against the ISA guide)
     * @return The builtin's result as CVecType
     */
    CK_TILE_DEVICE static CVecType
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
    {
        CVecType result{__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32(aVec, bVec, cVec, idx)};
        return result;
    }
};
/**
 * @struct amdgcn_mma
 * @brief Sparse (MmaOpFamily::SPARSE) specialization of amdgcn_mma for fp16_t (A), fp16_t (B),
 * fp16_t (C) operands with a 16x16x32 (MxNxK) block on the GFX12 target family.
 * @tparam CtrlFlags Control flags for the WMMA operation (not read by this specialization)
 * @tparam CompilerTarget Current compiler target; the specialization is only enabled through
 * enable_if_target_family_gfx12_t<CompilerTarget>
 */
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp16_t, fp16_t, fp16_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
    /// Name of the compiler builtin that exec() forwards to.
    static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32";

    /**
     * @brief Forwards the operands unchanged to the wrapped builtin.
     * @param aVec A operand vector
     * @param bVec B operand vector
     * @param cVec C operand vector
     * @param idx  Index operand passed as the builtin's last argument
     *             (presumably the sparsity index for A -- TODO confirm against the ISA guide)
     * @return The builtin's result as CVecType
     */
    CK_TILE_DEVICE static CVecType
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
    {
        CVecType result{__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32(aVec, bVec, cVec, idx)};
        return result;
    }
};
/**
 * @struct amdgcn_mma
 * @brief Sparse (MmaOpFamily::SPARSE) specialization of amdgcn_mma for bf16_t (A), bf16_t (B),
 * bf16_t (C) operands with a 16x16x32 (MxNxK) block on the GFX12 target family.
 * @tparam CtrlFlags Control flags for the WMMA operation (not read by this specialization)
 * @tparam CompilerTarget Current compiler target; the specialization is only enabled through
 * enable_if_target_family_gfx12_t<CompilerTarget>
 */
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<bf16_t, bf16_t, bf16_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
    /// Name of the compiler builtin that exec() forwards to.
    static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32";

    /**
     * @brief Forwards the operands unchanged to the wrapped builtin.
     * @param aVec A operand vector
     * @param bVec B operand vector
     * @param cVec C operand vector
     * @param idx  Index operand passed as the builtin's last argument
     *             (presumably the sparsity index for A -- TODO confirm against the ISA guide)
     * @return The builtin's result as CVecType
     */
    CK_TILE_DEVICE static CVecType
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
    {
        CVecType result{__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32(aVec, bVec, cVec, idx)};
        return result;
    }
};
/**
 * @struct amdgcn_mma
 * @brief Sparse (MmaOpFamily::SPARSE) specialization of amdgcn_mma for int8_t (A), int8_t (B),
 * int32_t (C) operands with a 16x16x32 (MxNxK) block on the GFX12 target family.
 * @tparam CtrlFlags Control flags for the WMMA operation; must provide a `Clamp` member, which
 * is forwarded as the builtin's final argument
 * @tparam CompilerTarget Current compiler target; the specialization is only enabled through
 * enable_if_target_family_gfx12_t<CompilerTarget>
 */
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
    /// Name of the compiler builtin that exec() forwards to.
    static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32";

    /**
     * @brief Forwards the operands to the wrapped builtin with both signedness flags set to
     * true and CtrlFlags::Clamp as the clamp argument.
     * @param aVec A operand vector
     * @param bVec B operand vector
     * @param cVec C operand vector
     * @param idx  Index operand passed to the builtin after cVec
     *             (presumably the sparsity index for A -- TODO confirm against the ISA guide)
     * @return The builtin's result as CVecType
     */
    CK_TILE_DEVICE static CVecType
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
    {
        // The signedness flags and the clamp flag are kept as direct arguments so they stay
        // compile-time constants at the call site.
        CVecType result{__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32(
            /* A signedness */ true, aVec, /* B signedness */ true, bVec, cVec, idx, CtrlFlags::Clamp)};
        return result;
    }
};
/**
 * @struct amdgcn_mma
 * @brief Sparse (MmaOpFamily::SPARSE) specialization of amdgcn_mma for fp8_t (A), fp8_t (B),
 * fp32_t (C) operands with a 16x16x32 (MxNxK) block on the GFX12 target family.
 * @tparam CtrlFlags Control flags for the WMMA operation (not read by this specialization)
 * @tparam CompilerTarget Current compiler target; the specialization is only enabled through
 * enable_if_target_family_gfx12_t<CompilerTarget>
 */
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
    /// Name of the compiler builtin that exec() forwards to.
    static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32";

    /**
     * @brief Forwards the operands unchanged to the wrapped builtin.
     * @param aVec A operand vector
     * @param bVec B operand vector
     * @param cVec C operand vector
     * @param idx  Index operand passed as the builtin's last argument
     *             (presumably the sparsity index for A -- TODO confirm against the ISA guide)
     * @return The builtin's result as CVecType
     */
    CK_TILE_DEVICE static CVecType
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
    {
        CVecType result{__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32(aVec, bVec, cVec, idx)};
        return result;
    }
};
/**
 * @struct amdgcn_mma
 * @brief Sparse (MmaOpFamily::SPARSE) specialization of amdgcn_mma for fp8_t (A), bf8_t (B),
 * fp32_t (C) operands with a 16x16x32 (MxNxK) block on the GFX12 target family.
 * @tparam CtrlFlags Control flags for the WMMA operation (not read by this specialization)
 * @tparam CompilerTarget Current compiler target; the specialization is only enabled through
 * enable_if_target_family_gfx12_t<CompilerTarget>
 */
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
    /// Name of the compiler builtin that exec() forwards to.
    static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32";

    /**
     * @brief Forwards the operands unchanged to the wrapped builtin.
     * @param aVec A operand vector
     * @param bVec B operand vector
     * @param cVec C operand vector
     * @param idx  Index operand passed as the builtin's last argument
     *             (presumably the sparsity index for A -- TODO confirm against the ISA guide)
     * @return The builtin's result as CVecType
     */
    CK_TILE_DEVICE static CVecType
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
    {
        CVecType result{__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32(aVec, bVec, cVec, idx)};
        return result;
    }
};
/**
 * @struct amdgcn_mma
 * @brief Sparse (MmaOpFamily::SPARSE) specialization of amdgcn_mma for bf8_t (A), fp8_t (B),
 * fp32_t (C) operands with a 16x16x32 (MxNxK) block on the GFX12 target family.
 * @tparam CtrlFlags Control flags for the WMMA operation (not read by this specialization)
 * @tparam CompilerTarget Current compiler target; the specialization is only enabled through
 * enable_if_target_family_gfx12_t<CompilerTarget>
 */
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
    /// Name of the compiler builtin that exec() forwards to.
    static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32";

    /**
     * @brief Forwards the operands unchanged to the wrapped builtin.
     * @param aVec A operand vector
     * @param bVec B operand vector
     * @param cVec C operand vector
     * @param idx  Index operand passed as the builtin's last argument
     *             (presumably the sparsity index for A -- TODO confirm against the ISA guide)
     * @return The builtin's result as CVecType
     */
    CK_TILE_DEVICE static CVecType
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
    {
        CVecType result{__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32(aVec, bVec, cVec, idx)};
        return result;
    }
};
/**
 * @struct amdgcn_mma
 * @brief Sparse (MmaOpFamily::SPARSE) specialization of amdgcn_mma for bf8_t (A), bf8_t (B),
 * fp32_t (C) operands with a 16x16x32 (MxNxK) block on the GFX12 target family.
 * @tparam CtrlFlags Control flags for the WMMA operation (not read by this specialization)
 * @tparam CompilerTarget Current compiler target; the specialization is only enabled through
 * enable_if_target_family_gfx12_t<CompilerTarget>
 */
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>
// clang-format off
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
// clang-format on
{
    /// Name of the compiler builtin that exec() forwards to.
    static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32";

    /**
     * @brief Forwards the operands unchanged to the wrapped builtin.
     * @param aVec A operand vector
     * @param bVec B operand vector
     * @param cVec C operand vector
     * @param idx  Index operand passed as the builtin's last argument
     *             (presumably the sparsity index for A -- TODO confirm against the ISA guide)
     * @return The builtin's result as CVecType
     */
    CK_TILE_DEVICE static CVecType
    exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
    {
        CVecType result{__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32(aVec, bVec, cVec, idx)};
        return result;
    }
};
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
// TODO: c++20 requires
template <typename CtrlFlags, typename CompilerTarget>

View File

@@ -434,6 +434,14 @@ using Gfx12Intrinsics = ::testing::Types<
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32_gfx12
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x32_iu4_w32_gfx12
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_f16_w32
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf16_w32
amdgcn_mma<F16, F16, F16, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f16_16x16x32_f16_w32
amdgcn_mma<BF16, BF16, BF16, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_bf16_16x16x32_bf16_w32
amdgcn_mma<I8, I8, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_i32_16x16x32_iu8_w32
amdgcn_mma<F8, F8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_fp8_fp8_w32
amdgcn_mma<F8, BF8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_fp8_bf8_w32
amdgcn_mma<BF8, F8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf8_fp8_w32
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf8_bf8_w32
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_i32_16x16x32_iu4_w32
amdgcn_mma<I4, I4, I32, 16u, 16u, 64u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_i32_16x16x64_iu4_w32
>;