diff --git a/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp b/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp index 9e62f6e939..41b954a6de 100644 --- a/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp +++ b/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp @@ -73,17 +73,17 @@ int check_tile_distr_enc() // List of intrinsics to test. // clang-format off using Intrinsics = ck_tile::tuple< - amdgcn_mma, // mfma_f32_16x16x16f16 - amdgcn_mma, // mfma_f32_32x32x4f16 - amdgcn_mma, // mfma_f32_32x32x4f16 - amdgcn_mma, // mfma_f32_4x4x4f16 - amdgcn_mma, // mfma_f32_4x4x4f16 - amdgcn_mma, // mfma_f32_16x16x32_f16 - amdgcn_mma, // wmma_f32_16x16x16_f16_w32 - amdgcn_mma, // wmma_i32_16x16x16_iu4_w32 - amdgcn_mma, // wmma_f32_16x16x16_f16_w32_gfx12 - amdgcn_mma, // wmma_i32_16x16x16_iu4_w32_gfx12 - amdgcn_mma // wmma_i32_16x16x32_iu4_w32_gfx12 + amdgcn_mma, // mfma_f32_16x16x16f16 + amdgcn_mma, // mfma_f32_32x32x4f16 + amdgcn_mma, // mfma_f32_32x32x4f16 + amdgcn_mma, // mfma_f32_4x4x4f16 + amdgcn_mma, // mfma_f32_4x4x4f16 + amdgcn_mma, // mfma_f32_16x16x32_f16 + amdgcn_mma, // wmma_f32_16x16x16_f16_w32 + amdgcn_mma, // wmma_i32_16x16x16_iu4_w32 + amdgcn_mma, // wmma_f32_16x16x16_f16_w32_gfx12 + amdgcn_mma, // wmma_i32_16x16x16_iu4_w32_gfx12 + amdgcn_mma // wmma_i32_16x16x32_iu4_w32_gfx12 >; // clang-format on diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index b070d0c68a..47ba274a15 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -43,7 +43,6 @@ #include "ck_tile/core/arch/mma/sparse/sparse.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp" -#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" #include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp" #include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp" diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index d4330b8c73..938fc1791d 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -246,25 +246,14 @@ CK_TILE_HOST_DEVICE constexpr const char* to_string(Unsupported) { return "Unsup #if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -/** - * @concept HasExecSignature - * @brief Helper concept for exec signature check. - */ -template -concept HasExecSignature = requires { - { - MmaOp::exec(typename MmaOp::AVecType{}, - typename MmaOp::BVecType{}, - typename MmaOp::CVecType{}, - std::declval()...) - } -> std::convertible_to; -}; - /** * @concept MmaOpI * @brief Expresses the meta-data interface required for each MmaOp policy. */ // TODO: Make sure this actually matches amdgcn_mma. +// NOTE: It is no longer possible to perform a check on the exec() function, since it is now +// templated over the variadic WarpGemmParams template pack for intrinsic flags. It seems like +// concepts do not work for templated device functions. template concept MmaOpI = requires(MmaOp op) { // Requires an op context @@ -287,7 +276,7 @@ concept MmaOpI = requires(MmaOp op) { { MmaOp::kCMPerLane } -> std::convertible_to; { MmaOp::kCMNumAccess } -> std::convertible_to; { MmaOp::kCompressionRatio } -> std::convertible_to; -} && (HasExecSignature || HasExecSignature || HasExecSignature); +}; #endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER @@ -305,7 +294,6 @@ concept MmaOpI = requires(MmaOp op) { * @tparam FragM M-dimension of mma intrinsic (MmaTile) * @tparam FragN N-dimension of mma intrinsic (MmaTile) * @tparam FragK K-dimension of mma intrinsic (MmaTile) - * @tparam CtrlFlags Control flags for mma operation * @tparam CompilerTarget The current compiler target * @tparam OpFamily_ The type of operation (dense, sparse, scale, etc.) * @tparam Enabler SFINAE enabler @@ -316,7 +304,6 @@ template @@ -326,6 +313,7 @@ struct amdgcn_mma : amdgcn_mma_base CK_TILE_DEVICE static auto exec(AVecType const& regsA, BVecType const& regsB, CVecType const& regsC) { @@ -347,7 +335,6 @@ template @@ -357,7 +344,6 @@ CK_TILE_HOST_DEVICE void print(amdgcn_mma const& mmaObj) @@ -392,10 +378,6 @@ CK_TILE_HOST_DEVICE void print(amdgcn_mma) - { - print_flags(CtrlFlags{}); - } print(CompilerTarget{}); } diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp index b3f2a90cd4..ad4a055a06 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp @@ -17,6 +17,7 @@ #include "ck_tile/core/numeric/tfloat32.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" namespace ck_tile::core::arch::mma { @@ -31,29 +32,26 @@ namespace ck_tile::core::arch::mma { * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for fp32_t, fp32_t, fp32_t MMA operation on GFX9 * architecture. - * @tparam CtrlFlags Control flags for the MFMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x1f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x1f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x1f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -61,29 +59,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x1f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x1f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x1f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -91,29 +86,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x1f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x1f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x1f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -121,29 +113,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x1f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x1f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x1f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -151,29 +140,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x1f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x1f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_4x4x1f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -181,29 +167,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x1f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x1f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_4x4x1f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -211,29 +194,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x2f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x2f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x2f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -241,29 +221,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x4f32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x4f32(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x4f32( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -271,25 +248,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x4f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x4f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x4f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -297,25 +274,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x4f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x4f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x4f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -323,25 +300,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x4f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x4f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x4f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -349,25 +326,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x4f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x4f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x4f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -375,25 +352,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x4f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x4f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_4x4x4f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -401,25 +378,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x4f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x4f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_4x4x4f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -427,25 +404,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x8f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x8f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x8f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -453,25 +430,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x16f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x16f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x16f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -479,29 +456,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_32x32x4i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_32x32x4i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_32x32x4i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -509,29 +483,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_32x32x4i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_32x32x4i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_32x32x4i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -539,29 +510,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_16x16x4i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_16x16x4i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_16x16x4i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -569,29 +537,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_16x16x4i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_16x16x4i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_16x16x4i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -599,29 +564,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_4x4x4i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_4x4x4i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_4x4x4i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -629,29 +591,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_4x4x4i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_4x4x4i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_4x4x4i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -659,29 +618,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_32x32x8i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_32x32x8i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -689,29 +645,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_16x16x16i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_16x16x16i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -719,25 +672,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x2bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x2bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x2bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -745,25 +698,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x2bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x2bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x2bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -771,25 +724,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x2bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x2bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x2bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -797,25 +750,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x2bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x2bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x2bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -823,25 +776,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x2bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x2bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_4x4x2bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -849,25 +802,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x2bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x2bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_4x4x2bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -875,25 +828,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x4bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x4bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x4bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -901,25 +854,25 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x8bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x8bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x8bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -927,25 +880,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x4bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x4bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_32x32x4bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -953,25 +907,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x4bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x4bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_32x32x4bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -979,25 +934,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x4bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x4bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_16x16x4bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1005,25 +961,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x4bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x4bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_16x16x4bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1031,25 +988,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x4bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x4bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1057,25 +1015,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_4x4x4bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_4x4x4bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1083,25 +1042,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x8bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x8bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1109,25 +1069,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x16bf16_1k"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x16bf16_1k( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1135,31 +1096,32 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f64_16x16x4f64"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { // Note: BLGP flag has another meaning for f64 builtins: BLGP bits [0:2] cause negation of // the A, B, and C input matrices respectively (ref. ISA docs for MI300 Instinct) + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_f64_16x16x4f64(bit_cast(aVec), bit_cast(bVec), cVec, - CtrlFlags::Cbsz, // CBSZ ignored for f64 - CtrlFlags::Abid, // ABID ignored for f64 - CtrlFlags::Blgp)}; + P::cbsz, // CBSZ ignored for f64 + P::abid, // ABID ignored for f64 + P::blgp)}; } }; @@ -1167,30 +1129,31 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // |A B C DataTypes |MNK | -struct amdgcn_mma> +struct amdgcn_mma> // |WS |AParams |BPar |CPar | : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f64_4x4x4f64"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_f64_4x4x4f64(bit_cast(aVec), bit_cast(bVec), bit_cast(cVec), - CtrlFlags::Cbsz, // CBSZ ignored for f64 - CtrlFlags::Abid, // ABID ignored for f64 - CtrlFlags::Blgp)}; + P::cbsz, // CBSZ ignored for f64 + P::abid, // ABID ignored for f64 + P::blgp)}; } }; @@ -1198,30 +1161,31 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // |A B C DataTypes |MNK | -struct amdgcn_mma> +struct amdgcn_mma> // |WS |AParams |BPar |CPar | : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f64_4x4x4f64"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_f64_4x4x4f64(bit_cast(aVec), bit_cast(bVec), bit_cast(cVec), - CtrlFlags::Cbsz, // CBSZ ignored for f64 - CtrlFlags::Abid, // ABID ignored for f64 - CtrlFlags::Blgp)}; + P::cbsz, // CBSZ ignored for f64 + P::abid, // ABID ignored for f64 + P::blgp)}; } }; @@ -1229,29 +1193,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_16x16x32_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_16x16x32_i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1259,29 +1220,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_32x32x16_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_32x32x16_i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1289,26 +1247,27 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // |A B C DataTypes |MNK | -struct amdgcn_mma> +struct amdgcn_mma> // |WS |AParams |BPar |CPar | : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x8_xf32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x8_xf32( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_16x16x8_xf32(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1316,26 +1275,27 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // |A B C DataTypes |MNK | -struct amdgcn_mma> +struct amdgcn_mma> // |WS |AParams |BPar |CPar | : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x4_xf32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x4_xf32( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_32x32x4_xf32(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1343,29 +1303,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1373,29 +1330,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1403,29 +1357,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1433,29 +1384,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1463,29 +1411,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1493,29 +1438,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1523,29 +1465,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1553,29 +1492,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1583,25 +1519,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x32_f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x32_f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_16x16x32_f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1609,25 +1546,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_16x16x32_bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_16x16x32_bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_16x16x32_bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1635,25 +1573,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x16_f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x16_f16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_32x32x16_f16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1661,25 +1600,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_f32_32x32x16_bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_f32_32x32x16_bf16( - aVec, bVec, cVec, CtrlFlags::Cbsz, CtrlFlags::Abid, CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_mfma_f32_32x32x16_bf16(aVec, bVec, cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1687,29 +1627,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_16x16x64_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_16x16x64_i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_16x16x64_i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; @@ -1717,29 +1654,26 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_i32_32x32x32_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { - return {__builtin_amdgcn_mfma_i32_32x32x32_i8(bit_cast(aVec), - bit_cast(bVec), - cVec, - CtrlFlags::Cbsz, - CtrlFlags::Abid, - CtrlFlags::Blgp)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_mfma_i32_32x32x32_i8( + bit_cast(aVec), bit_cast(bVec), cVec, P::cbsz, P::abid, P::blgp)}; } }; diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp index 28af0c4568..233ee77526 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp @@ -53,7 +53,6 @@ struct MmaDefaultSelector::SelectedOp; }; diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp index b4f23fc9bc..29340fece7 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp @@ -3,16 +3,8 @@ #pragma once -#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/config.hpp" -#include "ck_tile/core/numeric/integer.hpp" - -#include -#include #include -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -#include -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER namespace ck_tile::core::arch::mma { @@ -56,42 +48,4 @@ struct is_mma_op_mfma static constexpr bool is_mma_op_mfma_v = is_mma_op_mfma::value; -/** - * @struct DefaultMfmaCtrlFlags - * @brief Default MFMA flags, no broadcasting or rotation of inputs - * @note For f64 MFMA instructions, CBSZ and ABID are ignored and BLGP is repurposed for matrix - * negation. BLGP bits [0:2] negate the A, B, and C input matrices respectively (ref. ISA docs for - * MI300 Instinct). - */ -struct DefaultMfmaCtrlFlags -{ - static constexpr int32_t Cbsz = 0; // CBSZ flag, default 0 - static constexpr int32_t Abid = 0; // ABID flag, default 0 - static constexpr int32_t Blgp = 0; // BLGP flag, default 0 -}; - -CK_TILE_HOST_DEVICE void print_flags(DefaultMfmaCtrlFlags const& ctrlFlags) -{ - printf("CtrlFlags Cbsz / Abid / Blgp : %" PRId32 " / %" PRId32 " / %" PRId32 "\n", - ctrlFlags.Cbsz, - ctrlFlags.Abid, - ctrlFlags.Blgp); -} - -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - -/** - * @concept CtrlFlagsGfx9I - * @brief Expresses the interface of required members for each CtrlFlags type on Gfx9 - */ -template -concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) { - // Flag members for Gfx9 MFMA instructions - { CtrlFlags::Cbsz } -> std::convertible_to; - { CtrlFlags::Abid } -> std::convertible_to; - { CtrlFlags::Blgp } -> std::convertible_to; -}; - -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mma_pipeline.hpp b/include/ck_tile/core/arch/mma/mma_pipeline.hpp index 3198f2c41f..b687c1adc9 100644 --- a/include/ck_tile/core/arch/mma/mma_pipeline.hpp +++ b/include/ck_tile/core/arch/mma/mma_pipeline.hpp @@ -16,88 +16,11 @@ #endif namespace ck_tile::core::arch::mma { -/*! @enum MmaPipelineOptionFlag - * @brief Individual option flags for configuring MmaPipeline behavior. - */ -enum struct MmaPipelineOptionFlag : unsigned -{ - NONE = 0x0, ///< No flags set - ABSwap = 0x1, ///< Swap A and B inputs to transpose the C output - COMPRESS_A = 0x2, ///< Enable compressed (sparse) A matrix input -}; - -/** - * @struct MmaPipelineOptionFlags - * @brief Type-safe bitmask wrapper for combining @ref MmaPipelineOptionFlag values. - * @par Provides bitwise OR, AND, NOT, and equality operators for composing - * and querying pipeline option flags. - */ -struct MmaPipelineOptionFlags -{ - using Type = std::underlying_type_t; - - explicit constexpr MmaPipelineOptionFlags() : mFlags(0) {} - explicit constexpr MmaPipelineOptionFlags(Type value) : mFlags(value) {} - constexpr MmaPipelineOptionFlags(MmaPipelineOptionFlag singleFlag) : mFlags(toType(singleFlag)) - { - } - constexpr MmaPipelineOptionFlags(const MmaPipelineOptionFlags& original) - : mFlags(original.mFlags) - { - } - - constexpr MmaPipelineOptionFlags& operator|=(MmaPipelineOptionFlag addValue) - { - mFlags |= toType(addValue); - return *this; - } - constexpr MmaPipelineOptionFlags operator|(MmaPipelineOptionFlag addValue) const - { - MmaPipelineOptionFlags result(*this); - result |= addValue; - return result; - } - constexpr MmaPipelineOptionFlags& operator&=(MmaPipelineOptionFlag maskValue) - { - mFlags &= toType(maskValue); - return *this; - } - constexpr MmaPipelineOptionFlags operator&(MmaPipelineOptionFlag maskValue) const - { - MmaPipelineOptionFlags result(*this); - result &= maskValue; - return result; - } - constexpr MmaPipelineOptionFlags operator~() const - { - MmaPipelineOptionFlags result(*this); - result.mFlags = ~result.mFlags; - return result; - } - constexpr bool testFlag(MmaPipelineOptionFlag flag) const - { - return (flag == MmaPipelineOptionFlag::NONE) ? mFlags == toType(flag) : *this & flag; - } - constexpr operator bool() const { return mFlags != toType(MmaPipelineOptionFlag::NONE); } - constexpr bool operator==(Type rhs) const { return mFlags == rhs; } - - private: - Type mFlags; - static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast(f); } -}; - -constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOptionFlags& rhs) -{ - return rhs == lhs; -} - /** * @class MmaPipelineBase * @brief CRTP base class that implements the common Mma pipeline logic shared by * all concrete pipeline drivers (e.g., dense wave-wise, sparse, etc.). * - * @tparam Flags_ Compile-time bitmask of @ref MmaPipelineOptionFlag controlling - * pipeline behavior (e.g., C transposition, A compression). * @tparam Derived The concrete CRTP-derived pipeline class. Must expose: * - Type aliases: @c AWarpTensor, @c BWarpTensor, @c CWarpTensor, @c MmaOp * - Transform aliases: @c ATransform, @c BTransform, @c CTransform, @c DTransform @@ -107,14 +30,11 @@ constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOpt * 1. Apply pre-transforms to input buffers (A, B, C). * 2. Delegate to @c Derived::execImpl for the actual mma loop. * 3. Apply post-transform to output buffer (D). - * When @c ABSwap is set, the A and B inputs are swapped before step 1. + * When CTranspose is used, the A and B inputs are swapped before step 1. */ -// TODO: c++20: use MmaPipelineOptionFlags directly -template +template struct MmaPipelineBase { - static constexpr auto Flags = MmaPipelineOptionFlags(Flags_); - /** * @brief Entry point: execute the full Mma pipeline (transforms + mma loop + output). * @tparam ATensor Type of the A WaveTile tensor (static_distributed_tensor). @@ -125,17 +45,17 @@ struct MmaPipelineBase * @param accum Input/output accumulator WaveTile C. * @return The output WaveTile D after accumulation and post-transform. */ - template + template CK_TILE_DEVICE static decltype(auto) exec(ATensor& a, BTensor& b, CTensor& accum) { if constexpr(MmaOpTraits::IsSupported) { - if constexpr(Flags & MmaPipelineOptionFlag::ABSwap) + if constexpr(Derived::CTranspose) { decltype(auto) a_transformed = Derived::ATransform::exec(b); decltype(auto) b_transformed = Derived::BTransform::exec(a); decltype(auto) c_transformed = Derived::CTransform::exec(accum); - Derived::execImpl(a_transformed, b_transformed, c_transformed); + Derived::template execImpl(a_transformed, b_transformed, c_transformed); return Derived::DTransform::exec(c_transformed); } else @@ -143,7 +63,7 @@ struct MmaPipelineBase decltype(auto) a_transformed = Derived::ATransform::exec(a); decltype(auto) b_transformed = Derived::BTransform::exec(b); decltype(auto) c_transformed = Derived::CTransform::exec(accum); - Derived::execImpl(a_transformed, b_transformed, c_transformed); + Derived::template execImpl(a_transformed, b_transformed, c_transformed); return Derived::DTransform::exec(c_transformed); } } @@ -153,7 +73,7 @@ struct MmaPipelineBase // Code should not reach here, but HOST/DEVICE compile passes are // weirdly intertwined and instead of having constexpr in the calling // site (tests) we do this. See also changes by this commit. - return Derived::MmaOp::exec({}, {}, {}); + return Derived::MmaOp::template exec({}, {}, {}); } } @@ -162,11 +82,10 @@ struct MmaPipelineBase template CK_TILE_DEVICE void operator()(CTensor& c, ATensor& a, const BTensor& b) const { - exec(a, b, c); + exec(a, b, c); } - template ::IsSupported) { - if constexpr(Flags & MmaPipelineOptionFlag::ABSwap) + if constexpr(Derived::CTranspose) { // TODO: Figure out which combination of a/b, scale_A/B, and opselA/B needs to be // AB-swapped in order to get correct results. Note that WarpGemmParamsParser @@ -188,7 +107,7 @@ struct MmaPipelineBase decltype(auto) a_transformed = Derived::ATransform::exec(b); decltype(auto) b_transformed = Derived::BTransform::exec(a); decltype(auto) c_transformed = Derived::CTransform::exec(accum); - Derived::template execImpl( + Derived::template execImpl( a_transformed, b_transformed, c_transformed, scale_A, scale_B); return Derived::DTransform::exec(c_transformed); } @@ -197,7 +116,7 @@ struct MmaPipelineBase decltype(auto) a_transformed = Derived::ATransform::exec(a); decltype(auto) b_transformed = Derived::BTransform::exec(b); decltype(auto) c_transformed = Derived::CTransform::exec(accum); - Derived::template execImpl( + Derived::template execImpl( a_transformed, b_transformed, c_transformed, scale_A, scale_B); return Derived::DTransform::exec(c_transformed); } @@ -219,8 +138,7 @@ struct MmaPipelineBase const int32_t& a_scale, const int32_t& b_scale) const { - using P = WarpGemmParamsParser; - exec(a, b, c, a_scale, b_scale); + exec(a, b, c, a_scale, b_scale); } }; @@ -232,8 +150,8 @@ struct MmaPipelineBase * @concept MmaPipelineI * @brief Expresses the meta-data interface required for a CRTP MmaPipeline. */ -template -concept MmaPipelineInterface = std::derived_from>; +template +concept MmaPipelineInterface = std::derived_from>; #endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER diff --git a/include/ck_tile/core/arch/mma/mma_selector.hpp b/include/ck_tile/core/arch/mma/mma_selector.hpp index 8491f96837..b8d6f31558 100644 --- a/include/ck_tile/core/arch/mma/mma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mma_selector.hpp @@ -49,7 +49,6 @@ struct MmaDefaultSelector WaveTileM, WaveTileN, WaveTileK, - void, amdgcn_target<>, MmaOpFamily::UNDEFINED>; }; @@ -88,7 +87,6 @@ template MmaOpFamily OpFamily> struct MmaKSearchSelector @@ -102,7 +100,6 @@ struct MmaKSearchSelector WaveTileM, WaveTileN, WaveTileKTest, - CtrlFlags, CompilerTarget, OpFamily>; @@ -118,7 +115,6 @@ struct MmaKSearchSelector WaveTileM, WaveTileN, WaveTileKTest / 2u, - CtrlFlags, CompilerTarget, OpFamily>::SelectedOp>; }; @@ -128,7 +124,6 @@ template MmaOpFamily OpFamily> struct MmaKSearchSelector { // Recursion endpoint: unsupported default implementation. - using SelectedOp = amdgcn_mma; + using SelectedOp = + amdgcn_mma; }; } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mma_traits.hpp b/include/ck_tile/core/arch/mma/mma_traits.hpp index 88764a75b0..863ad07bbb 100644 --- a/include/ck_tile/core/arch/mma/mma_traits.hpp +++ b/include/ck_tile/core/arch/mma/mma_traits.hpp @@ -6,11 +6,8 @@ #include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/config.hpp" #include "mfma/mfma_traits.hpp" -#include "scale/scale_traits.hpp" -#include "sparse/sparse_traits.hpp" #include "wmma/wmma_traits.hpp" -#include #include #include @@ -61,7 +58,6 @@ struct MmaOpTraits; * @tparam FragM_ Size of the M dimension * @tparam FragN_ Size of the N dimension * @tparam FragK_ Size of the K dimension - * @tparam CtrlFlags_ Control flags for the MMA operation * @tparam CompilerTarget_ The compiler target */ template // TODO: c++20 amdgcn_target_arch_id CompilerTarget_> @@ -80,7 +75,6 @@ struct MmaOpTraits> { @@ -90,12 +84,10 @@ struct MmaOpTraits; // Capture incoming template parameters not already in amdgcn - using CtrlFlags = CtrlFlags_; using CompilerTarget = CompilerTarget_; // TODO c++20static constexpr amdgcn_target_arch_id GfxTargetId = CompilerTarget_; @@ -115,7 +107,6 @@ template CK_TILE_HOST_DEVICE void print(MmaOpTraits> const& traitsObj) { @@ -134,7 +124,6 @@ CK_TILE_HOST_DEVICE void print(MmaOpTraits{}); printf( diff --git a/include/ck_tile/core/arch/mma/mma_wavewise.hpp b/include/ck_tile/core/arch/mma/mma_wavewise.hpp index 5894a520ea..253457525b 100644 --- a/include/ck_tile/core/arch/mma/mma_wavewise.hpp +++ b/include/ck_tile/core/arch/mma/mma_wavewise.hpp @@ -28,15 +28,6 @@ enum struct MmaAccumPolicy COL_MAJOR }; -namespace dense::wavewise::detail { -// TODO: c++20: return MmaPipelineOptionFlags directly -template -constexpr inline int getPipelineFlags() -{ - return static_cast(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE); -} -} // namespace dense::wavewise::detail - /** * @class Mma * @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation @@ -50,7 +41,7 @@ constexpr inline int getPipelineFlags() * @tparam WaveTileN Mma WaveTile N dimension * @tparam WaveTileK Mma WaveTile K dimension * @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order) - * @tparam CTranspose Swaps A and B input vectors and interprets C with transposed layout. + * @tparam CTranspose_ Swaps A and B input vectors and interprets C with transposed layout. * @tparam SwizzleFactor SwizzleFactor for Tile Distribution Encoding calculation. * @tparam AttrNumAccessAV Extra unmerge factor for vector dimension for A vec, see amdgcn_mma.hpp. * @tparam AttrNumAccessBV Extra unmerge factor for vector dimension for B vec, see amdgcn_mma.hpp. @@ -72,7 +63,7 @@ template ::SelectedTransforms> // clang-format off -struct WaveWiseMmaPipeline : public MmaPipelineBase(), WaveWiseMmaPipeline> +struct WaveWiseMmaPipeline : public MmaPipelineBase> { - using Base = MmaPipelineBase(), WaveWiseMmaPipeline>; + using Base = MmaPipelineBase>; // clang-format on - using MmaOp = MmaOp_; + using MmaOp = MmaOp_; + static constexpr bool CTranspose = CTranspose_; using ADataType = typename MmaOp::ADataType; using BDataType = typename MmaOp::BDataType; @@ -185,7 +177,7 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase + template CK_TILE_DEVICE static void execImpl(ATensor& a, BTensor& b, CTensor& c) { static_assert( @@ -205,9 +197,10 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase(a_buf.at(bm * FragsK + bk), + b_buf.at(bn * FragsK + bk), + c_buf.at(bm * FragsN + bn)); } } } @@ -220,9 +213,10 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase(a_buf.at(bm * FragsK + bk), + b_buf.at(bn * FragsK + bk), + c_buf.at(bm * FragsN + bn)); } } } diff --git a/include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp b/include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp index f9245dc06f..32dd252d0d 100644 --- a/include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp +++ b/include/ck_tile/core/arch/mma/scale/mfma/scale_gfx9.hpp @@ -13,6 +13,7 @@ #include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" namespace ck_tile::core::arch::mma { @@ -23,14 +24,13 @@ namespace ck_tile::core::arch::mma { * This specialization implements the Scale MFMA instruction for fp8_t A and B * matrices with fp32_t accumulator, with 16x16x128 block sizes. * - * @tparam CtrlFlags Control flags for the Scale MFMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | // clang-format on @@ -38,19 +38,20 @@ struct amdgcn_mma + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( bit_cast(aVec), bit_cast(bVec), cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -62,14 +63,13 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | // clang-format on @@ -77,19 +77,20 @@ struct amdgcn_mma + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( bit_cast(aVec), bit_cast(bVec), cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -101,14 +102,13 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | // clang-format on @@ -116,10 +116,11 @@ struct amdgcn_mma + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; int32x4_t arg_a = bit_cast(aVec); int32x4_t arg_b = bit_cast(bVec); @@ -129,9 +130,9 @@ struct amdgcn_mma, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -143,33 +144,33 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4"; - template + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0}, int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0}, cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -182,33 +183,33 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4"; - template + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0}, int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0}, cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -221,14 +222,13 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | // clang-format on @@ -236,19 +236,20 @@ struct amdgcn_mma + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( bit_cast(aVec), bit_cast(bVec), cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -260,14 +261,13 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | // clang-format on @@ -275,19 +275,20 @@ struct amdgcn_mma + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( bit_cast(aVec), bit_cast(bVec), cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -299,14 +300,13 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | // clang-format on @@ -314,10 +314,11 @@ struct amdgcn_mma + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; int32x4_t arg_a = bit_cast(aVec); int32x4_t arg_b = bit_cast(bVec); @@ -327,9 +328,9 @@ struct amdgcn_mma, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -341,33 +342,33 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4"; - template + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0}, int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0}, cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; @@ -380,33 +381,33 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | { static constexpr const char* instruction_name = "__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4"; - template + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int scale_A, int scale_B) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( int32x8_t{aVec.data[0], aVec.data[1], aVec.data[2], aVec.data[3], aVec.data[4], aVec.data[5], 0, 0}, int32x8_t{bVec.data[0], bVec.data[1], bVec.data[2], bVec.data[3], bVec.data[4], bVec.data[5], 0, 0}, cVec, scale::detail::ScaleDataTypeToFlag_v, scale::detail::ScaleDataTypeToFlag_v, - opselA, + P::op_sel_a, scale_A, - opselB, + P::op_sel_b, scale_B)}; } }; diff --git a/include/ck_tile/core/arch/mma/scale/mfma/selector.hpp b/include/ck_tile/core/arch/mma/scale/mfma/selector.hpp index cb7e68a2c7..bfdb78de09 100644 --- a/include/ck_tile/core/arch/mma/scale/mfma/selector.hpp +++ b/include/ck_tile/core/arch/mma/scale/mfma/selector.hpp @@ -55,7 +55,6 @@ struct MmaDefaultSelector; }; diff --git a/include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp b/include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp index 45c3d71789..bed2a6506b 100644 --- a/include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp +++ b/include/ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp @@ -32,7 +32,7 @@ namespace ck_tile::core::arch::mma { * @tparam WaveTileN Mma WaveTile N dimension * @tparam WaveTileK Mma WaveTile K dimension * @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order) - * @tparam CTranspose Swaps A and B input vectors and interprets C with transposed layout. + * @tparam CTranspose_ Swaps A and B input vectors and interprets C with transposed layout. * @tparam SwizzleFactor Swizzlefactor for Tile Distribution Encoding calculation. * @tparam AttrNumAccessAV Extra unmerge factor for vector dimension for A vec, see amdgcn_mma.hpp. * @tparam AttrNumAccessBV Extra unmerge factor for vector dimension for B vec, see amdgcn_mma.hpp. @@ -47,7 +47,7 @@ template ::SelectedTransforms> // clang-format off -struct ScaleMmaPipeline : public MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline> +struct ScaleMmaPipeline : public MmaPipelineBase> { - using Base = MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline>; + using Base = MmaPipelineBase>; // clang-format on - using MmaOp = MmaOp_; // Expose the selected MmaOp + using MmaOp = MmaOp_; // Expose the selected MmaOp + static constexpr bool CTranspose = CTranspose_; using ADataType = typename MmaOp::ADataType; using BDataType = typename MmaOp::BDataType; @@ -170,8 +171,7 @@ struct ScaleMmaPipeline : public MmaPipelineBase(MmaPipelineOpt static_assert(WaveTileK % MmaOp::kK == 0u, "WaveTileK must be a multiple of MmaOp::kK"); // TODO: Why does this even need to be a template? The types should be known. - template (MmaPipelineOpt for(uint32_t bk = 0u; bk < FragsK; ++bk) { c_buf.at(bm * FragsN + bn) = - MmaOp::template exec(a_buf.at(bm * FragsK + bk), - b_buf.at(bn * FragsK + bk), - c_buf.at(bm * FragsN + bn), - scale_A, - scale_B); + MmaOp::template exec(a_buf.at(bm * FragsK + bk), + b_buf.at(bn * FragsK + bk), + c_buf.at(bm * FragsN + bn), + scale_A, + scale_B); } } } @@ -216,11 +216,11 @@ struct ScaleMmaPipeline : public MmaPipelineBase(MmaPipelineOpt for(uint32_t bk = 0u; bk < FragsK; ++bk) { c_buf.at(bm * FragsN + bn) = - MmaOp::template exec(a_buf.at(bm * FragsK + bk), - b_buf.at(bn * FragsK + bk), - c_buf.at(bm * FragsN + bn), - scale_A, - scale_B); + MmaOp::template exec(a_buf.at(bm * FragsK + bk), + b_buf.at(bn * FragsK + bk), + c_buf.at(bm * FragsN + bn), + scale_A, + scale_B); } } } diff --git a/include/ck_tile/core/arch/mma/scale/scale_traits.hpp b/include/ck_tile/core/arch/mma/scale/scale_traits.hpp index fa55522015..ea879feebc 100644 --- a/include/ck_tile/core/arch/mma/scale/scale_traits.hpp +++ b/include/ck_tile/core/arch/mma/scale/scale_traits.hpp @@ -3,85 +3,32 @@ #pragma once -#include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/numeric/pk_f6.hpp" -#include -#include -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -#include -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - namespace ck_tile::core::arch::mma { - namespace scale::detail { +// Utility for converting the datatype of the A or B input matrix in a scale intrinsics to the +// appropriate datatype flag. Note that this is not the same as the flag indicating the scale +// datatype, see ScaleDataTypeToEnum. template -struct ScaleDataTypeToFlag; - +inline constexpr int32_t ScaleDataTypeToFlag_v = [] { + // sizeof(T) trick to only trigger the static assert for unsupported datatypes. + static_assert(sizeof(T) == 0, "Unsupported scale data type"); + return -1; +}(); template <> -struct ScaleDataTypeToFlag // e4m3 (4 exponent bits 3 mantissa bits) -{ - static constexpr int32_t value = 0; -}; - +inline constexpr int32_t ScaleDataTypeToFlag_v = 0; // e4m3 template <> -struct ScaleDataTypeToFlag // e5m2 -{ - static constexpr int32_t value = 1; -}; - +inline constexpr int32_t ScaleDataTypeToFlag_v = 1; // e5m2 template <> -struct ScaleDataTypeToFlag // e2m3 -{ - static constexpr int32_t value = 2; -}; - +inline constexpr int32_t ScaleDataTypeToFlag_v = 2; // e2m3 template <> -struct ScaleDataTypeToFlag // e3m2 -{ - static constexpr int32_t value = 3; -}; - +inline constexpr int32_t ScaleDataTypeToFlag_v = 3; // e3m2 template <> -struct ScaleDataTypeToFlag // e2m1 -{ - static constexpr int32_t value = 4; -}; - -template -inline constexpr int32_t ScaleDataTypeToFlag_v = ScaleDataTypeToFlag::value; - -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - -/** - * @concept ScaleMfmaDataTypeToFlag - * @brief Expresses the interface of required members for each DataTypeToFlag type on Gfx9 - */ -template -concept ScaleMfmaDataTypeToFlag = requires(DataTypeToFlag dataTypeToFlag) { - // Flag members for scale MFMA instructions - { DataTypeToFlag::value } -> std::convertible_to; -}; - -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER +inline constexpr int32_t ScaleDataTypeToFlag_v = 4; // e2m1 } // namespace scale::detail - -// No real flags for now, scale and opsel are handled in higher level and passed down directly. -// OPSEL is now passed as a template arg to exec(), see mma_pipeline.hpp -// We will soon get rid of these flags entirely in favor of variadic template packs passed down to -// the intrinsics directly, see WarpGemmParamsParser<>. -struct DefaultScaleMfmaCtrlFlags -{ -}; - -CK_TILE_HOST_DEVICE void print_flags([[maybe_unused]] DefaultScaleMfmaCtrlFlags const& ctrlFlags) -{ - printf("CtrlFlags: (empty)\n"); -} - } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp b/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp index 31a84ebf13..9ea90daab8 100644 --- a/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp +++ b/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp @@ -7,7 +7,6 @@ #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" -#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" namespace ck_tile::core::arch::mma { @@ -55,7 +54,6 @@ struct MmaDefaultSelector::SelectedOp; }; diff --git a/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp b/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp index 994489e9d0..bdb50cc232 100644 --- a/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp +++ b/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp @@ -7,11 +7,11 @@ #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp" #include "ck_tile/core/arch/mma/mma_op_family.hpp" -#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" #include @@ -24,55 +24,55 @@ namespace ck_tile::core::arch::mma { * This specialization implements the SMFMA instruction for fp16_t A and B * matrices with structured sparsity, fp32_t accumulator, with 16x16x32 fragment sizes. * - * @tparam CtrlFlags Control flags for the Sparse MFMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x32_f16"; - CK_TILE_DEVICE static auto - exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType + template + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x32_f16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; - } + using P = WarpGemmParamsParser; + return __builtin_amdgcn_smfmac_f32_16x16x32_f16( + aVec, + bVec, + cVec, + idx, + P::cbsz, // Ignore abid and use first portion Y/N + P::abid); // Portion of idx VGPR containing idx info + }; }; /** * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX942 and GFX950 * architecture. - * @tparam CtrlFlags Control flags for the MFMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x16_f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x16_f16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_32x32x16_f16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -80,27 +80,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x32_bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x32_bf16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_16x16x32_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -108,27 +105,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x16_bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x16_bf16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_32x32x16_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -136,27 +130,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_16x16x64_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_i32_16x16x64_i8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_i32_16x16x64_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -164,27 +155,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_32x32x32_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_i32_32x32x32_i8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_i32_32x32x32_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -192,27 +180,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_bf8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x64_bf8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x64_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -220,27 +206,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_bf8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x64_bf8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x64_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -248,27 +232,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_fp8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x64_fp8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x64_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -276,27 +258,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_fp8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x64_fp8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x64_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -304,27 +284,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_bf8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x32_bf8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x32_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -332,27 +310,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_bf8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x32_bf8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x32_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -360,27 +336,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_fp8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x32_fp8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x32_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -388,27 +362,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_fp8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x32_fp8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x32_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -416,27 +388,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x64_f16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_16x16x64_f16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -444,27 +413,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_f16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x32_f16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_32x32x32_f16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -472,27 +438,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x64_bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x64_bf16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_16x16x64_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -500,27 +463,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x32_bf16"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x32_bf16( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_f32_32x32x32_bf16(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -528,27 +488,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_16x16x128_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_i32_16x16x128_i8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_i32_16x16x128_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -556,27 +513,24 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_i32_32x32x64_i8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_i32_32x32x64_i8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return {__builtin_amdgcn_smfmac_i32_32x32x64_i8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -584,27 +538,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_bf8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x128_bf8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x128_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -612,27 +564,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_bf8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x128_bf8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x128_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -640,27 +590,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_fp8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x128_fp8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x128_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -668,27 +616,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_16x16x128_fp8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_16x16x128_fp8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_16x16x128_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -696,27 +642,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_bf8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x64_bf8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x64_bf8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -724,27 +668,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_bf8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x64_bf8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x64_bf8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -752,27 +694,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_fp8_bf8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x64_fp8_bf8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x64_fp8_bf8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; @@ -780,27 +720,25 @@ struct amdgcn_mma -// TODO: c++20 requires -template +// TODO: c++20 template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_smfmac_f32_32x32x64_fp8_fp8"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { - using namespace sparse::detail; - static constexpr BuiltinParams PARAMS = getBuiltinParams(); - return {__builtin_amdgcn_smfmac_f32_32x32x64_fp8_fp8( - aVec, bVec, cVec, idx, PARAMS.UseFirstIndex, PARAMS.ByteIndexToOverride)}; + using P = WarpGemmParamsParser; + return { + __builtin_amdgcn_smfmac_f32_32x32x64_fp8_fp8(aVec, bVec, cVec, idx, P::cbsz, P::abid)}; } }; } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/sparse.hpp b/include/ck_tile/core/arch/mma/sparse/sparse.hpp index e9792196c5..303687b1f8 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse.hpp @@ -11,5 +11,4 @@ namespace ck_tile::core::arch::mma { #include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp" #include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" -#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp" diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp index b2d5d5fac4..b39cdae770 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp @@ -12,16 +12,6 @@ namespace ck_tile::core::arch::mma { -namespace sparse::detail { -// TODO: c++20: return MmaPipelineOptionFlags directly -template -constexpr inline int getPipelineFlags() -{ - return static_cast(MmaPipelineOptionFlag::COMPRESS_A) | - static_cast(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE); -} -} // namespace sparse::detail - /** * @class SparseMmaPipeline * @brief Driver for the wave-tile sparse Mma operation. Given a backend MmaOp implementation @@ -38,7 +28,7 @@ constexpr inline int getPipelineFlags() * @tparam WaveTileN Mma WaveTile N dimension * @tparam WaveTileK Mma WaveTile K dimension * @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order) - * @tparam CTranspose Swaps A and B input vectors and interprets C with transposed layout. + * @tparam CTranspose_ Swaps A and B input vectors and interprets C with transposed layout. * @tparam SwizzleFactor SwizzleFactor for Tile Distribution Encoding calculation. * @tparam AttrNumAccessAV Extra unmerge factor for vector dimension for A vec, see amdgcn_mma.hpp. * @tparam AttrNumAccessBV Extra unmerge factor for vector dimension for B vec, see amdgcn_mma.hpp. @@ -53,7 +43,7 @@ template ::SelectedTransforms> // clang-format off -struct SparseMmaPipeline : public MmaPipelineBase(), SparseMmaPipeline> +struct SparseMmaPipeline : public MmaPipelineBase> { - using Base = MmaPipelineBase(), SparseMmaPipeline>; + using Base = MmaPipelineBase>; // clang-format on - using MmaOp = MmaOp_; + using MmaOp = MmaOp_; + static constexpr bool CTranspose = CTranspose_; using ADataType = typename MmaOp::ADataType; using BDataType = typename MmaOp::BDataType; @@ -86,8 +77,7 @@ struct SparseMmaPipeline : public MmaPipelineBase::IsSupported || std::is_same_v); static_assert(!MmaOpTraits::IsSupported || std::is_same_v); static_assert(!MmaOpTraits::IsSupported || std::is_same_v); - static_assert(!(Base::Flags & MmaPipelineOptionFlag::ABSwap), - "Cannot transpose C in sparse intrinsics."); + static_assert(!CTranspose, "Cannot transpose C in sparse intrinsics."); // WaveTile dimensions (Used to be fragment dims but higher level expects these to include k // iteration!) @@ -180,7 +170,7 @@ struct SparseMmaPipeline : public MmaPipelineBase + template CK_TILE_DEVICE static void execImpl(ATransformResult& a, BTensor& b, CTensor& c) { static_assert( @@ -206,7 +196,7 @@ struct SparseMmaPipeline : public MmaPipelineBase( a_frags[bm][bk], b_buf.at(bn * FragsK + bk), c_buf.at(bm * FragsN + bn), @@ -224,7 +214,7 @@ struct SparseMmaPipeline : public MmaPipelineBase( a_frags[bm][bk], b_buf.at(bn * FragsK + bk), c_buf.at(bm * FragsN + bn), diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp deleted file mode 100644 index f5132b89db..0000000000 --- a/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core/config.hpp" - -#include -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -#include -#endif - -namespace ck_tile::core::arch::mma { - -/** - * @enum SparseCompressionIndex - * @brief Indicates which set of sparse-indices within a VGPR starting at srcC - * containing 8-bits (for 16-bit source data) or 16-bits (for 8-bit source data) - * of index information for a lane. \see DefaultSparseMfmaCtrlFlags - */ -enum struct SparseCompressionIndex : int -{ - FIRST = 0, // Uses bits [7:0] or [15..0], for 16 and 8 bit data respectively - SECOND = 1, // Uses bits [15:8] or [31:16], for 16 and 8 bit data respectively - THIRD = 2, // Uses bits [23:16] - FOURTH = 3, // Uses bits [31:24] -}; - -// to_string methods for enum classes -CK_TILE_HOST_DEVICE constexpr const char* to_string(SparseCompressionIndex compressionIndex) -{ - switch(compressionIndex) - { - case SparseCompressionIndex::FIRST: return "FIRST"; - case SparseCompressionIndex::SECOND: return "SECOND"; - case SparseCompressionIndex::THIRD: return "THIRD"; - case SparseCompressionIndex::FOURTH: return "FOURTH"; - } - __builtin_unreachable(); -} - -namespace sparse::detail { - -/** - * @struct BuiltinParams - * @brief Translates the SparseCompressionIndex to the correct CBSZ and ABID pairs for sparse - * builtins. The actual behavior of the builtin depends on the input data type: 16-bit source data: - * If CBSZ=0, ABID selects one of four 8-bit sets of sparse-indices within a VGPR starting at srcC - * containing 8-bits of index information for a lane. If CBSZ!=0 the very first is selected - * (VGPR[srcC][7..0]). - * - * 8-bit source data: - * If CBSZ=0, ABID selects one of two 16-bit sets of sparse-indices within a VGPR starting at srcC - * containing 16-bits of index information for a lane. If CBSZ!=0; the very first is selected - * (VGPR[srcC][15..0]). - */ -struct BuiltinParams -{ - int UseFirstIndex; // CBSZ - int ByteIndexToOverride; // ABID -}; - -template -static constexpr BuiltinParams getBuiltinParams() -{ - // TODO c++20: designated initializers - if constexpr(Idx == SparseCompressionIndex::FIRST) - { - return BuiltinParams{1, 0}; - } - else - { - return BuiltinParams{0, static_cast(Idx)}; - } -} - -} // namespace sparse::detail - -/** - * @struct DefaultSparseMfmaCtrlFlags - * @brief Default MFMA sparse flags, select (VGPR[srcC][7..0]) if srcC is - * 16-bit or (VGPR[srcC][15..0]) if srcC is 8-bit. - */ -struct DefaultSparseMfmaCtrlFlags -{ - static constexpr SparseCompressionIndex CompressionIndex = SparseCompressionIndex::FIRST; -}; - -CK_TILE_HOST_DEVICE void print_flags(DefaultSparseMfmaCtrlFlags const& ctrlFlags) -{ - printf("CtrlFlags CompressionIndex : %s\n", to_string(ctrlFlags.CompressionIndex)); -} - -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -/** - * @concept SparseMfmaCtrlFlags - * @brief Expresses the interface of required members for each CtrlFlags type - */ -template -concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) { - // Flag members for sparse MFMA instructions - { CtrlFlags::CompressionIndex } -> std::convertible_to; -}; -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - -} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp b/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp index 1829741d37..87d46bd98f 100644 --- a/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp +++ b/include/ck_tile/core/arch/mma/sparse/wmma/selector.hpp @@ -53,7 +53,6 @@ struct MmaDefaultSelector::SelectedOp; }; diff --git a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp index 2257cf7db8..1418bc909c 100644 --- a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp +++ b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp @@ -11,6 +11,7 @@ #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" namespace ck_tile::core::arch::mma { @@ -18,20 +19,20 @@ namespace ck_tile::core::arch::mma { * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12 * architecture. - * @tparam CtrlFlags Control flags for the WMMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -43,20 +44,20 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -68,20 +69,20 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -93,21 +94,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -119,30 +120,31 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32(true, // A signedness aVec, true, // B signedness bVec, cVec, idx, - CtrlFlags::Clamp)}; + P::clamp)}; } }; @@ -150,21 +152,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -176,21 +178,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -202,21 +204,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -228,21 +230,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { @@ -250,51 +252,55 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32"; - CK_TILE_DEVICE static auto - exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType + template + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, idx, - CtrlFlags::Clamp)}; + P::clamp)}; } }; -// TODO: c++20 template +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32"; - CK_TILE_DEVICE static auto - exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) -> CVecType + template + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec, int32_t idx) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, idx, - CtrlFlags::Clamp)}; + P::clamp)}; } }; diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp index ec89e26ebc..2c4767fde1 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp @@ -16,6 +16,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" namespace ck_tile::core::arch::mma { // TODO: Specifically for gfx11 wmma, we need to deal with quirks such as: @@ -46,20 +47,20 @@ namespace ck_tile::core::arch::mma { * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX11 * architecture. - * @tparam CtrlFlags Control flags for the WMMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -71,20 +72,20 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -96,29 +97,30 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, - CtrlFlags::Clamp)}; + P::clamp)}; } }; @@ -126,29 +128,30 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, - CtrlFlags::Clamp)}; + P::clamp)}; } }; diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp index 92057b1446..9146d6e250 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_gfx12.hpp @@ -17,6 +17,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" namespace ck_tile::core::arch::mma { @@ -31,21 +32,21 @@ namespace ck_tile::core::arch::mma { * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for fp16_t, fp16_t, fp32_t MMA operation on GFX12 * architecture. - * @tparam CtrlFlags Control flags for the WMMA operation * @tparam CompilerTarget Current compiler target */ -// TODO: c++20 template +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -57,21 +58,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -83,21 +84,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -109,21 +110,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -135,30 +136,31 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, - CtrlFlags::Clamp)}; + P::clamp)}; } }; @@ -166,21 +168,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -193,21 +195,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -220,21 +222,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -247,21 +249,21 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { @@ -274,30 +276,31 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12"; + template CK_TILE_DEVICE static CVecType exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, - CtrlFlags::Clamp)}; + P::clamp)}; } }; @@ -305,30 +308,31 @@ struct amdgcn_mma +// TODO: c++20 template // TODO: c++20 requires -template +template // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { static constexpr const char* instruction_name = "__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12"; - CK_TILE_DEVICE static auto - exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) -> CVecType + template + CK_TILE_DEVICE static CVecType + exec(AVecType const& aVec, BVecType const& bVec, CVecType const& cVec) { + using P = WarpGemmParamsParser; return {__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12(true, // A signedness bit_cast(aVec), true, // B signedness bit_cast(bVec), cVec, - CtrlFlags::Clamp)}; + P::clamp)}; } }; diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp index 2f75d68d46..645fdc81e6 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_selector.hpp @@ -52,7 +52,6 @@ struct MmaDefaultSelector::SelectedOp; }; diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp index 1c7c3e9276..59375da04f 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_traits.hpp @@ -4,8 +4,6 @@ #pragma once #include "ck_tile/core/config.hpp" - -#include #include namespace ck_tile::core::arch::mma { @@ -50,25 +48,4 @@ struct is_mma_op_wmma static constexpr bool is_mma_op_wmma_v = is_mma_op_wmma::value; -/** - * @struct DefaultWmmaCtrlFlags - * @brief Default WMMA control flags for dense and sparse WMMA operations. - */ -struct DefaultWmmaCtrlFlags -{ - constexpr static bool Clamp = false; - - // Only has an effect on gfx11 when the accumulator is 16-bit. - // Determines which half of the 32-bit accum register to use for the 16-bit result. - // false = low bits [15:0], true = high bits [31:16] - constexpr static bool UseHighAccumBits = true; -}; - -CK_TILE_HOST_DEVICE void print_flags(DefaultWmmaCtrlFlags const& ctrlFlags) -{ - printf("CtrlFlags Clamp / UseHighAccumBits : %d / %d\n", - ctrlFlags.Clamp, - ctrlFlags.UseHighAccumBits); -} - } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp index fd9cd69813..7a04fb4633 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp @@ -9,21 +9,6 @@ namespace ck_tile::core::arch::mma { -/** - * @struct DuplicateTransform - * @brief Transform to duplicate low register elements to high register elements - */ -struct DuplicateTransform -{ - template - CK_TILE_DEVICE static decltype(auto) exec(VecType&& v) - { - // TODO: Implement duplication logic to broadcast low - // register elements to high elements [0 - (N/2 -1)] -> [N/2 - (N-1)] - return std::forward(v); - } -}; - /** * @struct PadTransform * @brief Transform to pad data from original type to b32 type @@ -59,8 +44,8 @@ struct UnpadTransform */ struct MmaDefaultTransformsGfx11 { - using ATransform = DuplicateTransform; - using BTransform = DuplicateTransform; + using ATransform = PassThroughTransform; + using BTransform = PassThroughTransform; using CTransform = PadTransform; using DTransform = UnpadTransform; }; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_params.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_params.hpp index ace20b923e..359b5bbca9 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_params.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_params.hpp @@ -83,6 +83,21 @@ struct SwapReuse_ : bool_constant { }; +template +struct Cbsz : number +{ +}; + +template +struct Abid : number +{ +}; + +template +struct Blgp : number +{ +}; + struct WarpGemmDefaultParams { using clamp = bool_constant; @@ -94,6 +109,9 @@ struct WarpGemmDefaultParams using swap_reuse = bool_constant; // internal use only using scale_a = number<0>; using scale_b = number<0>; + using cbsz = number<0>; + using abid = number<0>; + using blgp = number<0>; }; template class Tag> @@ -151,6 +169,9 @@ class WarpGemmParamsParser public: static constexpr bool clamp = extract(); static constexpr bool post_nop = extract(); + static constexpr index_t cbsz = extract(); + static constexpr index_t abid = extract(); + static constexpr index_t blgp = extract(); static constexpr bool reuse_a = swap_reuse ? raw_reuse_b : raw_reuse_a; static constexpr bool reuse_b = swap_reuse ? raw_reuse_a : raw_reuse_b; static constexpr index_t op_sel_a = swap_reuse ? raw_op_sel_b : raw_op_sel_a; diff --git a/test/ck_tile/core/arch/mma/CMakeLists.txt b/test/ck_tile/core/arch/mma/CMakeLists.txt index 1a62205490..727098269b 100644 --- a/test/ck_tile/core/arch/mma/CMakeLists.txt +++ b/test/ck_tile/core/arch/mma/CMakeLists.txt @@ -87,6 +87,3 @@ if(GPU_TARGETS MATCHES "gfx120") target_compile_options(test_amdgcn_mma_layout_gfx12 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() -_add_mma_gtest(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp) -target_compile_options(test_amdgcn_mma_pipeline PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - diff --git a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_mma_pipeline.cpp b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_mma_pipeline.cpp deleted file mode 100644 index cc6cee9b3e..0000000000 --- a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_mma_pipeline.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include -#include -#include - -#include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/arch/mma/mma_pipeline.hpp" - -namespace { -using namespace ck_tile::core::arch::mma; -} - -TEST(MmaPipelineOptionFlagsTests, ConversionTests) -{ - MmaPipelineOptionFlags flags_0{}; - MmaPipelineOptionFlags flags_1{MmaPipelineOptionFlag::ABSwap}; - MmaPipelineOptionFlags flags_2{MmaPipelineOptionFlag::COMPRESS_A}; - MmaPipelineOptionFlags flags_3{0b11}; // TODO c++20 - remove this - - EXPECT_TRUE(flags_0.testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::ABSwap)); - EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); - - EXPECT_TRUE(flags_1.testFlag(MmaPipelineOptionFlag::ABSwap)); - EXPECT_FALSE(flags_1.testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_FALSE(flags_1.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); - - EXPECT_TRUE(flags_2.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); - EXPECT_FALSE(flags_2.testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_FALSE(flags_2.testFlag(MmaPipelineOptionFlag::ABSwap)); - - EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); - EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::ABSwap)); - EXPECT_FALSE(flags_3.testFlag(MmaPipelineOptionFlag::NONE)); -} - -TEST(MmaPipelineOptionFlagsTests, OperatorsTests) -{ - MmaPipelineOptionFlags flags{}; - - EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::NONE)); - - flags |= MmaPipelineOptionFlag::ABSwap; - - EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::ABSwap)); - - flags |= MmaPipelineOptionFlag::COMPRESS_A; - - EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::ABSwap)); - EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); - - flags &= MmaPipelineOptionFlag::COMPRESS_A; - - EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::ABSwap)); - EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); - - EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::NONE)); - EXPECT_TRUE((~flags).testFlag(MmaPipelineOptionFlag::ABSwap)); - EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::COMPRESS_A)); -} diff --git a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp index e424b10d34..864997d5f1 100644 --- a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp +++ b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_scale_mma.cpp @@ -40,7 +40,6 @@ void ScaleMfmaGfx950Specialization_impl() WaveTileM, WaveTileN, WaveTileK, - DefaultScaleMfmaCtrlFlags, CompilerTargetGfx950, MmaOpFamily::SCALE>; @@ -79,10 +78,7 @@ TEST(ScaleMMATrait, ScaleMfmaGfx950Specialization) std::cout << "GFX950 scale MFMA specialization is correct" << std::endl; } -// TODO: It seems like the ExecSignature concept (and hence MmaOpI) can not be made to work for a -// templated device function for some reason. Disable test for now and fix this once we are using -// the variadic template pack for flags... -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER && 0 +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER template ; @@ -107,7 +102,7 @@ void TestConceptRequirements_impl() TEST(ScaleMMATrait, TestConceptRequirements) { -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER && 0 +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER TestConceptRequirements_impl(); TestConceptRequirements_impl(); TestConceptRequirements_impl(); @@ -216,9 +211,7 @@ struct ScalePipelineKernel constexpr int32_t replicate_byte = 0x01010101; ScaleAType scale_a = 126u * replicate_byte; ScaleBType scale_b = 129u * replicate_byte; - static constexpr index_t opselA = 0; - static constexpr index_t opselB = 0; - Pipeline::template exec(a, b, c, scale_a, scale_b); + Pipeline::template exec, OpSelB<0>>(a, b, c, scale_a, scale_b); __builtin_memcpy( static_cast(c_per_lane) + lane * sizeof(CTensor), &c, sizeof(CTensor)); } @@ -399,9 +392,7 @@ TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_32x32x64_Real) // constexpr int32_t replicate_byte = 0x01010101; // ScaleAType scale_a = 126u * replicate_byte; // ScaleBType scale_b = 129u * replicate_byte; -// static constexpr index_t opselA = 0; -// static constexpr index_t opselB = 0; -// Pipeline::template exec(a, b, c, scale_a, scale_b); +// Pipeline::template exec, OpSelB<0>>(a, b, c, scale_a, scale_b); // __builtin_memcpy( // static_cast(c_per_lane) + lane * sizeof(CTensor), &c, sizeof(CTensor)); // } diff --git a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp index 5d44d3333b..ff156e8692 100644 --- a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp +++ b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp @@ -39,7 +39,6 @@ TEST(SparseMMATrait, SparseMfmaGfx950Specialization) 16u, 16u, 32u, - DefaultSparseMfmaCtrlFlags, CompilerTargetGfx950, MmaOpFamily::SPARSE>; @@ -60,7 +59,6 @@ TEST(SparseMMATrait, MmaOpTraitsIntegration) 16u, 16u, 32u, - DefaultSparseMfmaCtrlFlags, CompilerTargetGfx950, MmaOpFamily::SPARSE>; @@ -83,7 +81,6 @@ TEST(SparseMMATrait, TestConceptRequirements) 16u, 16u, 32u, - DefaultSparseMfmaCtrlFlags, CompilerTargetGfx950, MmaOpFamily::SPARSE>; EXPECT_TRUE(MmaOpI); @@ -95,15 +92,8 @@ TEST(SparseMMATrait, TestConceptRequirements) TEST(SparseMMATrait, DenseVsSparseDistinction) { // Dense MFMA from mfma/mfma_gfx9.hpp - using DenseMfma = amdgcn_mma; + using DenseMfma = + amdgcn_mma; // Sparse MFMA on GFX950 using SparseMfma = amdgcn_mma; diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp index 8c8109e78d..f15c21bfe6 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp @@ -27,9 +27,6 @@ using namespace ck_tile::core::arch::testing; constexpr uint32_t DummyTargetIdVal = 55555u; using DummyCompilerTarget = amdgcn_target(DummyTargetIdVal)>; struct DummyOpType; -struct DummyCtrlFlags -{ -}; /** @brief Returns true if the given target id matches the dummy */ constexpr bool is_dummy_target(DummyCompilerTarget dummy) @@ -49,7 +46,7 @@ using enable_if_target_id_dummy_t = std::enable_if_t // clang-format off // | A B C DataTypes | MNK + WaveSize |AParams |BPar |CPar | -struct amdgcn_mma> +struct amdgcn_mma> : amdgcn_mma_base // clang-format on { @@ -63,15 +60,8 @@ struct amdgcn_mma template -using DummyAmdgcnMma = amdgcn_mma; +using DummyAmdgcnMma = + amdgcn_mma; /*! @struct MmaDefaultSelector * @brief For dummy Id only, instantiate tests for both MFMA and WMMA selectors so we can them both diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc index 1a3bc55aaf..7cd471f562 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma_layout.inc @@ -23,6 +23,7 @@ #include "ck_tile/host/hip_check_error.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/stream_config.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_params.hpp" #include #include @@ -256,12 +257,10 @@ struct MmaLayoutTestKernel { // The actual scale is computed as pow(2, scale - 127), so: // 125 -> 2^-2 and 129 -> 2^2. - int scale_A = 125; - int scale_B = 129; - static constexpr index_t opselA = 0; - static constexpr index_t opselB = 0; - c_frag = - MmaOp::template exec(a_frag, b_frag, c_frag, scale_A, scale_B); + int scale_A = 125; + int scale_B = 129; + c_frag = MmaOp::template exec, OpSelB<0>>( + a_frag, b_frag, c_frag, scale_A, scale_B); } else { @@ -357,145 +356,145 @@ void run_mma_layout_test() // available on all gfx9 (gfx908, gfx90a, gfx942, gfx950) using Gfx9CommonIntrinsics = ::testing::Types< - amdgcn_mma, // mfma_f32_32x32x1f32 - amdgcn_mma, // mfma_f32_32x32x1f32 - amdgcn_mma, // mfma_f32_16x16x1f32 - amdgcn_mma, // mfma_f32_16x16x1f32 - amdgcn_mma, // mfma_f32_4x4x1f32 - amdgcn_mma, // mfma_f32_4x4x1f32 - amdgcn_mma, // mfma_f32_32x32x2f32 - amdgcn_mma, // mfma_f32_16x16x4f32 - amdgcn_mma, // mfma_f32_32x32x4f16 - amdgcn_mma, // mfma_f32_32x32x4f16 - amdgcn_mma, // mfma_f32_16x16x4f16 - amdgcn_mma, // mfma_f32_16x16x4f16 - amdgcn_mma, // mfma_f32_4x4x4f16 - amdgcn_mma, // mfma_f32_4x4x4f16 - amdgcn_mma, // mfma_f32_32x32x8f16 - amdgcn_mma, // mfma_f32_16x16x16f16 - amdgcn_mma, // mfma_i32_32x32x4i8 - amdgcn_mma, // mfma_i32_32x32x4i8 - amdgcn_mma, // mfma_i32_16x16x4i8 - amdgcn_mma, // mfma_i32_16x16x4i8 - amdgcn_mma, // mfma_i32_4x4x4i8 - amdgcn_mma // mfma_i32_4x4x4i8 + amdgcn_mma, // mfma_f32_32x32x1f32 + amdgcn_mma, // mfma_f32_32x32x1f32 + amdgcn_mma, // mfma_f32_16x16x1f32 + amdgcn_mma, // mfma_f32_16x16x1f32 + amdgcn_mma, // mfma_f32_4x4x1f32 + amdgcn_mma, // mfma_f32_4x4x1f32 + amdgcn_mma, // mfma_f32_32x32x2f32 + amdgcn_mma, // mfma_f32_16x16x4f32 + amdgcn_mma, // mfma_f32_32x32x4f16 + amdgcn_mma, // mfma_f32_32x32x4f16 + amdgcn_mma, // mfma_f32_16x16x4f16 + amdgcn_mma, // mfma_f32_16x16x4f16 + amdgcn_mma, // mfma_f32_4x4x4f16 + amdgcn_mma, // mfma_f32_4x4x4f16 + amdgcn_mma, // mfma_f32_32x32x8f16 + amdgcn_mma, // mfma_f32_16x16x16f16 + amdgcn_mma, // mfma_i32_32x32x4i8 + amdgcn_mma, // mfma_i32_32x32x4i8 + amdgcn_mma, // mfma_i32_16x16x4i8 + amdgcn_mma, // mfma_i32_16x16x4i8 + amdgcn_mma, // mfma_i32_4x4x4i8 + amdgcn_mma // mfma_i32_4x4x4i8 >; using Gfx908andGfx90aIntrinsics = ::testing::Types< - amdgcn_mma, // mfma_f32_32x32x2bf16 - amdgcn_mma, // mfma_f32_32x32x2bf16 - amdgcn_mma, // mfma_f32_16x16x2bf16 - amdgcn_mma, // mfma_f32_16x16x2bf16 - amdgcn_mma, // mfma_f32_4x4x2bf16 - amdgcn_mma, // mfma_f32_4x4x2bf16 - amdgcn_mma, // mfma_f32_32x32x4bf16 - amdgcn_mma, // mfma_f32_16x16x8bf16 - amdgcn_mma, // mfma_i32_32x32x8i8 - amdgcn_mma // mfma_i32_16x16x16i8 + amdgcn_mma, // mfma_f32_32x32x2bf16 + amdgcn_mma, // mfma_f32_32x32x2bf16 + amdgcn_mma, // mfma_f32_16x16x2bf16 + amdgcn_mma, // mfma_f32_16x16x2bf16 + amdgcn_mma, // mfma_f32_4x4x2bf16 + amdgcn_mma, // mfma_f32_4x4x2bf16 + amdgcn_mma, // mfma_f32_32x32x4bf16 + amdgcn_mma, // mfma_f32_16x16x8bf16 + amdgcn_mma, // mfma_i32_32x32x8i8 + amdgcn_mma // mfma_i32_16x16x16i8 >; using Gfx90aAndHigherIntrinsics = ::testing::Types< - amdgcn_mma, // mfma_f32_32x32x4bf16_1k - amdgcn_mma, // mfma_f32_32x32x4bf16_1k - amdgcn_mma, // mfma_f32_16x16x4bf16_1k - amdgcn_mma, // mfma_f32_16x16x4bf16_1k - amdgcn_mma, // mfma_f32_4x4x4bf16_1k - amdgcn_mma, // mfma_f32_4x4x4bf16_1k - amdgcn_mma, // mfma_f32_32x32x8bf16_1k - amdgcn_mma, // mfma_f32_16x16x16bf16_1k - amdgcn_mma, // mfma_f64_16x16x4f64 - amdgcn_mma, // mfma_f64_4x4x4f64 - amdgcn_mma // mfma_f64_4x4x4f64 + amdgcn_mma, // mfma_f32_32x32x4bf16_1k + amdgcn_mma, // mfma_f32_32x32x4bf16_1k + amdgcn_mma, // mfma_f32_16x16x4bf16_1k + amdgcn_mma, // mfma_f32_16x16x4bf16_1k + amdgcn_mma, // mfma_f32_4x4x4bf16_1k + amdgcn_mma, // mfma_f32_4x4x4bf16_1k + amdgcn_mma, // mfma_f32_32x32x8bf16_1k + amdgcn_mma, // mfma_f32_16x16x16bf16_1k + amdgcn_mma, // mfma_f64_16x16x4f64 + amdgcn_mma, // mfma_f64_4x4x4f64 + amdgcn_mma // mfma_f64_4x4x4f64 >; using Gfx942AndHigherIntrinsics = ::testing::Types< - amdgcn_mma, // mfma_i32_16x16x32_i8 - amdgcn_mma, // mfma_i32_32x32x16_i8 - amdgcn_mma, // mfma_f32_16x16x32_bf8_bf8 - amdgcn_mma, // mfma_f32_16x16x32_bf8_fp8 - amdgcn_mma, // mfma_f32_16x16x32_fp8_bf8 - amdgcn_mma, // mfma_f32_16x16x32_fp8_fp8 - amdgcn_mma, // mfma_f32_32x32x16_bf8_bf8 - amdgcn_mma, // mfma_f32_32x32x16_bf8_fp8 - amdgcn_mma, // mfma_f32_32x32x16_fp8_bf8 - amdgcn_mma, // mfma_f32_32x32x16_fp8_fp8 - amdgcn_mma, // smfmac_f32_16x16x32_f16 - amdgcn_mma, // smfmac_f32_32x32x16_f16 - amdgcn_mma, // smfmac_f32_16x16x32_bf16 - amdgcn_mma, // smfmac_f32_32x32x16_bf16 - amdgcn_mma, // smfmac_i32_16x16x64_i8 - amdgcn_mma, // smfmac_i32_32x32x32_i8 - amdgcn_mma, // smfmac_f32_16x16x64_bf8_bf8 - amdgcn_mma, // smfmac_f32_16x16x64_bf8_fp8 - amdgcn_mma, // smfmac_f32_16x16x64_fp8_bf8 - amdgcn_mma, // smfmac_f32_16x16x64_fp8_fp8 - amdgcn_mma, // smfmac_f32_32x32x32_bf8_bf8 - amdgcn_mma, // smfmac_f32_32x32x32_bf8_fp8 - amdgcn_mma, // smfmac_f32_32x32x32_fp8_bf8 - amdgcn_mma // smfmac_f32_32x32x32_fp8_fp8 + amdgcn_mma, // mfma_i32_16x16x32_i8 + amdgcn_mma, // mfma_i32_32x32x16_i8 + amdgcn_mma, // mfma_f32_16x16x32_bf8_bf8 + amdgcn_mma, // mfma_f32_16x16x32_bf8_fp8 + amdgcn_mma, // mfma_f32_16x16x32_fp8_bf8 + amdgcn_mma, // mfma_f32_16x16x32_fp8_fp8 + amdgcn_mma, // mfma_f32_32x32x16_bf8_bf8 + amdgcn_mma, // mfma_f32_32x32x16_bf8_fp8 + amdgcn_mma, // mfma_f32_32x32x16_fp8_bf8 + amdgcn_mma, // mfma_f32_32x32x16_fp8_fp8 + amdgcn_mma, // smfmac_f32_16x16x32_f16 + amdgcn_mma, // smfmac_f32_32x32x16_f16 + amdgcn_mma, // smfmac_f32_16x16x32_bf16 + amdgcn_mma, // smfmac_f32_32x32x16_bf16 + amdgcn_mma, // smfmac_i32_16x16x64_i8 + amdgcn_mma, // smfmac_i32_32x32x32_i8 + amdgcn_mma, // smfmac_f32_16x16x64_bf8_bf8 + amdgcn_mma, // smfmac_f32_16x16x64_bf8_fp8 + amdgcn_mma, // smfmac_f32_16x16x64_fp8_bf8 + amdgcn_mma, // smfmac_f32_16x16x64_fp8_fp8 + amdgcn_mma, // smfmac_f32_32x32x32_bf8_bf8 + amdgcn_mma, // smfmac_f32_32x32x32_bf8_fp8 + amdgcn_mma, // smfmac_f32_32x32x32_fp8_bf8 + amdgcn_mma // smfmac_f32_32x32x32_fp8_fp8 >; using Gfx942Intrinsics = ::testing::Types< - amdgcn_mma, // mfma_f32_16x16x8_xf32 - amdgcn_mma // mfma_f32_32x32x4_xf32 + amdgcn_mma, // mfma_f32_16x16x8_xf32 + amdgcn_mma // mfma_f32_32x32x4_xf32 >; using Gfx950Intrinsics = ::testing::Types< - amdgcn_mma, // mfma_f32_16x16x32_f16 - amdgcn_mma, // mfma_f32_16x16x32_bf16 - amdgcn_mma, // mfma_f32_32x32x16_f16 - amdgcn_mma, // mfma_f32_32x32x16_bf16 - amdgcn_mma, // mfma_i32_16x16x64_i8 - amdgcn_mma, // mfma_i32_32x32x32_i8 - amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 - amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 - amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 - amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 - amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 - amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 - amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 - amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 - amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 - amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 - amdgcn_mma, // smfmac_f32_16x16x64_f16 - amdgcn_mma, // smfmac_f32_32x32x32_f16 - amdgcn_mma, // smfmac_f32_16x16x64_bf16 - amdgcn_mma, // smfmac_f32_32x32x32_bf16 - amdgcn_mma, // smfmac_i32_16x16x128_i8 - amdgcn_mma, // smfmac_i32_32x32x64_i8 - amdgcn_mma, // smfmac_f32_16x16x128_bf8_bf8 - amdgcn_mma, // smfmac_f32_16x16x128_bf8_fp8 - amdgcn_mma, // smfmac_f32_16x16x128_fp8_bf8 - amdgcn_mma, // smfmac_f32_16x16x128_fp8_fp8 - amdgcn_mma, // smfmac_f32_32x32x64_bf8_bf8 - amdgcn_mma, // smfmac_f32_32x32x64_bf8_fp8 - amdgcn_mma, // smfmac_f32_32x32x64_fp8_bf8 - amdgcn_mma // smfmac_f32_32x32x64_fp8_fp8 + amdgcn_mma, // mfma_f32_16x16x32_f16 + amdgcn_mma, // mfma_f32_16x16x32_bf16 + amdgcn_mma, // mfma_f32_32x32x16_f16 + amdgcn_mma, // mfma_f32_32x32x16_bf16 + amdgcn_mma, // mfma_i32_16x16x64_i8 + amdgcn_mma, // mfma_i32_32x32x32_i8 + amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + amdgcn_mma, // mfma_scale_f32_16x16x128_f8f6f4 + amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 + amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 + amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 + amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 + amdgcn_mma, // mfma_scale_f32_32x32x64_f8f6f4 + amdgcn_mma, // smfmac_f32_16x16x64_f16 + amdgcn_mma, // smfmac_f32_32x32x32_f16 + amdgcn_mma, // smfmac_f32_16x16x64_bf16 + amdgcn_mma, // smfmac_f32_32x32x32_bf16 + amdgcn_mma, // smfmac_i32_16x16x128_i8 + amdgcn_mma, // smfmac_i32_32x32x64_i8 + amdgcn_mma, // smfmac_f32_16x16x128_bf8_bf8 + amdgcn_mma, // smfmac_f32_16x16x128_bf8_fp8 + amdgcn_mma, // smfmac_f32_16x16x128_fp8_bf8 + amdgcn_mma, // smfmac_f32_16x16x128_fp8_fp8 + amdgcn_mma, // smfmac_f32_32x32x64_bf8_bf8 + amdgcn_mma, // smfmac_f32_32x32x64_bf8_fp8 + amdgcn_mma, // smfmac_f32_32x32x64_fp8_bf8 + amdgcn_mma // smfmac_f32_32x32x64_fp8_fp8 >; using Gfx11Intrinsics = ::testing::Types< - amdgcn_mma, // wmma_f32_16x16x16_f16_w32 - amdgcn_mma, // wmma_f32_16x16x16_bf16_w32 - amdgcn_mma, // wmma_i32_16x16x16_iu8_w32 - amdgcn_mma // wmma_i32_16x16x16_iu4_w32 + amdgcn_mma, // wmma_f32_16x16x16_f16_w32 + amdgcn_mma, // wmma_f32_16x16x16_bf16_w32 + amdgcn_mma, // wmma_i32_16x16x16_iu8_w32 + amdgcn_mma // wmma_i32_16x16x16_iu4_w32 >; using Gfx12Intrinsics = ::testing::Types< - amdgcn_mma, // wmma_f32_16x16x16_f16_w32_gfx12 - amdgcn_mma, // wmma_f32_16x16x16_bf16_w32_gfx12 - amdgcn_mma, // wmma_f16_16x16x16_f16_w32_gfx12 - amdgcn_mma, // wmma_bf16_16x16x16_bf16_w32_gfx12 - amdgcn_mma, // wmma_i32_16x16x16_iu8_w32_gfx12 - amdgcn_mma, // wmma_f32_16x16x16_fp8_fp8_w32_gfx12 - amdgcn_mma, // wmma_f32_16x16x16_fp8_bf8_w32_gfx12 - amdgcn_mma, // wmma_f32_16x16x16_bf8_fp8_w32_gfx12 - amdgcn_mma, // wmma_f32_16x16x16_bf8_bf8_w32_gfx12 - amdgcn_mma, // wmma_i32_16x16x16_iu4_w32_gfx12 - amdgcn_mma, // wmma_i32_16x16x32_iu4_w32_gfx12 - amdgcn_mma, // swmmac_f32_16x16x32_f16_w32 - amdgcn_mma, // swmmac_f32_16x16x32_bf16_w32 - amdgcn_mma, // swmmac_f16_16x16x32_f16_w32 - amdgcn_mma, // swmmac_bf16_16x16x32_bf16_w32 - amdgcn_mma, // swmmac_i32_16x16x32_iu8_w32 - amdgcn_mma, // swmmac_f32_16x16x32_fp8_fp8_w32 - amdgcn_mma, // swmmac_f32_16x16x32_fp8_bf8_w32 - amdgcn_mma, // swmmac_f32_16x16x32_bf8_fp8_w32 - amdgcn_mma, // swmmac_f32_16x16x32_bf8_bf8_w32 - amdgcn_mma, // swmmac_i32_16x16x32_iu4_w32 - amdgcn_mma // swmmac_i32_16x16x64_iu4_w32 + amdgcn_mma, // wmma_f32_16x16x16_f16_w32_gfx12 + amdgcn_mma, // wmma_f32_16x16x16_bf16_w32_gfx12 + amdgcn_mma, // wmma_f16_16x16x16_f16_w32_gfx12 + amdgcn_mma, // wmma_bf16_16x16x16_bf16_w32_gfx12 + amdgcn_mma, // wmma_i32_16x16x16_iu8_w32_gfx12 + amdgcn_mma, // wmma_f32_16x16x16_fp8_fp8_w32_gfx12 + amdgcn_mma, // wmma_f32_16x16x16_fp8_bf8_w32_gfx12 + amdgcn_mma, // wmma_f32_16x16x16_bf8_fp8_w32_gfx12 + amdgcn_mma, // wmma_f32_16x16x16_bf8_bf8_w32_gfx12 + amdgcn_mma, // wmma_i32_16x16x16_iu4_w32_gfx12 + amdgcn_mma, // wmma_i32_16x16x32_iu4_w32_gfx12 + amdgcn_mma, // swmmac_f32_16x16x32_f16_w32 + amdgcn_mma, // swmmac_f32_16x16x32_bf16_w32 + amdgcn_mma, // swmmac_f16_16x16x32_f16_w32 + amdgcn_mma, // swmmac_bf16_16x16x32_bf16_w32 + amdgcn_mma, // swmmac_i32_16x16x32_iu8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_fp8_fp8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_fp8_bf8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_bf8_fp8_w32 + amdgcn_mma, // swmmac_f32_16x16x32_bf8_bf8_w32 + amdgcn_mma, // swmmac_i32_16x16x32_iu4_w32 + amdgcn_mma // swmmac_i32_16x16x64_iu4_w32 >; // clang-format on