From 2089713f94fe2d23aa2ea94b90f38d09f6fbbded Mon Sep 17 00:00:00 2001 From: Kiefer van Teutem <50830967+krithalith@users.noreply.github.com> Date: Fri, 26 Jun 2026 12:00:58 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#8227 (commit 75c30d5) =?UTF-8?q?[CK=20TILE]=20Unification=20Work=20=E2=80=93=20?= =?UTF-8?q?Remove=20unification=20Flag=20structs=20in=20favor=20of=20new?= =?UTF-8?q?=20WarpGemmParams=20(#8227)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Recently, the way flags are sent down to the intrinsics was changed in CK Tile. At the point where the WarpGemm is invoked, an arbitrary number of template parameters can be passed, and these are passed down all the way to the lowest level intrinsics wrappers. Here `WarpGemmParamsParser<>` is used to extract flags for the intrinsics. In this MR we adapt the the unification framework (amdgcn_mma struct and MmaPipelines) to work in the same way. By doing this, there is no longer a point in our custom intrinsic Flag structs, so these are removed. Unrelated but I also tried removing the MmaPipeline flags because they arn't used for anything except CTranspose, which is already available. This also make test_amdgcn_mma_pipeline completely redundant so removed that as well. --- .../example_tile_distr_enc_calc.cpp | 22 +- include/ck_tile/core.hpp | 1 - include/ck_tile/core/arch/mma/amdgcn_mma.hpp | 28 +- .../ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp | 896 ++++++++---------- .../core/arch/mma/mfma/mfma_selector.hpp | 1 - .../core/arch/mma/mfma/mfma_traits.hpp | 46 - .../ck_tile/core/arch/mma/mma_pipeline.hpp | 112 +-- .../ck_tile/core/arch/mma/mma_selector.hpp | 17 +- include/ck_tile/core/arch/mma/mma_traits.hpp | 11 - .../ck_tile/core/arch/mma/mma_wavewise.hpp | 36 +- .../core/arch/mma/scale/mfma/scale_gfx9.hpp | 141 +-- .../core/arch/mma/scale/mfma/selector.hpp | 1 - .../arch/mma/scale/scale_mma_pipeline.hpp | 34 +- .../core/arch/mma/scale/scale_traits.hpp | 79 +- .../core/arch/mma/sparse/mfma/selector.hpp | 2 - .../core/arch/mma/sparse/mfma/sparse_gfx9.hpp | 450 ++++----- .../ck_tile/core/arch/mma/sparse/sparse.hpp | 1 - .../arch/mma/sparse/sparse_mma_pipeline.hpp | 30 +- .../core/arch/mma/sparse/sparse_traits.hpp | 106 --- .../core/arch/mma/sparse/wmma/selector.hpp | 1 - .../arch/mma/sparse/wmma/sparse_gfx12.hpp | 104 +- .../ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp | 39 +- .../ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp | 102 +- .../core/arch/mma/wmma/wmma_selector.hpp | 1 - .../core/arch/mma/wmma/wmma_traits.hpp | 23 - .../core/arch/mma/wmma/wmma_transforms.hpp | 19 +- .../ops/gemm/warp/warp_gemm_params.hpp | 21 + test/ck_tile/core/arch/mma/CMakeLists.txt | 3 - .../mma/pipeline/test_amdgcn_mma_pipeline.cpp | 66 -- .../mma/pipeline/test_amdgcn_scale_mma.cpp | 17 +- .../mma/pipeline/test_amdgcn_sparse_mma.cpp | 15 +- .../ck_tile/core/arch/mma/test_amdgcn_mma.cpp | 16 +- .../core/arch/mma/test_amdgcn_mma_layout.inc | 261 +++-- 33 files changed, 1059 insertions(+), 1643 deletions(-) delete mode 100644 include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp delete mode 100644 test/ck_tile/core/arch/mma/pipeline/test_amdgcn_mma_pipeline.cpp diff --git a/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp b/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp index 9e62f6e939..41b954a6de 100644 --- a/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp +++ b/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp @@ -73,17 +73,17 @@ int check_tile_distr_enc() // List of intrinsics to test. // clang-format off using Intrinsics = ck_tile::tuple< - amdgcn_mma, // mfma_f32_16x16x16f16 - amdgcn_mma, // mfma_f32_32x32x4f16 - amdgcn_mma, // mfma_f32_32x32x4f16 - amdgcn_mma, // mfma_f32_4x4x4f16 - amdgcn_mma, // mfma_f32_4x4x4f16 - amdgcn_mma, // mfma_f32_16x16x32_f16 - amdgcn_mma, // wmma_f32_16x16x16_f16_w32 - amdgcn_mma, // wmma_i32_16x16x16_iu4_w32 - amdgcn_mma, // wmma_f32_16x16x16_f16_w32_gfx12 - amdgcn_mma, // wmma_i32_16x16x16_iu4_w32_gfx12 - amdgcn_mma // wmma_i32_16x16x32_iu4_w32_gfx12 + amdgcn_mma, // mfma_f32_16x16x16f16 + amdgcn_mma, // mfma_f32_32x32x4f16 + amdgcn_mma, // mfma_f32_32x32x4f16 + amdgcn_mma, // mfma_f32_4x4x4f16 + amdgcn_mma, // mfma_f32_4x4x4f16 + amdgcn_mma, // mfma_f32_16x16x32_f16 + amdgcn_mma, // wmma_f32_16x16x16_f16_w32 + amdgcn_mma, // wmma_i32_16x16x16_iu4_w32 + amdgcn_mma, // wmma_f32_16x16x16_f16_w32_gfx12 + amdgcn_mma, // wmma_i32_16x16x16_iu4_w32_gfx12 + amdgcn_mma // wmma_i32_16x16x32_iu4_w32_gfx12 >; // clang-format on diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index b070d0c68a..47ba274a15 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -43,7 +43,6 @@ #include "ck_tile/core/arch/mma/sparse/sparse.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp" -#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" #include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp" #include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp" diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index d4330b8c73..938fc1791d 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -246,25 +246,14 @@ CK_TILE_HOST_DEVICE constexpr const char* to_string(Unsupported) { return "Unsup #if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -/** - * @concept HasExecSignature - * @brief Helper concept for exec signature check. - */ -template -concept HasExecSignature = requires { - { - MmaOp::exec(typename MmaOp::AVecType{}, - typename MmaOp::BVecType{}, - typename MmaOp::CVecType{}, - std::declval()...) - } -> std::convertible_to; -}; - /** * @concept MmaOpI * @brief Expresses the meta-data interface required for each MmaOp policy. */ // TODO: Make sure this actually matches amdgcn_mma. +// NOTE: It is no longer possible to perform a check on the exec() function, since it is now +// templated over the variadic WarpGemmParams template pack for intrinsic flags. It seems like +// concepts do not work for templated device functions. template concept MmaOpI = requires(MmaOp op) { // Requires an op context @@ -287,7 +276,7 @@ concept MmaOpI = requires(MmaOp op) { { MmaOp::kCMPerLane } -> std::convertible_to; { MmaOp::kCMNumAccess } -> std::convertible_to; { MmaOp::kCompressionRatio } -> std::convertible_to; -} && (HasExecSignature || HasExecSignature || HasExecSignature); +}; #endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER @@ -305,7 +294,6 @@ concept MmaOpI = requires(MmaOp op) { * @tparam FragM M-dimension of mma intrinsic (MmaTile) * @tparam FragN N-dimension of mma intrinsic (MmaTile) * @tparam FragK K-dimension of mma intrinsic (MmaTile) - * @tparam CtrlFlags Control flags for mma operation * @tparam CompilerTarget The current compiler target * @tparam OpFamily_ The type of operation (dense, sparse, scale, etc.) * @tparam Enabler SFINAE enabler @@ -316,7 +304,6 @@ template @@ -326,6 +313,7 @@ struct amdgcn_mma : amdgcn_mma_base CK_TILE_DEVICE static auto exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC) { @@ -347,7 +335,6 @@ template @@ -357,7 +344,6 @@ CK_TILE_HOST_DEVICE void print(amdgcn_mma const& mmaObj) @@ -392,10 +378,6 @@ CK_TILE_HOST_DEVICE void print(amdgcn_mma) - { - print_flags(CtrlFlags{}); - } print(CompilerTarget{}); } diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp index b3f2a90cd4..ad4a055a06 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp @@ -17,6 +17,7 @@ #include "ck_tile/core/numeric/tfloat32.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" namespace ck_tile::core::arch::mma { @@ -31,29 +32,26 @@ namespace ck_tile::core::arch::mma { * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for fp32_t, fp32_t, fp32_t MMA operation on GFX9 * architecture. - * @tparam CtrlFlags Control flags for the MFMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x1f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x1f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x1f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -61,29 +59,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x1f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x1f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x1f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -91,29 +86,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x1f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x1f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x1f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -121,29 +113,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x1f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x1f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x1f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -151,29 +140,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x1f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x1f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_4x4x1f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -181,29 +167,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x1f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x1f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_4x4x1f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -211,29 +194,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x2f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x2f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x2f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -241,29 +221,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x4f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x4f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x4f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -271,25 +248,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x4f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x4f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x4f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -297,25 +274,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x4f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x4f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x4f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -323,25 +300,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x4f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x4f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x4f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -349,25 +326,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x4f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x4f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x4f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -375,25 +352,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x4f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x4f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_4x4x4f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -401,25 +378,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x4f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x4f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_4x4x4f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -427,25 +404,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x8f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x8f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x8f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -453,25 +430,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x16f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x16f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x16f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -479,29 +456,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_32x32x4i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_32x32x4i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_32x32x4i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -509,29 +483,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_32x32x4i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_32x32x4i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_32x32x4i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -539,29 +510,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_16x16x4i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_16x16x4i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_16x16x4i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -569,29 +537,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_16x16x4i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_16x16x4i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_16x16x4i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -599,29 +564,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_4x4x4i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_4x4x4i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_4x4x4i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -629,29 +591,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_4x4x4i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_4x4x4i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_4x4x4i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -659,29 +618,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_32x32x8i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_32x32x8i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -689,29 +645,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_16x16x16i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_16x16x16i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -719,25 +672,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x2bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x2bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x2bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -745,25 +698,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x2bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x2bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x2bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -771,25 +724,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x2bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x2bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x2bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -797,25 +750,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x2bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x2bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x2bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -823,25 +776,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x2bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x2bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_4x4x2bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -849,25 +802,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x2bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x2bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_4x4x2bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -875,25 +828,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x4bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x4bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x4bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -901,25 +854,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x8bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x8bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x8bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -927,25 +880,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x4bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x4bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_32x32x4bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -953,25 +907,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x4bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x4bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_32x32x4bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -979,25 +934,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x4bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x4bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_16x16x4bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1005,25 +961,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x4bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x4bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_16x16x4bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1031,25 +988,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x4bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x4bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1057,25 +1015,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x4bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x4bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1083,25 +1042,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x8bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x8bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1109,25 +1069,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x16bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x16bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1135,31 +1096,32 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f64_16x16x4f64"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { // Note: BLGP flag has another meaning for f64 builtins: BLGP bits [0:2] cause negation of // the A, B, and C input matrices respectively (ref. ISA docs for MI300 Instinct) + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_f64_16x16x4f64(bit_cast(aVec), bit_cast(bVec), cVec, - CtrlFlags::Cbsz, // CBSZ ignored for f64 - CtrlFlags::Abid, // ABID ignored for f64 - CtrlFlags::Blgp)}; + P::cbsz, // CBSZ ignored for f64 + P::abid, // ABID ignored for f64 + P::blgp)}; } }; @@ -1167,30 +1129,31 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // |A B C DataTypes |MNK | -struct amdgcn_mma> +struct amdgcn_mma> // |WS |AParams |BPar |CPar | : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f64_4x4x4f64"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_f64_4x4x4f64(bit_cast(aVec), bit_cast(bVec), bit_cast(cVec), - CtrlFlags::Cbsz, // CBSZ ignored for f64 - CtrlFlags::Abid, // ABID ignored for f64 - CtrlFlags::Blgp)}; + P::cbsz, // CBSZ ignored for f64 + P::abid, // ABID ignored for f64 + P::blgp)}; } }; @@ -1198,30 +1161,31 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // |A B C DataTypes |MNK | -struct amdgcn_mma> +struct amdgcn_mma> // |WS |AParams |BPar |CPar | : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f64_4x4x4f64"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_f64_4x4x4f64(bit_cast(aVec), bit_cast(bVec), bit_cast(cVec), - CtrlFlags::Cbsz, // CBSZ ignored for f64 - CtrlFlags::Abid, // ABID ignored for f64 - CtrlFlags::Blgp)}; + P::cbsz, // CBSZ ignored for f64 + P::abid, // ABID ignored for f64 + P::blgp)}; } }; @@ -1229,29 +1193,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_16x16x32_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_16x16x32_i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1259,29 +1220,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_32x32x16_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_32x32x16_i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1289,26 +1247,27 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // |A B C DataTypes |MNK | -struct amdgcn_mma> +struct amdgcn_mma> // |WS |AParams |BPar |CPar | : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x8_xf32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x8_xf32( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_16x16x8_xf32(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1316,26 +1275,27 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // |A B C DataTypes |MNK | -struct amdgcn_mma> +struct amdgcn_mma> // |WS |AParams |BPar |CPar | : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x4_xf32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x4_xf32( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_32x32x4_xf32(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1343,29 +1303,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1373,29 +1330,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1403,29 +1357,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1433,29 +1384,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1463,29 +1411,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1493,29 +1438,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1523,29 +1465,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1553,29 +1492,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1583,25 +1519,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x32_f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x32_f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_16x16x32_f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1609,25 +1546,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x32_bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x32_bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_16x16x32_bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1635,25 +1573,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x16_f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x16_f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_32x32x16_f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1661,25 +1600,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x16_bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x16_bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_32x32x16_bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1687,29 +1627,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_16x16x64_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_16x16x64_i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_16x16x64_i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1717,29 +1654,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_32x32x32_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_32x32x32_i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_32x32x32_i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp index 28af0c4568..233ee77526 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp @@ -53,7 +53,6 @@ struct MmaDefaultSelector::SelectedOp; }; diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp index b4f23fc9bc..29340fece7 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp @@ -3,16 +3,8 @@ #pragma once -#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/config.hpp" -#include "ck_tile/core/numeric/integer.hpp" - -#include -#include #include -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -#include -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER namespace ck_tile::core::arch::mma { @@ -56,42 +48,4 @@ struct is_mma_op_mfma static constexpr bool is_mma_op_mfma_v = is_mma_op_mfma::value; -/** - * @struct DefaultMfmaCtrlFlags - * @brief Default MFMA flags, no broadcasting or rotation of inputs - * @note For f64 MFMA instructions, CBSZ and ABID are ignored and BLGP is repurposed for matrix - * negation. BLGP bits [0:2] negate the A, B, and C input matrices respectively (ref. ISA docs for - * MI300 Instinct). - */ -struct DefaultMfmaCtrlFlags -{ - static constexpr int32_t Cbsz = 0; // CBSZ flag, default 0 - static constexpr int32_t Abid = 0; // ABID flag, default 0 - static constexpr int32_t Blgp = 0; // BLGP flag, default 0 -}; - -CK_TILE_HOST_DEVICE void print_flags(DefaultMfmaCtrlFlags const& ctrlFlags) -{ - printf("CtrlFlags Cbsz / Abid / Blgp : %" PRId32 " / %" PRId32 " / %" PRId32 "\n", - ctrlFlags.Cbsz, - ctrlFlags.Abid, - ctrlFlags.Blgp); -} - -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - -/** - * @concept CtrlFlagsGfx9I - * @brief Expresses the interface of required members for each CtrlFlags type on Gfx9 - */ -template -concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) { - // Flag members for Gfx9 MFMA instructions - { CtrlFlags::Cbsz } -> std::convertible_to; - { CtrlFlags::Abid } -> std::convertible_to; - { CtrlFlags::Blgp } -> std::convertible_to; -}; - -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mma_pipeline.hpp b/include/ck_tile/core/arch/mma/mma_pipeline.hpp index 3198f2c41f..b687c1adc9 100644 --- a/include/ck_tile/core/arch/mma/mma_pipeline.hpp +++ b/include/ck_tile/core/arch/mma/mma_pipeline.hpp @@ -16,88 +16,11 @@ #endif namespace ck_tile::core::arch::mma { -/*! @enum MmaPipelineOptionFlag - * @brief Individual option flags for configuring MmaPipeline behavior. - */ -enum struct MmaPipelineOptionFlag : unsigned -{ - NONE = 0x0, ///< No flags set - ABSwap = 0x1, ///< Swap A and B inputs to transpose the C output - COMPRESS_A = 0x2, ///< Enable compressed (sparse) A matrix input -}; - -/** - * @struct MmaPipelineOptionFlags - * @brief Type-safe bitmask wrapper for combining @ref MmaPipelineOptionFlag values. - * @par Provides bitwise OR, AND, NOT, and equality operators for composing - * and querying pipeline option flags. - */ -struct MmaPipelineOptionFlags -{ - using Type = std::underlying_type_t; - - explicit constexpr MmaPipelineOptionFlags() : mFlags(0) {} - explicit constexpr MmaPipelineOptionFlags(Type value) : mFlags(value) {} - constexpr MmaPipelineOptionFlags(MmaPipelineOptionFlag singleFlag) : mFlags(toType(singleFlag)) - { - } - constexpr MmaPipelineOptionFlags(const MmaPipelineOptionFlags& original) - : mFlags(original.mFlags) - { - } - - constexpr MmaPipelineOptionFlags& operator|=(MmaPipelineOptionFlag addValue) - { - mFlags |= toType(addValue); - return *this; - } - constexpr MmaPipelineOptionFlags operator|(MmaPipelineOptionFlag addValue) const - { - MmaPipelineOptionFlags result(*this); - result |= addValue; - return result; - } - constexpr MmaPipelineOptionFlags& operator&=(MmaPipelineOptionFlag maskValue) - { - mFlags &= toType(maskValue); - return *this; - } - constexpr MmaPipelineOptionFlags operator&(MmaPipelineOptionFlag maskValue) const - { - MmaPipelineOptionFlags result(*this); - result &= maskValue; - return result; - } - constexpr MmaPipelineOptionFlags operator~() const - { - MmaPipelineOptionFlags result(*this); - result.mFlags = ~result.mFlags; - return result; - } - constexpr bool testFlag(MmaPipelineOptionFlag flag) const - { - return (flag == MmaPipelineOptionFlag::NONE) ? mFlags == toType(flag) : *this & flag; - } - constexpr operator bool() const { return mFlags != toType(MmaPipelineOptionFlag::NONE); } - constexpr bool operator==(Type rhs) const { return mFlags == rhs; } - - private: - Type mFlags; - static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast(f); } -}; - -constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOptionFlags& rhs) -{ - return rhs == lhs; -} - /** * @class MmaPipelineBase * @brief CRTP base class that implements the common Mma pipeline logic shared by * all concrete pipeline drivers (e.g., dense wave-wise, sparse, etc.). * - * @tparam Flags_ Compile-time bitmask of @ref MmaPipelineOptionFlag controlling - * pipeline behavior (e.g., C transposition, A compression). * @tparam Derived The concrete CRTP-derived pipeline class. Must expose: * - Type aliases: @c AWarpTensor, @c BWarpTensor, @c CWarpTensor, @c MmaOp * - Transform aliases: @c ATransform, @c BTransform, @c CTransform, @c DTransform @@ -107,14 +30,11 @@ constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOpt * 1. Apply pre-transforms to input buffers (A, B, C). * 2. Delegate to @c Derived::execImpl for the actual mma loop. * 3. Apply post-transform to output buffer (D). - * When @c ABSwap is set, the A and B inputs are swapped before step 1. + * When CTranspose is used, the A and B inputs are swapped before step 1. */ -// TODO: c++20: use MmaPipelineOptionFlags directly -template +template struct MmaPipelineBase { - static constexpr auto Flags = MmaPipelineOptionFlags(Flags_); - /** * @brief Entry point: execute the full Mma pipeline (transforms + mma loop + output). * @tparam ATensor Type of the A WaveTile tensor (static_distributed_tensor). @@ -125,17 +45,17 @@ struct MmaPipelineBase * @param accum Input/output accumulator WaveTile C. * @return The output WaveTile D after accumulation and post-transform. */ - template + template CK_TILE_DEVICE static decltype(auto) exec(ATensor& a, BTensor& b, CTensor& accum) { if constexpr(MmaOpTraits::IsSupported) { - if constexpr(Flags & MmaPipelineOptionFlag::ABSwap) + if constexpr(Derived::CTranspose) { decltype(auto) a_transformed = Derived::ATransform::exec(b); decltype(auto) b_transformed = Derived::BTransform::exec(a); decltype(auto) c_transformed = Derived::CTransform::exec(accum); - Derived::execImpl(a_transformed, b_transformed, c_transformed); + Derived::template execImpl(a_transformed, b_transformed, c_transformed); return Derived::DTransform::exec(c_transformed); } else @@ -143,7 +63,7 @@ struct MmaPipelineBase decltype(auto) a_transformed = Derived::ATransform::exec(a); decltype(auto) b_transformed = Derived::BTransform::exec(b); decltype(auto) c_transformed = Derived::CTransform::exec(accum); - Derived::execImpl(a_transformed, b_transformed, c_transformed); + Derived::template execImpl(a_transformed, b_transformed, c_transformed); return Derived::DTransform::exec(c_transformed); } } @@ -153,7 +73,7 @@ struct MmaPipelineBase // Code should not reach here, but HOST/DEVICE compile passes are // weirdly intertwined and instead of having constexpr in the calling // site (tests) we do this. See also changes by this commit. - return Derived::MmaOp::exec({}, {}, {}); + return Derived::MmaOp::template exec({}, {}, {}); } } @@ -162,11 +82,10 @@ struct MmaPipelineBase template CK_TILE_DEVICE void operator()(CTensor& c, ATensor& a, const BTensor& b) const { - exec(a, b, c); + exec(a, b, c); } - template ::IsSupported) { - if constexpr(Flags & MmaPipelineOptionFlag::ABSwap) + if constexpr(Derived::CTranspose) { // TODO: Figure out which combination of a/b, scale_A/B, and opselA/B needs to be // AB-swapped in order to get correct results. Note that WarpGemmParamsParser @@ -188,7 +107,7 @@ struct MmaPipelineBase decltype(auto) a_transformed = Derived::ATransform::exec(b); decltype(auto) b_transformed = Derived::BTransform::exec(a); decltype(auto) c_transformed = Derived::CTransform::exec(accum); - Derived::template execImpl( + Derived::template execImpl( a_transformed, b_transformed, c_transformed, scale_A, scale_B); return Derived::DTransform::exec(c_transformed); } @@ -197,7 +116,7 @@ struct MmaPipelineBase decltype(auto) a_transformed = Derived::ATransform::exec(a); decltype(auto) b_transformed = Derived::BTransform::exec(b); decltype(auto) c_transformed = Derived::CTransform::exec(accum); - Derived::template execImpl( + Derived::template execImpl( a_transformed, b_transformed, c_transformed, scale_A, scale_B); return Derived::DTransform::exec(c_transformed); } @@ -219,8 +138,7 @@ struct MmaPipelineBase const int32_t& a_scale, const int32_t& b_scale) const { - using P = WarpGemmParamsParser; - exec(a, b, c, a_scale, b_scale); + exec(a, b, c, a_scale, b_scale); } }; @@ -232,8 +150,8 @@ struct MmaPipelineBase * @concept MmaPipelineI * @brief Expresses the meta-data interface required for a CRTP MmaPipeline. */ -template -concept MmaPipelineInterface = std::derived_from>; +template +concept MmaPipelineInterface = std::derived_from>; #endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER diff --git a/include/ck_tile/core/arch/mma/mma_selector.hpp b/include/ck_tile/core/arch/mma/mma_selector.hpp index 8491f96837..b8d6f31558 100644 --- a/include/ck_tile/core/arch/mma/mma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mma_selector.hpp @@ -49,7 +49,6 @@ struct MmaDefaultSelector WaveTileM, WaveTileN, WaveTileK, - void, amdgcn_target<>, MmaOpFamily::UNDEFINED>; }; @@ -88,7 +87,6 @@ template MmaOpFamily OpFamily> struct MmaKSearchSelector @@ -102,7 +100,6 @@ struct MmaKSearchSelector WaveTileM, WaveTileN, WaveTileKTest, - CtrlFlags, CompilerTarget, OpFamily>; @@ -118,7 +115,6 @@ struct MmaKSearchSelector WaveTileM, WaveTileN, WaveTileKTest / 2u, - CtrlFlags, CompilerTarget, OpFamily>::SelectedOp>; }; @@ -128,7 +124,6 @@ template MmaOpFamily OpFamily> struct MmaKSearchSelector { // Recursion endpoint: unsupported default implementation. - using SelectedOp = amdgcn_mma; + using SelectedOp = + amdgcn_mma; }; } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mma_traits.hpp b/include/ck_tile/core/arch/mma/mma_traits.hpp index 88764a75b0..863ad07bbb 100644 --- a/include/ck_tile/core/arch/mma/mma_traits.hpp +++ b/include/ck_tile/core/arch/mma/mma_traits.hpp @@ -6,11 +6,8 @@ #include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/config.hpp" #include "mfma/mfma_traits.hpp" -#include "scale/scale_traits.hpp" -#include "sparse/sparse_traits.hpp" #include "wmma/wmma_traits.hpp" -#include #include #include @@ -61,7 +58,6 @@ struct MmaOpTraits; * @tparam FragM_ Size of the M dimension * @tparam FragN_ Size of the N dimension * @tparam FragK_ Size of the K dimension - * @tparam CtrlFlags_ Control flags for the MMA operation * @tparam CompilerTarget_ The compiler target */ template // TODO: c++20 amdgcn_target_arch_id CompilerTarget_> @@ -80,7 +75,6 @@ struct MmaOpTraits> { @@ -90,12 +84,10 @@ struct MmaOpTraits; // Capture incoming template parameters not already in amdgcn - using CtrlFlags = CtrlFlags_; using CompilerTarget = CompilerTarget_; // TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_; @@ -115,7 +107,6 @@ template CK_TILE_HOST_DEVICE void print(MmaOpTraits> const& traitsObj) { @@ -134,7 +124,6 @@ CK_TILE_HOST_DEVICE void print(MmaOpTraits{}); printf( diff --git a/include/ck_tile/core/arch/mma/mma_wavewise.hpp b/include/ck_tile/core/arch/mma/mma_wavewise.hpp index 5894a520ea..253457525b 100644 --- a/include/ck_tile/core/arch/mma/mma_wavewise.hpp +++ b/include/ck_tile/core/arch/mma/mma_wavewise.hpp @@ -28,15 +28,6 @@ enum struct MmaAccumPolicy COL_MAJOR }; -namespace dense::wavewise::detail { -// TODO: c++20: return MmaPipelineOptionFlags directly -template -constexpr inline int getPipelineFlags() -{ - return static_cast(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE); -} -} // namespace dense::wavewise::detail - /** * @class Mma * @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation @@ -50,7 +41,7 @@ constexpr inline int getPipelineFlags() * @tparam WaveTileN Mma WaveTile N dimension * @tparam WaveTileK Mma WaveTile K dimension * @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order) - * @tparam CTranspose Swaps A and B input vectors and interprets C with transposed layout. + * @tparam CTranspose_ Swaps A and B input vectors and interprets C with transposed layout. * @tparam SwizzleFactor SwizzleFactor for Tile Distribution Encoding calculation. * @tparam AttrNumAccessAV Extra unmerge factor for vector dimension for A vec, see amdgcn_mma.hpp. * @tparam AttrNumAccessBV Extra unmerge factor for vector dimension for B vec, see amdgcn_mma.hpp. @@ -72,7 +63,7 @@ template ::SelectedTransforms> // clang-format off -struct WaveWiseMmaPipeline : public MmaPipelineBase(), WaveWiseMmaPipeline> +struct WaveWiseMmaPipeline : public MmaPipelineBase> { - using Base = MmaPipelineBase(), WaveWiseMmaPipeline>; + using Base = MmaPipelineBase>; // clang-format on - using MmaOp = MmaOp_; + using MmaOp = MmaOp_; + static constexpr bool CTranspose = CTranspose_; using ADataType = typename MmaOp::ADataType; using BDataType = typename MmaOp::BDataType; @@ -185,7 +177,7 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase + template CK_TILE_DEVICE static void execImpl(ATensor& a, BTensor& b, CTensor& c) { static_assert( @@ -205,9 +197,10 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase(a_buf.at(bm * FragsK + bk), + b_buf.at(bn * FragsK + bk), + c_buf.at(bm * FragsN + bn)); } } } @@ -220,9 +213,10 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase(a_buf.at(bm * FragsK + bk), + b_buf.at(bn * FragsK + bk), + c_buf.at(bm * FragsN + bn)); } } } diff --git a/include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp b/include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp index f9245dc06f..32dd252d0d 100644 --- a/include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp +++ b/include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp @@ -13,6 +13,7 @@ #include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" namespace ck_tile::core::arch::mma { @@ -23,14 +24,13 @@ namespace ck_tile::core::arch::mma { * This specialization implements the Scale MFMA instruction for fp8_t A and B * matrices with fp32_t accumulator, with 16x16x128 block sizes. * - * @tparam CtrlFlags Control flags for the Scale MFMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | // clang-format on @@ -38,19 +38,20 @@ struct amdgcn_mma + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( bit_cast(aVec), bit_cast(bVec), cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -62,14 +63,13 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | // clang-format on @@ -77,19 +77,20 @@ struct amdgcn_mma + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( bit_cast(aVec), bit_cast(bVec), cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -101,14 +102,13 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | // clang-format on @@ -116,10 +116,11 @@ struct amdgcn_mma + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; int32x4_t arg_a = bit_cast(aVec); int32x4_t arg_b = bit_cast(bVec); @@ -129,9 +130,9 @@ struct amdgcn_mma, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -143,33 +144,33 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4"; - template + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0}, int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0}, cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -182,33 +183,33 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4"; - template + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0}, int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0}, cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -221,14 +222,13 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | // clang-format on @@ -236,19 +236,20 @@ struct amdgcn_mma + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( bit_cast(aVec), bit_cast(bVec), cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -260,14 +261,13 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | // clang-format on @@ -275,19 +275,20 @@ struct amdgcn_mma + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( bit_cast(aVec), bit_cast(bVec), cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -299,14 +300,13 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | // clang-format on @@ -314,10 +314,11 @@ struct amdgcn_mma + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; int32x4_t arg_a = bit_cast(aVec); int32x4_t arg_b = bit_cast(bVec); @@ -327,9 +328,9 @@ struct amdgcn_mma, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -341,33 +342,33 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4"; - template + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0}, int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0}, cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -380,33 +381,33 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4"; - template + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0}, int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0}, cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; diff --git a/include/ck_tile/core/arch/mma/scale/mfma/selector.hpp b/include/ck_tile/core/arch/mma/scale/mfma/selector.hpp index cb7e68a2c7..bfdb78de09 100644 --- a/include/ck_tile/core/arch/mma/scale/mfma/selector.hpp +++ b/include/ck_tile/core/arch/mma/scale/mfma/selector.hpp @@ -55,7 +55,6 @@ struct MmaDefaultSelector; }; diff --git a/include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp b/include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp index 45c3d71789..bed2a6506b 100644 --- a/include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp +++ b/include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp @@ -32,7 +32,7 @@ namespace ck_tile::core::arch::mma { * @tparam WaveTileN Mma WaveTile N dimension * @tparam WaveTileK Mma WaveTile K dimension * @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order) - * @tparam CTranspose Swaps A and B input vectors and interprets C with transposed layout. + * @tparam CTranspose_ Swaps A and B input vectors and interprets C with transposed layout. * @tparam SwizzleFactor Swizzlefactor for Tile Distribution Encoding calculation. * @tparam AttrNumAccessAV Extra unmerge factor for vector dimension for A vec, see amdgcn_mma.hpp. * @tparam AttrNumAccessBV Extra unmerge factor for vector dimension for B vec, see amdgcn_mma.hpp. @@ -47,7 +47,7 @@ template ::SelectedTransforms> // clang-format off -struct ScaleMmaPipeline : public MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline> +struct ScaleMmaPipeline : public MmaPipelineBase> { - using Base = MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline>; + using Base = MmaPipelineBase>; // clang-format on - using MmaOp = MmaOp_; // Expose the selected MmaOp + using MmaOp = MmaOp_; // Expose the selected MmaOp + static constexpr bool CTranspose = CTranspose_; using ADataType = typename MmaOp::ADataType; using BDataType = typename MmaOp::BDataType; @@ -170,8 +171,7 @@ struct ScaleMmaPipeline : public MmaPipelineBase(MmaPipelineOpt static_assert(WaveTileK % MmaOp::kK == 0u, "WaveTileK must be a multiple of MmaOp::kK"); // TODO: Why does this even need to be a template? The types should be known. - template (MmaPipelineOpt for(uint32_t bk = 0u; bk < FragsK; ++bk) { c_buf.at(bm * FragsN + bn) = - MmaOp::template exec(a_buf.at(bm * FragsK + bk), - b_buf.at(bn * FragsK + bk), - c_buf.at(bm * FragsN + bn), - scale_A, - scale_B); + MmaOp::template exec(a_buf.at(bm * FragsK + bk), + b_buf.at(bn * FragsK + bk), + c_buf.at(bm * FragsN + bn), + scale_A, + scale_B); } } } @@ -216,11 +216,11 @@ struct ScaleMmaPipeline : public MmaPipelineBase(MmaPipelineOpt for(uint32_t bk = 0u; bk < FragsK; ++bk) { c_buf.at(bm * FragsN + bn) = - MmaOp::template exec(a_buf.at(bm * FragsK + bk), - b_buf.at(bn * FragsK + bk), - c_buf.at(bm * FragsN + bn), - scale_A, - scale_B); + MmaOp::template exec(a_buf.at(bm * FragsK + bk), + b_buf.at(bn * FragsK + bk), + c_buf.at(bm * FragsN + bn), + scale_A, + scale_B); } } } diff --git a/include/ck_tile/core/arch/mma/scale/scale_traits.hpp b/include/ck_tile/core/arch/mma/scale/scale_traits.hpp index fa55522015..ea879feebc 100644 --- a/include/ck_tile/core/arch/mma/scale/scale_traits.hpp +++ b/include/ck_tile/core/arch/mma/scale/scale_traits.hpp @@ -3,85 +3,32 @@ #pragma once -#include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/numeric/pk_f6.hpp" -#include -#include -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -#include -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - namespace ck_tile::core::arch::mma { - namespace scale::detail { +// Utility for converting the datatype of the A or B input matrix in a scale intrinsics to the +// appropriate datatype flag. Note that this is not the same as the flag indicating the scale +// datatype, see ScaleDataTypeToEnum. template -struct ScaleDataTypeToFlag; - +inline constexpr int32_t ScaleDataTypeToFlag_v = [] { + // sizeof(T) trick to only trigger the static assert for unsupported datatypes. + static_assert(sizeof(T) == 0, "Unsupported scale data type"); + return -1; +}(); template <> -struct ScaleDataTypeToFlag // e4m3 (4 exponent bits 3 mantissa bits) -{ - static constexpr int32_t value = 0; -}; - +inline constexpr int32_t ScaleDataTypeToFlag_v = 0; // e4m3 template <> -struct ScaleDataTypeToFlag // e5m2 -{ - static constexpr int32_t value = 1; -}; - +inline constexpr int32_t ScaleDataTypeToFlag_v = 1; // e5m2 template <> -struct ScaleDataTypeToFlag // e2m3 -{ - static constexpr int32_t value = 2; -}; - +inline constexpr int32_t ScaleDataTypeToFlag_v = 2; // e2m3 template <> -struct ScaleDataTypeToFlag // e3m2 -{ - static constexpr int32_t value = 3; -}; - +inline constexpr int32_t ScaleDataTypeToFlag_v = 3; // e3m2 template <> -struct ScaleDataTypeToFlag // e2m1 -{ - static constexpr int32_t value = 4; -}; - -template -inline constexpr int32_t ScaleDataTypeToFlag_v = ScaleDataTypeToFlag::value; - -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - -/** - * @concept ScaleMfmaDataTypeToFlag - * @brief Expresses the interface of required members for each DataTypeToFlag type on Gfx9 - */ -template -concept ScaleMfmaDataTypeToFlag = requires(DataTypeToFlag dataTypeToFlag) { - // Flag members for scale MFMA instructions - { DataTypeToFlag::value } -> std::convertible_to; -}; - -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER +inline constexpr int32_t ScaleDataTypeToFlag_v = 4; // e2m1 } // namespace scale::detail - -// No real flags for now, scale and opsel are handled in higher level and passed down directly. -// OPSEL is now passed as a template arg to exec(), see mma_pipeline.hpp -// We will soon get rid of these flags entirely in favor of variadic template packs passed down to -// the intrinsics directly, see WarpGemmParamsParser<>. -struct DefaultScaleMfmaCtrlFlags -{ -}; - -CK_TILE_HOST_DEVICE void print_flags([[maybe_unused]] DefaultScaleMfmaCtrlFlags const& ctrlFlags) -{ - printf("CtrlFlags: (empty)\n"); -} - } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp b/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp index 31a84ebf13..9ea90daab8 100644 --- a/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp +++ b/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp @@ -7,7 +7,6 @@ #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" -#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" namespace ck_tile::core::arch::mma { @@ -55,7 +54,6 @@ struct MmaDefaultSelector::SelectedOp; }; diff --git a/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp b/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp index 994489e9d0..bdb50cc232 100644 --- a/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp +++ b/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp @@ -7,11 +7,11 @@ #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp" #include "ck_tile/core/arch/mma/mma_op_family.hpp" -#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" #include @@ -24,55 +24,55 @@ namespace ck_tile::core::arch::mma { * This specialization implements the SMFMA instruction for fp16_t A and B * matrices with structured sparsity, fp32_t accumulator, with 16x16x32 fragment sizes. * - * @tparam CtrlFlags Control flags for the Sparse MFMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x32_f16"; - CK_TILE_DEVICE static auto - exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType + template + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x32_f16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; - } + using P = WarpGemmParamsParser; + return __builtin_amdgcn_smfmac_f32_16x16x32_f16( + aVec, + bVec, + cVec, + idx, + P::cbsz, // Ignore abid and use first portion Y/N + P::abid); // Portion of idx VGPR containing idx info + }; }; /** * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX942 and GFX950 * architecture. - * @tparam CtrlFlags Control flags for the MFMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x16_f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x16_f16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_32x32x16_f16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -80,27 +80,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x32_bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x32_bf16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_16x16x32_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -108,27 +105,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x16_bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x16_bf16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_32x32x16_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -136,27 +130,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_16x16x64_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_i32_16x16x64_i8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_i32_16x16x64_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -164,27 +155,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_32x32x32_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_i32_32x32x32_i8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_i32_32x32x32_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -192,27 +180,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_bf8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x64_bf8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x64_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -220,27 +206,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_bf8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x64_bf8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x64_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -248,27 +232,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_fp8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x64_fp8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x64_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -276,27 +258,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_fp8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x64_fp8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x64_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -304,27 +284,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_bf8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x32_bf8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x32_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -332,27 +310,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_bf8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x32_bf8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x32_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -360,27 +336,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_fp8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x32_fp8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x32_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -388,27 +362,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_fp8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x32_fp8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x32_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -416,27 +388,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x64_f16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_16x16x64_f16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -444,27 +413,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x32_f16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_32x32x32_f16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -472,27 +438,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x64_bf16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_16x16x64_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -500,27 +463,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x32_bf16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_32x32x32_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -528,27 +488,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_16x16x128_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_i32_16x16x128_i8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_i32_16x16x128_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -556,27 +513,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_32x32x64_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_i32_32x32x64_i8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_i32_32x32x64_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -584,27 +538,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_bf8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x128_bf8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x128_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -612,27 +564,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_bf8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x128_bf8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x128_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -640,27 +590,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_fp8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x128_fp8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x128_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -668,27 +616,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_fp8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x128_fp8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x128_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -696,27 +642,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_bf8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x64_bf8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x64_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -724,27 +668,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_bf8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x64_bf8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x64_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -752,27 +694,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_fp8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x64_fp8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x64_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -780,27 +720,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_fp8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x64_fp8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x64_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/sparse.hpp b/include/ck_tile/core/arch/mma/sparse/sparse.hpp index e9792196c5..303687b1f8 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse.hpp @@ -11,5 +11,4 @@ namespace ck_tile::core::arch::mma { #include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp" #include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" -#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp" diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp index b2d5d5fac4..b39cdae770 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp @@ -12,16 +12,6 @@ namespace ck_tile::core::arch::mma { -namespace sparse::detail { -// TODO: c++20: return MmaPipelineOptionFlags directly -template -constexpr inline int getPipelineFlags() -{ - return static_cast(MmaPipelineOptionFlag::COMPRESS_A) | - static_cast(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE); -} -} // namespace sparse::detail - /** * @class SparseMmaPipeline * @brief Driver for the wave-tile sparse Mma operation. Given a backend MmaOp implementation @@ -38,7 +28,7 @@ constexpr inline int getPipelineFlags() * @tparam WaveTileN Mma WaveTile N dimension * @tparam WaveTileK Mma WaveTile K dimension * @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order) - * @tparam CTranspose Swaps A and B input vectors and interprets C with transposed layout. + * @tparam CTranspose_ Swaps A and B input vectors and interprets C with transposed layout. * @tparam SwizzleFactor SwizzleFactor for Tile Distribution Encoding calculation. * @tparam AttrNumAccessAV Extra unmerge factor for vector dimension for A vec, see amdgcn_mma.hpp. * @tparam AttrNumAccessBV Extra unmerge factor for vector dimension for B vec, see amdgcn_mma.hpp. @@ -53,7 +43,7 @@ template ::SelectedTransforms> // clang-format off -struct SparseMmaPipeline : public MmaPipelineBase(), SparseMmaPipeline> +struct SparseMmaPipeline : public MmaPipelineBase> { - using Base = MmaPipelineBase(), SparseMmaPipeline>; + using Base = MmaPipelineBase>; // clang-format on - using MmaOp = MmaOp_; + using MmaOp = MmaOp_; + static constexpr bool CTranspose = CTranspose_; using ADataType = typename MmaOp::ADataType; using BDataType = typename MmaOp::BDataType; @@ -86,8 +77,7 @@ struct SparseMmaPipeline : public MmaPipelineBase::IsSupported || std::is_same_v); static_assert(!MmaOpTraits::IsSupported || std::is_same_v); static_assert(!MmaOpTraits::IsSupported || std::is_same_v); - static_assert(!(Base::Flags & MmaPipelineOptionFlag::ABSwap), - "Cannot transpose C in sparse intrinsics."); + static_assert(!CTranspose, "Cannot transpose C in sparse intrinsics."); // WaveTile dimensions (Used to be fragment dims but higher level expects these to include k // iteration!) @@ -180,7 +170,7 @@ struct SparseMmaPipeline : public MmaPipelineBase + template CK_TILE_DEVICE static void execImpl(ATransformResult& a, BTensor& b, CTensor& c) { static_assert( @@ -206,7 +196,7 @@ struct SparseMmaPipeline : public MmaPipelineBase( a_frags[bm][bk], b_buf.at(bn * FragsK + bk), c_buf.at(bm * FragsN + bn), @@ -224,7 +214,7 @@ struct SparseMmaPipeline : public MmaPipelineBase( a_frags[bm][bk], b_buf.at(bn * FragsK + bk), c_buf.at(bm * FragsN + bn), diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp deleted file mode 100644 index f5132b89db..0000000000 --- a/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core/config.hpp" - -#include -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -#include -#endif - -namespace ck_tile::core::arch::mma { - -/** - * @enum SparseCompressionIndex - * @brief Indicates which set of sparse-indices within a VGPR starting at srcC - * containing 8-bits (for 16-bit source data) or 16-bits (for 8-bit source data) - * of index information for a lane. \see DefaultSparseMfmaCtrlFlags - */ -enum struct SparseCompressionIndex : int -{ - FIRST = 0, // Uses bits [7:0] or [15..0], for 16 and 8 bit data respectively - SECOND = 1, // Uses bits [15:8] or [31:16], for 16 and 8 bit data respectively - THIRD = 2, // Uses bits [23:16] - FOURTH = 3, // Uses bits [31:24] -}; - -// to_string methods for enum classes -CK_TILE_HOST_DEVICE constexpr const char* to_string(SparseCompressionIndex compressionIndex) -{ - switch(compressionIndex) - { - case SparseCompressionIndex::FIRST: return "FIRST"; - case SparseCompressionIndex::SECOND: return "SECOND"; - case SparseCompressionIndex::THIRD: return "THIRD"; - case SparseCompressionIndex::FOURTH: return "FOURTH"; - } - __builtin_unreachable(); -} - -namespace sparse::detail { - -/** - * @struct BuiltinParams - * @brief Translates the SparseCompressionIndex to the correct CBSZ and ABID pairs for sparse - * builtins. The actual behavior of the builtin depends on the input data type: 16-bit source data: - * If CBSZ=0, ABID selects one of four 8-bit sets of sparse-indices within a VGPR starting at srcC - * containing 8-bits of index information for a lane. If CBSZ!=0 the very first is selected - * (VGPR[srcC][7..0]). - * - * 8-bit source data: - * If CBSZ=0, ABID selects one of two 16-bit sets of sparse-indices within a VGPR starting at srcC - * containing 16-bits of index information for a lane. If CBSZ!=0; the very first is selected - * (VGPR[srcC][15..0]). - */ -struct BuiltinParams -{ - int UseFirstIndex; // CBSZ - int ByteIndexToOverride; // ABID -}; - -template -static constexpr BuiltinParams getBuiltinParams() -{ - // TODO c++20: designated initializers - if constexpr(Idx == SparseCompressionIndex::FIRST) - { - return BuiltinParams{1, 0}; - } - else - { - return BuiltinParams{0, static_cast(Idx)}; - } -} - -} // namespace sparse::detail - -/** - * @struct DefaultSparseMfmaCtrlFlags - * @brief Default MFMA sparse flags, select (VGPR[srcC][7..0]) if srcC is - * 16-bit or (VGPR[srcC][15..0]) if srcC is 8-bit. - */ -struct DefaultSparseMfmaCtrlFlags -{ - static constexpr SparseCompressionIndex CompressionIndex = SparseCompressionIndex::FIRST; -}; - -CK_TILE_HOST_DEVICE void print_flags(DefaultSparseMfmaCtrlFlags const& ctrlFlags) -{ - printf("CtrlFlags CompressionIndex : %s\n", to_string(ctrlFlags.CompressionIndex)); -} - -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -/** - * @concept SparseMfmaCtrlFlags - * @brief Expresses the interface of required members for each CtrlFlags type - */ -template -concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) { - // Flag members for sparse MFMA instructions - { CtrlFlags::CompressionIndex } -> std::convertible_to; -}; -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - -} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp b/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp index 1829741d37..87d46bd98f 100644 --- a/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp +++ b/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp @@ -53,7 +53,6 @@ struct MmaDefaultSelector::SelectedOp; }; diff --git a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp index 2257cf7db8..1418bc909c 100644 --- a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp +++ b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp @@ -11,6 +11,7 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" namespace ck_tile::core::arch::mma { @@ -18,20 +19,20 @@ namespace ck_tile::core::arch::mma { * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12 * architecture. - * @tparam CtrlFlags Control flags for the WMMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -43,20 +44,20 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -68,20 +69,20 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -93,21 +94,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -119,30 +120,31 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32(true, // A signedness aVec, true, // B signedness bVec, cVec, idx, - CtrlFlags::Clamp)}; + P::clamp)}; } }; @@ -150,21 +152,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -176,21 +178,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -202,21 +204,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -228,21 +230,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -250,51 +252,55 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32"; - CK_TILE_DEVICE static auto - exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType + template + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, idx, - CtrlFlags::Clamp)}; + P::clamp)}; } }; -// TODO: c++20 template +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32"; - CK_TILE_DEVICE static auto - exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType + template + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, idx, - CtrlFlags::Clamp)}; + P::clamp)}; } }; diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp index ec89e26ebc..2c4767fde1 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp @@ -16,6 +16,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" namespace ck_tile::core::arch::mma { // TODO: Specifically for gfx11 wmma, we need to deal with quirks such as: @@ -46,20 +47,20 @@ namespace ck_tile::core::arch::mma { * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX11 * architecture. - * @tparam CtrlFlags Control flags for the WMMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -71,20 +72,20 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -96,29 +97,30 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, - CtrlFlags::Clamp)}; + P::clamp)}; } }; @@ -126,29 +128,30 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, - CtrlFlags::Clamp)}; + P::clamp)}; } }; diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp index 92057b1446..9146d6e250 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp @@ -17,6 +17,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" namespace ck_tile::core::arch::mma { @@ -31,21 +32,21 @@ namespace ck_tile::core::arch::mma { * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12 * architecture. - * @tparam CtrlFlags Control flags for the WMMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -57,21 +58,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -83,21 +84,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -109,21 +110,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -135,30 +136,31 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, - CtrlFlags::Clamp)}; + P::clamp)}; } }; @@ -166,21 +168,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -193,21 +195,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -220,21 +222,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -247,21 +249,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -274,30 +276,31 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, - CtrlFlags::Clamp)}; + P::clamp)}; } }; @@ -305,30 +308,31 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12"; - CK_TILE_DEVICE static auto - exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType + template + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, - CtrlFlags::Clamp)}; + P::clamp)}; } }; diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp index 2f75d68d46..645fdc81e6 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp @@ -52,7 +52,6 @@ struct MmaDefaultSelector::SelectedOp; }; diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp index 1c7c3e9276..59375da04f 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp @@ -4,8 +4,6 @@ #pragma once #include "ck_tile/core/config.hpp" - -#include #include namespace ck_tile::core::arch::mma { @@ -50,25 +48,4 @@ struct is_mma_op_wmma static constexpr bool is_mma_op_wmma_v = is_mma_op_wmma::value; -/** - * @struct DefaultWmmaCtrlFlags - * @brief Default WMMA control flags for dense and sparse WMMA operations. - */ -struct DefaultWmmaCtrlFlags -{ - constexpr static bool Clamp = false; - - // Only has an effect on gfx11 when the accumulator is 16-bit. - // Determines which half of the 32-bit accum register to use for the 16-bit result. - // false = low bits [15:0], true = high bits [31:16] - constexpr static bool UseHighAccumBits = true; -}; - -CK_TILE_HOST_DEVICE void print_flags(DefaultWmmaCtrlFlags const& ctrlFlags) -{ - printf("CtrlFlags Clamp / UseHighAccumBits : %d / %d\n", - ctrlFlags.Clamp, - ctrlFlags.UseHighAccumBits); -} - } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp index fd9cd69813..7a04fb4633 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp @@ -9,21 +9,6 @@ namespace ck_tile::core::arch::mma { -/** - * @struct DuplicateTransform - * @brief Transform to duplicate low register elements to high register elements - */ -struct DuplicateTransform -{ - template - CK_TILE_DEVICE static decltype(auto) exec(VecType&& v) - { - // TODO: Implement duplication logic to broadcast low - // register elements to high elements [0 - (N/2 -1)] -> [N/2 - (N-1)] - return std::forward(v); - } -}; - /** * @struct PadTransform * @brief Transform to pad data from original type to b32 type @@ -59,8 +44,8 @@ struct UnpadTransform */ struct MmaDefaultTransformsGfx11 { - using ATransform = DuplicateTransform; - using BTransform = DuplicateTransform; + using ATransform = PassThroughTransform; + using BTransform = PassThroughTransform; using CTransform = PadTransform; using DTransform = UnpadTransform; }; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_params.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_params.hpp index ace20b923e..359b5bbca9 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_params.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_params.hpp @@ -83,6 +83,21 @@ struct SwapReuse_ : bool_constant { }; +template +struct Cbsz : number +{ +}; + +template +struct Abid : number +{ +}; + +template +struct Blgp : number +{ +}; + struct WarpGemmDefaultParams { using clamp = bool_constant; @@ -94,6 +109,9 @@ struct WarpGemmDefaultParams using swap_reuse = bool_constant; // internal use only using scale_a = number<0>; using scale_b = number<0>; + using cbsz = number<0>; + using abid = number<0>; + using blgp = number<0>; }; template class Tag> @@ -151,6 +169,9 @@ class WarpGemmParamsParser public: static constexpr bool clamp = extract(); static constexpr bool post_nop = extract(); + static constexpr index_t cbsz = extract(); + static constexpr index_t abid = extract(); + static constexpr index_t blgp = extract(); static constexpr bool reuse_a = swap_reuse ? raw_reuse_b : raw_reuse_a; static constexpr bool reuse_b = swap_reuse ? raw_reuse_a : raw_reuse_b; static constexpr index_t op_sel_a = swap_reuse ? raw_op_sel_b : raw_op_sel_a; diff --git a/test/ck_tile/core/arch/mma/CMakeLists.txt b/test/ck_tile/core/arch/mma/CMakeLists.txt index 1a62205490..727098269b 100644 --- a/test/ck_tile/core/arch/mma/CMakeLists.txt +++ b/test/ck_tile/core/arch/mma/CMakeLists.txt @@ -87,6 +87,3 @@ if(GPU_TARGETS MATCHES "gfx120") target_compile_options(test_amdgcn_mma_layout_gfx12 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() -_add_mma_gtest(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp) -target_compile_options(test_amdgcn_mma_pipeline PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - diff --git a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_mma_pipeline.cpp b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_mma_pipeline.cpp deleted file mode 100644 index cc6cee9b3e..0000000000 --- a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_mma_pipeline.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include - -#include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/arch/mma/mma_pipeline.hpp" - -namespace { -using namespace ck_tile::core::arch::mma; -} - -TEST(MmaPipelineOptionFlagsTests, ConversionTests) -{ - MmaPipelineOptionFlags flags_0{}; - MmaPipelineOptionFlags flags_1{MmaPipelineOptionFlag::ABSwap}; - MmaPipelineOptionFlags flags_2{MmaPipelineOptionFlag::COMPRESS_A}; - MmaPipelineOptionFlags flags_3{0b11}; // TODO c++20 - remove this - - EXPECT_TRUE(flags_0.testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::ABSwap)); - EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); - - EXPECT_TRUE(flags_1.testFlag(MmaPipelineOptionFlag::ABSwap)); - EXPECT_FALSE(flags_1.testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_FALSE(flags_1.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); - - EXPECT_TRUE(flags_2.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); - EXPECT_FALSE(flags_2.testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_FALSE(flags_2.testFlag(MmaPipelineOptionFlag::ABSwap)); - - EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); - EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::ABSwap)); - EXPECT_FALSE(flags_3.testFlag(MmaPipelineOptionFlag::NONE)); -} - -TEST(MmaPipelineOptionFlagsTests, OperatorsTests) -{ - MmaPipelineOptionFlags flags{}; - - EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::NONE)); - - flags |= MmaPipelineOptionFlag::ABSwap; - - EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::ABSwap)); - - flags |= MmaPipelineOptionFlag::COMPRESS_A; - - EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::ABSwap)); - EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); - - flags &= MmaPipelineOptionFlag::COMPRESS_A; - - EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::ABSwap)); - EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); - - EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_TRUE((~flags).testFlag(MmaPipelineOptionFlag::ABSwap)); - EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::COMPRESS_A)); -} diff --git a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp index e424b10d34..864997d5f1 100644 --- a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp +++ b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp @@ -40,7 +40,6 @@ void ScaleMfmaGfx950Specialization_impl() WaveTileM, WaveTileN, WaveTileK, - DefaultScaleMfmaCtrlFlags, CompilerTargetGfx950, MmaOpFamily::SCALE>; @@ -79,10 +78,7 @@ TEST(ScaleMMATrait, ScaleMfmaGfx950Specialization) std::cout << "GFX950 scale MFMA specialization is correct" << std::endl; } -// TODO: It seems like the ExecSignature concept (and hence MmaOpI) can not be made to work for a -// templated device function for some reason. Disable test for now and fix this once we are using -// the variadic template pack for flags... -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER && 0 +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER template ; @@ -107,7 +102,7 @@ void TestConceptRequirements_impl() TEST(ScaleMMATrait, TestConceptRequirements) { -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER && 0 +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER TestConceptRequirements_impl(); TestConceptRequirements_impl(); TestConceptRequirements_impl(); @@ -216,9 +211,7 @@ struct ScalePipelineKernel constexpr int32_t replicate_byte = 0x01010101; ScaleAType scale_a = 126u * replicate_byte; ScaleBType scale_b = 129u * replicate_byte; - static constexpr index_t opselA = 0; - static constexpr index_t opselB = 0; - Pipeline::template exec(a, b, c, scale_a, scale_b); + Pipeline::template exec, OpSelB<0>>(a, b, c, scale_a, scale_b); __builtin_memcpy( static_cast(c_per_lane) + lane * sizeof(CTensor), &c, sizeof(CTensor)); } @@ -399,9 +392,7 @@ TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_32x32x64_Real) // constexpr int32_t replicate_byte = 0x01010101; // ScaleAType scale_a = 126u * replicate_byte; // ScaleBType scale_b = 129u * replicate_byte; -// static constexpr index_t opselA = 0; -// static constexpr index_t opselB = 0; -// Pipeline::template exec(a, b, c, scale_a, scale_b); +// Pipeline::template exec, OpSelB<0>>(a, b, c, scale_a, scale_b); // __builtin_memcpy( // static_cast(c_per_lane) + lane * sizeof(CTensor), &c, sizeof(CTensor)); // } diff --git a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp index 5d44d3333b..ff156e8692 100644 --- a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp +++ b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp @@ -39,7 +39,6 @@ TEST(SparseMMATrait, SparseMfmaGfx950Specialization) 16u, 16u, 32u, - DefaultSparseMfmaCtrlFlags, CompilerTargetGfx950, MmaOpFamily::SPARSE>; @@ -60,7 +59,6 @@ TEST(SparseMMATrait, MmaOpTraitsIntegration) 16u, 16u, 32u, - DefaultSparseMfmaCtrlFlags, CompilerTargetGfx950, MmaOpFamily::SPARSE>; @@ -83,7 +81,6 @@ TEST(SparseMMATrait, TestConceptRequirements) 16u, 16u, 32u, - DefaultSparseMfmaCtrlFlags, CompilerTargetGfx950, MmaOpFamily::SPARSE>; EXPECT_TRUE(MmaOpI); @@ -95,15 +92,8 @@ TEST(SparseMMATrait, TestConceptRequirements) TEST(SparseMMATrait, DenseVsSparseDistinction) { // Dense MFMA from mfma/mfma_gfx9.hpp - using DenseMfma = amdgcn_mma; + using DenseMfma = + amdgcn_mma; // Sparse MFMA on GFX950 using SparseMfma = amdgcn_mma; diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp index 8c8109e78d..f15c21bfe6 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp @@ -27,9 +27,6 @@ using namespace ck_tile::core::arch::testing; constexpr uint32_t DummyTargetIdVal = 55555u; using DummyCompilerTarget = amdgcn_target(DummyTargetIdVal)>; struct DummyOpType; -struct DummyCtrlFlags -{ -}; /** @brief Returns true if the given target id matches the dummy */ constexpr bool is_dummy_target(DummyCompilerTarget dummy) @@ -49,7 +46,7 @@ using enable_if_target_id_dummy_t = std::enable_if_t // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { @@ -63,15 +60,8 @@ struct amdgcn_mma template -using DummyAmdgcnMma = amdgcn_mma; +using DummyAmdgcnMma = + amdgcn_mma; /*! @struct MmaDefaultSelector * @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can them both diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc index 1a3bc55aaf..7cd471f562 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc @@ -23,6 +23,7 @@ #include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/stream_config.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" #include #include @@ -256,12 +257,10 @@ struct MmaLayoutTestKernel { // The actual scale is computed as pow(2, scale - 127), so: // 125 -> 2^-2 and 129 -> 2^2. - int scale_A = 125; - int scale_B = 129; - static constexpr index_t opselA = 0; - static constexpr index_t opselB = 0; - c_frag = - MmaOp::template exec(a_frag, b_frag, c_frag, scale_A, scale_B); + int scale_A = 125; + int scale_B = 129; + c_frag = MmaOp::template exec, OpSelB<0>>( + a_frag, b_frag, c_frag, scale_A, scale_B); } else { @@ -357,145 +356,145 @@ void run_mma_layout_test() // available on all gfx9 (gfx908, gfx90a, gfx942, gfx950) using Gfx9CommonIntrinsics = ::testing::Types< - amdgcn_mma, // mfma_f32_32x32x1f32 - amdgcn_mma, // mfma_f32_32x32x1f32 - amdgcn_mma, // mfma_f32_16x16x1f32 - amdgcn_mma, // mfma_f32_16x16x1f32 - amdgcn_mma, // mfma_f32_4x4x1f32 - amdgcn_mma, // mfma_f32_4x4x1f32 - amdgcn_mma, // mfma_f32_32x32x2f32 - amdgcn_mma, // mfma_f32_16x16x4f32 - amdgcn_mma, // mfma_f32_32x32x4f16 - amdgcn_mma, // mfma_f32_32x32x4f16 - amdgcn_mma, // mfma_f32_16x16x4f16 - amdgcn_mma, // mfma_f32_16x16x4f16 - amdgcn_mma, // mfma_f32_4x4x4f16 - amdgcn_mma, // mfma_f32_4x4x4f16 - amdgcn_mma, // mfma_f32_32x32x8f16 - amdgcn_mma, // mfma_f32_16x16x16f16 - amdgcn_mma, // mfma_i32_32x32x4i8 - amdgcn_mma, // mfma_i32_32x32x4i8 - amdgcn_mma, // mfma_i32_16x16x4i8 - amdgcn_mma, // mfma_i32_16x16x4i8 - amdgcn_mma, // mfma_i32_4x4x4i8 - amdgcn_mma // mfma_i32_4x4x4i8 + amdgcn_mma, // mfma_f32_32x32x1f32 + amdgcn_mma, // mfma_f32_32x32x1f32 + amdgcn_mma, // mfma_f32_16x16x1f32 + amdgcn_mma, // mfma_f32_16x16x1f32 + amdgcn_mma, // mfma_f32_4x4x1f32 + amdgcn_mma, // mfma_f32_4x4x1f32 + amdgcn_mma, // mfma_f32_32x32x2f32 + amdgcn_mma, // mfma_f32_16x16x4f32 + amdgcn_mma, // mfma_f32_32x32x4f16 + amdgcn_mma, // mfma_f32_32x32x4f16 + amdgcn_mma, // mfma_f32_16x16x4f16 + amdgcn_mma, // mfma_f32_16x16x4f16 + amdgcn_mma, // mfma_f32_4x4x4f16 + amdgcn_mma, // mfma_f32_4x4x4f16 + amdgcn_mma, // mfma_f32_32x32x8f16 + amdgcn_mma, // mfma_f32_16x16x16f16 + amdgcn_mma, // mfma_i32_32x32x4i8 + amdgcn_mma, // mfma_i32_32x32x4i8 + amdgcn_mma, // mfma_i32_16x16x4i8 + amdgcn_mma, // mfma_i32_16x16x4i8 + amdgcn_mma, // mfma_i32_4x4x4i8 + amdgcn_mma // mfma_i32_4x4x4i8 >; using Gfx908andGfx90aIntrinsics = ::testing::Types< - amdgcn_mma, // mfma_f32_32x32x2bf16 - amdgcn_mma, // mfma_f32_32x32x2bf16 - amdgcn_mma, // mfma_f32_16x16x2bf16 - amdgcn_mma, // mfma_f32_16x16x2bf16 - amdgcn_mma, // mfma_f32_4x4x2bf16 - amdgcn_mma, // mfma_f32_4x4x2bf16 - amdgcn_mma, // mfma_f32_32x32x4bf16 - amdgcn_mma, // mfma_f32_16x16x8bf16 - amdgcn_mma, // mfma_i32_32x32x8i8 - amdgcn_mma // mfma_i32_16x16x16i8 + amdgcn_mma, // mfma_f32_32x32x2bf16 + amdgcn_mma, // mfma_f32_32x32x2bf16 + amdgcn_mma, // mfma_f32_16x16x2bf16 + amdgcn_mma, // mfma_f32_16x16x2bf16 + amdgcn_mma, // mfma_f32_4x4x2bf16 + amdgcn_mma, // mfma_f32_4x4x2bf16 + amdgcn_mma, // mfma_f32_32x32x4bf16 + amdgcn_mma, // mfma_f32_16x16x8bf16 + amdgcn_mma, // mfma_i32_32x32x8i8 + amdgcn_mma // mfma_i32_16x16x16i8 >; using Gfx90aAndHigherIntrinsics = ::testing::Types< - amdgcn_mma, // mfma_f32_32x32x4bf16_1k - amdgcn_mma, // mfma_f32_32x32x4bf16_1k - amdgcn_mma, // mfma_f32_16x16x4bf16_1k - amdgcn_mma, // mfma_f32_16x16x4bf16_1k - amdgcn_mma, // mfma_f32_4x4x4bf16_1k - amdgcn_mma, // mfma_f32_4x4x4bf16_1k - amdgcn_mma, // mfma_f32_32x32x8bf16_1k - amdgcn_mma, // mfma_f32_16x16x16bf16_1k - amdgcn_mma, // mfma_f64_16x16x4f64 - amdgcn_mma, // mfma_f64_4x4x4f64 - amdgcn_mma // mfma_f64_4x4x4f64 + amdgcn_mma, // mfma_f32_32x32x4bf16_1k + amdgcn_mma, // mfma_f32_32x32x4bf16_1k + amdgcn_mma, // mfma_f32_16x16x4bf16_1k + amdgcn_mma, // mfma_f32_16x16x4bf16_1k + amdgcn_mma, // mfma_f32_4x4x4bf16_1k + amdgcn_mma, // mfma_f32_4x4x4bf16_1k + amdgcn_mma, // mfma_f32_32x32x8bf16_1k + amdgcn_mma, // mfma_f32_16x16x16bf16_1k + amdgcn_mma, // mfma_f64_16x16x4f64 + amdgcn_mma, // mfma_f64_4x4x4f64 + amdgcn_mma // mfma_f64_4x4x4f64 >; using Gfx942AndHigherIntrinsics = ::testing::Types< - amdgcn_mma, // mfma_i32_16x16x32_i8 - amdgcn_mma, // mfma_i32_32x32x16_i8 - amdgcn_mma, // mfma_f32_16x16x32_bf8_bf8 - amdgcn_mma, // mfma_f32_16x16x32_bf8_fp8 - amdgcn_mma, // mfma_f32_16x16x32_fp8_bf8 - amdgcn_mma, // mfma_f32_16x16x32_fp8_fp8 - amdgcn_mma, // mfma_f32_32x32x16_bf8_bf8 - amdgcn_mma, // mfma_f32_32x32x16_bf8_fp8 - amdgcn_mma, // mfma_f32_32x32x16_fp8_bf8 - amdgcn_mma, // mfma_f32_32x32x16_fp8_fp8 - amdgcn_mma, // smfmac_f32_16x16x32_f16 - amdgcn_mma, // smfmac_f32_32x32x16_f16 - amdgcn_mma, // smfmac_f32_16x16x32_bf16 - amdgcn_mma, // smfmac_f32_32x32x16_bf16 - amdgcn_mma, // smfmac_i32_16x16x64_i8 - amdgcn_mma, // smfmac_i32_32x32x32_i8 - amdgcn_mma, // smfmac_f32_16x16x64_bf8_bf8 - amdgcn_mma, // smfmac_f32_16x16x64_bf8_fp8 - amdgcn_mma, // smfmac_f32_16x16x64_fp8_bf8 - amdgcn_mma, // smfmac_f32_16x16x64_fp8_fp8 - amdgcn_mma, // smfmac_f32_32x32x32_bf8_bf8 - amdgcn_mma, // smfmac_f32_32x32x32_bf8_fp8 - amdgcn_mma, // smfmac_f32_32x32x32_fp8_bf8 - amdgcn_mma // smfmac_f32_32x32x32_fp8_fp8 + amdgcn_mma, // mfma_i32_16x16x32_i8 + amdgcn_mma, // mfma_i32_32x32x16_i8 + amdgcn_mma, // mfma_f32_16x16x32_bf8_bf8 + amdgcn_mma, // mfma_f32_16x16x32_bf8_fp8 + amdgcn_mma, // mfma_f32_16x16x32_fp8_bf8 + amdgcn_mma, // mfma_f32_16x16x32_fp8_fp8 + amdgcn_mma, // mfma_f32_32x32x16_bf8_bf8 + amdgcn_mma, // mfma_f32_32x32x16_bf8_fp8 + amdgcn_mma, // mfma_f32_32x32x16_fp8_bf8 + amdgcn_mma, // mfma_f32_32x32x16_fp8_fp8 + amdgcn_mma, // smfmac_f32_16x16x32_f16 + amdgcn_mma, // smfmac_f32_32x32x16_f16 + amdgcn_mma, // smfmac_f32_16x16x32_bf16 + amdgcn_mma, // smfmac_f32_32x32x16_bf16 + amdgcn_mma, // smfmac_i32_16x16x64_i8 + amdgcn_mma, // smfmac_i32_32x32x32_i8 + amdgcn_mma, // smfmac_f32_16x16x64_bf8_bf8 + amdgcn_mma, // smfmac_f32_16x16x64_bf8_fp8 + amdgcn_mma, // smfmac_f32_16x16x64_fp8_bf8 + amdgcn_mma, // smfmac_f32_16x16x64_fp8_fp8 + amdgcn_mma, // smfmac_f32_32x32x32_bf8_bf8 + amdgcn_mma, // smfmac_f32_32x32x32_bf8_fp8 + amdgcn_mma, // smfmac_f32_32x32x32_fp8_bf8 + amdgcn_mma // smfmac_f32_32x32x32_fp8_fp8 >; using Gfx942Intrinsics = ::testing::Types< - amdgcn_mma, // mfma_f32_16x16x8_xf32 - amdgcn_mma // mfma_f32_32x32x4_xf32 + amdgcn_mma, // mfma_f32_16x16x8_xf32 + amdgcn_mma // mfma_f32_32x32x4_xf32 >; using Gfx950Intrinsics = ::testing::Types< - amdgcn_mma, // mfma_f32_16x16x32_f16 - amdgcn_mma, // mfma_f32_16x16x32_bf16 - amdgcn_mma, // mfma_f32_32x32x16_f16 - amdgcn_mma, // mfma_f32_32x32x16_bf16 - amdgcn_mma, // mfma_i32_16x16x64_i8 - amdgcn_mma, // mfma_i32_32x32x32_i8 - amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 - amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 - amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 - amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 - amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 - amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 - amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 - amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 - amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 - amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 - amdgcn_mma, // smfmac_f32_16x16x64_f16 - amdgcn_mma, // smfmac_f32_32x32x32_f16 - amdgcn_mma, // smfmac_f32_16x16x64_bf16 - amdgcn_mma, // smfmac_f32_32x32x32_bf16 - amdgcn_mma, // smfmac_i32_16x16x128_i8 - amdgcn_mma, // smfmac_i32_32x32x64_i8 - amdgcn_mma, // smfmac_f32_16x16x128_bf8_bf8 - amdgcn_mma, // smfmac_f32_16x16x128_bf8_fp8 - amdgcn_mma, // smfmac_f32_16x16x128_fp8_bf8 - amdgcn_mma, // smfmac_f32_16x16x128_fp8_fp8 - amdgcn_mma, // smfmac_f32_32x32x64_bf8_bf8 - amdgcn_mma, // smfmac_f32_32x32x64_bf8_fp8 - amdgcn_mma, // smfmac_f32_32x32x64_fp8_bf8 - amdgcn_mma // smfmac_f32_32x32x64_fp8_fp8 + amdgcn_mma, // mfma_f32_16x16x32_f16 + amdgcn_mma, // mfma_f32_16x16x32_bf16 + amdgcn_mma, // mfma_f32_32x32x16_f16 + amdgcn_mma, // mfma_f32_32x32x16_bf16 + amdgcn_mma, // mfma_i32_16x16x64_i8 + amdgcn_mma, // mfma_i32_32x32x32_i8 + amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 + amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 + amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 + amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 + amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 + amdgcn_mma, // smfmac_f32_16x16x64_f16 + amdgcn_mma, // smfmac_f32_32x32x32_f16 + amdgcn_mma, // smfmac_f32_16x16x64_bf16 + amdgcn_mma, // smfmac_f32_32x32x32_bf16 + amdgcn_mma, // smfmac_i32_16x16x128_i8 + amdgcn_mma, // smfmac_i32_32x32x64_i8 + amdgcn_mma, // smfmac_f32_16x16x128_bf8_bf8 + amdgcn_mma, // smfmac_f32_16x16x128_bf8_fp8 + amdgcn_mma, // smfmac_f32_16x16x128_fp8_bf8 + amdgcn_mma, // smfmac_f32_16x16x128_fp8_fp8 + amdgcn_mma, // smfmac_f32_32x32x64_bf8_bf8 + amdgcn_mma, // smfmac_f32_32x32x64_bf8_fp8 + amdgcn_mma, // smfmac_f32_32x32x64_fp8_bf8 + amdgcn_mma // smfmac_f32_32x32x64_fp8_fp8 >; using Gfx11Intrinsics = ::testing::Types< - amdgcn_mma, // wmma_f32_16x16x16_f16_w32 - amdgcn_mma, // wmma_f32_16x16x16_bf16_w32 - amdgcn_mma, // wmma_i32_16x16x16_iu8_w32 - amdgcn_mma // wmma_i32_16x16x16_iu4_w32 + amdgcn_mma, // wmma_f32_16x16x16_f16_w32 + amdgcn_mma, // wmma_f32_16x16x16_bf16_w32 + amdgcn_mma, // wmma_i32_16x16x16_iu8_w32 + amdgcn_mma // wmma_i32_16x16x16_iu4_w32 >; using Gfx12Intrinsics = ::testing::Types< - amdgcn_mma, // wmma_f32_16x16x16_f16_w32_gfx12 - amdgcn_mma, // wmma_f32_16x16x16_bf16_w32_gfx12 - amdgcn_mma, // wmma_f16_16x16x16_f16_w32_gfx12 - amdgcn_mma, // wmma_bf16_16x16x16_bf16_w32_gfx12 - amdgcn_mma, // wmma_i32_16x16x16_iu8_w32_gfx12 - amdgcn_mma, // wmma_f32_16x16x16_fp8_fp8_w32_gfx12 - amdgcn_mma, // wmma_f32_16x16x16_fp8_bf8_w32_gfx12 - amdgcn_mma, // wmma_f32_16x16x16_bf8_fp8_w32_gfx12 - amdgcn_mma, // wmma_f32_16x16x16_bf8_bf8_w32_gfx12 - amdgcn_mma, // wmma_i32_16x16x16_iu4_w32_gfx12 - amdgcn_mma, // wmma_i32_16x16x32_iu4_w32_gfx12 - amdgcn_mma, // swmmac_f32_16x16x32_f16_w32 - amdgcn_mma, // swmmac_f32_16x16x32_bf16_w32 - amdgcn_mma, // swmmac_f16_16x16x32_f16_w32 - amdgcn_mma, // swmmac_bf16_16x16x32_bf16_w32 - amdgcn_mma, // swmmac_i32_16x16x32_iu8_w32 - amdgcn_mma, // swmmac_f32_16x16x32_fp8_fp8_w32 - amdgcn_mma, // swmmac_f32_16x16x32_fp8_bf8_w32 - amdgcn_mma, // swmmac_f32_16x16x32_bf8_fp8_w32 - amdgcn_mma, // swmmac_f32_16x16x32_bf8_bf8_w32 - amdgcn_mma, // swmmac_i32_16x16x32_iu4_w32 - amdgcn_mma // swmmac_i32_16x16x64_iu4_w32 + amdgcn_mma, // wmma_f32_16x16x16_f16_w32_gfx12 + amdgcn_mma, // wmma_f32_16x16x16_bf16_w32_gfx12 + amdgcn_mma, // wmma_f16_16x16x16_f16_w32_gfx12 + amdgcn_mma, // wmma_bf16_16x16x16_bf16_w32_gfx12 + amdgcn_mma, // wmma_i32_16x16x16_iu8_w32_gfx12 + amdgcn_mma, // wmma_f32_16x16x16_fp8_fp8_w32_gfx12 + amdgcn_mma, // wmma_f32_16x16x16_fp8_bf8_w32_gfx12 + amdgcn_mma, // wmma_f32_16x16x16_bf8_fp8_w32_gfx12 + amdgcn_mma, // wmma_f32_16x16x16_bf8_bf8_w32_gfx12 + amdgcn_mma, // wmma_i32_16x16x16_iu4_w32_gfx12 + amdgcn_mma, // wmma_i32_16x16x32_iu4_w32_gfx12 + amdgcn_mma, // swmmac_f32_16x16x32_f16_w32 + amdgcn_mma, // swmmac_f32_16x16x32_bf16_w32 + amdgcn_mma, // swmmac_f16_16x16x32_f16_w32 + amdgcn_mma, // swmmac_bf16_16x16x32_bf16_w32 + amdgcn_mma, // swmmac_i32_16x16x32_iu8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_fp8_fp8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_fp8_bf8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_bf8_fp8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_bf8_bf8_w32 + amdgcn_mma, // swmmac_i32_16x16x32_iu4_w32 + amdgcn_mma // swmmac_i32_16x16x64_iu4_w32 >; // clang-format on