diff --git a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp index fd8c817794..2257cf7db8 100644 --- a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp +++ b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp @@ -14,23 +14,242 @@ namespace ck_tile::core::arch::mma { +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for the sparse fp16_t, fp16_t, fp32_t MMA operation on GFX12 + * architecture; exec() forwards aVec, bVec, cVec and idx to __builtin_amdgcn_swmmac_f32_16x16x32_f16_w32. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ // TODO: c++20 template // TODO: c++20 requires template // clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32"; - CK_TILE_DEVICE static auto - exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { return {__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32(aVec, bVec, cVec, idx)}; } }; +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for the sparse bf16_t, bf16_t, fp32_t MMA operation on GFX12 + * architecture; exec() forwards aVec, bVec, cVec and idx to __builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32.
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32(aVec, bVec, cVec, idx)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for the sparse fp16_t, fp16_t, fp16_t MMA operation on GFX12 + * architecture; exec() forwards aVec, bVec, cVec and idx to __builtin_amdgcn_swmmac_f16_16x16x32_f16_w32. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32(aVec, bVec, cVec, idx)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for the sparse bf16_t, bf16_t, bf16_t MMA operation on GFX12 + * architecture; exec() forwards aVec, bVec, cVec and idx to __builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32.
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = + "__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32(aVec, bVec, cVec, idx)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for the sparse int8_t, int8_t, int32_t MMA operation on GFX12 + * architecture; exec() calls __builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32 with both signedness arguments set to true and CtrlFlags::Clamp as the last argument. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32(true, // A signedness + aVec, + true, // B signedness + bVec, + cVec, + idx, + CtrlFlags::Clamp)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for the sparse fp8_t, fp8_t, fp32_t MMA operation on GFX12 + * architecture; exec() forwards aVec, bVec, cVec and idx to __builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32.
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = + "__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32(aVec, bVec, cVec, idx)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for the sparse fp8_t, bf8_t, fp32_t MMA operation on GFX12 + * architecture; exec() forwards aVec, bVec, cVec and idx to __builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = + "__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32(aVec, bVec, cVec, idx)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for the sparse bf8_t, fp8_t, fp32_t MMA operation on GFX12 + * architecture; exec() forwards aVec, bVec, cVec and idx to __builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32.
+ * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = + "__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32(aVec, bVec, cVec, idx)}; + } +}; + +/** + * @struct amdgcn_mma + * @brief Specialization of amdgcn_mma for the sparse bf8_t, bf8_t, fp32_t MMA operation on GFX12 + * architecture; exec() forwards aVec, bVec, cVec and idx to __builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32. + * @tparam CtrlFlags Control flags for the WMMA operation + * @tparam CompilerTarget Current compiler target + */ +// TODO: c++20 template +// TODO: c++20 requires +template +// clang-format off +// | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | +struct amdgcn_mma> +: amdgcn_mma_base +// clang-format on +{ + static constexpr const char* instruction_name = + "__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32"; + + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) + { + return {__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32(aVec, bVec, cVec, idx)}; + } +}; + // TODO: c++20 template // TODO: c++20 requires template diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc index 391caa924d..45656f9e7f 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc @@ -434,6 +434,14 @@ using Gfx12Intrinsics = ::testing::Types< amdgcn_mma, // wmma_i32_16x16x16_iu4_w32_gfx12 amdgcn_mma, // wmma_i32_16x16x32_iu4_w32_gfx12 amdgcn_mma, // swmmac_f32_16x16x32_f16_w32 + amdgcn_mma, //
swmmac_f32_16x16x32_bf16_w32 + amdgcn_mma, // swmmac_f16_16x16x32_f16_w32 + amdgcn_mma, // swmmac_bf16_16x16x32_bf16_w32 + amdgcn_mma, // swmmac_i32_16x16x32_iu8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_fp8_fp8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_fp8_bf8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_bf8_fp8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_bf8_bf8_w32 amdgcn_mma, // swmmac_i32_16x16x32_iu4_w32 amdgcn_mma // swmmac_i32_16x16x64_iu4_w32 >;