diff --git a/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp b/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp index a491c0d2b9..9e62f6e939 100644 --- a/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp +++ b/example/ck_tile/51_tile_distr_enc_reg_map/example_tile_distr_enc_calc.cpp @@ -5,7 +5,7 @@ #include #include #include "ck_tile/core/arch/arch.hpp" -#include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma.hpp" #include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp" #include "ck_tile/core/container/tuple.hpp" diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 4afba77d6a..2b7066cabf 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -24,6 +24,7 @@ #include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp" #include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp" +#include "ck_tile/core/arch/mma/mma.hpp" #include "ck_tile/core/arch/mma/mma_op_family.hpp" #include "ck_tile/core/arch/mma/mma_pipeline.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" diff --git a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp index 5985f63440..4cb28762ba 100644 --- a/include/ck_tile/core/arch/mma/amdgcn_mma.hpp +++ b/include/ck_tile/core/arch/mma/amdgcn_mma.hpp @@ -393,10 +393,3 @@ CK_TILE_HOST_DEVICE void print(amdgcn_mma= 23 #pragma clang diagnostic pop #endif - -// Include the implementations -#include "wmma/wmma.hpp" // should be included before the below headers - -#include "mfma/mfma.hpp" -#include "scale/scale.hpp" -#include "sparse/sparse.hpp" diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp index 2140e3317a..0fa1bada78 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_selector.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/mma/amdgcn_mma.hpp" +#include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" #include "ck_tile/core/numeric/vector_type.hpp" diff --git a/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp b/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp index 5a3fc9a7e4..9609ed5116 100644 --- a/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp +++ b/include/ck_tile/core/arch/mma/mfma/mfma_transforms.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/mma_transforms.hpp" namespace ck_tile::core::arch::mma { diff --git a/include/ck_tile/core/arch/mma/mma.hpp b/include/ck_tile/core/arch/mma/mma.hpp new file mode 100644 index 0000000000..ec38fe78e3 --- /dev/null +++ b/include/ck_tile/core/arch/mma/mma.hpp @@ -0,0 +1,8 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "wmma/wmma.hpp" +#include "mfma/mfma.hpp" +#include "sparse/sparse.hpp" diff --git a/include/ck_tile/core/arch/mma/mma_pipeline.hpp b/include/ck_tile/core/arch/mma/mma_pipeline.hpp index de01760620..0994f2d178 100644 --- a/include/ck_tile/core/arch/mma/mma_pipeline.hpp +++ b/include/ck_tile/core/arch/mma/mma_pipeline.hpp @@ -270,12 +270,20 @@ struct MmaPipelineBase { if constexpr(MmaOpTraits::IsSupported) { - auto transformed_inputs = applyTransformsToInputs( - hasFlag() ? std::forward(b) - : std::forward(a), - hasFlag() ? std::forward(a) - : std::forward(b), - std::forward(accum)); + constexpr bool swap_a_and_b = hasFlag(); + + auto transformed_inputs = [&]() { + if constexpr(swap_a_and_b) + { + return applyTransformsToInputs( + std::forward(b), std::forward(a), std::forward(accum)); + } + else + { + return applyTransformsToInputs( + std::forward(a), std::forward(b), std::forward(accum)); + } + }(); Derived::execImpl(transformed_inputs); @@ -302,17 +310,18 @@ struct MmaPipelineBase { if constexpr(MmaOpTraits::IsSupported) { - // TODO: c++20: Call template functions with MmaPipelineOptionFlags directly + static_assert(MmaOpTraits::IsScale, + "This exec variant is intended for scale policy structs"); + constexpr bool swap_a_and_b = hasFlag(); + auto transformed_inputs = applyTransformsToInputs( - hasFlag() ? std::forward(b) - : std::forward(a), - hasFlag() ? std::forward(a) - : std::forward(b), + swap_a_and_b ? std::forward(b) : std::forward(a), + swap_a_and_b ? std::forward(a) : std::forward(b), std::forward(accum), - hasFlag() ? std::forward(scale_B) - : std::forward(scale_A), - hasFlag() ? std::forward(scale_A) - : std::forward(scale_B)); + swap_a_and_b ? std::forward(scale_B) + : std::forward(scale_A), + swap_a_and_b ? std::forward(scale_A) + : std::forward(scale_B)); Derived::execImpl(transformed_inputs); diff --git a/include/ck_tile/core/arch/mma/mma_wavewise.hpp b/include/ck_tile/core/arch/mma/mma_wavewise.hpp index 9fbbab411e..bc7e383f6d 100644 --- a/include/ck_tile/core/arch/mma/mma_wavewise.hpp +++ b/include/ck_tile/core/arch/mma/mma_wavewise.hpp @@ -169,7 +169,7 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase::SelectedOp, typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = typename MmaTransformsDefaultSelector::SelectedTransforms> // clang-format off -struct ScaleMmaPipeline : public MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline> +struct ScaleMmaPipeline : public MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline> { - using Base = MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline>; + using Base = MmaPipelineBase(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline>; // clang-format on using MmaOp = MmaOp_; // Expose the selected MmaOp - // Expose caller-side vector types - using AVecType = typename MmaOp::AVecType; - using BVecType = typename MmaOp::BVecType; - using CVecType = typename MmaOp::CVecType; + // Fragment dimensions (from the hardware MmaOp) + constexpr static uint32_t FragM = MmaOp::kM; + constexpr static uint32_t FragN = MmaOp::kN; + constexpr static uint32_t FragK = MmaOp::kK; - // Expose internal vector types + // Fragment counts for decomposition + constexpr static uint32_t FragsM = WaveTileM / FragM; + constexpr static uint32_t FragsN = WaveTileN / FragN; + constexpr static uint32_t FragsK = WaveTileK / FragK; + + // Vector types for packed registers in each fragment using InternalAVecT = typename MmaOp::AVecType; using InternalBVecT = typename MmaOp::BVecType; using InternalCVecT = typename MmaOp::CVecType; + // Buffer types for WaveTiles + using AVecType = InternalAVecT[FragsM][FragsK]; + using BVecType = InternalBVecT[FragsN][FragsK]; + using CVecType = InternalCVecT[FragsM][FragsN]; + // Transforms using ATransform = typename MmaTransforms::ATransform; using BTransform = typename MmaTransforms::BTransform; using CTransform = typename MmaTransforms::CTransform; using DTransform = typename MmaTransforms::DTransform; + // Sanity checks + static_assert(WaveTileM >= FragM, "WaveTileM must be >= FragM"); + static_assert(WaveTileN >= FragN, "WaveTileN must be >= FragN"); + static_assert(WaveTileK >= FragK, "WaveTileK must be >= FragK"); + static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM"); + static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN"); + static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK"); + template (MmaPipelineOpt CK_TILE_DEVICE static void execImpl(std::tuple& vecs) { - auto& [a_vec, b_vec, c_vec, scale_A, scale_B] = vecs; - c_vec = MmaOp::exec(a_vec, b_vec, c_vec, scale_A, scale_B); + auto& [a_frag, b_frag, c_frag, scale_A, scale_B] = vecs; + + if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR) + { + for(uint32_t bm = 0u; bm < FragsM; ++bm) + { + for(uint32_t bn = 0u; bn < FragsN; ++bn) + { + for(uint32_t bk = 0u; bk < FragsK; ++bk) + { + c_frag[bm][bn] = MmaOp::exec( + a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn], scale_A, scale_B); + } + } + } + } + else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR) + { + for(uint32_t bn = 0u; bn < FragsN; ++bn) + { + for(uint32_t bm = 0u; bm < FragsM; ++bm) + { + for(uint32_t bk = 0u; bk < FragsK; ++bk) + { + c_frag[bm][bn] = MmaOp::exec( + a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn], scale_A, scale_B); + } + } + } + } + else + { + static_assert(false, "Invalid accumulation policy"); + } } }; diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp index d57f544a41..7b2f24dea8 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp @@ -5,10 +5,10 @@ #include "ck_tile/core/arch/mma/mma_pipeline.hpp" #include "ck_tile/core/arch/mma/mma_selector.hpp" #include "ck_tile/core/arch/mma/mma_traits.hpp" +#include "ck_tile/core/arch/mma/mma_wavewise.hpp" #include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include -#include namespace ck_tile::core::arch::mma { @@ -20,12 +20,33 @@ constexpr inline int getPipelineFlags() } } // namespace sparse::detail +/** + * @class SparseMmaPipeline + * @brief Driver for the wave-tile sparse Mma operation. Given a backend MmaOp implementation + * (e.g., smfmac), this class performs fragment-wise (MmaTile) decomposition to matrix-multiply + * input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and accumulates + * results into output WaveTile (C: WaveTileM x WaveTileN). + * Like WaveWiseMmaPipeline, this decomposes WaveTile dimensions into fragments and iterates + * internally over FragsM × FragsN × FragsK. The A operand is provided in uncompressed form; + * 2:4 structured sparsity compression (SparseCompressTransform) is applied. + * @tparam ADataType Data type of input WaveTile A + * @tparam BDataType Data type of input WaveTile B + * @tparam CDataType Data type of input/output WaveTile C (accumulator) + * @tparam WaveTileM Mma WaveTile M dimension + * @tparam WaveTileN Mma WaveTile N dimension + * @tparam WaveTileK Mma WaveTile K dimension + * @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order) + * @tparam CompilerTarget The compiler target + * @tparam MmaOp_ Backend wrapper class that will perform the mma op + * @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles + */ template ::SelectedOp, typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms = typename MmaTransformsDefaultSelector::SelectedTransforms> // clang-format off -struct SparseMmaPipeline : public MmaPipelineBase> +struct SparseMmaPipeline : public MmaPipelineBase> { - using Base = MmaPipelineBase>; + using Base = MmaPipelineBase>; // clang-format on static_assert(!Base::template hasFlag(), @@ -52,48 +73,153 @@ struct SparseMmaPipeline : public MmaPipelineBase; static constexpr index_t ASize = AVecTraits::vector_size * MmaOp::kCompressionRatio; using AVecType = ext_vector_t; }; + using ExternalAFragVecT = typename ExternalAVecCalculator::AVecType; - // Expose caller-side vector types - using AVecType = typename ExternalAVecCalculator::AVecType; - using BVecType = typename MmaOp::BVecType; - using CVecType = typename MmaOp::CVecType; + // Scalar type of A + using AScalarT = typename ExternalAVecCalculator::AVecTraits::scalar_type; - // Expose internal vector types - using InternalAVecT = typename MmaOp::AVecType; + // Per-fragment sizes + static constexpr uint32_t ExternalAFragSize = ExternalAVecCalculator::ASize; + static constexpr uint32_t InternalAFragSize = + vector_traits::vector_size; + + // Full wave-tile sizes (all fragments combined) + static constexpr uint32_t TotalUncompressedElems = FragsM * FragsK * ExternalAFragSize; + static constexpr uint32_t TotalCompressedElems = + TotalUncompressedElems / MmaOp::kCompressionRatio; + + // Variable-length idx type for the whole wave-tile (spans multiple int32_t words if needed) + static constexpr index_t IdxNumWords = sparse::detail::idx_words_needed; + using IdxType = sparse::detail::SparseIdxPack; + + // Per-fragment compressed vector type (for individual MmaOp::exec calls) + using FragAVecT = typename MmaOp::AVecType; + + // Internal vector types used by the base class formatBuffer. + // InternalAVecT matches the full compressed wave-tile so the base class can + // format the SparseCompressTransform result via formatBuffer. + using InternalAVecT = ext_vector_t; using InternalBVecT = typename MmaOp::BVecType; using InternalCVecT = typename MmaOp::CVecType; + // Buffer types for WaveTiles (caller-facing). + // A is a single flat uncompressed vector covering the whole wave-tile. + // The base class compresses it in one pass via ATransform. + using AVecType = ext_vector_t; + using BVecType = InternalBVecT[FragsN][FragsK]; + using CVecType = InternalCVecT[FragsM][FragsN]; + // Transforms using ATransform = typename MmaTransforms::ATransform; using BTransform = typename MmaTransforms::BTransform; using CTransform = typename MmaTransforms::CTransform; using DTransform = typename MmaTransforms::DTransform; + // Sanity checks + static_assert(WaveTileM >= FragM, "WaveTileM must be >= FragM"); + static_assert(WaveTileN >= FragN, "WaveTileN must be >= FragN"); + static_assert(WaveTileK >= FragK, "WaveTileK must be >= FragK"); + static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM"); + static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN"); + static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK"); + template CK_TILE_DEVICE static void - execImpl(std::tuple& vecs) + execImpl(std::tuple& transformedInputs) { + auto& [a, b_frag, c_frag] = transformedInputs; + auto& [a_compressed_whole, idx] = a; + + // Validate that the ATransform result and per-fragment reinterpretation are correct checkATransformResult(); - auto& [a_result, b_vec, c_vec] = vecs; - auto& [a_vec, idx] = a_result; - c_vec = MmaOp::exec(a_vec, b_vec, c_vec, idx); + + // Reinterpret the full compressed vector as per-fragment arrays + auto* a_frags = ck_tile::bit_cast(&a_compressed_whole); + + // Accumulation loop with per-fragment idx extraction + if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR) + { + for(uint32_t bm = 0u; bm < FragsM; ++bm) + { + for(uint32_t bn = 0u; bn < FragsN; ++bn) + { + for(uint32_t bk = 0u; bk < FragsK; ++bk) + { + c_frag[bm][bn] = MmaOp::exec( + a_frags[bm][bk], + b_frag[bn][bk], + c_frag[bm][bn], + sparse::detail::extract_fragment_idx( + idx, bm, bk)); + } + } + } + } + else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR) + { + for(uint32_t bn = 0u; bn < FragsN; ++bn) + { + for(uint32_t bm = 0u; bm < FragsM; ++bm) + { + for(uint32_t bk = 0u; bk < FragsK; ++bk) + { + c_frag[bm][bn] = MmaOp::exec( + a_frags[bm][bk], + b_frag[bn][bk], + c_frag[bm][bn], + sparse::detail::extract_fragment_idx( + idx, bm, bk)); + } + } + } + } + else + { + static_assert(false, "Invalid accumulation policy"); + } } private: - // Type check helper - not a device function, so std::declval is available + // Compile-time validation of ATransform result and per-fragment reinterpretation. + // Ensures the compressed vector returned by ATransform::exec can be safely + // reinterpreted as FragAVecT[FragsM][FragsK] for per-fragment MmaOp dispatch. template static constexpr void checkATransformResult() { using ExternalAvecRef = std::add_lvalue_reference_t; static_assert(std::is_same_v()))>); + decltype(ATransform::exec(std::declval()))>, + "ATransformResult must match the return type of ATransform::exec"); + + using CompressedVecType = + std::remove_reference_t>; + static_assert(sizeof(CompressedVecType) == sizeof(FragAVecT) * FragsM * FragsK, + "Compressed A vector size must equal sizeof(FragAVecT[FragsM][FragsK])"); + + static_assert(alignof(CompressedVecType) >= alignof(FragAVecT), + "Compressed vector alignment must be >= FragAVecT alignment " + "for safe reinterpret_cast to per-fragment array"); + + using ActualIdxType = std::tuple_element_t<1, ATransformResult>; + static_assert(std::is_same_v, + "Sparsity index type must match SparseIdxPack"); } }; diff --git a/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp b/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp index 4b0effc2bf..f89a062240 100644 --- a/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp +++ b/include/ck_tile/core/arch/mma/sparse/sparse_transforms.hpp @@ -13,6 +13,30 @@ namespace ck_tile::core::arch::mma { namespace sparse::detail { + +/// Number of int32_t words needed to store CompressedSize 2-bit idx fields. +template +static constexpr index_t idx_words_needed = (CompressedSize * 2 + 31) / 32; + +/** + * @class SparseIdxPack + * @brief Variable-length container for 2:4 structured sparsity index metadata. + * + * Each compressed element produces a 2-bit index field encoding the original + * position (0–3) within its group of 4. When composing multiple MMA fragments + * in M and K dimensions within a WaveTile, the total number of index bits can + * exceed 32. This struct packs the index fields into an array of int32_t words, + * sized at compile time. + * + * @tparam NumWords Number of int32_t words needed to store all index fields. + */ +template +struct SparseIdxPack +{ + static_assert(NumWords > 0, "SparseIdxPack requires at least 1 word"); + int32_t words[NumWords] = {}; +}; + /** * @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero * elements into lower part of a_vec to half its effective size. @@ -20,21 +44,29 @@ namespace sparse::detail { * @tparam ADataType The data type of a_vec * @tparam CompressedSize The target compression size * @tparam AVec The vector type of a_vec (deduced) - * @return Packed 32‑bit word containing **CompressedSize** 2‑bit fields. - * Each field encodes the original position (0–3) of the corresponding - * non‑zero element in the input. If fewer than CompressedSize - * non‑zeros are found, remaining fields default to 2 (see below). + * @return SparseIdxPack containing **CompressedSize** 2‑bit fields packed + * across one or more int32_t words. Each field encodes the original + * position (0–3) of the corresponding non‑zero element in the input. + * If fewer than CompressedSize non‑zeros are found, remaining fields + * default to 2 (see below). */ template -static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec) +static CK_TILE_DEVICE auto compress_a_impl(AVec& a_vec) { - // idx holds one 2‑bit index per output element (total CompressedSize entries). + static constexpr index_t NumIdxWords = idx_words_needed; + // idx holds one 2‑bit index per output element (total CompressedSize entries), + // packed across NumIdxWords int32_t words. // It is initialized to the pattern 0b10 for every field. This matches // what the hardware expects when there are fewer than two non‑zero values // in a 4‑element group – the unused output is treated as coming from slot 2. // The loop below will clear and set each field as real non‑zeros are seen. - int32_t idx = 0; - static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2u << (2u * k)); }); + SparseIdxPack idx{}; + static_for<0, CompressedSize, 1>{}([&](auto k) { + constexpr uint32_t bit_pos = static_cast(k) * 2u; + constexpr uint32_t word = bit_pos / 32u; + constexpr uint32_t shift = bit_pos % 32u; + idx.words[word] |= static_cast(2u << shift); + }); static_for<0, CompressedSize / 2, 1>{}([&](auto i) { ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; @@ -45,8 +77,13 @@ static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec) { nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; // clear the two‑bit field for this output and insert j - idx &= ~(0b11u << (2u * (i * 2 + non_zero_pos))); - idx |= static_cast(j) << (2u * (i * 2 + non_zero_pos)); + const uint32_t field_idx = + static_cast(i) * 2u + static_cast(non_zero_pos); + const uint32_t bit_pos = field_idx * 2u; + const uint32_t word = bit_pos / 32u; + const uint32_t shift = bit_pos % 32u; + idx.words[word] &= ~static_cast(0b11u << shift); + idx.words[word] |= static_cast(static_cast(j) << shift); ++non_zero_pos; } }); @@ -56,6 +93,40 @@ static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec) return idx; } +/** + * @brief Extract the per-fragment sparsity index from a packed idx pack. + * After whole-wave-tile compression, the returned idx packs 2-bit fields for + * every compressed output element across one or more int32_t words. + * @return A single int32_t with this fragment's 2-bit fields at the + * least-significant positions, suitable for passing to the MMA builtin. + */ +template +static CK_TILE_DEVICE int32_t extract_fragment_idx(const SparseIdxPack& idx, + uint32_t m, + uint32_t k) +{ + static constexpr uint32_t IdxBitsPerFrag = FragCompressedSize * 2; + const auto fragLinearIdx = m * FragsK + k; + const auto totalBitOffset = fragLinearIdx * IdxBitsPerFrag; + const auto wordIdx = totalBitOffset / 32u; + const auto bitInWord = totalBitOffset % 32u; + + uint32_t result = static_cast(idx.words[wordIdx]) >> bitInWord; + + // If fragment bits span a word boundary, stitch in bits from the next word. + // (This is a safety measure; it should not occur when IdxBitsPerFrag is a + // power-of-2 divisor of 32, which is always the case for current MMA ops.) + if constexpr(NumIdxWords > 1) + { + if(bitInWord != 0 && bitInWord + IdxBitsPerFrag > 32u) + { + result |= static_cast(idx.words[wordIdx + 1]) << (32u - bitInWord); + } + } + + return static_cast(result); +} + } // namespace sparse::detail /** @@ -75,15 +146,15 @@ struct SparseCompressTransform static constexpr auto VecN = VecTraits::vector_size; static constexpr index_t CompressedSize = VecN / CompressionRatio; using VecCompressed = ext_vector_t; + using IdxType = + sparse::detail::SparseIdxPack>; static_assert(VecN % CompressionRatio == 0, "VecN must be divisible by CompressionRatio"); static_assert(CompressedSize > 0, "CompressedSize must be > 0"); - const auto idx = sparse::detail::compress_a_impl(v); + auto idx = sparse::detail::compress_a_impl(v); - // TODO c++20: Use bit_cast - return std::tuple( - *std::launder(reinterpret_cast(&v)), idx); + return std::tuple(*ck_tile::bit_cast(&v), idx); } }; diff --git a/test/ck_tile/core/arch/mma/CMakeLists.txt b/test/ck_tile/core/arch/mma/CMakeLists.txt index e65a76c134..7f7817b5bf 100644 --- a/test/ck_tile/core/arch/mma/CMakeLists.txt +++ b/test/ck_tile/core/arch/mma/CMakeLists.txt @@ -7,19 +7,109 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx9|gfx120") - add_gtest_executable(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp) +# --------------------------------------------------------------------------- +# Map GPU target strings to hex amdgcn_target_id values (arch.hpp). +# Builds a -DCK_CMAKE_GPU_TARGET_IDS=0xHHHH,... definition that host-side +# test code can consume without launching a device kernel. +# --------------------------------------------------------------------------- +function(_ck_gpu_target_string_to_id TARGET_STR OUT_VAR) + string(TOLOWER "${TARGET_STR}" _tgt) + string(REGEX REPLACE ":.*" "" _tgt "${_tgt}") + # GFX9 + if(_tgt STREQUAL "gfx908") + set(${OUT_VAR} "0x0908" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx90a") + set(${OUT_VAR} "0x090A" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx942") + set(${OUT_VAR} "0x0942" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx950") + set(${OUT_VAR} "0x0950" PARENT_SCOPE) + # GFX10.3 + elseif(_tgt STREQUAL "gfx1030") + set(${OUT_VAR} "0x1030" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1031") + set(${OUT_VAR} "0x1031" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1032") + set(${OUT_VAR} "0x1032" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1033") + set(${OUT_VAR} "0x1033" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1034") + set(${OUT_VAR} "0x1034" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1035") + set(${OUT_VAR} "0x1035" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1036") + set(${OUT_VAR} "0x1036" PARENT_SCOPE) + elseif(_tgt MATCHES "^gfx10-3-generic$") + set(${OUT_VAR} "0x103F" PARENT_SCOPE) + # GFX11 + elseif(_tgt STREQUAL "gfx1100") + set(${OUT_VAR} "0x1100" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1101") + set(${OUT_VAR} "0x1101" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1102") + set(${OUT_VAR} "0x1102" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1103") + set(${OUT_VAR} "0x1103" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1150") + set(${OUT_VAR} "0x1150" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1151") + set(${OUT_VAR} "0x1151" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1152") + set(${OUT_VAR} "0x1152" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1153") + set(${OUT_VAR} "0x1153" PARENT_SCOPE) + elseif(_tgt MATCHES "^gfx11-generic$") + set(${OUT_VAR} "0x11FF" PARENT_SCOPE) + # GFX12 + elseif(_tgt STREQUAL "gfx1200") + set(${OUT_VAR} "0x1200" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1201") + set(${OUT_VAR} "0x1201" PARENT_SCOPE) + elseif(_tgt MATCHES "^gfx12-generic$") + set(${OUT_VAR} "0x12FF" PARENT_SCOPE) + elseif(_tgt STREQUAL "gfx1250") + set(${OUT_VAR} "0x1250" PARENT_SCOPE) + else() + message(WARNING "_ck_gpu_target_string_to_id: unknown GPU target '${TARGET_STR}', skipping") + set(${OUT_VAR} "" PARENT_SCOPE) + endif() +endfunction() +function(_ck_add_gpu_target_ids_define TARGET_NAME) + get_property(_archs TARGET ${TARGET_NAME} PROPERTY HIP_ARCHITECTURES) + string(REPLACE "," ";" _archs "${_archs}") + set(_hex_ids) + foreach(_tgt IN LISTS _archs) + _ck_gpu_target_string_to_id("${_tgt}" _hex) + if(_hex AND NOT _hex STREQUAL "0x0000") + list(APPEND _hex_ids "${_hex}") + endif() + endforeach() + list(JOIN _hex_ids "," _hex_str) + if(_hex_str) + target_compile_definitions(${TARGET_NAME} PRIVATE "CK_CMAKE_GPU_TARGET_IDS=${_hex_str}") + endif() +endfunction() + +# Convenience: add_gtest_executable + inject CK_CMAKE_GPU_TARGET_IDS +macro(_add_mma_gtest TEST_NAME) + add_gtest_executable(${TEST_NAME} ${ARGN}) + _ck_add_gpu_target_ids_define(${TEST_NAME}) +endmacro() +# --------------------------------------------------------------------------- + +if(GPU_TARGETS MATCHES "gfx9|gfx120") + _add_mma_gtest(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp) target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() if(GPU_TARGETS MATCHES "gfx950") - add_gtest_executable(test_amdgcn_scale_mma pipeline/test_amdgcn_scale_mma.cpp) + _add_mma_gtest(test_amdgcn_scale_mma pipeline/test_amdgcn_scale_mma.cpp) target_compile_options(test_amdgcn_scale_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() if(GPU_TARGETS MATCHES "gfx9") - add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp) + _add_mma_gtest(test_amdgcn_mma test_amdgcn_mma.cpp) target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_amdgcn_wavewise_mma pipeline/test_amdgcn_wavewise_mma.cpp) + _add_mma_gtest(test_amdgcn_wavewise_mma pipeline/test_amdgcn_wavewise_mma.cpp) target_compile_options(test_amdgcn_wavewise_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") @@ -37,45 +127,45 @@ macro(set_mma_test_arch_define target_name) endmacro() if(GPU_TARGETS MATCHES "gfx9") - add_gtest_executable(test_amdgcn_mma_layout_gfx9 test_amdgcn_mma_layout_gfx9.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx9 test_amdgcn_mma_layout_gfx9.cpp) target_compile_options(test_amdgcn_mma_layout_gfx9 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) set_mma_test_arch_define(test_amdgcn_mma_layout_gfx9) endif() if(GPU_TARGETS MATCHES "gfx908|gfx90a") - add_gtest_executable(test_amdgcn_mma_layout_gfx908_and_gfx90a test_amdgcn_mma_layout_gfx908_and_gfx90a.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx908_and_gfx90a test_amdgcn_mma_layout_gfx908_and_gfx90a.cpp) target_compile_options(test_amdgcn_mma_layout_gfx908_and_gfx90a PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) set_mma_test_arch_define(test_amdgcn_mma_layout_gfx908_and_gfx90a) endif() if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") - add_gtest_executable(test_amdgcn_mma_layout_gfx90a_and_higher test_amdgcn_mma_layout_gfx90a_and_higher.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx90a_and_higher test_amdgcn_mma_layout_gfx90a_and_higher.cpp) target_compile_options(test_amdgcn_mma_layout_gfx90a_and_higher PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) set_mma_test_arch_define(test_amdgcn_mma_layout_gfx90a_and_higher) endif() if(GPU_TARGETS MATCHES "gfx942|gfx950") - add_gtest_executable(test_amdgcn_mma_layout_gfx942_and_higher test_amdgcn_mma_layout_gfx942_and_higher.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx942_and_higher test_amdgcn_mma_layout_gfx942_and_higher.cpp) target_compile_options(test_amdgcn_mma_layout_gfx942_and_higher PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS} -Wno-header-hygiene) set_mma_test_arch_define(test_amdgcn_mma_layout_gfx942_and_higher) endif() if(GPU_TARGETS MATCHES "gfx950") - add_gtest_executable(test_amdgcn_mma_layout_gfx950 test_amdgcn_mma_layout_gfx950.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx950 test_amdgcn_mma_layout_gfx950.cpp) target_compile_options(test_amdgcn_mma_layout_gfx950 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) set_mma_test_arch_define(test_amdgcn_mma_layout_gfx950) endif() if(GPU_TARGETS MATCHES "gfx11") - add_gtest_executable(test_amdgcn_mma_layout_gfx11 test_amdgcn_mma_layout_gfx11.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx11 test_amdgcn_mma_layout_gfx11.cpp) target_compile_options(test_amdgcn_mma_layout_gfx11 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() if(GPU_TARGETS MATCHES "gfx120") - add_gtest_executable(test_amdgcn_mma_layout_gfx12 test_amdgcn_mma_layout_gfx12.cpp) + _add_mma_gtest(test_amdgcn_mma_layout_gfx12 test_amdgcn_mma_layout_gfx12.cpp) target_compile_options(test_amdgcn_mma_layout_gfx12 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() -add_gtest_executable(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp) +_add_mma_gtest(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp) target_compile_options(test_amdgcn_mma_pipeline PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/test/ck_tile/core/arch/mma/get_cmake_targets_helper.hpp b/test/ck_tile/core/arch/mma/get_cmake_targets_helper.hpp new file mode 100644 index 0000000000..eeac607a1b --- /dev/null +++ b/test/ck_tile/core/arch/mma/get_cmake_targets_helper.hpp @@ -0,0 +1,87 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include + +namespace ck_tile::core::arch::testing { + +static CK_TILE_HOST auto getCMakeGpuTargetIds() +{ + using ck_tile::core::arch::amdgcn_target_id; +#ifdef CK_CMAKE_GPU_TARGET_IDS + constexpr uint32_t ids[] = {CK_CMAKE_GPU_TARGET_IDS}; + std::unordered_set result; + for(auto id : ids) + result.insert(static_cast(id)); + return result; +#else + return std::unordered_set{}; +#endif +} + +template +static CK_TILE_HOST bool dispatchCompilerTarget(ck_tile::core::arch::amdgcn_target_id id, + Func&& func) +{ + using namespace ck_tile::core::arch; + + // clang-format off + switch(id) + { + case amdgcn_target_id::GFX908: func(make_amdgcn_gfx9_target()); return true; + case amdgcn_target_id::GFX90A: func(make_amdgcn_gfx9_target()); return true; + case amdgcn_target_id::GFX942: func(make_amdgcn_gfx9_target()); return true; + case amdgcn_target_id::GFX950: func(make_amdgcn_gfx9_target()); return true; + case amdgcn_target_id::GFX1030: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1031: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1032: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1033: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1034: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1035: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1036: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX103_GENERIC: func(make_amdgcn_gfx10_3_target()); return true; + case amdgcn_target_id::GFX1100: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1101: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1102: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1103: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1150: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1151: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1152: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1153: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX11_GENERIC: func(make_amdgcn_gfx11_target()); return true; + case amdgcn_target_id::GFX1200: func(make_amdgcn_gfx12_target()); return true; + case amdgcn_target_id::GFX1201: func(make_amdgcn_gfx12_target()); return true; + case amdgcn_target_id::GFX12_GENERIC: func(make_amdgcn_gfx12_target()); return true; + case amdgcn_target_id::GFX1250: func(make_amdgcn_gfx12_target()); return true; + case amdgcn_target_id::HOST: return false; + } + // clang-format on + __builtin_unreachable(); +} + +static CK_TILE_HOST constexpr int32_t getCMakeWaveSize() +{ + using ck_tile::core::arch::amdgcn_target_id; +#ifdef CK_CMAKE_GPU_TARGET_IDS + constexpr uint32_t ids[] = {CK_CMAKE_GPU_TARGET_IDS}; + constexpr index_t targets_size = sizeof(ids) / sizeof(ids[0]); + static_assert(targets_size > 0); + constexpr auto first_target_id = static_cast(ids[0]); + if constexpr(first_target_id >= amdgcn_target_id::GFX908 && + first_target_id <= amdgcn_target_id::GFX950) + { + return 64; + } + else + { + return 32; + } +#else + static_assert(false, "Configure CK_CMAKE_GPU_TARGET_IDS before calling this function."); + return 0; +#endif +} +} // namespace ck_tile::core::arch::testing diff --git a/test/ck_tile/core/arch/mma/get_wave_size_helper.hpp b/test/ck_tile/core/arch/mma/get_wave_size_helper.hpp deleted file mode 100644 index 84a3f955e5..0000000000 --- a/test/ck_tile/core/arch/mma/get_wave_size_helper.hpp +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include - -#include "ck_tile/core/arch/arch.hpp" -#include -#include "ck_tile/host/hip_check_error.hpp" - -namespace { - -__global__ void getWaveSizeForSelectedOp(uint32_t* waveSize) -{ - using CompilerTarget = decltype(ck_tile::core::arch::get_compiler_target()); - - if(waveSize) - *waveSize = static_cast(CompilerTarget::WAVE_SIZE_ID); -} - -static __host__ uint32_t getDeviceWaveSize() -{ - uint32_t* d_wave_size; - HIP_CHECK_ERROR(hipMalloc(&d_wave_size, sizeof(uint32_t))); - getWaveSizeForSelectedOp<<<1, 64>>>(d_wave_size); - HIP_CHECK_ERROR(hipDeviceSynchronize()); - uint32_t wave_size; - HIP_CHECK_ERROR(hipMemcpy(&wave_size, d_wave_size, sizeof(uint32_t), hipMemcpyDeviceToHost)); - return wave_size; -} - -} // namespace diff --git a/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp b/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp index 8460100aa9..edd5828a5b 100644 --- a/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp +++ b/test/ck_tile/core/arch/mma/pipeline/pipeline_tests_helper.hpp @@ -10,223 +10,421 @@ #include #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp" +#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp" #include "ck_tile/core/numeric/type_convert.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/host/hip_check_error.hpp" +#include "ck_tile/host/kernel_launch.hpp" #include -#include "../get_wave_size_helper.hpp" +#include "../get_cmake_targets_helper.hpp" -template -struct MmaPipelineTest +namespace mma_pipeline_test { + +using namespace ck_tile; +using namespace ck_tile::core::arch; +using namespace ck_tile::core::arch::mma; +using namespace ck_tile::core::arch::testing; + +inline bool hipTargetMatchesCmakeTargets(amdgcn_target_id arch) { - using AType = AType_; - using BType = BType_; - using CType = CType_; - using ScaleAType = ScaleAType_; - using ScaleBType = ScaleBType_; - static constexpr auto WaveTileM = WaveTileM_; - static constexpr auto WaveTileN = WaveTileN_; - static constexpr auto WaveTileK = WaveTileK_; - - void test_pipeline(std::function shouldSkip, - std::function kernel, - std::function getExpected, - std::function aInitializer = nullptr) + const auto cmake_targets = getCMakeGpuTargetIds(); + if(cmake_targets.count(arch) == 0) { - using namespace ck_tile; - using namespace ck_tile::core::arch; - - int devCount; - hipDevice_t dev; - HIP_CHECK_ERROR(hipGetDevice(&dev)); - HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); - - hipDeviceProp_t devProp; - HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); - - auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); - bool hasDevice = static_cast(devCount > 0); - int deviceWarpSize = devProp.warpSize; - - if(!hasDevice || shouldSkip(currentArchId)) + // gfx12-generic and gfx11-generic make no difference with the specialized archs. + // Some CI pipelines make use of that and configure the project with the generic + // flags besides compiling for (f.e.) gfx1201. + if(arch >= amdgcn_target_id::GFX1200 && arch <= amdgcn_target_id::GFX12_GENERIC) { - GTEST_SKIP() << "No HIP device found. Skipping test."; + return (cmake_targets.count(amdgcn_target_id::GFX12_GENERIC) > 0); } - - // WaveTile size, also the expected fragment size (MmaTile) from the selector. - // Note: Actual FragK might be slightly different due to hardware implementation, but the - // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is - // correct. - static constexpr uint32_t FragM = WaveTileM; - static constexpr uint32_t FragN = WaveTileN; - static constexpr uint32_t FragK = WaveTileK; - - // The number of elements per thread - uint32_t AElements = FragM * FragK / deviceWarpSize; - uint32_t BElements = FragN * FragK / deviceWarpSize; - uint32_t CElements = FragM * FragN / deviceWarpSize; - - uint32_t ASize = AElements * sizeof(AType); - uint32_t BSize = BElements * sizeof(BType); - uint32_t CSize = CElements * sizeof(CType); - - // Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's - std::vector h_a(AElements); - if(aInitializer) + else if(arch >= amdgcn_target_id::GFX1100 && arch <= amdgcn_target_id::GFX11_GENERIC) { - for(size_t i = 0; i < AElements; ++i) - h_a[i] = aInitializer(i); + return (cmake_targets.count(amdgcn_target_id::GFX11_GENERIC) > 0); } - else + } + return true; +} +template +void reference_matmul(std::vector& C, + const std::vector& A, + const std::vector& B, + uint32_t M, + uint32_t N, + uint32_t K) +{ + for(uint32_t m = 0; m < M; ++m) + { + for(uint32_t n = 0; n < N; ++n) { - std::fill(h_a.begin(), h_a.end(), type_convert(1)); + float acc = 0.0f; + for(uint32_t k = 0; k < K; ++k) + { + acc += type_convert(A[m * K + k]) * type_convert(B[k * N + n]); + } + C[m * N + n] = static_cast(acc); } - std::vector h_b(BElements, type_convert(1)); - std::vector h_c(CElements, type_convert(0)); - std::vector h_out(CElements, type_convert(0)); + } +} - AType* d_a; - BType* d_b; - CType* d_c; - CType* d_out; +template +T deterministic_value(uint32_t row, uint32_t col, uint32_t minor_dim) +{ + float v = static_cast((row * minor_dim + col) % 7 + 1) * 0.25f; + return type_convert(v); +} - HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); - HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); - HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); - HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); - - // Copy inputs to device - HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); - - const auto wave_size = getDeviceWaveSize(); - kernel(wave_size, d_a, d_b, d_c, d_out); - HIP_CHECK_ERROR(hipDeviceSynchronize()); - - HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); - - // Verify output against expected value for all elements - for(size_t i = 0; i < CElements; ++i) +// Apply 2:4 sparsity pattern to A matrix in-place (for sparse pipeline tests). +// Every group of 4 consecutive K elements keeps slots 0 and 2, zeros slots 1 and 3. +template +void apply_sparse_pattern(std::vector& A, uint32_t M, uint32_t K) +{ + for(uint32_t m = 0; m < M; ++m) + { + for(uint32_t k = 0; k < K; k += 4) { - EXPECT_NEAR(h_out[i], getExpected(FragK), 1e-3); + // Keep slots 0, 2. Zero out slots 1, 3. + if(k + 1 < K) + A[m * K + k + 1] = static_cast(0); + if(k + 3 < K) + A[m * K + k + 3] = static_cast(0); } + } +} - HIP_CHECK_ERROR(hipFree(d_a)); - HIP_CHECK_ERROR(hipFree(d_b)); - HIP_CHECK_ERROR(hipFree(d_c)); - HIP_CHECK_ERROR(hipFree(d_out)); +// Fill per-lane A fragments from logical A[M][K] matrix. +// For dense pipelines: AVecType = InternalAVecT[FragsM][FragsK] +// For sparse pipelines: AVecType = ExternalAFragVecT[FragsM][FragsK] (uncompressed) +template +void fill_a_fragments(typename Pipeline::AVecType* a_per_lane, + const std::vector& A_matrix, + uint32_t K, + uint32_t waveSize) +{ + using MmaOp = typename Pipeline::MmaOp; + using ARegMap = TileDistrEncRegMap::AWarpDstrEncoding>; + using AFragScalar = typename vector_traits::scalar_type; + + constexpr uint32_t FragM = Pipeline::FragM; + constexpr uint32_t FragK = Pipeline::FragK; + constexpr uint32_t FragsM = Pipeline::FragsM; + constexpr uint32_t FragsK = Pipeline::FragsK; + + constexpr uint32_t kCompressionRatio = MmaOp::kCompressionRatio; + + // The A register map maps (lane, vec_idx) -> (m_within_frag, k_within_frag) + // For sparse: k_within_frag is in the compressed K domain (K / kCompressionRatio) + constexpr index_t a_vec_size = ARegMap::num_vector_items; + constexpr index_t external_a_frag_vec_size = a_vec_size * kCompressionRatio; + + for(uint32_t lane = 0; lane < waveSize; ++lane) + { + auto* lane_a = reinterpret_cast(&a_per_lane[lane]); + + for(uint32_t bm = 0; bm < FragsM; ++bm) + { + for(uint32_t bk = 0; bk < FragsK; ++bk) + { + uint32_t frag_offset = (bm * FragsK + bk) * external_a_frag_vec_size; + + if constexpr(kCompressionRatio > 1) + { + // Sparse: fill external (uncompressed) vector + for(index_t ev = 0; ev < external_a_frag_vec_size; ++ev) + { + index_t compressed_v = ev / kCompressionRatio; + index_t sub_pos = ev % kCompressionRatio; + + auto coords = + ARegMap::calc_matrix_indices_from_lane_vector(lane, compressed_v); + uint32_t m_local = coords[0]; + uint32_t k_compressed = coords[1]; + uint32_t k_local = k_compressed * kCompressionRatio + sub_pos; + + uint32_t m_global = bm * FragM + m_local; + uint32_t k_global = bk * FragK + k_local; + + lane_a[frag_offset + ev] = + static_cast(A_matrix[m_global * K + k_global]); + } + } + else + { + // Dense/Scale: direct mapping + for(index_t v = 0; v < a_vec_size; ++v) + { + auto coords = ARegMap::calc_matrix_indices_from_lane_vector(lane, v); + uint32_t m_local = coords[0]; + uint32_t k_local = coords[1]; + + uint32_t m_global = bm * FragM + m_local; + uint32_t k_global = bk * FragK + k_local; + + lane_a[frag_offset + v] = + static_cast(A_matrix[m_global * K + k_global]); + } + } + } + } + } +} + +// Fill per-lane B fragments from logical B[K][N] matrix. +// BVecType = InternalBVecT[FragsN][FragsK] +template +void fill_b_fragments(typename Pipeline::BVecType* b_per_lane, + const std::vector& B_matrix, + uint32_t N, + uint32_t waveSize) +{ + using MmaOp = typename Pipeline::MmaOp; + using BRegMap = TileDistrEncRegMap::BWarpDstrEncoding>; + using BFragScalar = typename vector_traits::scalar_type; + + constexpr uint32_t FragN = Pipeline::FragN; + constexpr uint32_t FragK = Pipeline::FragK; + constexpr uint32_t FragsN = Pipeline::FragsN; + constexpr uint32_t FragsK = Pipeline::FragsK; + + constexpr index_t b_vec_size = BRegMap::num_vector_items; + + for(uint32_t lane = 0; lane < waveSize; ++lane) + { + auto* lane_b = reinterpret_cast(&b_per_lane[lane]); + + for(uint32_t bn = 0; bn < FragsN; ++bn) + { + for(uint32_t bk = 0; bk < FragsK; ++bk) + { + uint32_t frag_offset = (bn * FragsK + bk) * b_vec_size; + + for(index_t v = 0; v < b_vec_size; ++v) + { + auto coords = BRegMap::calc_matrix_indices_from_lane_vector(lane, v); + uint32_t n_local = coords[0]; + uint32_t k_local = coords[1]; + + uint32_t n_global = bn * FragN + n_local; + uint32_t k_global = bk * FragK + k_local; + + // B matrix is stored as B[K][N] + lane_b[frag_offset + v] = + static_cast(B_matrix[k_global * N + n_global]); + } + } + } + } +} + +// Extract C matrix from per-lane C fragments. +// CVecType = InternalCVecT[FragsM][FragsN] +template +void extract_c_matrix(const typename Pipeline::CVecType* c_per_lane, + std::vector& C_matrix, + uint32_t N, + uint32_t waveSize) +{ + using MmaOp = typename Pipeline::MmaOp; + using CRegMap = TileDistrEncRegMap::CWarpDstrEncoding>; + using CFragScalar = typename vector_traits::scalar_type; + + constexpr uint32_t FragM = Pipeline::FragM; + constexpr uint32_t FragN = Pipeline::FragN; + constexpr uint32_t FragsM = Pipeline::FragsM; + constexpr uint32_t FragsN = Pipeline::FragsN; + + constexpr index_t c_vec_size = CRegMap::num_vector_items; + + for(uint32_t lane = 0; lane < waveSize; ++lane) + { + auto* lane_c = reinterpret_cast(&c_per_lane[lane]); + + for(uint32_t bm = 0; bm < FragsM; ++bm) + { + for(uint32_t bn = 0; bn < FragsN; ++bn) + { + uint32_t frag_offset = (bm * FragsN + bn) * c_vec_size; + + for(index_t v = 0; v < c_vec_size; ++v) + { + auto coords = CRegMap::calc_matrix_indices_from_lane_vector(lane, v); + uint32_t m_local = coords[0]; + uint32_t n_local = coords[1]; + + uint32_t m_global = bm * FragM + m_local; + uint32_t n_global = bn * FragN + n_local; + + C_matrix[m_global * N + n_global] = + static_cast(lane_c[frag_offset + v]); + } + } + } + } +} + +/// Internal: runs the test with a fully resolved Pipeline type. +/// Called from run_pipeline_matrix_test after dispatching on compiler target. +template +void run_pipeline_matrix_test_impl(uint32_t M, + uint32_t N, + uint32_t K, + uint32_t waveSize, + KernelType kernel, + bool isSparse, + bool transposeExpected = false, + float referenceScale = 1.0f) +{ + std::vector A_matrix(M * K); + std::vector B_matrix(K * N); + std::vector C_expected(M * N, static_cast(0)); + std::vector C_actual(M * N, static_cast(0)); + + for(uint32_t m = 0; m < M; ++m) + for(uint32_t k = 0; k < K; ++k) + A_matrix[m * K + k] = deterministic_value(m, k, K); + + for(uint32_t k = 0; k < K; ++k) + for(uint32_t n = 0; n < N; ++n) + B_matrix[k * N + n] = deterministic_value(k, n, N); + + if(isSparse) + { + apply_sparse_pattern(A_matrix, M, K); } - void - test_pipeline(std::function shouldSkip, - std::function kernel, - std::function getExpected, - std::function aInitializer = nullptr) + reference_matmul(C_expected, A_matrix, B_matrix, M, N, K); + + using AVecType = typename Pipeline::AVecType; + using BVecType = typename Pipeline::BVecType; + using CVecType = typename Pipeline::CVecType; + + const size_t a_buf_size = waveSize * sizeof(AVecType); + const size_t b_buf_size = waveSize * sizeof(BVecType); + const size_t c_buf_size = waveSize * sizeof(CVecType); + + std::vector h_a(a_buf_size, 0); + std::vector h_b(b_buf_size, 0); + std::vector h_c(c_buf_size, 0); + + fill_a_fragments(reinterpret_cast(h_a.data()), A_matrix, K, waveSize); + fill_b_fragments(reinterpret_cast(h_b.data()), B_matrix, N, waveSize); + + void *d_a, *d_b, *d_c; + HIP_CHECK_ERROR(hipMalloc(&d_a, a_buf_size)); + HIP_CHECK_ERROR(hipMalloc(&d_b, b_buf_size)); + HIP_CHECK_ERROR(hipMalloc(&d_c, c_buf_size)); + + HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), a_buf_size, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), b_buf_size, hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemset(d_c, 0, c_buf_size)); + + ck_tile::launch_kernel(ck_tile::stream_config{}, + ck_tile::make_kernel(kernel, dim3(1), dim3(waveSize), 0, d_a, d_b, d_c)); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + HIP_CHECK_ERROR(hipMemcpy(h_c.data(), d_c, c_buf_size, hipMemcpyDeviceToHost)); + extract_c_matrix( + reinterpret_cast(h_c.data()), C_actual, N, waveSize); + + for(uint32_t m = 0; m < M; ++m) { - using namespace ck_tile; - using namespace ck_tile::core::arch; - - int devCount; - hipDevice_t dev; - HIP_CHECK_ERROR(hipGetDevice(&dev)); - HIP_CHECK_ERROR(hipGetDeviceCount(&devCount)); - - hipDeviceProp_t devProp; - HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev)); - - auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName); - bool hasDevice = static_cast(devCount > 0); - int deviceWarpSize = devProp.warpSize; - - if(!hasDevice || shouldSkip(currentArchId)) + for(uint32_t n = 0; n < N; ++n) { - GTEST_SKIP() << "No HIP device found. Skipping test."; + // When transposeExpected is true, the kernel computes C^T via SwapAB, + // so compare actual C[m][n] against reference C[n][m]. + constexpr float relative_tolerance = 1e-2f; + constexpr float absolute_tolerance = 1e-3f; + + float expected = transposeExpected ? static_cast(C_expected[n * M + m]) + : static_cast(C_expected[m * N + n]); + expected *= referenceScale; + float actual = static_cast(C_actual[m * N + n]); + EXPECT_NEAR( + actual, expected, std::abs(expected) * relative_tolerance + absolute_tolerance) + << "Mismatch at C[" << m << "][" << n << "]"; } - - // WaveTile size, also the expected fragment size (MmaTile) from the selector. - // Note: Actual FragK might be slightly different due to hardware implementation, but the - // test_accum_over_k kernel will loop over the K dimension to ensure that the total K is - // correct. - static constexpr uint32_t FragM = WaveTileM; - static constexpr uint32_t FragN = WaveTileN; - static constexpr uint32_t FragK = WaveTileK; - - // The number of elements per thread - uint32_t AElements = FragM * FragK / deviceWarpSize / numeric_traits::PackedSize; - uint32_t BElements = FragN * FragK / deviceWarpSize / numeric_traits::PackedSize; - uint32_t CElements = FragM * FragN / deviceWarpSize; - - uint32_t ASize = AElements * sizeof(AType); - uint32_t BSize = BElements * sizeof(BType); - uint32_t CSize = CElements * sizeof(CType); - uint32_t ScaleASize = 1 * sizeof(ScaleAType); - uint32_t ScaleBSize = 1 * sizeof(ScaleBType); - - // Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's - std::vector h_a(AElements); - if(aInitializer) - { - for(size_t i = 0; i < AElements; ++i) - h_a[i] = aInitializer(i); - } - else - { - std::fill(h_a.begin(), h_a.end(), type_convert(1.0f)); - } - std::vector h_b(BElements, type_convert(1.0f)); - std::vector h_c(CElements, type_convert(0.0f)); - std::vector h_out(CElements, type_convert(0.0f)); - // The actual scale is computed as pow(2, scale - 127), so: - // 126 -> 2^-1 and 129 -> 2^2. - ScaleAType h_scale_a = 126; - ScaleBType h_scale_b = 129; - - AType* d_a; - BType* d_b; - CType* d_c; - CType* d_out; - ScaleAType* d_scale_a; - ScaleBType* d_scale_b; - - HIP_CHECK_ERROR(hipMalloc(&d_a, ASize)); - HIP_CHECK_ERROR(hipMalloc(&d_b, BSize)); - HIP_CHECK_ERROR(hipMalloc(&d_c, CSize)); - HIP_CHECK_ERROR(hipMalloc(&d_out, CSize)); - HIP_CHECK_ERROR(hipMalloc(&d_scale_a, ScaleASize)); - HIP_CHECK_ERROR(hipMalloc(&d_scale_b, ScaleBSize)); - - // Copy inputs to device - HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_scale_a, &h_scale_a, ScaleASize, hipMemcpyHostToDevice)); - HIP_CHECK_ERROR(hipMemcpy(d_scale_b, &h_scale_b, ScaleBSize, hipMemcpyHostToDevice)); - - const auto wave_size = getDeviceWaveSize(); - kernel(wave_size, d_a, d_b, d_c, d_out, d_scale_a, d_scale_b); - HIP_CHECK_ERROR(hipDeviceSynchronize()); - - HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost)); - - // Verify output against expected value for all elements - for(size_t i = 0; i < CElements; ++i) - { - EXPECT_NEAR(h_out[i], getExpected(FragK, h_scale_a, h_scale_b), 1e-3); - } - - HIP_CHECK_ERROR(hipFree(d_a)); - HIP_CHECK_ERROR(hipFree(d_b)); - HIP_CHECK_ERROR(hipFree(d_c)); - HIP_CHECK_ERROR(hipFree(d_out)); - HIP_CHECK_ERROR(hipFree(d_scale_a)); - HIP_CHECK_ERROR(hipFree(d_scale_b)); } -}; + + HIP_CHECK_ERROR(hipFree(d_a)); + HIP_CHECK_ERROR(hipFree(d_b)); + HIP_CHECK_ERROR(hipFree(d_c)); +} + +/// @tparam PipelineFactory A template template that, given a CompilerTarget type, produces +/// the Pipeline type: PipelineFactory::type +/// @tparam KernelType Kernel functor struct with kBlockSize and __device__ operator() +/// @tparam AScalar Scalar type for A matrix (e.g., fp16_t) +/// @tparam BScalar Scalar type for B matrix (e.g., fp16_t) +/// @tparam CScalar Scalar type for C matrix (e.g., fp32_t) +/// @param M WaveTile M dimension +/// @param N WaveTile N dimension +/// @param K WaveTile K dimension +/// @param shouldSkip Predicate returning true if current device should skip +/// @param kernel Kernel functor instance to launch via make_kernel +/// @param isSparse Whether to apply 2:4 sparsity pattern to A +/// @param transposeExpected When true, compare against transposed reference (for +/// SwapAB/TransposeC) +/// @param referenceScale Scalar multiplier applied to the reference matmul result before +/// comparison (e.g., to account for scale-MMA scaling factors) +template