diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 45c0e302e5..3a9309e41e 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -19,14 +19,16 @@ #include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp" -#include "ck_tile/core/arch/mma/mma.hpp" #include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_pipeline.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/arch/mma/mma_transforms.hpp" +#include "ck_tile/core/arch/mma/mma_wavewise.hpp" #include "ck_tile/core/arch/mma/sparse/mfma/selector.hpp" #include "ck_tile/core/arch/mma/sparse/mfma/sparse_gfx9.hpp" #include "ck_tile/core/arch/mma/sparse/sparse.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_selector.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_traits.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index bbf1217919..072ac0bc36 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -204,6 +204,20 @@ struct Unsupported; #if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER #include +/** + * @concept HasExecSignature + * @brief Helper concept for exec signature check. + */ +template +concept HasExecSignature = requires { + { + MmaOp::exec(typename MmaOp::AVecType{}, + typename MmaOp::BVecType{}, + typename MmaOp::CVecType{}, + std::declval()...) + } -> std::convertible_to; +}; + /** * @concept MmaOpI * @brief Expresses the meta-data interface required for each MmaOp policy. 
@@ -213,7 +227,7 @@ template concept MmaOpI = requires(MmaOp op) { // Requires an op context typename MmaOp::OpType; - typename MmaOp::OpFamily; + { MmaOp::OpFamily } -> std::convertible_to; // Captures types for inputs / outputs to mma function typename MmaOp::ADataType; @@ -230,13 +244,8 @@ concept MmaOpI = requires(MmaOp op) { { MmaOp::kBRepeat } -> std::convertible_to; { MmaOp::kCMPerLane } -> std::convertible_to; { MmaOp::kCMNumAccess } -> std::convertible_to; - - // Static exec function - { - MmaOp::exec( - typename MmaOp::AVecType{}, typename MmaOp::BVecType{}, typename MmaOp::CVecType{}) - } -> std::convertible_to; -}; + { MmaOp::kCompressionRatio } -> std::convertible_to; +} && (HasExecSignature || HasExecSignature); #endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER @@ -275,7 +284,7 @@ struct amdgcn_mma : amdgcn_mma_base // TODO: c++20 requires template -struct MmaTransformsDefaultSelector> +struct MmaTransformsDefaultSelector< + MmaOp, + CompilerTarget, + enable_if_all, + std::enable_if_t>> { using SelectedTransforms = MmaDefaultTransformsGfx9; }; diff --git a/include/ck_tile/core/arch/mma/mma.hpp b/include/ck_tile/core/arch/mma/mma.hpp deleted file mode 100644 index b0eb507b49..0000000000 --- a/include/ck_tile/core/arch/mma/mma.hpp +++ /dev/null @@ -1,230 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#pragma once -#include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/numeric/vector_type.hpp" - -#include "amdgcn_mma.hpp" -#include "mma_selector.hpp" -#include "mma_transforms.hpp" - -#include "mfma/mfma.hpp" -#include "wmma/wmma.hpp" - -namespace ck_tile::core::arch::mma { - -/*! 
@enum MmaAccumPolicy - * @brief Accumulation order for Mma decomposition - */ -enum struct MmaAccumPolicy -{ - // Decomposition and accumulation in row-major fragment order - ROW_MAJOR, - // Decomposition and accumulation in col-major fragment order - COL_MAJOR -}; - -/** - * @class Mma - * @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation - * (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to - * matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and - * accumulates results into output WaveTile (C: WaveTileM x WaveTileN). - * @tparam ADataType Data type of input WaveTile A - * @tparam BDataType Data type of input WaveTile B - * @tparam CDataType Data type of input/output WaveTile C (accumulator) - * @tparam WaveTileM Mma WaveTile M dimension - * @tparam WaveTileN Mma WaveTile K dimension - * @tparam WaveTileK Mma WaveTile M dimension - * @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order) - * @tparam CompilerTarget The compiler target - * @tparam MmaOp Backend wrapper class that will perform the mma op (e.g., mfma or wmma) - * @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles - * @par This is an example of an Mma decomposition driver class that can be used in a wave-tile - * context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments - * that are natively supported by the hardware (e.g., mfma or wmma). The class also supports - * applying transforms to the input/output frags as needed (e.g., layout conversions, data type - * conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the - * output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver - * that can adapt to different hardware capabilities and requirements. 
- */ -template ::SelectedOp, - typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = - typename MmaTransformsDefaultSelector::SelectedTransforms> -struct WaveWiseMma -{ - using FragWiseMmaOp = MmaOp; - - // Fragment dimensions - constexpr static uint32_t FragM = MmaOp::kM; - constexpr static uint32_t FragN = MmaOp::kN; - constexpr static uint32_t FragK = MmaOp::kK; - - // Fragment counts for decomposition - constexpr static uint32_t FragsM = WaveTileM / FragM; - constexpr static uint32_t FragsN = WaveTileN / FragN; - constexpr static uint32_t FragsK = WaveTileK / FragK; - constexpr static uint32_t FragsC = FragsM * FragsN; - - // Vector types for packed registers in each fragment - using AVecType = typename MmaOp::AVecType; - using BVecType = typename MmaOp::BVecType; - using CVecType = typename MmaOp::CVecType; - - // Buffer types for WaveTiles - using ABufferType = AVecType[FragsM][FragsK]; - using BBufferType = BVecType[FragsN][FragsK]; - using CBufferType = CVecType[FragsM][FragsN]; - - // Transforms - using ATransform = typename MmaTransforms::ATransform; - using BTransform = typename MmaTransforms::BTransform; - using CTransform = typename MmaTransforms::CTransform; - using DTransform = typename MmaTransforms::DTransform; - - // Sanity checks - static_assert(WaveTileM >= FragM, "WaveTileM must be larger than FragM"); - static_assert(WaveTileN >= FragN, "WaveTileN must be larger than FragN"); - static_assert(WaveTileK >= FragK, "WaveTileK must be larger than FragK"); - static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM"); - static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN"); - static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK"); - - private: - template - CK_TILE_DEVICE static auto formatBuffer(SrcT const& inputBuffer) - { - // TODO: Implement formatting logic as needed. 
- // This is intended to convert input WaveTiles to the native vector types - // required by the FragWiseMma operation for iteration - static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer"); - return reinterpret_cast(inputBuffer); - } - - template - CK_TILE_DEVICE static auto formatBuffer(SrcT& inputBuffer) - { - // TODO: Implement formatting logic as needed. - // This is intended to convert input WaveTiles to the native vector types - // required by the FragWiseMma operation for iteration - static_assert(sizeof(DstT) == sizeof(SrcT), "Size mismatch in formatBuffer"); - return reinterpret_cast(inputBuffer); - } - - /*! @brief Execute Mma in row-major accumulation order. - * @tparam VecTA The input WaveTile A vector type - * @tparam VecTB The input WaveTile B vector type - * @tparam VecTC The input/output WaveTile C vector type - */ - template - CK_TILE_DEVICE static decltype(auto) exec_col_major(VecTA&& a, VecTB&& b, VecTC&& accum) - { - // We implement an example wave-tile pipeline here. - // First, we apply the necessary transforms to the input fragments, - // then we convert the result into buffers of native vector formats - // that we can easily index. Native vector formats are necessary inputs - // to the given MmaOp exec function. - auto a_frag = formatBuffer(ATransform::exec(a)); - auto b_frag = formatBuffer(BTransform::exec(b)); - auto c_frag = formatBuffer(CTransform::exec(accum)); - - // "Col-major" accumulation over the M-dimension fragments first. - // Pseudo code here, but we would basically iterate over the fragments in col-major order - for(uint32_t bn = 0u; bn < FragsN; ++bn) - { - for(uint32_t bm = 0u; bm < FragsM; ++bm) - { - for(uint32_t bk = 0u; bk < FragsK; ++bk) - { - c_frag[bm][bn] = - FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); - } - } - } - - // Convert native vector results back to the output WaveTile format - // and then return after we apply the final output transform. 
- return DTransform::exec(formatBuffer>(c_frag)); - } - - /*! @brief Execute Mma in row-major accumulation order. - * @tparam VecTA The input WaveTile A vector type - * @tparam VecTB The input WaveTile B vector type - * @tparam VecTC The input/output WaveTile C vector type - */ - template - CK_TILE_DEVICE static decltype(auto) exec_row_major(VecTA&& a, VecTB&& b, VecTC&& accum) - { - // We implement an example wave-tile pipeline here. - // First, we apply the necessary transforms to the input WaveTiles, - // then we convert the result into buffers of native vector formats - // that we can easily index. Native vector formats are necessary inputs - // to the given MmaOp exec function. - auto a_frag = formatBuffer(ATransform::exec(a)); - auto b_frag = formatBuffer(BTransform::exec(b)); - auto c_frag = formatBuffer(CTransform::exec(accum)); - - // "Row-major" accumulation over the N-dimension fragments first. - // Pseudo code here, but we would basically iterate over the fragments in row-major order. - // We also have to ensure that the incoming vector WaveTiles are converted to native vector - // types before passing to the FragWiseMma exec function. - for(uint32_t bm = 0u; bm < FragsM; ++bm) - { - for(uint32_t bn = 0u; bn < FragsN; ++bn) - { - for(uint32_t bk = 0u; bk < FragsK; ++bk) - { - c_frag[bm][bn] = - FragWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); - } - } - } - - // Convert native vector results back to the output WaveTile format - // and then return after we apply the final output transform. - return DTransform::exec(formatBuffer>(c_frag)); - } - - public: - /*! @brief Forward to Mma operation with specified accumulation order. 
- * @tparam VecTA The input WaveTile A vector type - * @tparam VecTB The input WaveTile B vector type - * @tparam VecTC The input/output WaveTile C vector type - */ - template - CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum) - { - if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR) - { - return exec_row_major( - std::forward(a), std::forward(b), std::forward(accum)); - } - else // if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR) - { - return exec_col_major( - std::forward(a), std::forward(b), std::forward(accum)); - } - } -}; - -} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mma_pipeline.hpp b/include/ck_tile/core/arch/mma/mma_pipeline.hpp new file mode 100644 index 0000000000..fb5e2b1b21 --- /dev/null +++ b/include/ck_tile/core/arch/mma/mma_pipeline.hpp @@ -0,0 +1,299 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" + +#include "amdgcn_mma.hpp" +#include "mma_selector.hpp" +#include "mma_traits.hpp" +#include "mma_transforms.hpp" + +namespace ck_tile::core::arch::mma { + +/*! @enum MmaPipelineOptionFlag + * @brief Individual option flags for configuring MmaPipeline behavior. + */ +enum struct MmaPipelineOptionFlag : unsigned +{ + NONE = 0x0, ///< No flags set + ABSwap = 0x1, ///< Swap A and B inputs to transpose the C output + COMPRESS_A = 0x2, ///< Enable compressed (sparse) A matrix input +}; + +/** + * @struct MmaPipelineOptionFlags + * @brief Type-safe bitmask wrapper for combining @ref MmaPipelineOptionFlag values. + * @par Provides bitwise OR, AND, NOT, and equality operators for composing + * and querying pipeline option flags. 
+ */ +struct MmaPipelineOptionFlags +{ + using Type = std::underlying_type_t; + + explicit constexpr MmaPipelineOptionFlags() : mFlags(0) {} + explicit constexpr MmaPipelineOptionFlags(Type value) : mFlags(value) {} + constexpr MmaPipelineOptionFlags(MmaPipelineOptionFlag singleFlag) : mFlags(toType(singleFlag)) + { + } + constexpr MmaPipelineOptionFlags(const MmaPipelineOptionFlags& original) + : mFlags(original.mFlags) + { + } + + constexpr MmaPipelineOptionFlags& operator|=(MmaPipelineOptionFlag addValue) + { + mFlags |= toType(addValue); + return *this; + } + constexpr MmaPipelineOptionFlags operator|(MmaPipelineOptionFlag addValue) const + { + MmaPipelineOptionFlags result(*this); + result |= addValue; + return result; + } + constexpr MmaPipelineOptionFlags& operator&=(MmaPipelineOptionFlag maskValue) + { + mFlags &= toType(maskValue); + return *this; + } + constexpr MmaPipelineOptionFlags operator&(MmaPipelineOptionFlag maskValue) const + { + MmaPipelineOptionFlags result(*this); + result &= maskValue; + return result; + } + constexpr MmaPipelineOptionFlags operator~() const + { + MmaPipelineOptionFlags result(*this); + result.mFlags = ~result.mFlags; + return result; + } + constexpr bool testFlag(MmaPipelineOptionFlag flag) const + { + return (flag == MmaPipelineOptionFlag::NONE) ? mFlags == toType(flag) : *this & flag; + } + constexpr operator bool() const { return mFlags != toType(MmaPipelineOptionFlag::NONE); } + constexpr bool operator==(Type rhs) const { return mFlags == rhs; } + + private: + Type mFlags; + static constexpr Type toType(MmaPipelineOptionFlag f) { return static_cast(f); } +}; + +constexpr bool operator==(MmaPipelineOptionFlags::Type lhs, const MmaPipelineOptionFlags& rhs) +{ + return rhs == lhs; +} + +/** + * @class MmaPipelineBase + * @brief CRTP base class that implements the common Mma pipeline logic shared by + * all concrete pipeline drivers (e.g., dense wave-wise, sparse, etc.). 
+ * + * @tparam Flags_ Compile-time bitmask of @ref MmaPipelineOptionFlag controlling + * pipeline behavior (e.g., C transposition, A compression). + * @tparam Derived The concrete CRTP-derived pipeline class. Must expose: + * - Type aliases: @c InternalAVecT, @c InternalBVecT, @c InternalCVecT, + * @c CVecType, @c MmaOp + * - Transform aliases: @c ATransform, @c BTransform, @c CTransform, + * @c DTransform + * - A static @c execImpl(std::tuple&) method. + * + * @par The pipeline performs the following steps in @c exec(): + * 1. Apply pre-transforms and format input buffers (A, B, C). + * 2. Delegate to @c Derived::execImpl for the actual mma loop. + * 3. Apply post-transform and format the output buffer (D) back to the user type. + * When @c ABSwap is set, the A and B inputs are swapped before step 1. + */ +// TODO: c++20: use MmaPipelineOptionFlags directly +template +struct MmaPipelineBase +{ + static constexpr auto Flags = MmaPipelineOptionFlags(Flags_); + + private: + /** + * @brief Reconstruct a tuple with its first element passed through @c formatBuffer + * while preserving all remaining elements unchanged. + * @tparam DstT Target type for the formatted first element. + * @tparam SrcT Forwarding-reference type of the input tuple. + * @tparam Is Index pack for elements 1..N-1 of the tuple. + * @param inputTuple The source tuple whose first element will be formatted. + * @return A new tuple with the formatted first element and the remaining elements forwarded. 
+ */ + template + CK_TILE_DEVICE static auto formatBufferTupleImpl(SrcT&& inputTuple, std::index_sequence) + { + auto&& first_elem = std::get<0>(std::forward(inputTuple)); + using FirstElemResultType = + decltype(formatBuffer(std::forward(first_elem))); + using InputTupleType = ck_tile::remove_cvref_t; + return std::tuple...>( + formatBuffer(std::forward(first_elem)), + std::get(std::forward(inputTuple))...); + } + + /** + * @brief Format (reinterpret-cast) a buffer to the hardware-native vector type @p DstT. + * + * Three cases are handled: + * - **Tuple**: recursively format the first element via @c formatBufferTupleImpl, + * preserving any metadata in the remaining tuple elements. + * - **Array / Pointer**: forwarded unchanged. + * - **Scalar / Vector**: reinterpret-cast to @p DstT (sizes must match). + * + * @tparam DstT The target hardware vector type. + * @tparam SrcT Forwarding-reference type of the input buffer. + * @param inputBuffer The buffer to format. + * @return A reference (or value) of type @p DstT corresponding to @p inputBuffer. + */ + template + CK_TILE_DEVICE static decltype(auto) formatBuffer(SrcT&& inputBuffer) + { + using DecayedSrcT = ck_tile::remove_cvref_t; + + // If SrcT is a tuple, extract the first element (the vector) and format it + // while preserving all remaining elements (metadata) + if constexpr(is_std_tuple_v) + { + // Create index sequence for all remaining elements (skip first) + constexpr std::size_t tuple_size = std::tuple_size_v; + return formatBufferTupleImpl(std::forward(inputBuffer), + std::make_index_sequence{}); + } + else if constexpr(std::is_array_v || std::is_pointer_v) + { + return std::forward(inputBuffer); + } + else + { + static_assert(sizeof(DstT) == sizeof(DecayedSrcT), "Size mismatch in formatBuffer"); + + using QualifiedDstT = + std::conditional_t, DstT const, DstT>; + + return reinterpret_cast(inputBuffer); + } + } + + protected: + /** @brief Query whether a specific @ref MmaPipelineOptionFlag is set. 
*/ + template + constexpr CK_TILE_DEVICE static bool hasFlag() + { + return Flags.testFlag(Flag); + } + + /** + * @brief Apply a transform **then** format the result to @p DstT. + * Used for input operands (A, B, C) before the mma loop. + */ + template + CK_TILE_DEVICE static auto preApplyTransform(Args&&... args) + { + return formatBuffer(Transform::exec(std::forward(args)...)); + } + + /** + * @brief Format a buffer to @p DstT **then** apply a transform. + * Used for the output operand (D) after the mma loop. + */ + template + CK_TILE_DEVICE static auto postApplyTransform(Args&&... args) + { + return Transform::exec(formatBuffer(std::forward(args)...)); + } + + /** + * @brief Apply the per-operand pre-transforms and buffer formatting to A, B, and C. + * @return A @c std::tuple of the transformed (A, B, C) vectors ready for the mma loop. + */ + template + CK_TILE_DEVICE static decltype(auto) + applyTransformsToInputs(ATransformInputs&& a, BTransformInputs&& b, CTransformInputs&& accum) + { + using InternalAVecT = typename Derived::InternalAVecT; + using InternalBVecT = typename Derived::InternalBVecT; + using InternalCVecT = typename Derived::InternalCVecT; + + using ATransform = typename Derived::ATransform; + using BTransform = typename Derived::BTransform; + using CTransform = typename Derived::CTransform; + + return std::make_tuple( + preApplyTransform(std::forward(a)), + preApplyTransform(std::forward(b)), + preApplyTransform(std::forward(accum))); + } + + /** + * @brief Apply the post-transform and buffer formatting to the C (accumulator) output. + * @param vecs The (A, B, C) tuple after @c execImpl; only C is consumed. + * @return The final D output in the user-facing vector type. 
 */ + template + CK_TILE_DEVICE static auto + applyTransformToOutput(std::tuple&& vecs) + { + auto&& [a_result, b_result, c_result] = vecs; + static_assert(!is_std_tuple_v, + "If CTransform returns more than the vector, update this function."); + + using CVecT = typename Derived::CVecType; + using DTransform = typename Derived::DTransform; + return postApplyTransform(c_result); + } + + public: + /** + * @brief Entry point: execute the full Mma pipeline (transforms + mma loop + output). + * @tparam VecTA Type of the A WaveTile buffer. + * @tparam VecTB Type of the B WaveTile buffer. + * @tparam VecTC Type of the C (accumulator) WaveTile buffer. + * @param a Input WaveTile A. + * @param b Input WaveTile B. + * @param accum Input/output accumulator WaveTile C. + * @return The output WaveTile D after accumulation and post-transform. + */ + template + CK_TILE_DEVICE static decltype(auto) exec(VecTA&& a, VecTB&& b, VecTC&& accum) + { + if constexpr(MmaOpTraits::IsSupported) + { + auto transformed_inputs = applyTransformsToInputs( + hasFlag() ? std::forward(b) + : std::forward(a), + hasFlag() ? std::forward(a) + : std::forward(b), + std::forward(accum)); + + Derived::execImpl(transformed_inputs); + + return applyTransformToOutput(std::move(transformed_inputs)); + } + else + { + // Return the unsupported exec. This should print a runtime warning. (amdgcn_mma.hpp) + // Code should not reach here, but HOST/DEVICE compile passes are + // weirdly intertwined and instead of having constexpr in the calling + // site (tests) we do this. See also changes by this commit. + return Derived::MmaOp::exec({}, {}, {}); + } + } +}; + +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + +#include + +/** + * @concept MmaPipelineInterface + * @brief Expresses the meta-data interface required for a CRTP MmaPipeline. 
+ */ +template +concept MmaPipelineInterface = std::derived_from>; + +#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/mma_selector.hpp b/include/ck_tile/core/arch/mma/mma_selector.hpp index 208b90d273..740f0f3c33 100644 --- a/include/ck_tile/core/arch/mma/mma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mma_selector.hpp @@ -72,3 +72,4 @@ concept MmaSelectorI = requires(MmaSelector op) { // Include the implementations #include "wmma/wmma_selector.hpp" #include "mfma/mfma_selector.hpp" +#include "sparse/sparse_selector.hpp" diff --git a/include/ck_tile/core/arch/mma/mma_transforms.hpp b/include/ck_tile/core/arch/mma/mma_transforms.hpp index 811df04364..c41aa0ae11 100644 --- a/include/ck_tile/core/arch/mma/mma_transforms.hpp +++ b/include/ck_tile/core/arch/mma/mma_transforms.hpp @@ -18,6 +18,18 @@ struct PassThroughTransform } }; +/** + * @struct MmaDefaultPassThroughTransforms + * @brief Implements the default MMA transforms + */ +struct MmaDefaultPassThroughTransforms +{ + using ATransform = PassThroughTransform; + using BTransform = PassThroughTransform; + using CTransform = PassThroughTransform; + using DTransform = PassThroughTransform; +}; + /** * @class MmaTransformsDefaultSelector * @brief Default selector for MmaTransforms based on MmaOp and CompilerTarget @@ -27,7 +39,10 @@ struct PassThroughTransform */ template // TODO: c++20 template -struct MmaTransformsDefaultSelector; +struct MmaTransformsDefaultSelector +{ + using SelectedTransforms = MmaDefaultPassThroughTransforms; +}; #if CK_TILE_CONCEPTS diff --git a/include/ck_tile/core/arch/mma/mma_wavewise.hpp b/include/ck_tile/core/arch/mma/mma_wavewise.hpp new file mode 100644 index 0000000000..9fbbab411e --- /dev/null +++ b/include/ck_tile/core/arch/mma/mma_wavewise.hpp @@ -0,0 +1,177 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT +#pragma once +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" + +#include "amdgcn_mma.hpp" +#include "mma_pipeline.hpp" +#include "mma_selector.hpp" +#include "mma_transforms.hpp" + +#include "mfma/mfma.hpp" +#include "wmma/wmma.hpp" +#include + +namespace ck_tile::core::arch::mma { + +/*! @enum MmaAccumPolicy + * @brief Accumulation order for Mma decomposition + */ +enum struct MmaAccumPolicy +{ + // Decomposition and accumulation in row-major fragment order + ROW_MAJOR, + // Decomposition and accumulation in col-major fragment order + COL_MAJOR +}; + +namespace dense::wavewise::detail { +// TODO: c++20: return MmaPipelineOptionFlags directly +template +constexpr inline int getPipelineFlags() +{ + return static_cast(SwapAB ? MmaPipelineOptionFlag::ABSwap : MmaPipelineOptionFlag::NONE); +} +} // namespace dense::wavewise::detail + +/** + * @class WaveWiseMmaPipeline + * @brief Driver for the wave-tile Mma operation. Given a backend MmaOp implementation + * (e.g., mfma or wmma), this class performs fragment-wise (MmaTile) decomposition to + * matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and + * accumulates results into output WaveTile (C: WaveTileM x WaveTileN). + * @tparam ADataType Data type of input WaveTile A + * @tparam BDataType Data type of input WaveTile B + * @tparam CDataType Data type of input/output WaveTile C (accumulator) + * @tparam WaveTileM Mma WaveTile M dimension + * @tparam WaveTileN Mma WaveTile N dimension + * @tparam WaveTileK Mma WaveTile K dimension + * @tparam AccumPolicy The fragment order of the accum. 
registers (row or col major frag order) + * @tparam SwapAB Swaps A and B input vectors + * @tparam CompilerTarget The compiler target + * @tparam MmaOp_ Backend wrapper class that will perform the mma op (e.g., mfma or wmma) + * @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles + * @par This is an example of an Mma decomposition driver class that can be used in a wave-tile + * context. Given a WaveTile size, we can decompose the WaveTile into smaller mma op fragments + * that are natively supported by the hardware (e.g., mfma or wmma). The class also supports + * applying transforms to the input/output frags as needed (e.g., layout conversions, data type + * conversions, etc.). We may also specify the accumulation order (row-major or col-major) for the + * output WaveTile. This is a powerful example of how to build a flexible and reusable mma driver + * that can adapt to different hardware capabilities and requirements. + */ +template ::SelectedOp, + typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = + typename MmaTransformsDefaultSelector::SelectedTransforms> +// clang-format off +struct WaveWiseMmaPipeline : public MmaPipelineBase(), + WaveWiseMmaPipeline> +{ + using Base = MmaPipelineBase(), + WaveWiseMmaPipeline>; + // clang-format on + using MmaOp = MmaOp_; + + // Fragment dimensions + constexpr static uint32_t FragM = MmaOp::kM; + constexpr static uint32_t FragN = MmaOp::kN; + constexpr static uint32_t FragK = MmaOp::kK; + + // Fragment counts for decomposition + constexpr static uint32_t FragsM = WaveTileM / FragM; + constexpr static uint32_t FragsN = WaveTileN / FragN; + constexpr static uint32_t FragsK = WaveTileK / FragK; + constexpr static uint32_t FragsC = FragsM * FragsN; + + // Vector types for packed registers in each fragment + using InternalAVecT = typename MmaOp::AVecType; + using InternalBVecT = typename MmaOp::BVecType; + using InternalCVecT = typename MmaOp::CVecType; + + // Buffer types for 
WaveTiles + using AVecType = InternalAVecT[FragsM][FragsK]; + using BVecType = InternalBVecT[FragsN][FragsK]; + using CVecType = InternalCVecT[FragsM][FragsN]; + + // Transforms + using ATransform = typename MmaTransforms::ATransform; + using BTransform = typename MmaTransforms::BTransform; + using CTransform = typename MmaTransforms::CTransform; + using DTransform = typename MmaTransforms::DTransform; + + // Sanity checks + static_assert(WaveTileM >= FragM, "WaveTileM must be larger than FragM"); + static_assert(WaveTileN >= FragN, "WaveTileN must be larger than FragN"); + static_assert(WaveTileK >= FragK, "WaveTileK must be larger than FragK"); + static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM"); + static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN"); + static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK"); + + template + CK_TILE_DEVICE static void execImpl(std::tuple& vecs) + { + auto& [a_frag, b_frag, c_frag] = vecs; + + if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR) + { + // "Row-major" accumulation over the N-dimension fragments first. + // Pseudo code here, but we would basically iterate over the fragments in row-major + // order. We also have to ensure that the incoming vector WaveTiles are converted to + // native vector types before passing to the FragWiseMma exec function. + for(uint32_t bm = 0u; bm < FragsM; ++bm) + { + for(uint32_t bn = 0u; bn < FragsN; ++bn) + { + for(uint32_t bk = 0u; bk < FragsK; ++bk) + { + c_frag[bm][bn] = + MmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); + } + } + } + } + else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR) + { + // "Col-major" accumulation over the M-dimension fragments first. 
+ // Pseudo code here, but we would basically iterate over the blocks in col-major order + for(uint32_t bn = 0u; bn < FragsN; ++bn) + { + for(uint32_t bm = 0u; bm < FragsM; ++bm) + { + for(uint32_t bk = 0u; bk < FragsK; ++bk) + { + c_frag[bm][bn] = + MmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]); + } + } + } + } + else + { + static_assert(false); + } + } +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp new file mode 100644 index 0000000000..d57f544a41 --- /dev/null +++ b/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp @@ -0,0 +1,100 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core/arch/mma/mma_pipeline.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" +#include +#include + +namespace ck_tile::core::arch::mma { + +namespace sparse::detail { +// TODO: c++20: return MmaPipelineOptionFlags directly +constexpr inline int getPipelineFlags() +{ + return static_cast(MmaPipelineOptionFlag::COMPRESS_A); +} +} // namespace sparse::detail + +template ::SelectedOp, + typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = + typename MmaTransformsDefaultSelector::SelectedTransforms> +// clang-format off +struct SparseMmaPipeline : public MmaPipelineBase> +{ + using Base = MmaPipelineBase>; + // clang-format on + + static_assert(!Base::template hasFlag(), + "Cannot transpose C in sparse intrinsics."); + + using MmaOp = MmaOp_; // Expose the selected MmaOp + + // Calculate the uncompressed A vector type + struct ExternalAVecCalculator + { + using AVecTraits = vector_traits; + static constexpr index_t ASize = AVecTraits::vector_size * 
MmaOp::kCompressionRatio; + using AVecType = ext_vector_t; + }; + + // Expose caller-side vector types + using AVecType = typename ExternalAVecCalculator::AVecType; + using BVecType = typename MmaOp::BVecType; + using CVecType = typename MmaOp::CVecType; + + // Expose internal vector types + using InternalAVecT = typename MmaOp::AVecType; + using InternalBVecT = typename MmaOp::BVecType; + using InternalCVecT = typename MmaOp::CVecType; + + // Transforms + using ATransform = typename MmaTransforms::ATransform; + using BTransform = typename MmaTransforms::BTransform; + using CTransform = typename MmaTransforms::CTransform; + using DTransform = typename MmaTransforms::DTransform; + + template + CK_TILE_DEVICE static void + execImpl(std::tuple& vecs) + { + checkATransformResult(); + auto& [a_result, b_vec, c_vec] = vecs; + auto& [a_vec, idx] = a_result; + c_vec = MmaOp::exec(a_vec, b_vec, c_vec, idx); + } + + private: + // Type check helper - not a device function, so std::declval is available + template + static constexpr void checkATransformResult() + { + using ExternalAvecRef = std::add_lvalue_reference_t; + static_assert(std::is_same_v()))>); + } +}; + +} // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp index 7da8f4f616..4b0effc2bf 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp @@ -6,22 +6,101 @@ #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/arch/mma/mma_transforms.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" +#include namespace ck_tile::core::arch::mma { +namespace sparse::detail { /** - * @struct MmaDefaultTransformsSparse + * @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero + * elements into lower 
part of a_vec to half its effective size. + * @param a_vec Vector to be compressed. + * @tparam ADataType The data type of a_vec + * @tparam CompressedSize The target compression size + * @tparam AVec The vector type of a_vec (deduced) + * @return Packed 32‑bit word containing **CompressedSize** 2‑bit fields. + * Each field encodes the original position (0–3) of the corresponding + * non‑zero element in the input. If fewer than CompressedSize + * non‑zeros are found, remaining fields default to 2 (see below). + */ +template +static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec) +{ + // idx holds one 2‑bit index per output element (total CompressedSize entries). + // It is initialized to the pattern 0b10 for every field. This matches + // what the hardware expects when there are fewer than two non‑zero values + // in a 4‑element group – the unused output is treated as coming from slot 2. + // The loop below will clear and set each field as real non‑zeros are seen. + int32_t idx = 0; + static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2u << (2u * k)); }); + + static_for<0, CompressedSize / 2, 1>{}([&](auto i) { + ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; + int32_t non_zero_pos = 0; + + static_for<0, 4, 1>{}([&](auto j) { + if(static_cast(a_vec[i * 4 + j]) != 0.0f) + { + nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; + // clear the two‑bit field for this output and insert j + idx &= ~(0b11u << (2u * (i * 2 + non_zero_pos))); + idx |= static_cast(j) << (2u * (i * 2 + non_zero_pos)); + ++non_zero_pos; + } + }); + a_vec[i * 2] = nonzero_elems[0]; + a_vec[i * 2 + 1] = nonzero_elems[1]; + }); + + return idx; +} +} // namespace sparse::detail + +/** + * @class SparseCompressTransform + * @brief Performs 2:4 structured sparsity compression to the vector v and produces an index mask. + * @note Returns a tuple of two. The first element is the vector v with the same scalar type but + * its size halved. The second element is the index mask. 
+ */ +template +struct SparseCompressTransform +{ + template + CK_TILE_DEVICE static decltype(auto) exec(VecType& v) + { + using VecTraits = vector_traits>; + using ScalarT = typename VecTraits::scalar_type; + static constexpr auto VecN = VecTraits::vector_size; + static constexpr index_t CompressedSize = VecN / CompressionRatio; + using VecCompressed = ext_vector_t; + + static_assert(VecN % CompressionRatio == 0, "VecN must be divisible by CompressionRatio"); + static_assert(CompressedSize > 0, "CompressedSize must be > 0"); + + const auto idx = sparse::detail::compress_a_impl(v); + + // TODO c++20: Use bit_cast + return std::tuple( + *std::launder(reinterpret_cast(&v)), idx); + } +}; + +/** + * @class MmaDefaultTransformsSparse * @brief Implements the default transforms for Sparse * * For 2:4 structured sparsity with inline register metadata: - * - ATransform: Pass-through (sparse operands formatted in Exec) TODO! + * - ATransform: 2:4 structured sparsity compression * - BTransform: Pass-through (sparse operands already formatted) * - CTransform: Pass-through (input accumulator) * - DTransform: Pass-through (output accumulator as-is) */ +template struct MmaDefaultTransformsSparse { - using ATransform = PassThroughTransform; + using ATransform = SparseCompressTransform; using BTransform = PassThroughTransform; using CTransform = PassThroughTransform; using DTransform = PassThroughTransform; @@ -42,7 +121,7 @@ struct MmaTransformsDefaultSelector> { - using SelectedTransforms = MmaDefaultTransformsSparse; + using SelectedTransforms = MmaDefaultTransformsSparse; }; } // namespace ck_tile::core::arch::mma diff --git a/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp b/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp index eb87c38e87..fd9cd69813 100644 --- a/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp +++ b/include/ck_tile/core/arch/mma/wmma/wmma_transforms.hpp @@ -86,9 +86,11 @@ struct MmaDefaultTransformsGfx12 template // TODO: c++20 template 
// TODO: c++20 requires -struct MmaTransformsDefaultSelector> +struct MmaTransformsDefaultSelector< + MmaOp, + CompilerTarget, + enable_if_all, + std::enable_if_t>> { using SelectedTransforms = MmaDefaultTransformsGfx11; }; @@ -102,9 +104,11 @@ struct MmaTransformsDefaultSelector // TODO: c++20 template // TODO: c++20 requires -struct MmaTransformsDefaultSelector> +struct MmaTransformsDefaultSelector< + MmaOp, + CompilerTarget, + enable_if_all, + std::enable_if_t>> { using SelectedTransforms = MmaDefaultTransformsGfx12; }; diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index 7e0c0886bb..391fc0e4d7 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -209,4 +209,21 @@ template using largest_type_t = std::conditional_t= sizeof(BDataType), ADataType, BDataType>; +/** + * @brief Type trait to detect whether a type is a @c std::tuple specialization. + * @tparam T The type to inspect. + */ +template +struct is_std_tuple : std::false_type +{ +}; + +template +struct is_std_tuple> : std::true_type +{ +}; + +template +static constexpr bool is_std_tuple_v = is_std_tuple::value; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp index 0a184cfacf..b99fc91fa7 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp @@ -5,52 +5,9 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/tensor/static_distributed_tensor.hpp" + namespace ck_tile { -/** - * @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero - * elements into lower part of a_vec to half its effective size. - * @param a_vec Vector to be compressed. 
- * @tparam ADataType The data type of a_vec - * @tparam CompressedSize The target compression size - * @tparam AVec The vector type of a_vec (deduced) - * @return Packed 32‑bit word containing **CompressedSize** 2‑bit fields. - * Each field encodes the original position (0–3) of the corresponding - * non‑zero element in the input. If fewer than CompressedSize - * non‑zeros are found, remaining fields default to 2 (see below). - */ -template -static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec) -{ - // idx holds one 2‑bit index per output element (total CompressedSize entries). - // It is initialized to the pattern 0b10 for every field. This matches - // what the hardware expects when there are fewer than two non‑zero values - // in a 4‑element group – the unused output is treated as coming from slot 2. - // The loop below will clear and set each field as real non‑zeros are seen. - int32_t idx = 0; - static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2 << (2 * k)); }); - - static_for<0, CompressedSize / 2, 1>{}([&](auto i) { - ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; - int32_t non_zero_pos = 0; - - static_for<0, 3, 1>{}([&](auto j) { - if(a_vec[i * 4 + j] != 0.0f) - { - nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; - // clear the two‑bit field for this output and insert j - idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos)); - idx |= j << 2 * (i * 2 + non_zero_pos); - ++non_zero_pos; - } - }); - a_vec[i * 2] = nonzero_elems[0]; - a_vec[i * 2 + 1] = nonzero_elems[1]; - }); - - return idx; -} - template struct WarpGemmSmfmacImpl { @@ -86,10 +43,37 @@ struct WarpGemmSmfmacImpl return WarpGemmAttribute_::get_num_of_access(); } - template - CK_TILE_DEVICE int32_t compress_a_vec(AVec& a_vec) + //---------------------------------------------------------------------------------------------- + /// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero + /// elements into lower part of a_vec to half its 
effective size. + /// + /// @param a_vec Vector to be compressed. + /// + /// @return Four 2-bit indexes of non-zero elements locations + /// + template + CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const { - return compress_a_impl(a_vec); + int32_t idx = 0b11101110; + + static_for<0, 2, 1>{}([&](auto i) { + ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; + int32_t non_zero_pos = 0; + + static_for<0, 3, 1>{}([&](auto j) { + if(a_vec[i * 4 + j] != 0.0f) + { + nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; + idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos)); + idx |= j << 2 * (i * 2 + non_zero_pos); + ++non_zero_pos; + } + }); + a_vec[i * 2] = nonzero_elems[0]; + a_vec[i * 2 + 1] = nonzero_elems[1]; + }); + + return idx; } template @@ -102,11 +86,10 @@ struct WarpGemmSmfmacImpl constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio; using AVec = ext_vector_t; - static constexpr index_t CompressedSize = - ATensor::get_thread_buffer_size() / CompressionRatio; - using AVecCompressed = ext_vector_t; - using BVec = ext_vector_t; - using CVec = ext_vector_t; + using AVecCompressed = + ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; constexpr auto I0 = number<0>{}; @@ -114,9 +97,8 @@ struct WarpGemmSmfmacImpl const auto b_vec = b.get_thread_buffer().template get_as()[I0]; auto c_vec = c.get_thread_buffer().template get_as()[I0]; - const int32_t idx = compress_a_vec(a_vec); + const int32_t idx = compress_a(a_vec); - static_assert(CompressedSize == 4); // @TODO can we simply set a_vec_pruned to a_vec[0:3]? 
const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]}; diff --git a/test/ck_tile/core/arch/mma/CMakeLists.txt b/test/ck_tile/core/arch/mma/CMakeLists.txt index 99ebd6ece3..d93de32fea 100644 --- a/test/ck_tile/core/arch/mma/CMakeLists.txt +++ b/test/ck_tile/core/arch/mma/CMakeLists.txt @@ -7,14 +7,15 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -# TODO: This test is temporarily disabled for cooperation / work planning reasons. Re-enable after merging related work. -# if(GPU_TARGETS MATCHES "gfx9|gfx12") -# add_gtest_executable(test_amdgcn_sparse_mma test_amdgcn_sparse_mma.cpp) -# target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -# endif() +if(GPU_TARGETS MATCHES "gfx9|gfx12") + add_gtest_executable(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp) + target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() if(GPU_TARGETS MATCHES "gfx9") add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp) target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_amdgcn_wavewise_mma pipeline/test_amdgcn_wavewise_mma.cpp) + target_compile_options(test_amdgcn_wavewise_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") endif() @@ -44,3 +45,6 @@ if(GPU_TARGETS MATCHES "gfx12") target_compile_options(test_amdgcn_mma_layout_gfx12 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() +add_gtest_executable(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp) +target_compile_options(test_amdgcn_mma_pipeline PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + diff --git a/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp b/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp new file mode 100644 index 0000000000..a23cf08b1e --- /dev/null +++ 
b/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp @@ -0,0 +1,123 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include + +#include "ck_tile/core/arch/arch.hpp" +#include +#include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/core/numeric/type_convert.hpp" + +#include "../get_wave_size_helper.hpp" + +template +struct MmaPipelineTest +{ + using AType = AType_; + using BType = BType_; + using CType = CType_; + static constexpr auto WaveTileM = WaveTileM_; + static constexpr auto WaveTileN = WaveTileN_; + static constexpr auto WaveTileK = WaveTileK_; + + void test_pipeline(std::function shouldSkip, + std::function kernel, + std::function getExpected, + std::function aInitializer = nullptr) + { + using namespace ck_tile; + using namespace ck_tile::core::arch; + + int devCount; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); + + hipDeviceProp_t devProp; + HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); + + auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); + bool hasDevice = static_cast(devCount > 0); + int deviceWarpSize = devProp.warpSize; + + if(!hasDevice || shouldSkip(currentArchId)) + { + GTEST_SKIP() << "No HIP device found. Skipping test."; + } + + // WaveTile size, also the expected fragment size (MmaTile) from the selector. + // Note: Actual FragK might be slightly different due to hardware implementation, but the + // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is + // correct. 
+ static constexpr uint32_t FragM = WaveTileM; + static constexpr uint32_t FragN = WaveTileN; + static constexpr uint32_t FragK = WaveTileK; + + // The number of elements per thread + uint32_t AElements = FragM * FragK / deviceWarpSize; + uint32_t BElements = FragN * FragK / deviceWarpSize; + uint32_t CElements = FragM * FragN / deviceWarpSize; + + uint32_t ASize = AElements * sizeof(AType); + uint32_t BSize = BElements * sizeof(BType); + uint32_t CSize = CElements * sizeof(CType); + + // Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's + std::vector h_a(AElements); + if(aInitializer) + { + for(size_t i = 0; i < AElements; ++i) + h_a[i] = aInitializer(i); + } + else + { + std::fill(h_a.begin(), h_a.end(), type_convert(1)); + } + std::vector h_b(BElements, type_convert(1)); + std::vector h_c(CElements, type_convert(0)); + std::vector h_out(CElements, type_convert(0)); + + AType* d_a; + BType* d_b; + CType* d_c; + CType* d_out; + + HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); + HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); + HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); + HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); + + // Copy inputs to device + HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); + + const auto wave_size = getDeviceWaveSize(); + kernel(wave_size, d_a, d_b, d_c, d_out); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); + + // Verify output against expected value for all elements + for(size_t i = 0; i < CElements; ++i) + { + EXPECT_NEAR(h_out[i], getExpected(FragK), 1e-3); + } + + HIP_CHECK_ERROR(hipFree(d_a)); + HIP_CHECK_ERROR(hipFree(d_b)); + HIP_CHECK_ERROR(hipFree(d_c)); + HIP_CHECK_ERROR(hipFree(d_out)); + } +}; diff --git 
a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_mma_pipeline.cpp b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_mma_pipeline.cpp new file mode 100644 index 0000000000..da3800fdda --- /dev/null +++ b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_mma_pipeline.cpp @@ -0,0 +1,66 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_pipeline.hpp" + +namespace { +using namespace ck_tile::core::arch::mma; +} + +TEST(MmaPipelineOptionFlagsTests, ConversionTests) +{ + MmaPipelineOptionFlags flags_0{}; + MmaPipelineOptionFlags flags_1{MmaPipelineOptionFlag::ABSwap}; + MmaPipelineOptionFlags flags_2{MmaPipelineOptionFlag::COMPRESS_A}; + MmaPipelineOptionFlags flags_3{0b11}; + + EXPECT_TRUE(flags_0.testFlag(MmaPipelineOptionFlag::NONE)); + EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::ABSwap)); + EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); + + EXPECT_TRUE(flags_1.testFlag(MmaPipelineOptionFlag::ABSwap)); + EXPECT_FALSE(flags_1.testFlag(MmaPipelineOptionFlag::NONE)); + EXPECT_FALSE(flags_1.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); + + EXPECT_TRUE(flags_2.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); + EXPECT_FALSE(flags_2.testFlag(MmaPipelineOptionFlag::NONE)); + EXPECT_FALSE(flags_2.testFlag(MmaPipelineOptionFlag::ABSwap)); + + EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); + EXPECT_TRUE(flags_3.testFlag(MmaPipelineOptionFlag::ABSwap)); + EXPECT_FALSE(flags_3.testFlag(MmaPipelineOptionFlag::NONE)); +} + +TEST(MmaPipelineOptionFlagsTests, OperatorsTests) +{ + MmaPipelineOptionFlags flags{}; + + EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::NONE)); + + flags |= MmaPipelineOptionFlag::ABSwap; + + EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE)); + EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::ABSwap)); + + flags |= 
MmaPipelineOptionFlag::COMPRESS_A; + + EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE)); + EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::ABSwap)); + EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); + + flags &= MmaPipelineOptionFlag::COMPRESS_A; + + EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::NONE)); + EXPECT_FALSE(flags.testFlag(MmaPipelineOptionFlag::ABSwap)); + EXPECT_TRUE(flags.testFlag(MmaPipelineOptionFlag::COMPRESS_A)); + + EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::NONE)); + EXPECT_TRUE((~flags).testFlag(MmaPipelineOptionFlag::ABSwap)); + EXPECT_FALSE((~flags).testFlag(MmaPipelineOptionFlag::COMPRESS_A)); +} diff --git a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp new file mode 100644 index 0000000000..be631f0659 --- /dev/null +++ b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_sparse_mma.cpp @@ -0,0 +1,523 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" +#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp" +#include +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +#include "pipeline_tests_helper.hpp" + +using namespace ck_tile; +using namespace ck_tile::core::arch; +using namespace ck_tile::core::arch::mma; + +using CompilerTargetGfx950 = decltype(make_amdgcn_gfx9_target()); + +TEST(SparseMMATrait, SparseMfmaGfx950Specialization) +{ + // Test fp16 → fp32 sparse MFMA for GFX950 (16x16x32) + using TestSparseMfma16x16 = amdgcn_mma; + + static_assert(std::is_same_v && + TestSparseMfma16x16::OpFamily == MmaOpFamily::SPARSE, + "GFX950 sparse 16x16x32 should have SparseMFMAOp type"); + + static_assert(is_mma_op_of_family_v, + "GFX950 sparse 16x16x32 should be detected as Sparse"); + + std::cout << "GFX950 sparse MFMA specialization is correct" << std::endl; +} + +TEST(SparseMMATrait, MmaOpTraitsIntegration) +{ + // Create a sparse MMA op (16x16x32 fp16 specialization) + using TestSparseMmma = amdgcn_mma; + + // Get its traits + using TestTraits = MmaOpTraits; + + // Verify trait detection + static_assert(TestTraits::IsSparse, "Sparse MMA should be detected as sparse"); + static_assert(TestTraits::IsSupported, "Sparse MMA specialization should be supported"); + static_assert(TestTraits::IsMfma, "Sparse MFMA should be detected as MFMA"); + static_assert(!TestTraits::IsWmma, "Sparse MFMA should not be detected as WMMA"); + + std::cout << "MmaOpTraits correctly integrates sparse 
operations" << std::endl; +} + +TEST(SparseMMATrait, TestConceptRequirements) +{ +#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER + using TestSparseMmma = amdgcn_mma; + static_assert(MmaOpI); +#else + GTEST_SKIP() << "Not compiled with concepts. Skipping test."; +#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER +} + +TEST(SparseMMATrait, DenseVsSparseDistinction) +{ + // Dense MFMA from mfma/mfma_gfx9.hpp + using DenseMfma = amdgcn_mma; + + // Sparse MFMA on GFX950 + using SparseMfma = amdgcn_mma; + + // Verify they have different operation types + static_assert(std::is_same_v && + DenseMfma::OpFamily != SparseMfma::OpFamily, + "Dense and Sparse MFMA should have the same OpType tags and different OpFamily"); + + // Verify traits correctly identify them + static_assert(MmaOpTraits::IsMfma && MmaOpTraits::IsDense && + !MmaOpTraits::IsSparse && !MmaOpTraits::IsScale && + MmaOpTraits::IsSupported, + "Dense MFMA should be identified correctly"); + + static_assert(MmaOpTraits::IsSparse && MmaOpTraits::IsMfma && + !MmaOpTraits::IsDense && !MmaOpTraits::IsScale && + MmaOpTraits::IsSupported, + "Sparse MFMA should be identified correctly"); + + std::cout << "Dense and sparse MMA operations are correctly distinguished" << std::endl; +} + +TEST(SparseMMATrait, SparseSelector) +{ + static_for<1, 33, 1>{}([](auto i) { + using Selected = typename MmaDefaultSelector(i), + static_cast(i), + static_cast(2 * i), + CompilerTargetGfx950, + MmaOpFamily::SPARSE>::SelectedOp; + + static constexpr bool isValid = (i == 16) || (i == 32); + if constexpr(isValid) + { + // Selector should pick a sparse MFMA implementation + static_assert(MmaOpTraits::IsSparse); + static_assert(MmaOpTraits::IsMfma); + static_assert(MmaOpTraits::IsSupported); + static_assert((std::is_same::value)); + } + else + { + // Selector should pick the unsupported pass through + static_assert(!MmaOpTraits::IsSupported); + } + }); +} + +template +__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, 
void* out) +{ + using Pipeline = SparseMmaPipeline; + + using AVecType = typename Pipeline::AVecType; + using BVecType = typename Pipeline::BVecType; + using CVecType = typename Pipeline::CVecType; + + static constexpr uint32_t kIters = WaveTileK / Pipeline::MmaOp::kK; + + // Initialize the accumulator + CVecType result = *reinterpret_cast(c); + + // Accumulate input AxB over WaveTileK/FragK iterations + for(uint32_t i = 0; i < kIters; ++i) + { + result = Pipeline::exec( + *reinterpret_cast(a), *reinterpret_cast(b), result); + } + + *reinterpret_cast(out) = result; +} + +// Live test on real hardware for sparse selection and execution. +TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real) +{ + MmaPipelineTest<> test; + const auto should_skip = [](amdgcn_target_id currentArchId) { + bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) && + (currentArchId <= amdgcn_target_id::GFX12_GENERIC); + bool isSupportedMfma = (currentArchId >= amdgcn_target_id::GFX942) && + (currentArchId <= amdgcn_target_id::GFX950); + return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma)); + }; + const std::function validator = [](uint32_t waveTileK) { + return static_cast(waveTileK) / 2; + }; + const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) { + test_sparse_accum_over_k::AType, + MmaPipelineTest<>::BType, + MmaPipelineTest<>::CType, + MmaPipelineTest<>::WaveTileM, + MmaPipelineTest<>::WaveTileN, + MmaPipelineTest<>::WaveTileK><<<1, waveSize>>>(a, b, c, out); + }; + // Initialize A with 2:4 structured sparsity pattern: {1, 0, 1, 0, ...} + // This ensures the sparse compression transform is actually exercised — + // a no-op or broken compression would pass zeros through, causing incorrect results. + const std::function sparseAInit = [](size_t i) -> fp16_t { + return (i % 2 == 0) ? 
type_convert(1) : type_convert(0); + }; + test.test_pipeline(should_skip, kernel, validator, sparseAInit); +} + +template +__global__ void test_sparse_transform(void* a, void* idx) +{ + using ResultT = + decltype(SparseCompressTransform::exec(*static_cast(a))); + using FirstT = std::tuple_element_t<0, ResultT>; + const auto& [vec, i] = SparseCompressTransform::exec(*static_cast(a)); + *reinterpret_cast*>(a) = vec; + *reinterpret_cast(idx) = i; +} + +// Generalized helper: runs the sparse transform kernel and verifies compressed output and index. +template +void sparse_transform_verify(const std::vector& input, + const std::vector& expected_output, + int32_t expected_idx) +{ + static_assert(RATIO == 2, "Extend functionality if other ratio is used."); + ASSERT_EQ(static_cast(input.size()), NUM); + ASSERT_EQ(static_cast(expected_output.size()), NUM / RATIO); + + int devCount; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); + + hipDeviceProp_t devProp; + HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); + + auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); + bool hasDevice = static_cast(devCount > 0); + + // TODO: c++20 add check for arch id + if(!hasDevice || (currentArchId == amdgcn_target_id::HOST)) + { + GTEST_SKIP() << "No HIP device found. 
Skipping test."; + } + + float* d_v; + int32_t* d_idx; + + static constexpr auto Size = sizeof(Type) * NUM; + HIP_CHECK_ERROR(hipMalloc(&d_v, Size)); + HIP_CHECK_ERROR(hipMalloc(&d_idx, sizeof(int32_t))); + + // Copy inputs to device + HIP_CHECK_ERROR(hipMemcpy(d_v, input.data(), Size, hipMemcpyHostToDevice)); + + test_sparse_transform><<<1, 32>>>(d_v, d_idx); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + std::vector h_out(NUM / RATIO, static_cast(0)); + HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_v, Size / RATIO, hipMemcpyDeviceToHost)); + int32_t h_idx; + HIP_CHECK_ERROR(hipMemcpy(&h_idx, d_idx, sizeof(int32_t), hipMemcpyDeviceToHost)); + + EXPECT_EQ(h_idx, expected_idx) << "Index mask mismatch"; + for(int i = 0; i < NUM / RATIO; ++i) + { + EXPECT_EQ(h_out[i], expected_output[i]) << "Output mismatch at position " << i; + } + + // Semantic index validation: each 2-bit field in h_idx encodes the original + // slot (0–3) within the group of 4 that the corresponding compressed element + // came from. Verify that the index is consistent with input and output. + // + // Note: when a group has fewer than 2 non-zeros, unused output slots contain + // initialization values (from nonzero_elems init) that don't correspond to the + // default index (slot 2). We only validate entries where the index was explicitly + // set, i.e. where input[slot] is non-zero. + constexpr int CompressedSize = NUM / RATIO; + for(int i = 0; i < CompressedSize; ++i) + { + int slot = (h_idx >> (2 * i)) & 0b11; + int group = i / 2; + Type input_at_slot = input[group * 4 + slot]; + // Only check when input at the indexed slot is non-zero (explicitly assigned) + // or when both are zero (consistent default for all-zero groups). 
+ if(static_cast(input_at_slot) != 0.0f || static_cast(h_out[i]) == 0.0f) + { + EXPECT_EQ(h_out[i], input_at_slot) + << "Index field " << i << " points to slot " << slot << " in group " << group + << " but output[" << i << "] != input[" << (group * 4 + slot) << "]"; + } + } + + HIP_CHECK_ERROR(hipFree(d_v)); + HIP_CHECK_ERROR(hipFree(d_idx)); +} + +// Helper: build expected index from a per-group 4-bit pattern, repeated for all groups. +// Each group of 4 input elements contributes 2 compressed elements → 2 x 2-bit index fields = 4 +// bits. +static int32_t build_repeated_group_idx(int num_groups, int32_t group_bits_4) +{ + int32_t idx = 0; + for(int g = 0; g < num_groups; ++g) + idx |= (group_bits_4 << (4 * g)); + return idx; +} + +// Helper: build expected index from alternating even/odd 4-bit group patterns. +static int32_t build_alternating_group_idx(int num_groups, int32_t even_bits_4, int32_t odd_bits_4) +{ + int32_t idx = 0; + for(int g = 0; g < num_groups; ++g) + idx |= ((g % 2 == 0 ? even_bits_4 : odd_bits_4) << (4 * g)); + return idx; +} + +// 1. Basic correctness: valid divisible sizes +// Input pattern: {1, 0, 3, 0, 5, 0, 7, 0, ...} → non-zeros at slots 0,2 +// Group idx pattern: field0=0b00 (slot 0), field1=0b10 (slot 2) → 0b1000 +template +void sparse_transform_test_case() +{ + std::vector v(NUM); + for(int i = 0; i < NUM; ++i) + { + v[i] = i % 2 == 0 ? 
i + 1 : 0; + } + + std::vector expected_out(NUM / RATIO); + for(int i = 0; i < NUM / RATIO; ++i) + { + expected_out[i] = v[i * 2]; + } + + int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1000); + sparse_transform_verify(v, expected_out, expected_idx); +} + +TEST(SparseTransformsTest, ValidCompressionRatio) +{ + // TODO: extend those when new sparse builtins are + // introduced and use different type combinations + sparse_transform_test_case<8, 2, fp16_t>(); + sparse_transform_test_case<16, 2, fp16_t>(); + sparse_transform_test_case<32, 2, fp16_t>(); +} + +// All-zero input: no non-zeros in any group of 4. +// Each output pair defaults to {a_vec[slot2], a_vec[slot3]} = {0, 0}, +// and the index uses default slot-2 encoding (0b10) for every 2-bit field. +// Group idx pattern: 0b1010 +template +void sparse_transform_all_zero() +{ + using T = fp16_t; + std::vector input(NUM, static_cast(0)); + std::vector expected_output(NUM / 2, static_cast(0)); + int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1010); + sparse_transform_verify(input, expected_output, expected_idx); +} + +TEST(SparseTransformsTest, AllZeroInput) +{ + sparse_transform_all_zero<8>(); + sparse_transform_all_zero<16>(); + sparse_transform_all_zero<32>(); +} + +// Single non-zero per group of 4 (at slot 3). +// nonzero_elems initializes to {a_vec[slot2]=0, a_vec[slot3]=V}. +// Only j=3 triggers: nonzero_elems[0]=V, field0=0b11, pos becomes 1. +// nonzero_elems[1] keeps its init V. Output: {V, V}. 
+// Group idx pattern: field0=0b11, field1=0b10 (default) → 0b1011 +template +void sparse_transform_single_nonzero() +{ + using T = fp16_t; + std::vector input(NUM, static_cast(0)); + std::vector expected_output(NUM / 2); + + for(int g = 0; g < NUM / 4; ++g) + { + T val = static_cast(g + 5); + input[g * 4 + 3] = val; + expected_output[g * 2] = val; + expected_output[g * 2 + 1] = val; + } + + int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1011); + sparse_transform_verify(input, expected_output, expected_idx); +} + +TEST(SparseTransformsTest, SingleNonZeroPerGroup) +{ + sparse_transform_single_nonzero<8>(); + sparse_transform_single_nonzero<16>(); + sparse_transform_single_nonzero<32>(); +} + +// Non-zeros at slots 1 and 3 in each group. +// Input: {0, a, 0, b, ...}. Output: {a, b, ...}. +// Group idx pattern: field0=0b01 (slot 1), field1=0b11 (slot 3) → 0b1101 +template +void sparse_transform_slots_1_and_3() +{ + using T = fp16_t; + std::vector input(NUM, static_cast(0)); + std::vector expected_output(NUM / 2); + + for(int g = 0; g < NUM / 4; ++g) + { + T a = static_cast(g * 2 + 3); + T b = static_cast(g * 2 + 4); + input[g * 4 + 1] = a; + input[g * 4 + 3] = b; + expected_output[g * 2] = a; + expected_output[g * 2 + 1] = b; + } + + int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1101); + sparse_transform_verify(input, expected_output, expected_idx); +} + +TEST(SparseTransformsTest, NonZerosAtSlots1And3) +{ + sparse_transform_slots_1_and_3<8>(); + sparse_transform_slots_1_and_3<16>(); + sparse_transform_slots_1_and_3<32>(); +} + +// Non-zeros at slots 0 and 3 in each group (non-adjacent). +// Input: {a, 0, 0, b, ...}. Output: {a, b, ...}. 
+// Group idx pattern: field0=0b00 (slot 0), field1=0b11 (slot 3) → 0b1100 +template +void sparse_transform_slots_0_and_3() +{ + using T = fp16_t; + std::vector input(NUM, static_cast(0)); + std::vector expected_output(NUM / 2); + + for(int g = 0; g < NUM / 4; ++g) + { + T a = static_cast(g * 2 + 2); + T b = static_cast(g * 2 + 3); + input[g * 4] = a; + input[g * 4 + 3] = b; + expected_output[g * 2] = a; + expected_output[g * 2 + 1] = b; + } + + int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1100); + sparse_transform_verify(input, expected_output, expected_idx); +} + +TEST(SparseTransformsTest, NonZerosAtSlots0And3) +{ + sparse_transform_slots_0_and_3<8>(); + sparse_transform_slots_0_and_3<16>(); + sparse_transform_slots_0_and_3<32>(); +} + +// Mixed sparsity pattern: even groups have non-zeros at slots 0,2; odd groups at slots 1,3. +// Even group idx: field0=0b00, field1=0b10 → 0b1000 +// Odd group idx: field0=0b01, field1=0b11 → 0b1101 +template +void sparse_transform_mixed() +{ + using T = fp16_t; + std::vector input(NUM, static_cast(0)); + std::vector expected_output(NUM / 2); + + for(int g = 0; g < NUM / 4; ++g) + { + T a = static_cast(g * 2 + 1); + T b = static_cast(g * 2 + 2); + if(g % 2 == 0) + { + // Slots 0, 2 + input[g * 4] = a; + input[g * 4 + 2] = b; + } + else + { + // Slots 1, 3 + input[g * 4 + 1] = a; + input[g * 4 + 3] = b; + } + expected_output[g * 2] = a; + expected_output[g * 2 + 1] = b; + } + + int32_t expected_idx = build_alternating_group_idx(NUM / 4, 0b1000, 0b1101); + sparse_transform_verify(input, expected_output, expected_idx); +} + +TEST(SparseTransformsTest, MixedSparsityPattern) +{ + sparse_transform_mixed<8>(); + sparse_transform_mixed<16>(); + sparse_transform_mixed<32>(); +} diff --git a/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_wavewise_mma.cpp b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_wavewise_mma.cpp new file mode 100644 index 0000000000..a3ee03c5eb --- /dev/null +++ 
b/test/ck_tile/core/arch/mma/pipeline/test_amdgcn_wavewise_mma.cpp @@ -0,0 +1,93 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_op_family.hpp" +#include "ck_tile/core/arch/mma/mma_wavewise.hpp" + +#include "pipeline_tests_helper.hpp" +#include + +using namespace ck_tile; +using namespace ck_tile::core::arch; +using namespace ck_tile::core::arch::mma; + +template +__global__ void test_wavewise_pipeline(void* a, void* b, void* c, void* out) +{ + using CompilerTarget = decltype(get_compiler_target()); + + using Pipeline = WaveWiseMmaPipeline; + + using AVecType = typename Pipeline::AVecType; + using BVecType = typename Pipeline::BVecType; + using CVecType = typename Pipeline::CVecType; + + auto result = Pipeline::exec(*reinterpret_cast(a), + *reinterpret_cast(b), + *reinterpret_cast(c)); + + if constexpr(MmaOpTraits::IsSupported) + { + // When the MmaOp is Unsupported (default) it returns the CVecType by value + // so this cast is impossible... 
+ __builtin_memcpy(out, static_cast(result), sizeof(CVecType)); + } +} + +namespace { +const auto should_skip = [](amdgcn_target_id currentArchId) { + bool isSupportedWmma = false; + bool isSupportedMfma = + (currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950); + return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma)); +}; +const std::function validator = [](uint32_t waveTileK) { + return static_cast(waveTileK); +}; +} // namespace + +TEST(WaveWiseMmaPipeline, testKIter) +{ + MmaPipelineTest<> test; + const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) { + test_wavewise_pipeline::AType, + MmaPipelineTest<>::BType, + MmaPipelineTest<>::CType, + MmaPipelineTest<>::WaveTileM, + MmaPipelineTest<>::WaveTileN, + MmaPipelineTest<>::WaveTileK, + false><<<1, waveSize>>>(a, b, c, out); + }; + test.test_pipeline(should_skip, kernel, validator); +} + +TEST(WaveWiseMmaPipeline, testKIterSwapAB) +{ + MmaPipelineTest<> test; + const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) { + test_wavewise_pipeline::AType, + MmaPipelineTest<>::BType, + MmaPipelineTest<>::CType, + MmaPipelineTest<>::WaveTileM, + MmaPipelineTest<>::WaveTileN, + MmaPipelineTest<>::WaveTileK, + true><<<1, waveSize>>>(a, b, c, out); + }; + test.test_pipeline(should_skip, kernel, validator); +} diff --git a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp index 865c3e1011..5a8f478f48 100644 --- a/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp +++ b/test/ck_tile/core/arch/mma/test_amdgcn_mma.cpp @@ -7,7 +7,7 @@ #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" -#include "ck_tile/core/arch/mma/mma.hpp" +#include "ck_tile/core/arch/mma/mma_wavewise.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/host/hip_check_error.hpp" diff --git 
a/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp b/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp deleted file mode 100644 index 03abcb5772..0000000000 --- a/test/ck_tile/core/arch/mma/test_amdgcn_sparse_mma.cpp +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include - -#include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" -#include "ck_tile/core/arch/mma/mma_op_family.hpp" -#include "ck_tile/core/arch/mma/mma_selector.hpp" -#include -#include "ck_tile/host/hip_check_error.hpp" -#include "ck_tile/core/arch/mma/mma_traits.hpp" -#include "ck_tile/core/utility/type_traits.hpp" - -#include "get_wave_size_helper.hpp" - -using namespace ck_tile; -using namespace ck_tile::core::arch; -using namespace ck_tile::core::arch::mma; - -using CompilerTargetGfx950 = decltype(make_amdgcn_gfx9_target()); - -TEST(SparseMMATrait, SparseMfmaGfx950Specialization) -{ - // Test fp16 → fp32 sparse MFMA for GFX950 (16x16x32) - using TestSparseMfma16x16 = amdgcn_mma; - - static_assert(std::is_same_v && - TestSparseMfma16x16::OpFamily == MmaOpFamily::SPARSE, - "GFX950 sparse 16x16x32 should have SparseMFMAOp type"); - - static_assert(is_mma_op_of_family_v, - "GFX950 sparse 16x16x32 should be detected as Sparse"); - - std::cout << "GFX950 sparse MFMA specialization is correct" << std::endl; -} - -TEST(SparseMMATrait, MmaOpTraitsIntegration) -{ - // Create a sparse MMA op (16x16x32 fp16 specialization) - using TestSparseMmma = amdgcn_mma; - - // Get its traits - using TestTraits = MmaOpTraits; - - // Verify trait detection - static_assert(TestTraits::IsSparse, "Sparse MMA should be detected as sparse"); - static_assert(TestTraits::IsSupported, "Sparse MMA specialization should be supported"); - static_assert(TestTraits::IsMfma, "Sparse MFMA should be detected as MFMA"); - static_assert(!TestTraits::IsWmma, "Sparse MFMA should not be detected as WMMA"); 
- - std::cout << "MmaOpTraits correctly integrates sparse operations" << std::endl; -} - -TEST(SparseMMATrait, DenseVsSparseDistinction) -{ - // Dense MFMA from mfma/mfma_gfx9.hpp - using DenseMfma = amdgcn_mma; - - // Sparse MFMA on GFX950 - using SparseMfma = amdgcn_mma; - - // Verify they have different operation types - static_assert(std::is_same_v && - DenseMfma::OpFamily != SparseMfma::OpFamily, - "Dense and Sparse MFMA should have the same OpType tags and different OpFamily"); - - // Verify traits correctly identify them - static_assert(MmaOpTraits::IsMfma && MmaOpTraits::IsDense && - !MmaOpTraits::IsSparse && !MmaOpTraits::IsScale && - MmaOpTraits::IsSupported, - "Dense MFMA should be identified correctly"); - - static_assert(MmaOpTraits::IsSparse && MmaOpTraits::IsMfma && - !MmaOpTraits::IsDense && !MmaOpTraits::IsScale && - MmaOpTraits::IsSupported, - "Sparse MFMA should be identified correctly"); - - std::cout << "Dense and sparse MMA operations are correctly distinguished" << std::endl; -} - -TEST(SparseMMATrait, SparseSelector) -{ - static_for<1, 33, 1>{}([](auto i) { - using Selected = typename MmaDefaultSelector(i), - static_cast(i), - static_cast(2 * i), - CompilerTargetGfx950, - MmaOpFamily::SPARSE>::SelectedOp; - - static constexpr bool isValid = (i == 16) || (i == 32); - if constexpr(isValid) - { - // Selector should pick a sparse MFMA implementation - static_assert(MmaOpTraits::IsSparse); - static_assert(MmaOpTraits::IsMfma); - static_assert(MmaOpTraits::IsSupported); - static_assert((std::is_same::value)); - } - else - { - // Selector should pick the unsupported pass through - static_assert(!MmaOpTraits::IsSupported); - } - }); -} - -template -__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out) -{ - using CompilerTarget = decltype(get_compiler_target()); - using Selector = MmaDefaultSelector; - using MmaOp = typename Selector::SelectedOp; - using CVecType = typename MmaOp::CVecType; - - static constexpr uint32_t 
kIters = WaveTileK / MmaOp::kK; - - // Initialize the accumulator - CVecType result = *reinterpret_cast(c); - - // Accumulate input AxB over WaveTileK/FragK iterations - for(uint32_t i = 0; i < kIters; ++i) - { - result = MmaOp::exec(*reinterpret_cast(a), - *reinterpret_cast(b), - result); - } - - *reinterpret_cast(out) = result; -} - -// Live test on real hardware for sparse selection and execution. -TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real) -{ - int devCount; - hipDevice_t dev; - HIP_CHECK_ERROR(hipGetDevice(&dev)); - HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); - - hipDeviceProp_t devProp; - HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); - - auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); - bool hasDevice = static_cast(devCount > 0); - int deviceWarpSize = devProp.warpSize; - - bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) && - (currentArchId <= amdgcn_target_id::GFX12_GENERIC); - bool isSupportedMfma = - (currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950); - // TODO: c++20 add check for arch id - if(!hasDevice || (currentArchId == amdgcn_target_id::HOST) || - !(isSupportedWmma || isSupportedMfma)) - { - GTEST_SKIP() << "No HIP device found. Skipping test."; - } - - using AType = fp16_t; - using BType = fp16_t; - using CType = fp32_t; - - // WaveTile size, also the expected fragment size (MmaTile) from the selector. - // Note: Actual FragK might be slightly different due to hardware implementation, but the - // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is - // correct. 
- static constexpr uint32_t WaveTileM = 16; - static constexpr uint32_t WaveTileN = 16; - static constexpr uint32_t WaveTileK = 32; - static constexpr uint32_t FragM = WaveTileM; - static constexpr uint32_t FragN = WaveTileN; - static constexpr uint32_t FragK = WaveTileK; - - // The number of elements per thread - uint32_t AElements = FragM * FragK / deviceWarpSize; - uint32_t BElements = FragN * FragK / deviceWarpSize; - uint32_t CElements = FragM * FragN / deviceWarpSize; - - uint32_t ASize = AElements * sizeof(AType); - uint32_t BSize = BElements * sizeof(BType); - uint32_t CSize = CElements * sizeof(CType); - - // Initialize A and B to all 1's, C to all 0's - std::vector h_a(AElements, static_cast(1)); - std::vector h_b(BElements, static_cast(1)); - std::vector h_c(CElements, static_cast(0)); - std::vector h_out(CElements, static_cast(0)); - - AType* d_a; - BType* d_b; - CType* d_c; - CType* d_out; - - HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); - HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); - HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); - HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); - - // Copy inputs to device - HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); - - const auto wave_size = getDeviceWaveSize(); - test_sparse_accum_over_k - <<<1, wave_size>>>(d_a, d_b, d_c, d_out); - HIP_CHECK_ERROR(hipDeviceSynchronize()); - - HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); - - // Output should be FragK for all elements, because the inputs are all 1's - for(size_t i = 0; i < CElements; ++i) - { - // In sparse only half of the A values are non-zero, thus the /2. 
- CType expected = static_cast(FragK) / 2; - - EXPECT_NEAR(h_out[i], expected, 1e-3); - } - - HIP_CHECK_ERROR(hipFree(d_a)); - HIP_CHECK_ERROR(hipFree(d_b)); - HIP_CHECK_ERROR(hipFree(d_c)); - HIP_CHECK_ERROR(hipFree(d_out)); -}