From 6c49b6a6703afb24b0f4bbdb03d63b271b602b4d Mon Sep 17 00:00:00 2001 From: chris-tsiaousis-hpc Date: Thu, 12 Mar 2026 09:26:58 +0100 Subject: [PATCH] Changed the include order of the new WMMA/MFMA unification framework (#5241) These changes fix the include order and make the header files independent of one another. In addition, the `remod.py` script has been run and has changed the `grouped_convolution.hpp` and `core.hpp` files. ## Motivation Some headers appear to depend on include order. For example, moving `#include "wmma/wmma.hpp"` in [amdgcn_mma.hpp](https://github.com/ROCm/rocm-libraries/blob/develop/projects/composablekernel/include/ck_tile/core/arch/mma/amdgcn_mma.hpp) later in the include list causes compilation errors. Also, the pre-commit script `remod.py` shuffles includes into alphabetical order, which causes compilation issues. Expected behaviour: Headers should be independent of one another: no header should require another to be included first. Each header should compile correctly on its own. ## Test Plan The CI (that runs `remod.py`) should compile. ## Test Result Existing CI should compile and be green. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. 
--------- Signed-off-by: Chris Tsiaousis --- include/ck_tile/core.hpp | 9 ++ include/ck_tile/core/arch/mma/amdgcn_mma.hpp | 1 + .../ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp | 28 ------ .../core/arch/mma/mfma/mfma_traits.hpp | 28 ++++++ include/ck_tile/core/arch/mma/mma_traits.hpp | 1 + .../core/arch/mma/sparse/mfma/selector.hpp | 1 + .../core/arch/mma/sparse/mfma/sparse_gfx9.hpp | 26 +----- .../ck_tile/core/arch/mma/sparse/sparse.hpp | 55 +----------- .../core/arch/mma/sparse/sparse_traits.hpp | 89 +++++++++++++++++++ .../arch/mma/sparse/wmma/sparse_gfx12.hpp | 8 +- .../ops/gemm/warp/warp_gemm_smfmac_impl.hpp | 1 - include/ck_tile/ops/gemm_mx.hpp | 6 +- include/ck_tile/ops/grouped_convolution.hpp | 2 +- 13 files changed, 140 insertions(+), 115 deletions(-) create mode 100644 include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index f42526ddf7..ed063b3abb 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -20,9 +20,18 @@ #include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp" #include "ck_tile/core/arch/mma/mma.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/arch/mma/mma_transforms.hpp" +#include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp" +#include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" +#include "ck_tile/core/arch/mma/sparse/wmma/selector.hpp" +#include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp" #include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/arch/mma/wmma/wmma.hpp" #include 
"ck_tile/core/arch/mma/wmma/wmma_gfx11.hpp" diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index 52943dc2e4..f9f9d7ca37 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -21,6 +21,7 @@ namespace ck_tile::core::arch::mma { struct Unsupported; #if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + #include /** * @concept MmaOpI diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp index 225ceb60f5..1d1267a839 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_gfx9.hpp @@ -20,34 +20,6 @@ namespace ck_tile::core::arch::mma { // For flexibility, it is recommended that for each backend wrapper it supports at least // one packed register for each input to be able to process smaller K values by padding. -/** - * @struct DefaultMmaCtrlFlags - * @brief Default MFMA flags, no broadcasting or rotation of inputs - */ -struct DefaultMfmaCtrlFlags -{ - static constexpr uint32_t Cbsz = 0; // CBSZ flag, default 0 - static constexpr uint32_t Abid = 0; // ABID flag, default 0 - static constexpr uint32_t Blgp = 0; // BLGP flag, default 0 -}; - -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -#include - -/** - * @concept CtrlFlagsGfx9I - * @brief Expresses the interface of required members for each CtrlFlags type on Gfx9 - */ -template -concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) { - // Flag members for Gfx9 MFMA instructions - { CtrlFlags::Cbsz } -> std::convertible_to; - { CtrlFlags::Abid } -> std::convertible_to; - { CtrlFlags::Blgp } -> std::convertible_to; -}; - -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - /** * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for MFMA on GFX9 targets diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp index 
170e06f08c..009fa0a36e 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_traits.hpp @@ -41,4 +41,32 @@ struct is_mma_op_mfma static constexpr bool is_mma_op_mfma_v = is_mma_op_mfma::value; +/** + * @struct DefaultMfmaCtrlFlags + * @brief Default MFMA flags, no broadcasting or rotation of inputs + */ +struct DefaultMfmaCtrlFlags +{ + static constexpr uint32_t Cbsz = 0; // CBSZ flag, default 0 + static constexpr uint32_t Abid = 0; // ABID flag, default 0 + static constexpr uint32_t Blgp = 0; // BLGP flag, default 0 +}; + +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER +#include + +/** + * @concept CtrlFlagsGfx9I + * @brief Expresses the interface of required members for each CtrlFlags type on Gfx9 + */ +template +concept CtrlFlagsGfx9I = requires(CtrlFlags ctrlFlags) { + // Flag members for Gfx9 MFMA instructions + { CtrlFlags::Cbsz } -> std::convertible_to; + { CtrlFlags::Abid } -> std::convertible_to; + { CtrlFlags::Blgp } -> std::convertible_to; +}; + +#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mma_traits.hpp b/include/ck_tile/core/arch/mma/mma_traits.hpp index fca2dd058c..90cfd8aaf2 100644 --- a/include/ck_tile/core/arch/mma/mma_traits.hpp +++ b/include/ck_tile/core/arch/mma/mma_traits.hpp @@ -7,6 +7,7 @@ #include "ck_tile/core/arch/arch.hpp" #include "mfma/mfma_traits.hpp" #include "wmma/wmma_traits.hpp" +#include "sparse/sparse_traits.hpp" namespace ck_tile::core::arch::mma { diff --git a/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp b/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp index 92d14a257d..8daea552b7 100644 --- a/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp +++ b/include/ck_tile/core/arch/mma/sparse/mfma/selector.hpp @@ -7,6 +7,7 @@ #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include 
"ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" namespace ck_tile::core::arch::mma { diff --git a/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp b/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp index 89fb6688c0..3c89bc5c0f 100644 --- a/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp +++ b/include/ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp @@ -7,34 +7,10 @@ #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" namespace ck_tile::core::arch::mma { -/** - * @struct DefaultSparseMfmaCtrlFlags - * @brief Default MFMA sparse flags, select (VGPR[srcC][7..0]) if srcC is - * 16-bit or (VGPR[srcC][15..0]) if srcC is 8-bit. - */ -struct DefaultSparseMfmaCtrlFlags -{ - static constexpr SparseCompressionIndex CompressionIndex = SparseCompressionIndex::FIRST; -}; - -#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER -#include - -/** - * @concept SparseMfmaCtrlFlags - * @brief Expresses the interface of required members for each CtrlFlags type - */ -template -concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) { - // Flag members for sparse MFMA instructions - { CtrlFlags::CompressionIndex } -> std::convertible_to; -}; - -#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER - /** * @struct amdgcn_mma * @brief Specialization of amdgcn_mma for Sparse MFMA (SMFMA) on GFX942, GFX950 targets diff --git a/include/ck_tile/core/arch/mma/sparse/sparse.hpp b/include/ck_tile/core/arch/mma/sparse/sparse.hpp index 5adadd371b..e9792196c5 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse.hpp @@ -5,64 +5,11 @@ namespace ck_tile::core::arch::mma { -/** - * @enum SparseCompressionIndex - * @brief Indicates which set of sparse-indices within a VGPR starting at srcC - * containing 8-bits (for 16-bit 
source data) or 16-bits (for 8-bit source data) - * of index information for a lane. \see DefaultSparseMfmaCtrlFlags - */ -enum struct SparseCompressionIndex : int -{ - FIRST = 0, // Uses bits [7:0] or [15..0], for 16 and 8 bit data respectively - SECOND = 1, // Uses bits [15:8] or [31:16], for 16 and 8 bit data respectively - THIRD = 2, // Uses bits [23:16] - FOURTH = 3, // Uses bits [31:24] -}; - -namespace sparse::detail { - -/** - * @struct BuiltinParams - * @brief Translates the SparseCompressionIndex to the correct CBSZ and ABID pairs for sparse - * builtins. The actual behavior of the builtin depends on the input data type: 16-bit source data: - * If CBSZ=0, ABID selects one of four 8-bit sets of sparse-indices within a VGPR starting at srcC - * containing 8-bits of index information for a lane. If CBSZ!=0 the very first is selected - * (VGPR[srcC][7..0]). - * - * 8-bit source data: - * If CBSZ=0, ABID selects one of two 16-bit sets of sparse-indices within a VGPR starting at srcC - * containing 16-bits of index information for a lane. If CBSZ!=0; the very first is selected - * (VGPR[srcC][15..0]). 
- */ -struct BuiltinParams -{ - int UseFirstIndex; // CBSZ - int ByteIndexToOverride; // ABID -}; - -template -static constexpr BuiltinParams getBuiltinParams() -{ - BuiltinParams params; - if constexpr(Idx == SparseCompressionIndex::FIRST) - { - params.UseFirstIndex = 1; - params.ByteIndexToOverride = 0; - } - else - { - params.UseFirstIndex = 0; - params.ByteIndexToOverride = static_cast(Idx); - } - return params; -} - -} // namespace sparse::detail - } // namespace ck_tile::core::arch::mma // Include sparse MFMA traits and architecture-specific implementations #include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp" #include "ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp" diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp new file mode 100644 index 0000000000..946a44c221 --- /dev/null +++ b/include/ck_tile/core/arch/mma/sparse/sparse_traits.hpp @@ -0,0 +1,89 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile::core::arch::mma { + +/** + * @enum SparseCompressionIndex + * @brief Indicates which set of sparse-indices within a VGPR starting at srcC + * containing 8-bits (for 16-bit source data) or 16-bits (for 8-bit source data) + * of index information for a lane. 
\see DefaultSparseMfmaCtrlFlags + */ +enum struct SparseCompressionIndex : int +{ + FIRST = 0, // Uses bits [7:0] or [15..0], for 16 and 8 bit data respectively + SECOND = 1, // Uses bits [15:8] or [31:16], for 16 and 8 bit data respectively + THIRD = 2, // Uses bits [23:16] + FOURTH = 3, // Uses bits [31:24] +}; + +namespace sparse::detail { + +/** + * @struct BuiltinParams + * @brief Translates the SparseCompressionIndex to the correct CBSZ and ABID pairs for sparse + * builtins. The actual behavior of the builtin depends on the input data type: 16-bit source data: + * If CBSZ=0, ABID selects one of four 8-bit sets of sparse-indices within a VGPR starting at srcC + * containing 8-bits of index information for a lane. If CBSZ!=0 the very first is selected + * (VGPR[srcC][7..0]). + * + * 8-bit source data: + * If CBSZ=0, ABID selects one of two 16-bit sets of sparse-indices within a VGPR starting at srcC + * containing 16-bits of index information for a lane. If CBSZ!=0; the very first is selected + * (VGPR[srcC][15..0]). + */ +struct BuiltinParams +{ + int UseFirstIndex; // CBSZ + int ByteIndexToOverride; // ABID +}; + +template +static constexpr BuiltinParams getBuiltinParams() +{ + BuiltinParams params; + if constexpr(Idx == SparseCompressionIndex::FIRST) + { + params.UseFirstIndex = 1; + params.ByteIndexToOverride = 0; + } + else + { + params.UseFirstIndex = 0; + params.ByteIndexToOverride = static_cast(Idx); + } + return params; +} + +} // namespace sparse::detail + +/** + * @struct DefaultSparseMfmaCtrlFlags + * @brief Default MFMA sparse flags, select (VGPR[srcC][7..0]) if srcC is + * 16-bit or (VGPR[srcC][15..0]) if srcC is 8-bit. 
+ */ +struct DefaultSparseMfmaCtrlFlags +{ + static constexpr SparseCompressionIndex CompressionIndex = SparseCompressionIndex::FIRST; +}; + +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER +#include +/** + * @concept SparseMfmaCtrlFlags + * @brief Expresses the interface of required members for each CtrlFlags type + */ +template +concept SparseMfmaCtrlFlags = requires(CtrlFlags ctrlFlags) { + // Flag members for sparse MFMA instructions + { CtrlFlags::CompressionIndex } -> std::convertible_to; +}; +#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + +struct DefaultSparseWmmaCtrlFlags +{ +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp index a1406a7f8c..c0d0e4169a 100644 --- a/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp +++ b/include/ck_tile/core/arch/mma/sparse/wmma/sparse_gfx12.hpp @@ -7,13 +7,11 @@ #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" #include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" namespace ck_tile::core::arch::mma { -struct DefaultSparseWmmaCtrlFlags -{ -}; - // TODO: c++20 template // TODO: c++20 requires template @@ -61,7 +59,7 @@ struct amdgcn_mma(aVec); + const int32_t idx = ::ck_tile::compress_a_impl(aVec); const AVecCompressed a_vec_pruned = { aVec[0], aVec[1], aVec[2], aVec[3], aVec[4], aVec[5], aVec[6], aVec[7]}; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp index 9b72839755..0a184cfacf 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp @@ -3,7 +3,6 @@ #pragma once -#include "ck_tile/core.hpp" #include "ck_tile/core/numeric/integer.hpp" #include 
"ck_tile/core/tensor/static_distributed_tensor.hpp" namespace ck_tile { diff --git a/include/ck_tile/ops/gemm_mx.hpp b/include/ck_tile/ops/gemm_mx.hpp index c8b328ab60..29fccf8057 100644 --- a/include/ck_tile/ops/gemm_mx.hpp +++ b/include/ck_tile/ops/gemm_mx.hpp @@ -1,9 +1,13 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT - #pragma once #include "ck_tile/ops/gemm_mx/kernel/gemm_mx_kernel.hpp" #include "ck_tile/ops/gemm_mx/kernel/scale_pointer.hpp" #include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" #include "ck_tile/ops/gemm_mx/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 3c7b00782f..5bc4f0c6a0 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -2,10 +2,10 @@ // SPDX-License-Identifier: MIT #pragma once -#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp" #include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp" +#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" #include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" #include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"