mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-10 08:18:26 +00:00
[rocm-libraries] ROCm/rocm-libraries#6567 (commit 753c7a8)
[CK Tile] Adding WMMA wrappers for sparse builtins (#6567) ## Motivation This PR is part of the [WMMA/MFMA] unification work. It is the third in the series of PRs (after https://github.com/ROCm/rocm-libraries/pull/5801 and https://github.com/ROCm/rocm-libraries/pull/6014) that add all the necessary MMA builtins as amdgcn_mma structs. This PR focuses on sparse WMMA intrinsics. ## Technical Details This change adds new amdgcn_mma specializations for the sparse WMMA builtins. In total, it adds wrappers for 8 sparse WMMA builtins. ## Test Plan All the new wrappers were added to the test suite in `test_amdgcn_mma_layout.inc`. ## Test Result Tests pass locally; waiting for CI. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
GitHub
parent
e02c566795
commit
3ea9ce7e37
@@ -14,23 +14,242 @@
|
||||
|
||||
namespace ck_tile::core::arch::mma {
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
// Sparse WMMA wrapper for A/B/C data types fp16_t, fp16_t, fp32_t with size arguments 16u, 16u, 32u
// and MmaOpFamily::SPARSE, gated by enable_if_target_family_gfx12_t<CompilerTarget>.
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Name of the compiler builtin that exec() calls.
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32";
|
||||
|
||||
// Forwards aVec, bVec, cVec and idx unchanged to the builtin and brace-initializes the returned
// CVecType from its result.
// NOTE(review): the exec signature appears twice below (trailing-return form, then plain CVecType
// form); this looks like the removed and added sides of a diff hunk - only one can remain in
// compilable code.
// NOTE(review): idx is presumably the sparsity-index operand of the SWMMAC instruction - confirm
// against the builtin reference.
// AVecType/BVecType/CVecType are not declared in this struct; presumably they come from
// amdgcn_mma_base - TODO confirm.
CK_TILE_DEVICE static auto
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType
|
||||
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(aVec, bVec, cVec, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
// Sparse WMMA wrapper for A/B/C data types bf16_t, bf16_t, fp32_t with size arguments 16u, 16u, 32u
// and MmaOpFamily::SPARSE, gated by enable_if_target_family_gfx12_t<CompilerTarget>.
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Name of the compiler builtin that exec() calls.
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32";
|
||||
|
||||
// Forwards aVec, bVec, cVec and idx unchanged to the builtin and brace-initializes the returned
// CVecType from its result.
// NOTE(review): idx is presumably the sparsity-index operand of the SWMMAC instruction - confirm
// against the builtin reference.
// AVecType/BVecType/CVecType are not declared in this struct; presumably they come from
// amdgcn_mma_base - TODO confirm.
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
return {__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32(aVec, bVec, cVec, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp16_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
// Sparse WMMA wrapper for A/B/C data types fp16_t, fp16_t, fp16_t with size arguments 16u, 16u, 32u
// and MmaOpFamily::SPARSE, gated by enable_if_target_family_gfx12_t<CompilerTarget>.
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp16_t, fp16_t, fp16_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp16_t, fp16_t, fp16_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Name of the compiler builtin that exec() calls.
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32";
|
||||
|
||||
// Forwards aVec, bVec, cVec and idx unchanged to the builtin and brace-initializes the returned
// CVecType from its result.
// NOTE(review): idx is presumably the sparsity-index operand of the SWMMAC instruction - confirm
// against the builtin reference.
// AVecType/BVecType/CVecType are not declared in this struct; presumably they come from
// amdgcn_mma_base - TODO confirm.
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
return {__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32(aVec, bVec, cVec, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf16_t, bf16_t, bf16_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
// Sparse WMMA wrapper for A/B/C data types bf16_t, bf16_t, bf16_t with size arguments 16u, 16u, 32u
// and MmaOpFamily::SPARSE, gated by enable_if_target_family_gfx12_t<CompilerTarget>.
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf16_t, bf16_t, bf16_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf16_t, bf16_t, bf16_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Name of the compiler builtin that exec() calls.
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32";
|
||||
|
||||
// Forwards aVec, bVec, cVec and idx unchanged to the builtin and brace-initializes the returned
// CVecType from its result.
// NOTE(review): idx is presumably the sparsity-index operand of the SWMMAC instruction - confirm
// against the builtin reference.
// AVecType/BVecType/CVecType are not declared in this struct; presumably they come from
// amdgcn_mma_base - TODO confirm.
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
return {__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32(aVec, bVec, cVec, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for int8_t, int8_t, int32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
// Sparse WMMA wrapper for A/B/C data types int8_t, int8_t, int32_t with size arguments 16u, 16u, 32u
// and MmaOpFamily::SPARSE, gated by enable_if_target_family_gfx12_t<CompilerTarget>.
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<int8_t, int8_t, int32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<int8_t, int8_t, int32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Name of the compiler builtin that exec() calls.
static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32";
|
||||
|
||||
// Calls the builtin with seven arguments: a literal `true` for the A signedness operand, aVec,
// a literal `true` for the B signedness operand, bVec, cVec, idx, and CtrlFlags::Clamp as the
// last argument. The returned CVecType is brace-initialized from the builtin's result.
// This is the only wrapper in this group that reads CtrlFlags.
// NOTE(review): `true` presumably selects signed interpretation of the 8-bit inputs and
// CtrlFlags::Clamp presumably enables result clamping - confirm against the builtin reference.
// NOTE(review): idx is presumably the sparsity-index operand of the SWMMAC instruction - confirm.
// AVecType/BVecType/CVecType are not declared in this struct; presumably they come from
// amdgcn_mma_base - TODO confirm.
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
return {__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32(true, // A signedness
|
||||
aVec,
|
||||
true, // B signedness
|
||||
bVec,
|
||||
cVec,
|
||||
idx,
|
||||
CtrlFlags::Clamp)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, fp8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
// Sparse WMMA wrapper for A/B/C data types fp8_t, fp8_t, fp32_t with size arguments 16u, 16u, 32u
// and MmaOpFamily::SPARSE, gated by enable_if_target_family_gfx12_t<CompilerTarget>.
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, fp8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp8_t, fp8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Name of the compiler builtin that exec() calls.
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32";
|
||||
|
||||
// Forwards aVec, bVec, cVec and idx unchanged to the builtin and brace-initializes the returned
// CVecType from its result.
// NOTE(review): idx is presumably the sparsity-index operand of the SWMMAC instruction - confirm
// against the builtin reference.
// AVecType/BVecType/CVecType are not declared in this struct; presumably they come from
// amdgcn_mma_base - TODO confirm.
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
return {__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32(aVec, bVec, cVec, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for fp8_t, bf8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
// Sparse WMMA wrapper for A/B/C data types fp8_t, bf8_t, fp32_t with size arguments 16u, 16u, 32u
// and MmaOpFamily::SPARSE, gated by enable_if_target_family_gfx12_t<CompilerTarget>.
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<fp8_t, bf8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<fp8_t, bf8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Name of the compiler builtin that exec() calls.
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32";
|
||||
|
||||
// Forwards aVec, bVec, cVec and idx unchanged to the builtin and brace-initializes the returned
// CVecType from its result.
// NOTE(review): idx is presumably the sparsity-index operand of the SWMMAC instruction - confirm
// against the builtin reference.
// AVecType/BVecType/CVecType are not declared in this struct; presumably they come from
// amdgcn_mma_base - TODO confirm.
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
return {__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32(aVec, bVec, cVec, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, fp8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
// Sparse WMMA wrapper for A/B/C data types bf8_t, fp8_t, fp32_t with size arguments 16u, 16u, 32u
// and MmaOpFamily::SPARSE, gated by enable_if_target_family_gfx12_t<CompilerTarget>.
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, fp8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf8_t, fp8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Name of the compiler builtin that exec() calls.
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32";
|
||||
|
||||
// Forwards aVec, bVec, cVec and idx unchanged to the builtin and brace-initializes the returned
// CVecType from its result.
// NOTE(review): idx is presumably the sparsity-index operand of the SWMMAC instruction - confirm
// against the builtin reference.
// AVecType/BVecType/CVecType are not declared in this struct; presumably they come from
// amdgcn_mma_base - TODO confirm.
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
return {__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32(aVec, bVec, cVec, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @struct amdgcn_mma
|
||||
* @brief Specialization of amdgcn_mma for bf8_t, bf8_t, fp32_t MMA operation on GFX12
|
||||
* architecture.
|
||||
* @tparam CtrlFlags Control flags for the WMMA operation
|
||||
* @tparam CompilerTarget Current compiler target
|
||||
*/
|
||||
// TODO: c++20 template <CtrlFlagsGfx12I CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
// Sparse WMMA wrapper for A/B/C data types bf8_t, bf8_t, fp32_t with size arguments 16u, 16u, 32u
// and MmaOpFamily::SPARSE, gated by enable_if_target_family_gfx12_t<CompilerTarget>.
template <typename CtrlFlags, typename CompilerTarget>
|
||||
// clang-format off
|
||||
// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar |
|
||||
struct amdgcn_mma<bf8_t, bf8_t, fp32_t, 16u, 16u, 32u, CtrlFlags, CompilerTarget, MmaOpFamily::SPARSE, enable_if_target_family_gfx12_t<CompilerTarget>>
|
||||
: amdgcn_mma_base<bf8_t, bf8_t, fp32_t, 16u, 16u, 32u, 32u, 16, 1, 1, 1, 1, 8, 1, WmmaOp, MmaOpFamily::SPARSE>
|
||||
// clang-format on
|
||||
{
|
||||
// Name of the compiler builtin that exec() calls.
static constexpr const char* instruction_name =
|
||||
"__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32";
|
||||
|
||||
// Forwards aVec, bVec, cVec and idx unchanged to the builtin and brace-initializes the returned
// CVecType from its result.
// NOTE(review): idx is presumably the sparsity-index operand of the SWMMAC instruction - confirm
// against the builtin reference.
// AVecType/BVecType/CVecType are not declared in this struct; presumably they come from
// amdgcn_mma_base - TODO confirm.
CK_TILE_DEVICE static CVecType
|
||||
exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx)
|
||||
{
|
||||
return {__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32(aVec, bVec, cVec, idx)};
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: c++20 template <CtrlFlagsSparseWmmaI CtrlFlags, amdgcn_target CompilerTarget>
|
||||
// TODO: c++20 requires
|
||||
template <typename CtrlFlags, typename CompilerTarget>
|
||||
|
||||
@@ -434,6 +434,14 @@ using Gfx12Intrinsics = ::testing::Types<
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 16u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x16_iu4_w32_gfx12
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::DENSE>, // wmma_i32_16x16x32_iu4_w32_gfx12
|
||||
amdgcn_mma<F16, F16, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_f16_w32
|
||||
amdgcn_mma<BF16, BF16, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf16_w32
|
||||
amdgcn_mma<F16, F16, F16, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f16_16x16x32_f16_w32
|
||||
amdgcn_mma<BF16, BF16, BF16, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_bf16_16x16x32_bf16_w32
|
||||
amdgcn_mma<I8, I8, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_i32_16x16x32_iu8_w32
|
||||
amdgcn_mma<F8, F8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_fp8_fp8_w32
|
||||
amdgcn_mma<F8, BF8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_fp8_bf8_w32
|
||||
amdgcn_mma<BF8, F8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf8_fp8_w32
|
||||
amdgcn_mma<BF8, BF8, F32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_f32_16x16x32_bf8_bf8_w32
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 32u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE>, // swmmac_i32_16x16x32_iu4_w32
|
||||
amdgcn_mma<I4, I4, I32, 16u, 16u, 64u, DefaultWmmaCtrlFlags, Target12, MmaOpFamily::SPARSE> // swmmac_i32_16x16x64_iu4_w32
|
||||
>;
|
||||
|
||||
Reference in New Issue
Block a user