[rocm-libraries] ROCm/rocm-libraries#6212 (commit ccee58d)

=?UTF-8?q?[CK=20TILE]=20Unification=20Work=20=E2=80=93=20?=
 =?UTF-8?q?More=20accurate=20tests=20for=20MmaPipelines=20(#6212)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

This PR solves several issues:

#### More accurate tests for MmaPipelines

The current tests for the MmaPipelines (test_amdgcn_sparse_mma,
test_amdgcn_wavewise_mma) use explicit input fragment vectors filled
with 1s, and only check the output of a single lane. We should have
tests that actually use the MmaPipelines with non-trivial input matrices
and verify the complete output.
Some other aspects of the current MmaPipelines tests that I noticed and
deserve some attention:

1. There is sometimes iteration over K outside of the pipeline, which is
then included in WaveTileK or FragK, which is not correct. We should
remove it, move K iteration inside of the pipeline, or be more clear
about this outer-K loop size and how it propagates downwards.
2. There is very tight coupling between the kernel, gtest code, and
test_pipeline helper, requiring a lot of information and functions to be
passed back and forth.
3. The test_pipeline helper is doing a bunch of register-related logic
on the host (related to point 1)
4. Without this register logic the only thing it does is check the
device, call the kernel, and check the output, but with a lot of
boilerplate.

#### Test helper for detecting target arch at HOST runtime

There is a really apparent issue we faced while writing tests:

Scenario:
1. Compile a test that supports both gfx950 and gfx1201 for gfx950
2. Run the test on a server that only has gfx1201 GPU

Actual:
Segmentation fault

Expected:
The test can correctly detect from HOST runtime that the DEVICE
target_id was different and skips the test.

Notes:

The only way of detecting the COMPILER_TARGET_ID in the existing "arch"
framework is launching a kernel and calling `get_compiler_target()` (so,
from a DEVICE code). This will create a segmentation fault if the
current arch differs from the target arch. To cope with this issue, we
propose to export the compiler target(s) (note they can be many) through
`projects/composablekernel/test/ck_tile/core/arch/CMakeLists.txt` and
define a test helper to deal with such cases.

#### Add composition support to Transforms

We have a small number of Transforms which act on MmaOp input and output
data, before and after the MmaOp call respectively. These are currently
implemented to work on an MmaTile level, but in theory they are also
supposed to work at a WaveTile level, i.e. after composition of multiple
MmaTiles to create larger effective MNK dimensions. Currently the
composed MmaTiles look like 2D C-style arrays of the individual MmaTile
level register vectors (see WaveWiseMmaPipeline). The transforms should
be able to take these and perform the proper transforms to the whole
WaveTile at once. This might allow for better performing
transformations.

Note: This PR handles the SparseTransform case and if we don't end up
doing scale as a transformation, there isn't really much left to do. If
we end up having only the sparse transform as a non-trivial transform,
then we could also consider removing the Transform framework.
This commit is contained in:
chris-tsiaousis-hpc
2026-06-03 14:35:18 +00:00
committed by assistant-librarian[bot]
parent 88f8d24c34
commit db05d61136
20 changed files with 1646 additions and 589 deletions

View File

@@ -5,7 +5,7 @@
#include <type_traits>
#include <tuple>
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
#include "ck_tile/core/container/tuple.hpp"

View File

@@ -24,6 +24,7 @@
#include "ck_tile/core/arch/mma/mfma/mfma_selector.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_traits.hpp"
#include "ck_tile/core/arch/mma/mfma/mfma_transforms.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"

View File

@@ -393,10 +393,3 @@ CK_TILE_HOST_DEVICE void print(amdgcn_mma<ADataType,
#if __clang_major__ >= 23
#pragma clang diagnostic pop
#endif
// Include the implementations
#include "wmma/wmma.hpp" // should be included before the below headers
#include "mfma/mfma.hpp"
#include "scale/scale.hpp"
#include "sparse/sparse.hpp"

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_transforms.hpp"
namespace ck_tile::core::arch::mma {

View File

@@ -0,0 +1,8 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "wmma/wmma.hpp"
#include "mfma/mfma.hpp"
#include "sparse/sparse.hpp"

View File

@@ -270,12 +270,20 @@ struct MmaPipelineBase
{
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
{
auto transformed_inputs = applyTransformsToInputs(
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTB>(b)
: std::forward<VecTA>(a),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTA>(a)
: std::forward<VecTB>(b),
std::forward<VecTC>(accum));
constexpr bool swap_a_and_b = hasFlag<MmaPipelineOptionFlag::ABSwap>();
auto transformed_inputs = [&]() {
if constexpr(swap_a_and_b)
{
return applyTransformsToInputs(
std::forward<VecTB>(b), std::forward<VecTA>(a), std::forward<VecTC>(accum));
}
else
{
return applyTransformsToInputs(
std::forward<VecTA>(a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
}
}();
Derived::execImpl(transformed_inputs);
@@ -302,17 +310,18 @@ struct MmaPipelineBase
{
if constexpr(MmaOpTraits<typename Derived::MmaOp>::IsSupported)
{
// TODO: c++20: Call template functions with MmaPipelineOptionFlags directly
static_assert(MmaOpTraits<typename Derived::MmaOp>::IsScale,
"This exec variant is intended for scale policy structs");
constexpr bool swap_a_and_b = hasFlag<MmaPipelineOptionFlag::ABSwap>();
auto transformed_inputs = applyTransformsToInputs(
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTB>(b)
: std::forward<VecTA>(a),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<VecTA>(a)
: std::forward<VecTB>(b),
swap_a_and_b ? std::forward<VecTB>(b) : std::forward<VecTA>(a),
swap_a_and_b ? std::forward<VecTA>(a) : std::forward<VecTB>(b),
std::forward<VecTC>(accum),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<ScaleBDataType>(scale_B)
: std::forward<ScaleADataType>(scale_A),
hasFlag<MmaPipelineOptionFlag::ABSwap>() ? std::forward<ScaleADataType>(scale_A)
: std::forward<ScaleBDataType>(scale_B));
swap_a_and_b ? std::forward<ScaleBDataType>(scale_B)
: std::forward<ScaleADataType>(scale_A),
swap_a_and_b ? std::forward<ScaleADataType>(scale_A)
: std::forward<ScaleBDataType>(scale_B));
Derived::execImpl(transformed_inputs);

View File

@@ -169,7 +169,7 @@ struct WaveWiseMmaPipeline : public MmaPipelineBase<dense::wavewise::detail::get
}
else
{
static_assert(false);
static_assert(false, "Invalid accumulation policy");
}
}
};

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "ck_tile/core/arch/mma/scale/scale_selector.hpp"
#include "ck_tile/core/arch/mma/scale/scale_transforms.hpp"
#include "ck_tile/core/config.hpp"
@@ -16,12 +17,32 @@
namespace ck_tile::core::arch::mma {
/**
* @class ScaleMmaPipeline
* @brief Driver for the wave-tile scale Mma operation. Given a backend MmaOp implementation
* (e.g., scale MFMA), this class performs fragment-wise (MmaTile) decomposition to
* matrix-multiply input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and
* accumulates results into output WaveTile (C: WaveTileM x WaveTileN).
* Like WaveWiseMmaPipeline, this decomposes WaveTile dimensions into fragments and iterates
* internally over FragsM x FragsN x FragsK, passing per-wave scale factors to each fragment call.
* @tparam ADataType Data type of input WaveTile A
* @tparam BDataType Data type of input WaveTile B
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
* @tparam WaveTileM Mma WaveTile M dimension
* @tparam WaveTileN Mma WaveTile N dimension
* @tparam WaveTileK Mma WaveTile K dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam CompilerTarget The compiler target
* @tparam MmaOp_ Backend wrapper class that will perform the mma op
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
std::uint32_t FragM,
std::uint32_t FragN,
std::uint32_t FragK,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
@@ -30,37 +51,55 @@ template <typename ADataType,
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
MmaOpFamily::SCALE>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>
struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>;
using Base = MmaPipelineBase<static_cast<int>(MmaPipelineOptionFlag::NONE), ScaleMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
using MmaOp = MmaOp_; // Expose the selected MmaOp
// Expose caller-side vector types
using AVecType = typename MmaOp::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Fragment dimensions (from the hardware MmaOp)
constexpr static uint32_t FragM = MmaOp::kM;
constexpr static uint32_t FragN = MmaOp::kN;
constexpr static uint32_t FragK = MmaOp::kK;
// Expose internal vector types
// Fragment counts for decomposition
constexpr static uint32_t FragsM = WaveTileM / FragM;
constexpr static uint32_t FragsN = WaveTileN / FragN;
constexpr static uint32_t FragsK = WaveTileK / FragK;
// Vector types for packed registers in each fragment
using InternalAVecT = typename MmaOp::AVecType;
using InternalBVecT = typename MmaOp::BVecType;
using InternalCVecT = typename MmaOp::CVecType;
// Buffer types for WaveTiles
using AVecType = InternalAVecT[FragsM][FragsK];
using BVecType = InternalBVecT[FragsN][FragsK];
using CVecType = InternalCVecT[FragsM][FragsN];
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
// Sanity checks
static_assert(WaveTileM >= FragM, "WaveTileM must be >= FragM");
static_assert(WaveTileN >= FragN, "WaveTileN must be >= FragN");
static_assert(WaveTileK >= FragK, "WaveTileK must be >= FragK");
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
template <typename VecTA,
typename VecTB,
typename VecTC,
@@ -69,8 +108,40 @@ struct ScaleMmaPipeline : public MmaPipelineBase<static_cast<int>(MmaPipelineOpt
CK_TILE_DEVICE static void
execImpl(std::tuple<VecTA, VecTB, VecTC, ScaleADataType, ScaleBDataType>& vecs)
{
auto& [a_vec, b_vec, c_vec, scale_A, scale_B] = vecs;
c_vec = MmaOp::exec(a_vec, b_vec, c_vec, scale_A, scale_B);
auto& [a_frag, b_frag, c_frag, scale_A, scale_B] = vecs;
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] = MmaOp::exec(
a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn], scale_A, scale_B);
}
}
}
}
else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] = MmaOp::exec(
a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn], scale_A, scale_B);
}
}
}
}
else
{
static_assert(false, "Invalid accumulation policy");
}
}
};

View File

@@ -5,10 +5,10 @@
#include "ck_tile/core/arch/mma/mma_pipeline.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_transforms.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include <cstdint>
#include <type_traits>
namespace ck_tile::core::arch::mma {
@@ -20,12 +20,33 @@ constexpr inline int getPipelineFlags()
}
} // namespace sparse::detail
/**
* @class SparseMmaPipeline
* @brief Driver for the wave-tile sparse Mma operation. Given a backend MmaOp implementation
* (e.g., smfmac), this class performs fragment-wise (MmaTile) decomposition to matrix-multiply
* input WaveTiles of (A: WaveTileM x WaveTileK) x (B: WaveTileK x WaveTileN) and accumulates
* results into output WaveTile (C: WaveTileM x WaveTileN).
* Like WaveWiseMmaPipeline, this decomposes WaveTile dimensions into fragments and iterates
* internally over FragsM × FragsN × FragsK. The A operand is provided in uncompressed form;
* 2:4 structured sparsity compression (SparseCompressTransform) is applied.
* @tparam ADataType Data type of input WaveTile A
* @tparam BDataType Data type of input WaveTile B
* @tparam CDataType Data type of input/output WaveTile C (accumulator)
* @tparam WaveTileM Mma WaveTile M dimension
* @tparam WaveTileN Mma WaveTile N dimension
* @tparam WaveTileK Mma WaveTile K dimension
* @tparam AccumPolicy The fragment order of the accum. registers (row or col major frag order)
* @tparam CompilerTarget The compiler target
* @tparam MmaOp_ Backend wrapper class that will perform the mma op
* @tparam MmaTransforms The set of transforms to be applied to input/output WaveTiles
*/
template <typename ADataType,
typename BDataType,
typename CDataType,
uint32_t FragM,
uint32_t FragN,
uint32_t FragK,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
typename CompilerTarget =
decltype(get_compiler_target()), // TODO: c++20 amdgcn_target_arch_id GfxTargetId =
// get_compiler_target(),
@@ -34,17 +55,17 @@ template <typename ADataType,
// MmaDefaultSelector<ADataType,
BDataType,
CDataType,
FragM,
FragN,
FragK,
WaveTileM,
WaveTileN,
WaveTileK,
CompilerTarget,
MmaOpFamily::SPARSE>::SelectedOp,
typename MmaTransforms = // TODO: c++20 MmaTransformsI MmaTransforms =
typename MmaTransformsDefaultSelector<MmaOp_, CompilerTarget>::SelectedTransforms>
// clang-format off
struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>
struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CompilerTarget, MmaOp_, MmaTransforms>>
{
using Base = MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, FragM, FragN, FragK, CompilerTarget, MmaOp_, MmaTransforms>>;
using Base = MmaPipelineBase<sparse::detail::getPipelineFlags(), SparseMmaPipeline<ADataType, BDataType, CDataType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy, CompilerTarget, MmaOp_, MmaTransforms>>;
// clang-format on
static_assert(!Base::template hasFlag<MmaPipelineOptionFlag::ABSwap>(),
@@ -52,48 +73,153 @@ struct SparseMmaPipeline : public MmaPipelineBase<sparse::detail::getPipelineFla
using MmaOp = MmaOp_; // Expose the selected MmaOp
// Calculate the uncompressed A vector type
// Fragment dimensions (from the hardware MmaOp)
constexpr static uint32_t FragM = MmaOp::kM;
constexpr static uint32_t FragN = MmaOp::kN;
constexpr static uint32_t FragK = MmaOp::kK;
// Fragment counts for decomposition
constexpr static uint32_t FragsM = WaveTileM / FragM;
constexpr static uint32_t FragsN = WaveTileN / FragN;
constexpr static uint32_t FragsK = WaveTileK / FragK;
// Calculate the uncompressed external A per-fragment vector type
struct ExternalAVecCalculator
{
using AVecTraits = vector_traits<typename MmaOp::AVecType>;
static constexpr index_t ASize = AVecTraits::vector_size * MmaOp::kCompressionRatio;
using AVecType = ext_vector_t<typename AVecTraits::scalar_type, ASize>;
};
using ExternalAFragVecT = typename ExternalAVecCalculator::AVecType;
// Expose caller-side vector types
using AVecType = typename ExternalAVecCalculator::AVecType;
using BVecType = typename MmaOp::BVecType;
using CVecType = typename MmaOp::CVecType;
// Scalar type of A
using AScalarT = typename ExternalAVecCalculator::AVecTraits::scalar_type;
// Expose internal vector types
using InternalAVecT = typename MmaOp::AVecType;
// Per-fragment sizes
static constexpr uint32_t ExternalAFragSize = ExternalAVecCalculator::ASize;
static constexpr uint32_t InternalAFragSize =
vector_traits<typename MmaOp::AVecType>::vector_size;
// Full wave-tile sizes (all fragments combined)
static constexpr uint32_t TotalUncompressedElems = FragsM * FragsK * ExternalAFragSize;
static constexpr uint32_t TotalCompressedElems =
TotalUncompressedElems / MmaOp::kCompressionRatio;
// Variable-length idx type for the whole wave-tile (spans multiple int32_t words if needed)
static constexpr index_t IdxNumWords = sparse::detail::idx_words_needed<TotalCompressedElems>;
using IdxType = sparse::detail::SparseIdxPack<IdxNumWords>;
// Per-fragment compressed vector type (for individual MmaOp::exec calls)
using FragAVecT = typename MmaOp::AVecType;
// Internal vector types used by the base class formatBuffer.
// InternalAVecT matches the full compressed wave-tile so the base class can
// format the SparseCompressTransform result via formatBuffer<InternalAVecT>.
using InternalAVecT = ext_vector_t<AScalarT, TotalCompressedElems>;
using InternalBVecT = typename MmaOp::BVecType;
using InternalCVecT = typename MmaOp::CVecType;
// Buffer types for WaveTiles (caller-facing).
// A is a single flat uncompressed vector covering the whole wave-tile.
// The base class compresses it in one pass via ATransform.
using AVecType = ext_vector_t<AScalarT, TotalUncompressedElems>;
using BVecType = InternalBVecT[FragsN][FragsK];
using CVecType = InternalCVecT[FragsM][FragsN];
// Transforms
using ATransform = typename MmaTransforms::ATransform;
using BTransform = typename MmaTransforms::BTransform;
using CTransform = typename MmaTransforms::CTransform;
using DTransform = typename MmaTransforms::DTransform;
// Sanity checks
static_assert(WaveTileM >= FragM, "WaveTileM must be >= FragM");
static_assert(WaveTileN >= FragN, "WaveTileN must be >= FragN");
static_assert(WaveTileK >= FragK, "WaveTileK must be >= FragK");
static_assert(WaveTileM % FragM == 0u, "WaveTileM must be a multiple of FragM");
static_assert(WaveTileN % FragN == 0u, "WaveTileN must be a multiple of FragN");
static_assert(WaveTileK % FragK == 0u, "WaveTileK must be a multiple of FragK");
template <typename ATransformResult, typename BTransformResult, typename CTransformResult>
CK_TILE_DEVICE static void
execImpl(std::tuple<ATransformResult, BTransformResult, CTransformResult>& vecs)
execImpl(std::tuple<ATransformResult, BTransformResult, CTransformResult>& transformedInputs)
{
auto& [a, b_frag, c_frag] = transformedInputs;
auto& [a_compressed_whole, idx] = a;
// Validate that the ATransform result and per-fragment reinterpretation are correct
checkATransformResult<ATransformResult>();
auto& [a_result, b_vec, c_vec] = vecs;
auto& [a_vec, idx] = a_result;
c_vec = MmaOp::exec(a_vec, b_vec, c_vec, idx);
// Reinterpret the full compressed vector as per-fragment arrays
auto* a_frags = ck_tile::bit_cast<FragAVecT(*)[FragsK]>(&a_compressed_whole);
// Accumulation loop with per-fragment idx extraction
if constexpr(AccumPolicy == MmaAccumPolicy::ROW_MAJOR)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] = MmaOp::exec(
a_frags[bm][bk],
b_frag[bn][bk],
c_frag[bm][bn],
sparse::detail::extract_fragment_idx<InternalAFragSize, FragsK>(
idx, bm, bk));
}
}
}
}
else if constexpr(AccumPolicy == MmaAccumPolicy::COL_MAJOR)
{
for(uint32_t bn = 0u; bn < FragsN; ++bn)
{
for(uint32_t bm = 0u; bm < FragsM; ++bm)
{
for(uint32_t bk = 0u; bk < FragsK; ++bk)
{
c_frag[bm][bn] = MmaOp::exec(
a_frags[bm][bk],
b_frag[bn][bk],
c_frag[bm][bn],
sparse::detail::extract_fragment_idx<InternalAFragSize, FragsK>(
idx, bm, bk));
}
}
}
}
else
{
static_assert(false, "Invalid accumulation policy");
}
}
private:
// Type check helper - not a device function, so std::declval is available
// Compile-time validation of ATransform result and per-fragment reinterpretation.
// Ensures the compressed vector returned by ATransform::exec can be safely
// reinterpreted as FragAVecT[FragsM][FragsK] for per-fragment MmaOp dispatch.
template <typename ATransformResult>
static constexpr void checkATransformResult()
{
using ExternalAvecRef = std::add_lvalue_reference_t<AVecType>;
static_assert(std::is_same_v<ATransformResult,
decltype(ATransform::exec(std::declval<ExternalAvecRef>()))>);
decltype(ATransform::exec(std::declval<ExternalAvecRef>()))>,
"ATransformResult must match the return type of ATransform::exec");
using CompressedVecType =
std::remove_reference_t<std::tuple_element_t<0, ATransformResult>>;
static_assert(sizeof(CompressedVecType) == sizeof(FragAVecT) * FragsM * FragsK,
"Compressed A vector size must equal sizeof(FragAVecT[FragsM][FragsK])");
static_assert(alignof(CompressedVecType) >= alignof(FragAVecT),
"Compressed vector alignment must be >= FragAVecT alignment "
"for safe reinterpret_cast to per-fragment array");
using ActualIdxType = std::tuple_element_t<1, ATransformResult>;
static_assert(std::is_same_v<ActualIdxType, IdxType>,
"Sparsity index type must match SparseIdxPack<IdxNumWords>");
}
};

View File

@@ -13,6 +13,30 @@
namespace ck_tile::core::arch::mma {
namespace sparse::detail {
/// Number of int32_t words needed to store CompressedSize 2-bit idx fields.
template <index_t CompressedSize>
static constexpr index_t idx_words_needed = (CompressedSize * 2 + 31) / 32;
/**
* @class SparseIdxPack
* @brief Variable-length container for 2:4 structured sparsity index metadata.
*
* Each compressed element produces a 2-bit index field encoding the original
* position (03) within its group of 4. When composing multiple MMA fragments
* in M and K dimensions within a WaveTile, the total number of index bits can
* exceed 32. This struct packs the index fields into an array of int32_t words,
* sized at compile time.
*
* @tparam NumWords Number of int32_t words needed to store all index fields.
*/
template <index_t NumWords>
struct SparseIdxPack
{
static_assert(NumWords > 0, "SparseIdxPack requires at least 1 word");
int32_t words[NumWords] = {};
};
/**
* @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero
* elements into lower part of a_vec to half its effective size.
@@ -20,21 +44,29 @@ namespace sparse::detail {
* @tparam ADataType The data type of a_vec
* @tparam CompressedSize The target compression size
* @tparam AVec The vector type of a_vec (deduced)
* @return Packed 32bit word containing **CompressedSize** 2bit fields.
* Each field encodes the original position (03) of the corresponding
* nonzero element in the input. If fewer than CompressedSize
* nonzeros are found, remaining fields default to 2 (see below).
* @return SparseIdxPack containing **CompressedSize** 2bit fields packed
* across one or more int32_t words. Each field encodes the original
* position (03) of the corresponding nonzero element in the input.
* If fewer than CompressedSize nonzeros are found, remaining fields
* default to 2 (see below).
*/
template <typename ADataType, index_t CompressedSize, typename AVec>
static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
static CK_TILE_DEVICE auto compress_a_impl(AVec& a_vec)
{
// idx holds one 2bit index per output element (total CompressedSize entries).
static constexpr index_t NumIdxWords = idx_words_needed<CompressedSize>;
// idx holds one 2bit index per output element (total CompressedSize entries),
// packed across NumIdxWords int32_t words.
// It is initialized to the pattern 0b10 for every field. This matches
// what the hardware expects when there are fewer than two nonzero values
// in a 4element group the unused output is treated as coming from slot 2.
// The loop below will clear and set each field as real nonzeros are seen.
int32_t idx = 0;
static_for<0, CompressedSize, 1>{}([&](auto k) { idx |= (2u << (2u * k)); });
SparseIdxPack<NumIdxWords> idx{};
static_for<0, CompressedSize, 1>{}([&](auto k) {
constexpr uint32_t bit_pos = static_cast<uint32_t>(k) * 2u;
constexpr uint32_t word = bit_pos / 32u;
constexpr uint32_t shift = bit_pos % 32u;
idx.words[word] |= static_cast<int32_t>(2u << shift);
});
static_for<0, CompressedSize / 2, 1>{}([&](auto i) {
ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]};
@@ -45,8 +77,13 @@ static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
{
nonzero_elems[non_zero_pos] = a_vec[i * 4 + j];
// clear the twobit field for this output and insert j
idx &= ~(0b11u << (2u * (i * 2 + non_zero_pos)));
idx |= static_cast<uint32_t>(j) << (2u * (i * 2 + non_zero_pos));
const uint32_t field_idx =
static_cast<uint32_t>(i) * 2u + static_cast<uint32_t>(non_zero_pos);
const uint32_t bit_pos = field_idx * 2u;
const uint32_t word = bit_pos / 32u;
const uint32_t shift = bit_pos % 32u;
idx.words[word] &= ~static_cast<int32_t>(0b11u << shift);
idx.words[word] |= static_cast<int32_t>(static_cast<uint32_t>(j) << shift);
++non_zero_pos;
}
});
@@ -56,6 +93,40 @@ static CK_TILE_DEVICE int32_t compress_a_impl(AVec& a_vec)
return idx;
}
/**
* @brief Extract the per-fragment sparsity index from a packed idx pack.
* After whole-wave-tile compression, the returned idx packs 2-bit fields for
* every compressed output element across one or more int32_t words.
* @return A single int32_t with this fragment's 2-bit fields at the
* least-significant positions, suitable for passing to the MMA builtin.
*/
template <uint32_t FragCompressedSize, uint32_t FragsK, index_t NumIdxWords>
static CK_TILE_DEVICE int32_t extract_fragment_idx(const SparseIdxPack<NumIdxWords>& idx,
uint32_t m,
uint32_t k)
{
static constexpr uint32_t IdxBitsPerFrag = FragCompressedSize * 2;
const auto fragLinearIdx = m * FragsK + k;
const auto totalBitOffset = fragLinearIdx * IdxBitsPerFrag;
const auto wordIdx = totalBitOffset / 32u;
const auto bitInWord = totalBitOffset % 32u;
uint32_t result = static_cast<uint32_t>(idx.words[wordIdx]) >> bitInWord;
// If fragment bits span a word boundary, stitch in bits from the next word.
// (This is a safety measure; it should not occur when IdxBitsPerFrag is a
// power-of-2 divisor of 32, which is always the case for current MMA ops.)
if constexpr(NumIdxWords > 1)
{
if(bitInWord != 0 && bitInWord + IdxBitsPerFrag > 32u)
{
result |= static_cast<uint32_t>(idx.words[wordIdx + 1]) << (32u - bitInWord);
}
}
return static_cast<int32_t>(result);
}
} // namespace sparse::detail
/**
@@ -75,15 +146,15 @@ struct SparseCompressTransform
static constexpr auto VecN = VecTraits::vector_size;
static constexpr index_t CompressedSize = VecN / CompressionRatio;
using VecCompressed = ext_vector_t<ScalarT, CompressedSize>;
using IdxType =
sparse::detail::SparseIdxPack<sparse::detail::idx_words_needed<CompressedSize>>;
static_assert(VecN % CompressionRatio == 0, "VecN must be divisible by CompressionRatio");
static_assert(CompressedSize > 0, "CompressedSize must be > 0");
const auto idx = sparse::detail::compress_a_impl<ScalarT, CompressedSize>(v);
auto idx = sparse::detail::compress_a_impl<ScalarT, CompressedSize>(v);
// TODO c++20: Use bit_cast
return std::tuple<VecCompressed&, int32_t>(
*std::launder(reinterpret_cast<VecCompressed*>(&v)), idx);
return std::tuple<VecCompressed&, IdxType>(*ck_tile::bit_cast<VecCompressed*>(&v), idx);
}
};

View File

@@ -7,19 +7,109 @@ if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
if(GPU_TARGETS MATCHES "gfx9|gfx120")
add_gtest_executable(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp)
# ---------------------------------------------------------------------------
# Map GPU target strings to hex amdgcn_target_id values (arch.hpp).
# Builds a -DCK_CMAKE_GPU_TARGET_IDS=0xHHHH,... definition that host-side
# test code can consume without launching a device kernel.
# ---------------------------------------------------------------------------
function(_ck_gpu_target_string_to_id TARGET_STR OUT_VAR)
string(TOLOWER "${TARGET_STR}" _tgt)
string(REGEX REPLACE ":.*" "" _tgt "${_tgt}")
# GFX9
if(_tgt STREQUAL "gfx908")
set(${OUT_VAR} "0x0908" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx90a")
set(${OUT_VAR} "0x090A" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx942")
set(${OUT_VAR} "0x0942" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx950")
set(${OUT_VAR} "0x0950" PARENT_SCOPE)
# GFX10.3
elseif(_tgt STREQUAL "gfx1030")
set(${OUT_VAR} "0x1030" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1031")
set(${OUT_VAR} "0x1031" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1032")
set(${OUT_VAR} "0x1032" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1033")
set(${OUT_VAR} "0x1033" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1034")
set(${OUT_VAR} "0x1034" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1035")
set(${OUT_VAR} "0x1035" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1036")
set(${OUT_VAR} "0x1036" PARENT_SCOPE)
elseif(_tgt MATCHES "^gfx10-3-generic$")
set(${OUT_VAR} "0x103F" PARENT_SCOPE)
# GFX11
elseif(_tgt STREQUAL "gfx1100")
set(${OUT_VAR} "0x1100" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1101")
set(${OUT_VAR} "0x1101" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1102")
set(${OUT_VAR} "0x1102" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1103")
set(${OUT_VAR} "0x1103" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1150")
set(${OUT_VAR} "0x1150" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1151")
set(${OUT_VAR} "0x1151" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1152")
set(${OUT_VAR} "0x1152" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1153")
set(${OUT_VAR} "0x1153" PARENT_SCOPE)
elseif(_tgt MATCHES "^gfx11-generic$")
set(${OUT_VAR} "0x11FF" PARENT_SCOPE)
# GFX12
elseif(_tgt STREQUAL "gfx1200")
set(${OUT_VAR} "0x1200" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1201")
set(${OUT_VAR} "0x1201" PARENT_SCOPE)
elseif(_tgt MATCHES "^gfx12-generic$")
set(${OUT_VAR} "0x12FF" PARENT_SCOPE)
elseif(_tgt STREQUAL "gfx1250")
set(${OUT_VAR} "0x1250" PARENT_SCOPE)
else()
message(WARNING "_ck_gpu_target_string_to_id: unknown GPU target '${TARGET_STR}', skipping")
set(${OUT_VAR} "" PARENT_SCOPE)
endif()
endfunction()
function(_ck_add_gpu_target_ids_define TARGET_NAME)
get_property(_archs TARGET ${TARGET_NAME} PROPERTY HIP_ARCHITECTURES)
string(REPLACE "," ";" _archs "${_archs}")
set(_hex_ids)
foreach(_tgt IN LISTS _archs)
_ck_gpu_target_string_to_id("${_tgt}" _hex)
if(_hex AND NOT _hex STREQUAL "0x0000")
list(APPEND _hex_ids "${_hex}")
endif()
endforeach()
list(JOIN _hex_ids "," _hex_str)
if(_hex_str)
target_compile_definitions(${TARGET_NAME} PRIVATE "CK_CMAKE_GPU_TARGET_IDS=${_hex_str}")
endif()
endfunction()
# Convenience: add_gtest_executable + inject CK_CMAKE_GPU_TARGET_IDS
macro(_add_mma_gtest TEST_NAME)
add_gtest_executable(${TEST_NAME} ${ARGN})
_ck_add_gpu_target_ids_define(${TEST_NAME})
endmacro()
# ---------------------------------------------------------------------------
if(GPU_TARGETS MATCHES "gfx9|gfx120")
_add_mma_gtest(test_amdgcn_sparse_mma pipeline/test_amdgcn_sparse_mma.cpp)
target_compile_options(test_amdgcn_sparse_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx950")
add_gtest_executable(test_amdgcn_scale_mma pipeline/test_amdgcn_scale_mma.cpp)
_add_mma_gtest(test_amdgcn_scale_mma pipeline/test_amdgcn_scale_mma.cpp)
target_compile_options(test_amdgcn_scale_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_amdgcn_mma test_amdgcn_mma.cpp)
_add_mma_gtest(test_amdgcn_mma test_amdgcn_mma.cpp)
target_compile_options(test_amdgcn_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_gtest_executable(test_amdgcn_wavewise_mma pipeline/test_amdgcn_wavewise_mma.cpp)
_add_mma_gtest(test_amdgcn_wavewise_mma pipeline/test_amdgcn_wavewise_mma.cpp)
target_compile_options(test_amdgcn_wavewise_mma PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile_gemm tests for current target")
@@ -37,45 +127,45 @@ macro(set_mma_test_arch_define target_name)
endmacro()
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_amdgcn_mma_layout_gfx9 test_amdgcn_mma_layout_gfx9.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx9 test_amdgcn_mma_layout_gfx9.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx9 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
set_mma_test_arch_define(test_amdgcn_mma_layout_gfx9)
endif()
if(GPU_TARGETS MATCHES "gfx908|gfx90a")
add_gtest_executable(test_amdgcn_mma_layout_gfx908_and_gfx90a test_amdgcn_mma_layout_gfx908_and_gfx90a.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx908_and_gfx90a test_amdgcn_mma_layout_gfx908_and_gfx90a.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx908_and_gfx90a PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
set_mma_test_arch_define(test_amdgcn_mma_layout_gfx908_and_gfx90a)
endif()
if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
add_gtest_executable(test_amdgcn_mma_layout_gfx90a_and_higher test_amdgcn_mma_layout_gfx90a_and_higher.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx90a_and_higher test_amdgcn_mma_layout_gfx90a_and_higher.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx90a_and_higher PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
set_mma_test_arch_define(test_amdgcn_mma_layout_gfx90a_and_higher)
endif()
if(GPU_TARGETS MATCHES "gfx942|gfx950")
add_gtest_executable(test_amdgcn_mma_layout_gfx942_and_higher test_amdgcn_mma_layout_gfx942_and_higher.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx942_and_higher test_amdgcn_mma_layout_gfx942_and_higher.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx942_and_higher PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS} -Wno-header-hygiene)
set_mma_test_arch_define(test_amdgcn_mma_layout_gfx942_and_higher)
endif()
if(GPU_TARGETS MATCHES "gfx950")
add_gtest_executable(test_amdgcn_mma_layout_gfx950 test_amdgcn_mma_layout_gfx950.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx950 test_amdgcn_mma_layout_gfx950.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx950 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
set_mma_test_arch_define(test_amdgcn_mma_layout_gfx950)
endif()
if(GPU_TARGETS MATCHES "gfx11")
add_gtest_executable(test_amdgcn_mma_layout_gfx11 test_amdgcn_mma_layout_gfx11.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx11 test_amdgcn_mma_layout_gfx11.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx11 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
if(GPU_TARGETS MATCHES "gfx120")
add_gtest_executable(test_amdgcn_mma_layout_gfx12 test_amdgcn_mma_layout_gfx12.cpp)
_add_mma_gtest(test_amdgcn_mma_layout_gfx12 test_amdgcn_mma_layout_gfx12.cpp)
target_compile_options(test_amdgcn_mma_layout_gfx12 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endif()
add_gtest_executable(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp)
_add_mma_gtest(test_amdgcn_mma_pipeline pipeline/test_amdgcn_mma_pipeline.cpp)
target_compile_options(test_amdgcn_mma_pipeline PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -0,0 +1,87 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include <unordered_set>
namespace ck_tile::core::arch::testing {
static CK_TILE_HOST auto getCMakeGpuTargetIds()
{
using ck_tile::core::arch::amdgcn_target_id;
#ifdef CK_CMAKE_GPU_TARGET_IDS
constexpr uint32_t ids[] = {CK_CMAKE_GPU_TARGET_IDS};
std::unordered_set<amdgcn_target_id> result;
for(auto id : ids)
result.insert(static_cast<amdgcn_target_id>(id));
return result;
#else
return std::unordered_set<amdgcn_target_id>{};
#endif
}
template <typename Func>
static CK_TILE_HOST bool dispatchCompilerTarget(ck_tile::core::arch::amdgcn_target_id id,
Func&& func)
{
using namespace ck_tile::core::arch;
// clang-format off
switch(id)
{
case amdgcn_target_id::GFX908: func(make_amdgcn_gfx9_target<amdgcn_target_id::GFX908>()); return true;
case amdgcn_target_id::GFX90A: func(make_amdgcn_gfx9_target<amdgcn_target_id::GFX90A>()); return true;
case amdgcn_target_id::GFX942: func(make_amdgcn_gfx9_target<amdgcn_target_id::GFX942>()); return true;
case amdgcn_target_id::GFX950: func(make_amdgcn_gfx9_target<amdgcn_target_id::GFX950>()); return true;
case amdgcn_target_id::GFX1030: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1030>()); return true;
case amdgcn_target_id::GFX1031: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1031>()); return true;
case amdgcn_target_id::GFX1032: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1032>()); return true;
case amdgcn_target_id::GFX1033: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1033>()); return true;
case amdgcn_target_id::GFX1034: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1034>()); return true;
case amdgcn_target_id::GFX1035: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1035>()); return true;
case amdgcn_target_id::GFX1036: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX1036>()); return true;
case amdgcn_target_id::GFX103_GENERIC: func(make_amdgcn_gfx10_3_target<amdgcn_target_id::GFX103_GENERIC>()); return true;
case amdgcn_target_id::GFX1100: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1100>()); return true;
case amdgcn_target_id::GFX1101: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1101>()); return true;
case amdgcn_target_id::GFX1102: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1102>()); return true;
case amdgcn_target_id::GFX1103: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1103>()); return true;
case amdgcn_target_id::GFX1150: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1150>()); return true;
case amdgcn_target_id::GFX1151: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1151>()); return true;
case amdgcn_target_id::GFX1152: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1152>()); return true;
case amdgcn_target_id::GFX1153: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX1153>()); return true;
case amdgcn_target_id::GFX11_GENERIC: func(make_amdgcn_gfx11_target<amdgcn_target_id::GFX11_GENERIC>()); return true;
case amdgcn_target_id::GFX1200: func(make_amdgcn_gfx12_target<amdgcn_target_id::GFX1200>()); return true;
case amdgcn_target_id::GFX1201: func(make_amdgcn_gfx12_target<amdgcn_target_id::GFX1201>()); return true;
case amdgcn_target_id::GFX12_GENERIC: func(make_amdgcn_gfx12_target<amdgcn_target_id::GFX12_GENERIC>()); return true;
case amdgcn_target_id::GFX1250: func(make_amdgcn_gfx12_target<amdgcn_target_id::GFX1250>()); return true;
case amdgcn_target_id::HOST: return false;
}
// clang-format on
__builtin_unreachable();
}
static CK_TILE_HOST constexpr int32_t getCMakeWaveSize()
{
using ck_tile::core::arch::amdgcn_target_id;
#ifdef CK_CMAKE_GPU_TARGET_IDS
constexpr uint32_t ids[] = {CK_CMAKE_GPU_TARGET_IDS};
constexpr index_t targets_size = sizeof(ids) / sizeof(ids[0]);
static_assert(targets_size > 0);
constexpr auto first_target_id = static_cast<amdgcn_target_id>(ids[0]);
if constexpr(first_target_id >= amdgcn_target_id::GFX908 &&
first_target_id <= amdgcn_target_id::GFX950)
{
return 64;
}
else
{
return 32;
}
#else
static_assert(false, "Configure CK_CMAKE_GPU_TARGET_IDS before calling this function.");
return 0;
#endif
}
} // namespace ck_tile::core::arch::testing

View File

@@ -1,34 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdint>
#include <cstdio>
#include "ck_tile/core/arch/arch.hpp"
#include <hip/hip_runtime.h>
#include "ck_tile/host/hip_check_error.hpp"
namespace {
__global__ void getWaveSizeForSelectedOp(uint32_t* waveSize)
{
using CompilerTarget = decltype(ck_tile::core::arch::get_compiler_target());
if(waveSize)
*waveSize = static_cast<uint32_t>(CompilerTarget::WAVE_SIZE_ID);
}
static __host__ uint32_t getDeviceWaveSize()
{
uint32_t* d_wave_size;
HIP_CHECK_ERROR(hipMalloc(&d_wave_size, sizeof(uint32_t)));
getWaveSizeForSelectedOp<<<1, 64>>>(d_wave_size);
HIP_CHECK_ERROR(hipDeviceSynchronize());
uint32_t wave_size;
HIP_CHECK_ERROR(hipMemcpy(&wave_size, d_wave_size, sizeof(uint32_t), hipMemcpyDeviceToHost));
return wave_size;
}
} // namespace

View File

@@ -10,223 +10,421 @@
#include <gtest/gtest.h>
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_calculator.hpp"
#include "ck_tile/core/arch/mma/utility/tile_distribution_encoding_register_mapper.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <hip/hip_runtime.h>
#include "../get_wave_size_helper.hpp"
#include "../get_cmake_targets_helper.hpp"
template <typename AType_ = ck_tile::fp16_t,
typename BType_ = ck_tile::fp16_t,
typename CType_ = ck_tile::fp32_t,
uint32_t WaveTileM_ = 16,
uint32_t WaveTileN_ = 16,
uint32_t WaveTileK_ = 32,
typename ScaleAType_ = int,
typename ScaleBType_ = int>
struct MmaPipelineTest
namespace mma_pipeline_test {
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace ck_tile::core::arch::mma;
using namespace ck_tile::core::arch::testing;
inline bool hipTargetMatchesCmakeTargets(amdgcn_target_id arch)
{
using AType = AType_;
using BType = BType_;
using CType = CType_;
using ScaleAType = ScaleAType_;
using ScaleBType = ScaleBType_;
static constexpr auto WaveTileM = WaveTileM_;
static constexpr auto WaveTileN = WaveTileN_;
static constexpr auto WaveTileK = WaveTileK_;
void test_pipeline(std::function<bool(ck_tile::core::arch::amdgcn_target_id)> shouldSkip,
std::function<void(uint32_t, void*, void*, void*, void*)> kernel,
std::function<CType(uint32_t)> getExpected,
std::function<AType(size_t)> aInitializer = nullptr)
const auto cmake_targets = getCMakeGpuTargetIds();
if(cmake_targets.count(arch) == 0)
{
using namespace ck_tile;
using namespace ck_tile::core::arch;
int devCount;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
hipDeviceProp_t devProp;
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
bool hasDevice = static_cast<bool>(devCount > 0);
int deviceWarpSize = devProp.warpSize;
if(!hasDevice || shouldSkip(currentArchId))
// gfx12-generic and gfx11-generic make no difference with the specialized archs.
// Some CI pipelines make use of that and configure the project with the generic
// flags besides compiling for (f.e.) gfx1201.
if(arch >= amdgcn_target_id::GFX1200 && arch <= amdgcn_target_id::GFX12_GENERIC)
{
GTEST_SKIP() << "No HIP device found. Skipping test.";
return (cmake_targets.count(amdgcn_target_id::GFX12_GENERIC) > 0);
}
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
// Note: Actual FragK might be slightly different due to hardware implementation, but the
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
// correct.
static constexpr uint32_t FragM = WaveTileM;
static constexpr uint32_t FragN = WaveTileN;
static constexpr uint32_t FragK = WaveTileK;
// The number of elements per thread
uint32_t AElements = FragM * FragK / deviceWarpSize;
uint32_t BElements = FragN * FragK / deviceWarpSize;
uint32_t CElements = FragM * FragN / deviceWarpSize;
uint32_t ASize = AElements * sizeof(AType);
uint32_t BSize = BElements * sizeof(BType);
uint32_t CSize = CElements * sizeof(CType);
// Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's
std::vector<AType> h_a(AElements);
if(aInitializer)
else if(arch >= amdgcn_target_id::GFX1100 && arch <= amdgcn_target_id::GFX11_GENERIC)
{
for(size_t i = 0; i < AElements; ++i)
h_a[i] = aInitializer(i);
return (cmake_targets.count(amdgcn_target_id::GFX11_GENERIC) > 0);
}
else
}
return true;
}
template <typename CType, typename AType, typename BType>
void reference_matmul(std::vector<CType>& C,
const std::vector<AType>& A,
const std::vector<BType>& B,
uint32_t M,
uint32_t N,
uint32_t K)
{
for(uint32_t m = 0; m < M; ++m)
{
for(uint32_t n = 0; n < N; ++n)
{
std::fill(h_a.begin(), h_a.end(), type_convert<AType>(1));
float acc = 0.0f;
for(uint32_t k = 0; k < K; ++k)
{
acc += type_convert<float>(A[m * K + k]) * type_convert<float>(B[k * N + n]);
}
C[m * N + n] = static_cast<CType>(acc);
}
std::vector<BType> h_b(BElements, type_convert<BType>(1));
std::vector<CType> h_c(CElements, type_convert<CType>(0));
std::vector<CType> h_out(CElements, type_convert<CType>(0));
}
}
AType* d_a;
BType* d_b;
CType* d_c;
CType* d_out;
template <typename T>
T deterministic_value(uint32_t row, uint32_t col, uint32_t minor_dim)
{
float v = static_cast<float>((row * minor_dim + col) % 7 + 1) * 0.25f;
return type_convert<T>(v);
}
HIP_CHECK_ERROR(hipMalloc(&d_a, ASize));
HIP_CHECK_ERROR(hipMalloc(&d_b, BSize));
HIP_CHECK_ERROR(hipMalloc(&d_c, CSize));
HIP_CHECK_ERROR(hipMalloc(&d_out, CSize));
// Copy inputs to device
HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
kernel(wave_size, d_a, d_b, d_c, d_out);
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
// Verify output against expected value for all elements
for(size_t i = 0; i < CElements; ++i)
// Apply 2:4 sparsity pattern to A matrix in-place (for sparse pipeline tests).
// Every group of 4 consecutive K elements keeps slots 0 and 2, zeros slots 1 and 3.
template <typename T>
void apply_sparse_pattern(std::vector<T>& A, uint32_t M, uint32_t K)
{
for(uint32_t m = 0; m < M; ++m)
{
for(uint32_t k = 0; k < K; k += 4)
{
EXPECT_NEAR(h_out[i], getExpected(FragK), 1e-3);
// Keep slots 0, 2. Zero out slots 1, 3.
if(k + 1 < K)
A[m * K + k + 1] = static_cast<T>(0);
if(k + 3 < K)
A[m * K + k + 3] = static_cast<T>(0);
}
}
}
HIP_CHECK_ERROR(hipFree(d_a));
HIP_CHECK_ERROR(hipFree(d_b));
HIP_CHECK_ERROR(hipFree(d_c));
HIP_CHECK_ERROR(hipFree(d_out));
// Fill per-lane A fragments from logical A[M][K] matrix.
// For dense pipelines: AVecType = InternalAVecT[FragsM][FragsK]
// For sparse pipelines: AVecType = ExternalAFragVecT[FragsM][FragsK] (uncompressed)
template <typename Pipeline, typename AScalar>
void fill_a_fragments(typename Pipeline::AVecType* a_per_lane,
const std::vector<AScalar>& A_matrix,
uint32_t K,
uint32_t waveSize)
{
using MmaOp = typename Pipeline::MmaOp;
using ARegMap = TileDistrEncRegMap<typename TileDistrEncCalc<MmaOp>::AWarpDstrEncoding>;
using AFragScalar = typename vector_traits<typename MmaOp::AVecType>::scalar_type;
constexpr uint32_t FragM = Pipeline::FragM;
constexpr uint32_t FragK = Pipeline::FragK;
constexpr uint32_t FragsM = Pipeline::FragsM;
constexpr uint32_t FragsK = Pipeline::FragsK;
constexpr uint32_t kCompressionRatio = MmaOp::kCompressionRatio;
// The A register map maps (lane, vec_idx) -> (m_within_frag, k_within_frag)
// For sparse: k_within_frag is in the compressed K domain (K / kCompressionRatio)
constexpr index_t a_vec_size = ARegMap::num_vector_items;
constexpr index_t external_a_frag_vec_size = a_vec_size * kCompressionRatio;
for(uint32_t lane = 0; lane < waveSize; ++lane)
{
auto* lane_a = reinterpret_cast<AFragScalar*>(&a_per_lane[lane]);
for(uint32_t bm = 0; bm < FragsM; ++bm)
{
for(uint32_t bk = 0; bk < FragsK; ++bk)
{
uint32_t frag_offset = (bm * FragsK + bk) * external_a_frag_vec_size;
if constexpr(kCompressionRatio > 1)
{
// Sparse: fill external (uncompressed) vector
for(index_t ev = 0; ev < external_a_frag_vec_size; ++ev)
{
index_t compressed_v = ev / kCompressionRatio;
index_t sub_pos = ev % kCompressionRatio;
auto coords =
ARegMap::calc_matrix_indices_from_lane_vector(lane, compressed_v);
uint32_t m_local = coords[0];
uint32_t k_compressed = coords[1];
uint32_t k_local = k_compressed * kCompressionRatio + sub_pos;
uint32_t m_global = bm * FragM + m_local;
uint32_t k_global = bk * FragK + k_local;
lane_a[frag_offset + ev] =
static_cast<AFragScalar>(A_matrix[m_global * K + k_global]);
}
}
else
{
// Dense/Scale: direct mapping
for(index_t v = 0; v < a_vec_size; ++v)
{
auto coords = ARegMap::calc_matrix_indices_from_lane_vector(lane, v);
uint32_t m_local = coords[0];
uint32_t k_local = coords[1];
uint32_t m_global = bm * FragM + m_local;
uint32_t k_global = bk * FragK + k_local;
lane_a[frag_offset + v] =
static_cast<AFragScalar>(A_matrix[m_global * K + k_global]);
}
}
}
}
}
}
// Fill per-lane B fragments from logical B[K][N] matrix.
// BVecType = InternalBVecT[FragsN][FragsK]
template <typename Pipeline, typename BScalar>
void fill_b_fragments(typename Pipeline::BVecType* b_per_lane,
const std::vector<BScalar>& B_matrix,
uint32_t N,
uint32_t waveSize)
{
using MmaOp = typename Pipeline::MmaOp;
using BRegMap = TileDistrEncRegMap<typename TileDistrEncCalc<MmaOp>::BWarpDstrEncoding>;
using BFragScalar = typename vector_traits<typename MmaOp::BVecType>::scalar_type;
constexpr uint32_t FragN = Pipeline::FragN;
constexpr uint32_t FragK = Pipeline::FragK;
constexpr uint32_t FragsN = Pipeline::FragsN;
constexpr uint32_t FragsK = Pipeline::FragsK;
constexpr index_t b_vec_size = BRegMap::num_vector_items;
for(uint32_t lane = 0; lane < waveSize; ++lane)
{
auto* lane_b = reinterpret_cast<BFragScalar*>(&b_per_lane[lane]);
for(uint32_t bn = 0; bn < FragsN; ++bn)
{
for(uint32_t bk = 0; bk < FragsK; ++bk)
{
uint32_t frag_offset = (bn * FragsK + bk) * b_vec_size;
for(index_t v = 0; v < b_vec_size; ++v)
{
auto coords = BRegMap::calc_matrix_indices_from_lane_vector(lane, v);
uint32_t n_local = coords[0];
uint32_t k_local = coords[1];
uint32_t n_global = bn * FragN + n_local;
uint32_t k_global = bk * FragK + k_local;
// B matrix is stored as B[K][N]
lane_b[frag_offset + v] =
static_cast<BFragScalar>(B_matrix[k_global * N + n_global]);
}
}
}
}
}
// Extract C matrix from per-lane C fragments.
// CVecType = InternalCVecT[FragsM][FragsN]
template <typename Pipeline, typename CScalar>
void extract_c_matrix(const typename Pipeline::CVecType* c_per_lane,
std::vector<CScalar>& C_matrix,
uint32_t N,
uint32_t waveSize)
{
using MmaOp = typename Pipeline::MmaOp;
using CRegMap = TileDistrEncRegMap<typename TileDistrEncCalc<MmaOp>::CWarpDstrEncoding>;
using CFragScalar = typename vector_traits<typename MmaOp::CVecType>::scalar_type;
constexpr uint32_t FragM = Pipeline::FragM;
constexpr uint32_t FragN = Pipeline::FragN;
constexpr uint32_t FragsM = Pipeline::FragsM;
constexpr uint32_t FragsN = Pipeline::FragsN;
constexpr index_t c_vec_size = CRegMap::num_vector_items;
for(uint32_t lane = 0; lane < waveSize; ++lane)
{
auto* lane_c = reinterpret_cast<const CFragScalar*>(&c_per_lane[lane]);
for(uint32_t bm = 0; bm < FragsM; ++bm)
{
for(uint32_t bn = 0; bn < FragsN; ++bn)
{
uint32_t frag_offset = (bm * FragsN + bn) * c_vec_size;
for(index_t v = 0; v < c_vec_size; ++v)
{
auto coords = CRegMap::calc_matrix_indices_from_lane_vector(lane, v);
uint32_t m_local = coords[0];
uint32_t n_local = coords[1];
uint32_t m_global = bm * FragM + m_local;
uint32_t n_global = bn * FragN + n_local;
C_matrix[m_global * N + n_global] =
static_cast<CScalar>(lane_c[frag_offset + v]);
}
}
}
}
}
/// Internal: runs the test with a fully resolved Pipeline type.
/// Called from run_pipeline_matrix_test after dispatching on compiler target.
template <typename Pipeline,
typename KernelType,
typename AScalar = fp16_t,
typename BScalar = fp16_t,
typename CScalar = fp32_t>
void run_pipeline_matrix_test_impl(uint32_t M,
uint32_t N,
uint32_t K,
uint32_t waveSize,
KernelType kernel,
bool isSparse,
bool transposeExpected = false,
float referenceScale = 1.0f)
{
std::vector<AScalar> A_matrix(M * K);
std::vector<BScalar> B_matrix(K * N);
std::vector<CScalar> C_expected(M * N, static_cast<CScalar>(0));
std::vector<CScalar> C_actual(M * N, static_cast<CScalar>(0));
for(uint32_t m = 0; m < M; ++m)
for(uint32_t k = 0; k < K; ++k)
A_matrix[m * K + k] = deterministic_value<AScalar>(m, k, K);
for(uint32_t k = 0; k < K; ++k)
for(uint32_t n = 0; n < N; ++n)
B_matrix[k * N + n] = deterministic_value<BScalar>(k, n, N);
if(isSparse)
{
apply_sparse_pattern(A_matrix, M, K);
}
void
test_pipeline(std::function<bool(ck_tile::core::arch::amdgcn_target_id)> shouldSkip,
std::function<void(uint32_t, void*, void*, void*, void*, void*, void*)> kernel,
std::function<CType(uint32_t, ScaleAType, ScaleBType)> getExpected,
std::function<AType(size_t)> aInitializer = nullptr)
reference_matmul(C_expected, A_matrix, B_matrix, M, N, K);
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
const size_t a_buf_size = waveSize * sizeof(AVecType);
const size_t b_buf_size = waveSize * sizeof(BVecType);
const size_t c_buf_size = waveSize * sizeof(CVecType);
std::vector<uint8_t> h_a(a_buf_size, 0);
std::vector<uint8_t> h_b(b_buf_size, 0);
std::vector<uint8_t> h_c(c_buf_size, 0);
fill_a_fragments<Pipeline>(reinterpret_cast<AVecType*>(h_a.data()), A_matrix, K, waveSize);
fill_b_fragments<Pipeline>(reinterpret_cast<BVecType*>(h_b.data()), B_matrix, N, waveSize);
void *d_a, *d_b, *d_c;
HIP_CHECK_ERROR(hipMalloc(&d_a, a_buf_size));
HIP_CHECK_ERROR(hipMalloc(&d_b, b_buf_size));
HIP_CHECK_ERROR(hipMalloc(&d_c, c_buf_size));
HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), a_buf_size, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), b_buf_size, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemset(d_c, 0, c_buf_size));
ck_tile::launch_kernel(ck_tile::stream_config{},
ck_tile::make_kernel(kernel, dim3(1), dim3(waveSize), 0, d_a, d_b, d_c));
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_c.data(), d_c, c_buf_size, hipMemcpyDeviceToHost));
extract_c_matrix<Pipeline>(
reinterpret_cast<const CVecType*>(h_c.data()), C_actual, N, waveSize);
for(uint32_t m = 0; m < M; ++m)
{
using namespace ck_tile;
using namespace ck_tile::core::arch;
int devCount;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
hipDeviceProp_t devProp;
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
bool hasDevice = static_cast<bool>(devCount > 0);
int deviceWarpSize = devProp.warpSize;
if(!hasDevice || shouldSkip(currentArchId))
for(uint32_t n = 0; n < N; ++n)
{
GTEST_SKIP() << "No HIP device found. Skipping test.";
// When transposeExpected is true, the kernel computes C^T via SwapAB,
// so compare actual C[m][n] against reference C[n][m].
constexpr float relative_tolerance = 1e-2f;
constexpr float absolute_tolerance = 1e-3f;
float expected = transposeExpected ? static_cast<float>(C_expected[n * M + m])
: static_cast<float>(C_expected[m * N + n]);
expected *= referenceScale;
float actual = static_cast<float>(C_actual[m * N + n]);
EXPECT_NEAR(
actual, expected, std::abs(expected) * relative_tolerance + absolute_tolerance)
<< "Mismatch at C[" << m << "][" << n << "]";
}
// WaveTile size, also the expected fragment size (MmaTile) from the selector.
// Note: Actual FragK might be slightly different due to hardware implementation, but the
// test_accum_over_k kernel will loop over the K dimension to ensure that the total K is
// correct.
static constexpr uint32_t FragM = WaveTileM;
static constexpr uint32_t FragN = WaveTileN;
static constexpr uint32_t FragK = WaveTileK;
// The number of elements per thread
uint32_t AElements = FragM * FragK / deviceWarpSize / numeric_traits<AType>::PackedSize;
uint32_t BElements = FragN * FragK / deviceWarpSize / numeric_traits<BType>::PackedSize;
uint32_t CElements = FragM * FragN / deviceWarpSize;
uint32_t ASize = AElements * sizeof(AType);
uint32_t BSize = BElements * sizeof(BType);
uint32_t CSize = CElements * sizeof(CType);
uint32_t ScaleASize = 1 * sizeof(ScaleAType);
uint32_t ScaleBSize = 1 * sizeof(ScaleBType);
// Initialize A (use custom initializer or default all 1's), B to all 1's, C to all 0's
std::vector<AType> h_a(AElements);
if(aInitializer)
{
for(size_t i = 0; i < AElements; ++i)
h_a[i] = aInitializer(i);
}
else
{
std::fill(h_a.begin(), h_a.end(), type_convert<AType>(1.0f));
}
std::vector<BType> h_b(BElements, type_convert<BType>(1.0f));
std::vector<CType> h_c(CElements, type_convert<CType>(0.0f));
std::vector<CType> h_out(CElements, type_convert<CType>(0.0f));
// The actual scale is computed as pow(2, scale - 127), so:
// 126 -> 2^-1 and 129 -> 2^2.
ScaleAType h_scale_a = 126;
ScaleBType h_scale_b = 129;
AType* d_a;
BType* d_b;
CType* d_c;
CType* d_out;
ScaleAType* d_scale_a;
ScaleBType* d_scale_b;
HIP_CHECK_ERROR(hipMalloc(&d_a, ASize));
HIP_CHECK_ERROR(hipMalloc(&d_b, BSize));
HIP_CHECK_ERROR(hipMalloc(&d_c, CSize));
HIP_CHECK_ERROR(hipMalloc(&d_out, CSize));
HIP_CHECK_ERROR(hipMalloc(&d_scale_a, ScaleASize));
HIP_CHECK_ERROR(hipMalloc(&d_scale_b, ScaleBSize));
// Copy inputs to device
HIP_CHECK_ERROR(hipMemcpy(d_a, h_a.data(), ASize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_scale_a, &h_scale_a, ScaleASize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_scale_b, &h_scale_b, ScaleBSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
kernel(wave_size, d_a, d_b, d_c, d_out, d_scale_a, d_scale_b);
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_out, CSize, hipMemcpyDeviceToHost));
// Verify output against expected value for all elements
for(size_t i = 0; i < CElements; ++i)
{
EXPECT_NEAR(h_out[i], getExpected(FragK, h_scale_a, h_scale_b), 1e-3);
}
HIP_CHECK_ERROR(hipFree(d_a));
HIP_CHECK_ERROR(hipFree(d_b));
HIP_CHECK_ERROR(hipFree(d_c));
HIP_CHECK_ERROR(hipFree(d_out));
HIP_CHECK_ERROR(hipFree(d_scale_a));
HIP_CHECK_ERROR(hipFree(d_scale_b));
}
};
HIP_CHECK_ERROR(hipFree(d_a));
HIP_CHECK_ERROR(hipFree(d_b));
HIP_CHECK_ERROR(hipFree(d_c));
}
/// @tparam PipelineFactory A template template that, given a CompilerTarget type, produces
/// the Pipeline type: PipelineFactory<Target>::type
/// @tparam KernelType Kernel functor struct with kBlockSize and __device__ operator()
/// @tparam AScalar Scalar type for A matrix (e.g., fp16_t)
/// @tparam BScalar Scalar type for B matrix (e.g., fp16_t)
/// @tparam CScalar Scalar type for C matrix (e.g., fp32_t)
/// @param M WaveTile M dimension
/// @param N WaveTile N dimension
/// @param K WaveTile K dimension
/// @param shouldSkip Predicate returning true if current device should skip
/// @param kernel Kernel functor instance to launch via make_kernel
/// @param isSparse Whether to apply 2:4 sparsity pattern to A
/// @param transposeExpected When true, compare against transposed reference (for
/// SwapAB/TransposeC)
/// @param referenceScale Scalar multiplier applied to the reference matmul result before
/// comparison (e.g., to account for scale-MMA scaling factors)
template <template <typename> class PipelineFactory,
typename KernelType,
typename AScalar = fp16_t,
typename BScalar = fp16_t,
typename CScalar = fp32_t>
void run_pipeline_matrix_test(uint32_t M,
uint32_t N,
uint32_t K,
std::function<bool(ck_tile::core::arch::amdgcn_target_id)> shouldSkip,
KernelType kernel,
bool isSparse = false,
bool transposeExpected = false,
float referenceScale = 1.0f)
{
int devCount;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceCount(&devCount));
hipDeviceProp_t devProp;
HIP_CHECK_ERROR(hipGetDeviceProperties(&devProp, dev));
auto currentArchId = hip_device_prop_gcn_arch_name_to_amdgcn_target_id(devProp.gcnArchName);
if(devCount <= 0 || shouldSkip(currentArchId))
{
GTEST_SKIP() << "No HIP device found or arch (0x" << std::hex
<< static_cast<int>(currentArchId) << ") not supported. Skipping test.";
}
if(!hipTargetMatchesCmakeTargets(currentArchId))
{
std::cout << "The GPU targets exposed by CMake are: ";
for(const auto& target : getCMakeGpuTargetIds())
{
std::cout << "(0x" << std::hex << static_cast<int>(target) << ")\n";
}
FAIL() << "The HIP device (0x" << std::hex << static_cast<int>(currentArchId)
<< ") does not match the compiler target(s).";
}
const uint32_t waveSize = static_cast<uint32_t>(devProp.warpSize);
bool dispatched = dispatchCompilerTarget(currentArchId, [&](auto target) {
using CompilerTarget = decltype(target);
using Pipeline = typename PipelineFactory<CompilerTarget>::type;
run_pipeline_matrix_test_impl<Pipeline, KernelType, AScalar, BScalar, CScalar>(
M, N, K, waveSize, kernel, isSparse, transposeExpected, referenceScale);
});
if(!dispatched)
{
GTEST_SKIP() << "Cannot dispatch on HOST target.";
}
}
} // namespace mma_pipeline_test

View File

@@ -18,7 +18,7 @@ TEST(MmaPipelineOptionFlagsTests, ConversionTests)
MmaPipelineOptionFlags flags_0{};
MmaPipelineOptionFlags flags_1{MmaPipelineOptionFlag::ABSwap};
MmaPipelineOptionFlags flags_2{MmaPipelineOptionFlag::COMPRESS_A};
MmaPipelineOptionFlags flags_3{0b11};
MmaPipelineOptionFlags flags_3{0b11}; // TODO c++20 - remove this
EXPECT_TRUE(flags_0.testFlag(MmaPipelineOptionFlag::NONE));
EXPECT_FALSE(flags_0.testFlag(MmaPipelineOptionFlag::ABSwap));

View File

@@ -5,17 +5,16 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/scale/scale_mma_pipeline.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/pk_fp4.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include <gtest/gtest.h>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <type_traits>
@@ -44,21 +43,21 @@ void ScaleMfmaGfx950Specialization_impl()
CompilerTargetGfx950,
MmaOpFamily::SCALE>;
static_assert(std::is_same_v<typename TestScaleMma::OpType, MfmaOp> &&
TestScaleMma::OpFamily == MmaOpFamily::SCALE,
"GFX950 scale intrinsic should have ScaleMFMAOp type");
EXPECT_TRUE((std::is_same_v<typename TestScaleMma::OpType, MfmaOp> &&
TestScaleMma::OpFamily == MmaOpFamily::SCALE))
<< "GFX950 scale intrinsic should have ScaleMFMAOp type";
static_assert(is_mma_op_of_family_v<MmaOpFamily::SCALE, TestScaleMma>,
"GFX950 scale intrinsic should be detected as Scale");
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::SCALE, TestScaleMma>))
<< "GFX950 scale intrinsic should be detected as Scale";
// Get its traits
using TestTraits = MmaOpTraits<TestScaleMma>;
// Verify trait detection
static_assert(TestTraits::IsScale, "Scale MMA should be detected as scale");
static_assert(TestTraits::IsSupported, "Scale MMA specialization should be supported");
static_assert(TestTraits::IsMfma, "Scale MFMA should be detected as MFMA");
static_assert(!TestTraits::IsWmma, "Scale MFMA should not be detected as WMMA");
EXPECT_TRUE(TestTraits::IsScale) << "Scale MMA should be detected as scale";
EXPECT_TRUE(TestTraits::IsSupported) << "Scale MMA specialization should be supported";
EXPECT_TRUE(TestTraits::IsMfma) << "Scale MFMA should be detected as MFMA";
EXPECT_FALSE(TestTraits::IsWmma) << "Scale MFMA should not be detected as WMMA";
}
TEST(ScaleMMATrait, ScaleMfmaGfx950Specialization)
@@ -67,14 +66,10 @@ TEST(ScaleMMATrait, ScaleMfmaGfx950Specialization)
ScaleMfmaGfx950Specialization_impl<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u>();
// Test bf8 -> fp32 scale MFMA for GFX950 (16x16x128)
ScaleMfmaGfx950Specialization_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
// Test fp4 -> fp32 scale MFMA for GFX950 (16x16x128)
ScaleMfmaGfx950Specialization_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
// Test fp8 -> fp32 scale MFMA for GFX950 (32x32x64)
ScaleMfmaGfx950Specialization_impl<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u>();
// Test bf8 -> fp32 scale MFMA for GFX950 (32x32x64)
ScaleMfmaGfx950Specialization_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u>();
// Test fp4 -> fp32 scale MFMA for GFX950 (32x32x64)
ScaleMfmaGfx950Specialization_impl<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u>();
std::cout << "GFX950 scale MFMA specialization is correct" << std::endl;
}
@@ -97,7 +92,7 @@ void TestConceptRequirements_impl()
DefaultScaleMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SCALE>;
static_assert(MmaOpI<TestScaleMma>);
EXPECT_TRUE(MmaOpI<TestScaleMma>);
}
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
@@ -106,10 +101,8 @@ TEST(ScaleMMATrait, TestConceptRequirements)
#if CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
TestConceptRequirements_impl<fp8_t, fp8_t, fp32_t, 16u, 16u, 128u>();
TestConceptRequirements_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
TestConceptRequirements_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
TestConceptRequirements_impl<fp8_t, fp8_t, fp32_t, 32u, 32u, 64u>();
TestConceptRequirements_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u>();
TestConceptRequirements_impl<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u>();
#else
GTEST_SKIP() << "Not compiled with concepts. Skipping test.";
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
@@ -132,15 +125,15 @@ void ScaleSelector_impl()
if constexpr(isValid)
{
// Selector should pick a scale MFMA implementation
static_assert(MmaOpTraits<Selected>::IsScale);
static_assert(MmaOpTraits<Selected>::IsMfma);
static_assert(MmaOpTraits<Selected>::IsSupported);
static_assert((std::is_same<typename Selected::OpType, MfmaOp>::value));
EXPECT_TRUE(MmaOpTraits<Selected>::IsScale);
EXPECT_TRUE(MmaOpTraits<Selected>::IsMfma);
EXPECT_TRUE(MmaOpTraits<Selected>::IsSupported);
EXPECT_TRUE((std::is_same<typename Selected::OpType, MfmaOp>::value));
}
else
{
// Selector should pick the unsupported pass through
static_assert(!MmaOpTraits<Selected>::IsSupported);
EXPECT_FALSE(MmaOpTraits<Selected>::IsSupported);
}
});
});
@@ -150,7 +143,6 @@ TEST(ScaleMMATrait, ScaleSelector)
{
ScaleSelector_impl<fp8_t, fp8_t, fp32_t>();
ScaleSelector_impl<bf8_t, bf8_t, fp32_t>();
ScaleSelector_impl<pk_fp4_t, pk_fp4_t, fp32_t>();
}
template <typename AType,
@@ -161,34 +153,72 @@ template <typename AType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK>
__global__ void
test_scale_accum_over_k(void* a, void* b, void* c, void* out, void* scale_A, void* scale_B)
struct ScalePipelineKernel
{
using Pipeline = ScaleMmaPipeline<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
static constexpr int kBlockSize = mma_pipeline_test::getCMakeWaveSize();
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
// NOTE: WaveTileK is used as a Pipeline template parameter, but the K iteration is
// happening outside the Pipeline. This is a bit incorrect currently.
static constexpr std::uint32_t kIters = WaveTileK / Pipeline::MmaOp::kK;
// Initialize the accumulator
CVecType result = *reinterpret_cast<CVecType*>(c);
// Accumulate input AxB over WaveTileK/FragK iterations
for(std::uint32_t i = 0; i < kIters; ++i)
__device__ void
operator()(const void* a_per_lane, const void* b_per_lane, void* c_per_lane) const
{
result = Pipeline::exec(*reinterpret_cast<AVecType*>(a),
*reinterpret_cast<BVecType*>(b),
result,
*reinterpret_cast<ScaleAType*>(scale_A),
*reinterpret_cast<ScaleBType*>(scale_B));
}
using Pipeline = ScaleMmaPipeline<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
*reinterpret_cast<CVecType*>(out) = result;
}
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
const uint32_t lane = threadIdx.x;
AVecType a;
BVecType b;
CVecType c;
__builtin_memcpy(&a,
static_cast<const uint8_t*>(a_per_lane) + lane * sizeof(AVecType),
sizeof(AVecType));
__builtin_memcpy(&b,
static_cast<const uint8_t*>(b_per_lane) + lane * sizeof(BVecType),
sizeof(BVecType));
__builtin_memset(&c, 0, sizeof(CVecType));
if constexpr(MmaOpTraits<typename Pipeline::MmaOp>::IsSupported)
{
// Each lane has a single 8-bit E8M0 scale that applies to all
// 32 A/B elements in that lane. The byte's position within the
// VGPR is selected by opsel. Replicating the byte to all 4
// positions makes the value opsel-independent.
// scale_a byte = 126 -> 2^(126-127) = 2^-1 = 0.5
// scale_b byte = 129 -> 2^(129-127) = 2^2 = 4.0
// Combined scale factor = 0.5 * 4.0 = 2.0
constexpr int32_t replicate_byte = 0x01010101;
ScaleAType scale_a = 126u * replicate_byte;
ScaleBType scale_b = 129u * replicate_byte;
Pipeline::exec(a, b, c, scale_a, scale_b);
__builtin_memcpy(
static_cast<uint8_t*>(c_per_lane) + lane * sizeof(CVecType), &c, sizeof(CVecType));
}
}
};
template <typename AType,
typename BType,
typename CType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK>
struct ScalePipelineFactory
{
template <typename Target>
struct Create
{
using type = ScaleMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
MmaAccumPolicy::ROW_MAJOR,
Target>;
};
};
template <typename AType,
typename BType,
@@ -198,39 +228,37 @@ template <typename AType,
std::uint32_t WaveTileK>
void MmaSelector_Scale_Real_impl()
{
using TestType = MmaPipelineTest<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
TestType test;
using ScaleAType = std::int32_t;
using ScaleBType = std::int32_t;
const auto should_skip = [](amdgcn_target_id currentArchId) {
bool isSupportedWmma = false;
bool isSupportedMfma = (currentArchId == amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma));
return ((currentArchId == amdgcn_target_id::HOST) || !isSupportedMfma);
};
const std::function<fp32_t(
std::uint32_t, typename TestType::ScaleAType, typename TestType::ScaleBType)>
validator =
[](std::uint32_t fragK, TestType::ScaleAType scale_A, TestType::ScaleBType scale_B) {
fp32_t actual_scale_A = std::powf(2.0f, scale_A - 127.0f);
fp32_t actual_scale_B = std::powf(2.0f, scale_B - 127.0f);
return static_cast<fp32_t>(fragK) * actual_scale_A * actual_scale_B;
};
const auto kernel = [](std::uint32_t waveSize,
void* a,
void* b,
void* c,
void* out,
void* scale_A,
void* scale_B) {
test_scale_accum_over_k<typename TestType::AType,
typename TestType::BType,
typename TestType::CType,
typename TestType::ScaleAType,
typename TestType::ScaleBType,
TestType::WaveTileM,
TestType::WaveTileN,
TestType::WaveTileK>
<<<1, waveSize>>>(a, b, c, out, scale_A, scale_B);
};
test.test_pipeline(should_skip, kernel, validator);
using Factory = ScalePipelineFactory<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
using Kernel = ScalePipelineKernel<AType,
BType,
CType,
ScaleAType,
ScaleBType,
WaveTileM,
WaveTileN,
WaveTileK>;
// scale_a=126 -> 2^-1=0.5, scale_b=129 -> 2^2=4.0 -> combined = 2.0
constexpr float reference_scale = 2.0f;
mma_pipeline_test::
run_pipeline_matrix_test<Factory::template Create, Kernel, AType, BType, CType>(
WaveTileM,
WaveTileN,
WaveTileK,
should_skip,
Kernel{},
/*isSparse=*/false,
/*transposeExpected=*/false,
reference_scale);
}
// Live test on real hardware for scale selection and execution.
@@ -245,12 +273,6 @@ TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_16x16x128_Real)
MmaSelector_Scale_Real_impl<bf8_t, bf8_t, fp32_t, 16u, 16u, 128u>();
}
// Live test on real hardware for scale selection and execution.
TEST(ScaleMMATrait, MmaSelector_Scale_F4_F4_F32_16x16x128_Real)
{
MmaSelector_Scale_Real_impl<pk_fp4_t, pk_fp4_t, fp32_t, 16u, 16u, 128u>();
}
// Live test on real hardware for scale selection and execution.
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_32x32x64_Real)
{
@@ -263,8 +285,215 @@ TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_32x32x64_Real)
MmaSelector_Scale_Real_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 64u>();
}
// Live test on real hardware for scale selection and execution.
TEST(ScaleMMATrait, MmaSelector_Scale_F4_F4_F32_32x32x64_Real)
// ---------------------------------------------------------------------------
// Multi-fragment (WaveWise) scale pipeline tests
// ---------------------------------------------------------------------------
// Kernel functor with AccumPolicy support for multi-fragment scale pipeline tests.
template <typename AType,
typename BType,
typename CType,
typename ScaleAType,
typename ScaleBType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy>
struct ScaleWaveWisePipelineKernel
{
MmaSelector_Scale_Real_impl<pk_fp4_t, pk_fp4_t, fp32_t, 32u, 32u, 64u>();
static constexpr int kBlockSize = mma_pipeline_test::getCMakeWaveSize();
__device__ void
operator()(const void* a_per_lane, const void* b_per_lane, void* c_per_lane) const
{
using CompilerTarget = decltype(get_compiler_target());
using Pipeline = ScaleMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy,
CompilerTarget>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
const uint32_t lane = threadIdx.x;
AVecType a;
BVecType b;
CVecType c;
__builtin_memcpy(&a,
static_cast<const uint8_t*>(a_per_lane) + lane * sizeof(AVecType),
sizeof(AVecType));
__builtin_memcpy(&b,
static_cast<const uint8_t*>(b_per_lane) + lane * sizeof(BVecType),
sizeof(BVecType));
__builtin_memset(&c, 0, sizeof(CVecType));
if constexpr(MmaOpTraits<typename Pipeline::MmaOp>::IsSupported)
{
// Each lane has a single 8-bit E8M0 scale that applies to all
// 32 A/B elements in that lane. The byte's position within the
// VGPR is selected by opsel. Replicating the byte to all 4
// positions makes the value opsel-independent.
// scale_a byte = 126 -> 2^(126-127) = 2^-1 = 0.5
// scale_b byte = 129 -> 2^(129-127) = 2^2 = 4.0
// Combined scale factor = 0.5 * 4.0 = 2.0
constexpr int32_t replicate_byte = 0x01010101;
ScaleAType scale_a = 126u * replicate_byte;
ScaleBType scale_b = 129u * replicate_byte;
Pipeline::exec(a, b, c, scale_a, scale_b);
__builtin_memcpy(
static_cast<uint8_t*>(c_per_lane) + lane * sizeof(CVecType), &c, sizeof(CVecType));
}
}
};
template <typename AType,
typename BType,
typename CType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy>
struct ScaleWaveWisePipelineFactory
{
template <typename Target>
struct Create
{
using type = ScaleMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy,
Target>;
};
};
template <typename AType,
typename BType,
typename CType,
std::uint32_t WaveTileM,
std::uint32_t WaveTileN,
std::uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR>
void MmaSelector_Scale_WaveWise_Real_impl()
{
using ScaleAType = std::int32_t;
using ScaleBType = std::int32_t;
const auto should_skip = [](amdgcn_target_id currentArchId) {
bool isSupportedMfma = (currentArchId == amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !isSupportedMfma);
};
using Factory = ScaleWaveWisePipelineFactory<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy>;
using Kernel = ScaleWaveWisePipelineKernel<AType,
BType,
CType,
ScaleAType,
ScaleBType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy>;
// scale_a=126 -> 2^-1=0.5, scale_b=129 -> 2^2=4.0 -> combined = 2.0
constexpr float reference_scale = 2.0f;
mma_pipeline_test::
run_pipeline_matrix_test<Factory::template Create, Kernel, AType, BType, CType>(
WaveTileM,
WaveTileN,
WaveTileK,
should_skip,
Kernel{},
/*isSparse=*/false,
/*transposeExpected=*/false,
reference_scale);
}
// Multi-fragment tests: 64x64x64 uses 32x32x64 op -> FragsM=2, FragsN=2, FragsK=1
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_64x64x64_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t,
fp8_t,
fp32_t,
64u,
64u,
64u,
MmaAccumPolicy::ROW_MAJOR>();
}
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_64x64x64_WaveWise_ColMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t,
fp8_t,
fp32_t,
64u,
64u,
64u,
MmaAccumPolicy::COL_MAJOR>();
}
TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_64x64x64_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<bf8_t,
bf8_t,
fp32_t,
64u,
64u,
64u,
MmaAccumPolicy::ROW_MAJOR>();
}
// Multi-fragment tests: 32x32x128 uses 32x32x64 op -> FragsM=1, FragsN=1, FragsK=2
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_32x32x128_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t, fp8_t, fp32_t, 32u, 32u, 128u>();
}
TEST(ScaleMMATrait, MmaSelector_Scale_BF8_BF8_F32_32x32x128_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<bf8_t, bf8_t, fp32_t, 32u, 32u, 128u>();
}
// Multi-fragment tests: 64x64x128 uses 32x32x64 op -> FragsM=2, FragsN=2, FragsK=2
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_64x64x128_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t, fp8_t, fp32_t, 64u, 64u, 128u>();
}
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_64x64x128_WaveWise_ColMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t,
fp8_t,
fp32_t,
64u,
64u,
128u,
MmaAccumPolicy::COL_MAJOR>();
}
// Multi-fragment tests with 16x16x128 op: 32x16x128 -> FragsM=2, FragsN=1, FragsK=1
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_32x16x128_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t, fp8_t, fp32_t, 32u, 16u, 128u>();
}
// Multi-fragment tests with 16x16x128 op: 16x32x128 -> FragsM=1, FragsN=2, FragsK=1
TEST(ScaleMMATrait, MmaSelector_Scale_F8_F8_F32_16x32x128_WaveWise_RowMajor_Real)
{
MmaSelector_Scale_WaveWise_Real_impl<fp8_t, fp8_t, fp32_t, 16u, 32u, 128u>();
}

View File

@@ -8,6 +8,7 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_selector.hpp"
#include "ck_tile/core/arch/mma/sparse/sparse_mma_pipeline.hpp"
@@ -41,14 +42,12 @@ TEST(SparseMMATrait, SparseMfmaGfx950Specialization)
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
static_assert(std::is_same_v<typename TestSparseMfma16x16::OpType, MfmaOp> &&
TestSparseMfma16x16::OpFamily == MmaOpFamily::SPARSE,
"GFX950 sparse 16x16x32 should have SparseMFMAOp type");
EXPECT_TRUE((std::is_same_v<typename TestSparseMfma16x16::OpType, MfmaOp> &&
TestSparseMfma16x16::OpFamily == MmaOpFamily::SPARSE))
<< "GFX950 sparse 16x16x32 should have SparseMFMAOp type";
static_assert(is_mma_op_of_family_v<MmaOpFamily::SPARSE, TestSparseMfma16x16>,
"GFX950 sparse 16x16x32 should be detected as Sparse");
std::cout << "GFX950 sparse MFMA specialization is correct" << std::endl;
EXPECT_TRUE((is_mma_op_of_family_v<MmaOpFamily::SPARSE, TestSparseMfma16x16>))
<< "GFX950 sparse 16x16x32 should be detected as Sparse";
}
TEST(SparseMMATrait, MmaOpTraitsIntegration)
@@ -68,12 +67,10 @@ TEST(SparseMMATrait, MmaOpTraitsIntegration)
using TestTraits = MmaOpTraits<TestSparseMmma>;
// Verify trait detection
static_assert(TestTraits::IsSparse, "Sparse MMA should be detected as sparse");
static_assert(TestTraits::IsSupported, "Sparse MMA specialization should be supported");
static_assert(TestTraits::IsMfma, "Sparse MFMA should be detected as MFMA");
static_assert(!TestTraits::IsWmma, "Sparse MFMA should not be detected as WMMA");
std::cout << "MmaOpTraits correctly integrates sparse operations" << std::endl;
EXPECT_TRUE(TestTraits::IsSparse) << "Sparse MMA should be detected as sparse";
EXPECT_TRUE(TestTraits::IsSupported) << "Sparse MMA specialization should be supported";
EXPECT_TRUE(TestTraits::IsMfma) << "Sparse MFMA should be detected as MFMA";
EXPECT_FALSE(TestTraits::IsWmma) << "Sparse MFMA should not be detected as WMMA";
}
TEST(SparseMMATrait, TestConceptRequirements)
@@ -88,7 +85,7 @@ TEST(SparseMMATrait, TestConceptRequirements)
DefaultSparseMfmaCtrlFlags,
CompilerTargetGfx950,
MmaOpFamily::SPARSE>;
static_assert(MmaOpI<TestSparseMmma>);
EXPECT_TRUE(MmaOpI<TestSparseMmma>);
#else
GTEST_SKIP() << "Not compiled with concepts. Skipping test.";
#endif // CK_TILE_CONCEPTS && CK_TILE_CONCEPTS_HEADER
@@ -119,22 +116,20 @@ TEST(SparseMMATrait, DenseVsSparseDistinction)
MmaOpFamily::SPARSE>;
// Verify they have different operation types
static_assert(std::is_same_v<typename DenseMfma::OpType, typename SparseMfma::OpType> &&
DenseMfma::OpFamily != SparseMfma::OpFamily,
"Dense and Sparse MFMA should have the same OpType tags and different OpFamily");
EXPECT_TRUE((std::is_same_v<typename DenseMfma::OpType, typename SparseMfma::OpType> &&
DenseMfma::OpFamily != SparseMfma::OpFamily))
<< "Dense and Sparse MFMA should have the same OpType tags and different OpFamily";
// Verify traits correctly identify them
static_assert(MmaOpTraits<DenseMfma>::IsMfma && MmaOpTraits<DenseMfma>::IsDense &&
!MmaOpTraits<DenseMfma>::IsSparse && !MmaOpTraits<DenseMfma>::IsScale &&
MmaOpTraits<DenseMfma>::IsSupported,
"Dense MFMA should be identified correctly");
EXPECT_TRUE((MmaOpTraits<DenseMfma>::IsMfma && MmaOpTraits<DenseMfma>::IsDense &&
!MmaOpTraits<DenseMfma>::IsSparse && !MmaOpTraits<DenseMfma>::IsScale &&
MmaOpTraits<DenseMfma>::IsSupported))
<< "Dense MFMA should be identified correctly";
static_assert(MmaOpTraits<SparseMfma>::IsSparse && MmaOpTraits<SparseMfma>::IsMfma &&
!MmaOpTraits<SparseMfma>::IsDense && !MmaOpTraits<SparseMfma>::IsScale &&
MmaOpTraits<SparseMfma>::IsSupported,
"Sparse MFMA should be identified correctly");
std::cout << "Dense and sparse MMA operations are correctly distinguished" << std::endl;
EXPECT_TRUE((MmaOpTraits<SparseMfma>::IsSparse && MmaOpTraits<SparseMfma>::IsMfma &&
!MmaOpTraits<SparseMfma>::IsDense && !MmaOpTraits<SparseMfma>::IsScale &&
MmaOpTraits<SparseMfma>::IsSupported))
<< "Sparse MFMA should be identified correctly";
}
TEST(SparseMMATrait, SparseSelector)
@@ -153,100 +148,53 @@ TEST(SparseMMATrait, SparseSelector)
if constexpr(isValid)
{
// Selector should pick a sparse MFMA implementation
static_assert(MmaOpTraits<Selected>::IsSparse);
static_assert(MmaOpTraits<Selected>::IsMfma);
static_assert(MmaOpTraits<Selected>::IsSupported);
static_assert((std::is_same<typename Selected::OpType, MfmaOp>::value));
EXPECT_TRUE(MmaOpTraits<Selected>::IsSparse);
EXPECT_TRUE(MmaOpTraits<Selected>::IsMfma);
EXPECT_TRUE(MmaOpTraits<Selected>::IsSupported);
EXPECT_TRUE((std::is_same<typename Selected::OpType, MfmaOp>::value));
}
else
{
// Selector should pick the unsupported pass through
static_assert(!MmaOpTraits<Selected>::IsSupported);
EXPECT_FALSE(MmaOpTraits<Selected>::IsSupported);
}
});
}
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK>
__global__ void test_sparse_accum_over_k(void* a, void* b, void* c, void* out)
{
using Pipeline = SparseMmaPipeline<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
static constexpr uint32_t kIters = WaveTileK / Pipeline::MmaOp::kK;
// Initialize the accumulator
CVecType result = *reinterpret_cast<CVecType*>(c);
// Accumulate input AxB over WaveTileK/FragK iterations
for(uint32_t i = 0; i < kIters; ++i)
{
result = Pipeline::exec(
*reinterpret_cast<AVecType*>(a), *reinterpret_cast<BVecType*>(b), result);
}
*reinterpret_cast<CVecType*>(out) = result;
}
// Live test on real hardware for sparse selection and execution.
TEST(SparseMMATrait, MmaSelector_Sparse_F16_F16_F32_16x16x32_Real)
{
MmaPipelineTest<> test;
const auto should_skip = [](amdgcn_target_id currentArchId) {
bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) &&
(currentArchId <= amdgcn_target_id::GFX12_GENERIC);
bool isSupportedMfma = (currentArchId >= amdgcn_target_id::GFX942) &&
(currentArchId <= amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma));
};
const std::function<fp32_t(uint32_t)> validator = [](uint32_t waveTileK) {
return static_cast<fp32_t>(waveTileK) / 2;
};
const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) {
test_sparse_accum_over_k<MmaPipelineTest<>::AType,
MmaPipelineTest<>::BType,
MmaPipelineTest<>::CType,
MmaPipelineTest<>::WaveTileM,
MmaPipelineTest<>::WaveTileN,
MmaPipelineTest<>::WaveTileK><<<1, waveSize>>>(a, b, c, out);
};
// Initialize A with 2:4 structured sparsity pattern: {1, 0, 1, 0, ...}
// This ensures the sparse compression transform is actually exercised -
// a no-op or broken compression would pass zeros through, causing incorrect results.
const std::function<fp16_t(size_t)> sparseAInit = [](size_t i) -> fp16_t {
return (i % 2 == 0) ? type_convert<fp16_t>(1) : type_convert<fp16_t>(0);
};
test.test_pipeline(should_skip, kernel, validator, sparseAInit);
}
template <uint32_t CompressionRatio, typename Vec>
__global__ void test_sparse_transform(void* a, void* idx)
struct SparseTransformKernel
{
using ResultT =
decltype(SparseCompressTransform<CompressionRatio>::exec(*static_cast<Vec*>(a)));
using FirstT = std::tuple_element_t<0, ResultT>;
const auto& [vec, i] = SparseCompressTransform<CompressionRatio>::exec(*static_cast<Vec*>(a));
*reinterpret_cast<remove_cvref_t<FirstT>*>(a) = vec;
*reinterpret_cast<int32_t*>(idx) = i;
}
static constexpr int kBlockSize = mma_pipeline_test::getCMakeWaveSize();
__device__ void operator()(void* a, void* idx) const
{
using ResultT =
decltype(SparseCompressTransform<CompressionRatio>::exec(*static_cast<Vec*>(a)));
using FirstT = std::tuple_element_t<0, ResultT>;
using IdxT = std::tuple_element_t<1, ResultT>;
const auto& [vec, i] =
SparseCompressTransform<CompressionRatio>::exec(*static_cast<Vec*>(a));
*reinterpret_cast<remove_cvref_t<FirstT>*>(a) = vec;
__builtin_memcpy(idx, &i, sizeof(IdxT));
}
};
// Generalized helper: runs the sparse transform kernel and verifies compressed output and index.
template <int NUM, int RATIO, typename Type>
void sparse_transform_verify(const std::vector<Type>& input,
const std::vector<Type>& expected_output,
int32_t expected_idx)
void sparse_transform_verify(
const std::vector<Type>& input,
const std::vector<Type>& expected_output,
const sparse::detail::SparseIdxPack<sparse::detail::idx_words_needed<NUM / RATIO>>&
expected_idx)
{
static_assert(RATIO == 2, "Extend functionality if other ratio is used.");
ASSERT_EQ(static_cast<int>(input.size()), NUM);
ASSERT_EQ(static_cast<int>(expected_output.size()), NUM / RATIO);
constexpr int CompressedSize = NUM / RATIO;
constexpr int IdxNumWords = sparse::detail::idx_words_needed<CompressedSize>;
using IdxType = sparse::detail::SparseIdxPack<IdxNumWords>;
int devCount;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
@@ -265,24 +213,31 @@ void sparse_transform_verify(const std::vector<Type>& input,
}
float* d_v;
int32_t* d_idx;
void* d_idx;
static constexpr auto Size = sizeof(Type) * NUM;
HIP_CHECK_ERROR(hipMalloc(&d_v, Size));
HIP_CHECK_ERROR(hipMalloc(&d_idx, sizeof(int32_t)));
HIP_CHECK_ERROR(hipMalloc(&d_idx, sizeof(IdxType)));
// Copy inputs to device
HIP_CHECK_ERROR(hipMemcpy(d_v, input.data(), Size, hipMemcpyHostToDevice));
test_sparse_transform<RATIO, ext_vector_t<Type, NUM>><<<1, 32>>>(d_v, d_idx);
using Kernel = SparseTransformKernel<RATIO, ext_vector_t<Type, NUM>>;
ck_tile::launch_kernel(ck_tile::stream_config{},
ck_tile::make_kernel(Kernel{}, dim3(1), dim3(32), 0, d_v, d_idx));
HIP_CHECK_ERROR(hipDeviceSynchronize());
std::vector<Type> h_out(NUM / RATIO, static_cast<Type>(0));
HIP_CHECK_ERROR(hipMemcpy(h_out.data(), d_v, Size / RATIO, hipMemcpyDeviceToHost));
int32_t h_idx;
HIP_CHECK_ERROR(hipMemcpy(&h_idx, d_idx, sizeof(int32_t), hipMemcpyDeviceToHost));
IdxType h_idx{};
HIP_CHECK_ERROR(hipMemcpy(&h_idx, d_idx, sizeof(IdxType), hipMemcpyDeviceToHost));
EXPECT_EQ(h_idx, expected_idx) << "Index mask mismatch";
EXPECT_EQ(h_idx.words[0], expected_idx.words[0]) << "Index mask mismatch (word 0)";
for(int w = 1; w < IdxNumWords; ++w)
{
EXPECT_EQ(h_idx.words[w], expected_idx.words[w])
<< "Index mask mismatch (word " << w << ")";
}
for(int i = 0; i < NUM / RATIO; ++i)
{
EXPECT_EQ(h_out[i], expected_output[i]) << "Output mismatch at position " << i;
@@ -296,10 +251,11 @@ void sparse_transform_verify(const std::vector<Type>& input,
// initialization values (from nonzero_elems init) that don't correspond to the
// default index (slot 2). We only validate entries where the index was explicitly
// set, i.e. where input[slot] is non-zero.
constexpr int CompressedSize = NUM / RATIO;
for(int i = 0; i < CompressedSize; ++i)
{
int slot = (h_idx >> (2 * i)) & 0b11;
const int word = (2 * i) / 32;
const int shift = (2 * i) % 32;
int slot = (h_idx.words[word] >> shift) & 0b11;
int group = i / 2;
Type input_at_slot = input[group * 4 + slot];
// Only check when input at the indexed slot is non-zero (explicitly assigned)
@@ -319,20 +275,36 @@ void sparse_transform_verify(const std::vector<Type>& input,
// Helper: build expected index from a per-group 4-bit pattern, repeated for all groups.
// Each group of 4 input elements contributes 2 compressed elements -> 2 x 2-bit index fields = 4
// bits.
static int32_t build_repeated_group_idx(int num_groups, int32_t group_bits_4)
template <int NumGroups>
static auto build_repeated_group_idx(int32_t group_bits_4)
{
int32_t idx = 0;
for(int g = 0; g < num_groups; ++g)
idx |= (group_bits_4 << (4 * g));
constexpr int CompressedSize = NumGroups * 2;
constexpr int NumWords = sparse::detail::idx_words_needed<CompressedSize>;
sparse::detail::SparseIdxPack<NumWords> idx{};
for(int g = 0; g < NumGroups; ++g)
{
const int bit_pos = g * 4;
const int word = bit_pos / 32;
const int shift = bit_pos % 32;
idx.words[word] |= (group_bits_4 << shift);
}
return idx;
}
// Helper: build expected index from alternating even/odd 4-bit group patterns.
static int32_t build_alternating_group_idx(int num_groups, int32_t even_bits_4, int32_t odd_bits_4)
template <int NumGroups>
static auto build_alternating_group_idx(int32_t even_bits_4, int32_t odd_bits_4)
{
int32_t idx = 0;
for(int g = 0; g < num_groups; ++g)
idx |= ((g % 2 == 0 ? even_bits_4 : odd_bits_4) << (4 * g));
constexpr int CompressedSize = NumGroups * 2;
constexpr int NumWords = sparse::detail::idx_words_needed<CompressedSize>;
sparse::detail::SparseIdxPack<NumWords> idx{};
for(int g = 0; g < NumGroups; ++g)
{
const int bit_pos = g * 4;
const int word = bit_pos / 32;
const int shift = bit_pos % 32;
idx.words[word] |= ((g % 2 == 0 ? even_bits_4 : odd_bits_4) << shift);
}
return idx;
}
@@ -354,7 +326,7 @@ void sparse_transform_test_case()
expected_out[i] = v[i * 2];
}
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1000);
auto expected_idx = build_repeated_group_idx<NUM / 4>(0b1000);
sparse_transform_verify<NUM, RATIO, Type>(v, expected_out, expected_idx);
}
@@ -365,6 +337,7 @@ TEST(SparseTransformsTest, ValidCompressionRatio)
sparse_transform_test_case<8, 2, fp16_t>();
sparse_transform_test_case<16, 2, fp16_t>();
sparse_transform_test_case<32, 2, fp16_t>();
sparse_transform_test_case<64, 2, fp16_t>(); // multi-word SparseIdxPack
}
// All-zero input: no non-zeros in any group of 4.
@@ -377,7 +350,7 @@ void sparse_transform_all_zero()
using T = fp16_t;
std::vector<T> input(NUM, static_cast<T>(0));
std::vector<T> expected_output(NUM / 2, static_cast<T>(0));
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1010);
auto expected_idx = build_repeated_group_idx<NUM / 4>(0b1010);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
@@ -386,6 +359,7 @@ TEST(SparseTransformsTest, AllZeroInput)
sparse_transform_all_zero<8>();
sparse_transform_all_zero<16>();
sparse_transform_all_zero<32>();
sparse_transform_all_zero<64>(); // multi-word SparseIdxPack
}
// Single non-zero per group of 4 (at slot 3).
@@ -408,7 +382,7 @@ void sparse_transform_single_nonzero()
expected_output[g * 2 + 1] = val;
}
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1011);
auto expected_idx = build_repeated_group_idx<NUM / 4>(0b1011);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
@@ -417,6 +391,7 @@ TEST(SparseTransformsTest, SingleNonZeroPerGroup)
sparse_transform_single_nonzero<8>();
sparse_transform_single_nonzero<16>();
sparse_transform_single_nonzero<32>();
sparse_transform_single_nonzero<64>(); // multi-word SparseIdxPack
}
// Non-zeros at slots 1 and 3 in each group.
@@ -439,7 +414,7 @@ void sparse_transform_slots_1_and_3()
expected_output[g * 2 + 1] = b;
}
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1101);
auto expected_idx = build_repeated_group_idx<NUM / 4>(0b1101);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
@@ -448,6 +423,7 @@ TEST(SparseTransformsTest, NonZerosAtSlots1And3)
sparse_transform_slots_1_and_3<8>();
sparse_transform_slots_1_and_3<16>();
sparse_transform_slots_1_and_3<32>();
sparse_transform_slots_1_and_3<64>(); // multi-word SparseIdxPack
}
// Non-zeros at slots 0 and 3 in each group (non-adjacent).
@@ -470,7 +446,7 @@ void sparse_transform_slots_0_and_3()
expected_output[g * 2 + 1] = b;
}
int32_t expected_idx = build_repeated_group_idx(NUM / 4, 0b1100);
auto expected_idx = build_repeated_group_idx<NUM / 4>(0b1100);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
@@ -479,6 +455,7 @@ TEST(SparseTransformsTest, NonZerosAtSlots0And3)
sparse_transform_slots_0_and_3<8>();
sparse_transform_slots_0_and_3<16>();
sparse_transform_slots_0_and_3<32>();
sparse_transform_slots_0_and_3<64>(); // multi-word SparseIdxPack
}
// Mixed sparsity pattern: even groups have non-zeros at slots 0,2; odd groups at slots 1,3.
@@ -511,7 +488,7 @@ void sparse_transform_mixed()
expected_output[g * 2 + 1] = b;
}
int32_t expected_idx = build_alternating_group_idx(NUM / 4, 0b1000, 0b1101);
auto expected_idx = build_alternating_group_idx<NUM / 4>(0b1000, 0b1101);
sparse_transform_verify<NUM, 2, T>(input, expected_output, expected_idx);
}
@@ -520,4 +497,156 @@ TEST(SparseTransformsTest, MixedSparsityPattern)
sparse_transform_mixed<8>();
sparse_transform_mixed<16>();
sparse_transform_mixed<32>();
sparse_transform_mixed<64>(); // multi-word SparseIdxPack
}
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy>
struct SparsePipelineKernel
{
static constexpr int kBlockSize = mma_pipeline_test::getCMakeWaveSize();
__device__ void
operator()(const void* a_per_lane, const void* b_per_lane, void* c_per_lane) const
{
using CompilerTarget = decltype(get_compiler_target());
using Pipeline = SparseMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy,
CompilerTarget>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
const uint32_t lane = threadIdx.x;
AVecType a;
BVecType b;
CVecType c;
__builtin_memcpy(&a,
static_cast<const uint8_t*>(a_per_lane) + lane * sizeof(AVecType),
sizeof(AVecType));
__builtin_memcpy(&b,
static_cast<const uint8_t*>(b_per_lane) + lane * sizeof(BVecType),
sizeof(BVecType));
__builtin_memset(&c, 0, sizeof(CVecType));
if constexpr(MmaOpTraits<typename Pipeline::MmaOp>::IsSupported)
{
Pipeline::exec(a, b, c);
__builtin_memcpy(
static_cast<uint8_t*>(c_per_lane) + lane * sizeof(CVecType), &c, sizeof(CVecType));
}
}
};
namespace {
const auto should_skip = [](amdgcn_target_id currentArchId) {
bool isSupportedWmma = (currentArchId >= amdgcn_target_id::GFX1200) &&
(currentArchId <= amdgcn_target_id::GFX12_GENERIC);
bool isSupportedMfma =
(currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma));
};
} // namespace
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy>
struct SparsePipelineFactory
{
template <typename Target>
struct Create
{
using type = SparseMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy,
Target>;
};
};
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR>
void SparsePipeline_Real_impl()
{
using Factory =
SparsePipelineFactory<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy>;
using Kernel =
SparsePipelineKernel<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy>;
mma_pipeline_test::
run_pipeline_matrix_test<Factory::template Create, Kernel, AType, BType, CType>(
WaveTileM, WaveTileN, WaveTileK, should_skip, Kernel{}, /*isSparse=*/true);
}
// Full matrix verification: 16x16x32 single-fragment sparse pipeline (ROW_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x32)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u>();
}
// Multi-fragment K: 16x16x64 -> 2 K fragments, tests internal K iteration (ROW_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x64)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 64u>();
}
// Full matrix verification: 16x16x32 single-fragment sparse pipeline (COL_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x32_ColMajor)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, MmaAccumPolicy::COL_MAJOR>();
}
// Multi-fragment K: 16x16x64 -> 2 K fragments, tests internal K iteration (COL_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x64_ColMajor)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 64u, MmaAccumPolicy::COL_MAJOR>();
}
// Multi-fragment K: 16x16x128 -> 4 K fragments, exercises multi-word SparseIdxPack (ROW_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x128)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 128u>();
}
// Multi-fragment K: 16x16x256 -> 8 K fragments, exercises larger multi-word SparseIdxPack
// (ROW_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x256)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 256u>();
}
// Multi-fragment K: 16x16x128 -> 4 K fragments (COL_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x128_ColMajor)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 128u, MmaAccumPolicy::COL_MAJOR>();
}
// Multi-fragment K: 16x16x256 -> 8 K fragments (COL_MAJOR)
TEST(SparseMmaPipeline, FullMatrixVerify_16x16x256_ColMajor)
{
SparsePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 256u, MmaAccumPolicy::COL_MAJOR>();
}

View File

@@ -3,52 +3,70 @@
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
#include "ck_tile/core/arch/mma/mma_traits.hpp"
#include "ck_tile/core/arch/mma/mma_wavewise.hpp"
#include "ck_tile/core/arch/mma/mma.hpp"
#include "pipeline_tests_helper.hpp"
#include <memory>
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace ck_tile::core::arch::mma;
// Kernel functor: constructs Pipeline internally using device-side get_compiler_target().
// Uses void* for data to avoid host/device symbol mismatches.
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
bool CTranspose>
__global__ void test_wavewise_pipeline(void* a, void* b, void* c, void* out)
MmaAccumPolicy AccumPolicy,
bool TransposeC>
struct WaveWisePipelineKernel
{
using CompilerTarget = decltype(get_compiler_target());
static constexpr int kBlockSize = mma_pipeline_test::getCMakeWaveSize();
using Pipeline = WaveWiseMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
MmaOpFamily::DENSE,
MmaAccumPolicy::ROW_MAJOR,
CTranspose,
CompilerTarget>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
auto result = Pipeline::exec(*reinterpret_cast<AVecType*>(a),
*reinterpret_cast<BVecType*>(b),
*reinterpret_cast<CVecType*>(c));
if constexpr(MmaOpTraits<typename Pipeline::MmaOp>::IsSupported)
__device__ void
operator()(const void* a_per_lane, const void* b_per_lane, void* c_per_lane) const
{
// When the MmaOp is Unsupported (default) it returns the CVecType by value
// so this cast is impossible...
__builtin_memcpy(out, static_cast<const void*>(result), sizeof(CVecType));
using CompilerTarget = decltype(get_compiler_target());
using Pipeline = WaveWiseMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
MmaOpFamily::DENSE,
AccumPolicy,
TransposeC,
CompilerTarget>;
using AVecType = typename Pipeline::AVecType;
using BVecType = typename Pipeline::BVecType;
using CVecType = typename Pipeline::CVecType;
const uint32_t lane = threadIdx.x;
AVecType a;
BVecType b;
CVecType c;
__builtin_memcpy(&a,
static_cast<const uint8_t*>(a_per_lane) + lane * sizeof(AVecType),
sizeof(AVecType));
__builtin_memcpy(&b,
static_cast<const uint8_t*>(b_per_lane) + lane * sizeof(BVecType),
sizeof(BVecType));
__builtin_memset(&c, 0, sizeof(CVecType));
if constexpr(MmaOpTraits<typename Pipeline::MmaOp>::IsSupported)
{
Pipeline::exec(a, b, c);
__builtin_memcpy(
static_cast<uint8_t*>(c_per_lane) + lane * sizeof(CVecType), &c, sizeof(CVecType));
}
}
}
};
namespace {
const auto should_skip = [](amdgcn_target_id currentArchId) {
@@ -57,37 +75,95 @@ const auto should_skip = [](amdgcn_target_id currentArchId) {
(currentArchId >= amdgcn_target_id::GFX942) && (currentArchId <= amdgcn_target_id::GFX950);
return ((currentArchId == amdgcn_target_id::HOST) || !(isSupportedWmma || isSupportedMfma));
};
const std::function<fp32_t(uint32_t)> validator = [](uint32_t waveTileK) {
return static_cast<fp32_t>(waveTileK);
};
} // namespace
TEST(WaveWiseMmaPipeline, testKIter)
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy>
struct WaveWisePipelineFactory
{
MmaPipelineTest<> test;
const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) {
test_wavewise_pipeline<MmaPipelineTest<>::AType,
MmaPipelineTest<>::BType,
MmaPipelineTest<>::CType,
MmaPipelineTest<>::WaveTileM,
MmaPipelineTest<>::WaveTileN,
MmaPipelineTest<>::WaveTileK,
false><<<1, waveSize>>>(a, b, c, out);
template <typename Target>
struct Create
{
using type = WaveWiseMmaPipeline<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
MmaOpFamily::DENSE,
AccumPolicy,
false,
Target>;
};
test.test_pipeline(should_skip, kernel, validator);
};
template <typename AType,
typename BType,
typename CType,
uint32_t WaveTileM,
uint32_t WaveTileN,
uint32_t WaveTileK,
MmaAccumPolicy AccumPolicy = MmaAccumPolicy::ROW_MAJOR,
bool TransposeC = false>
void WaveWisePipeline_Real_impl()
{
using Factory =
WaveWisePipelineFactory<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK, AccumPolicy>;
using Kernel = WaveWisePipelineKernel<AType,
BType,
CType,
WaveTileM,
WaveTileN,
WaveTileK,
AccumPolicy,
TransposeC>;
mma_pipeline_test::
run_pipeline_matrix_test<Factory::template Create, Kernel, AType, BType, CType>(
WaveTileM,
WaveTileN,
WaveTileK,
should_skip,
Kernel{},
/*isSparse=*/false,
/*transposeExpected=*/TransposeC);
}
TEST(WaveWiseMmaPipeline, testKIterSwapAB)
TEST(WaveWiseMmaPipeline, FullMatrixVerify_16x16x32)
{
MmaPipelineTest<> test;
const auto kernel = [](uint32_t waveSize, void* a, void* b, void* c, void* out) {
test_wavewise_pipeline<MmaPipelineTest<>::AType,
MmaPipelineTest<>::BType,
MmaPipelineTest<>::CType,
MmaPipelineTest<>::WaveTileM,
MmaPipelineTest<>::WaveTileN,
MmaPipelineTest<>::WaveTileK,
true><<<1, waveSize>>>(a, b, c, out);
};
test.test_pipeline(should_skip, kernel, validator);
WaveWisePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u>();
}
TEST(WaveWiseMmaPipeline, FullMatrixVerify_16x16x32_SwapAB)
{
WaveWisePipeline_Real_impl<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
MmaAccumPolicy::ROW_MAJOR,
true>();
}
TEST(WaveWiseMmaPipeline, FullMatrixVerify_16x16x32_ColMajor)
{
WaveWisePipeline_Real_impl<fp16_t, fp16_t, fp32_t, 16u, 16u, 32u, MmaAccumPolicy::COL_MAJOR>();
}
TEST(WaveWiseMmaPipeline, FullMatrixVerify_16x16x32_ColMajor_TransposeC)
{
WaveWisePipeline_Real_impl<fp16_t,
fp16_t,
fp32_t,
16u,
16u,
32u,
MmaAccumPolicy::COL_MAJOR,
true>();
}

View File

@@ -1,8 +1,6 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "get_wave_size_helper.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/mma/amdgcn_mma.hpp"
#include "ck_tile/core/arch/mma/mma_op_family.hpp"
@@ -12,6 +10,8 @@
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "get_cmake_targets_helper.hpp"
#include <gtest/gtest.h>
#include <hip/hip_runtime.h>
@@ -21,6 +21,7 @@
using namespace ck_tile;
using namespace ck_tile::core::arch;
using namespace ck_tile::core::arch::mma;
using namespace ck_tile::core::arch::testing;
// Dummy values for testing
constexpr uint32_t DummyTargetIdVal = 55555u;
@@ -484,7 +485,7 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_16x16x32_Real)
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
const auto wave_size = getCMakeWaveSize();
test_accum_over_k<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
HIP_CHECK_ERROR(hipDeviceSynchronize());
@@ -585,7 +586,7 @@ TEST(TestAmdgcnMma, MmaSelector_F16_F16_F32_112x112x128_Real)
HIP_CHECK_ERROR(hipMemcpy(d_b, h_b.data(), BSize, hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_c, h_c.data(), CSize, hipMemcpyHostToDevice));
const auto wave_size = getDeviceWaveSize();
const auto wave_size = getCMakeWaveSize();
test_accum_over_k<AType, BType, CType, WaveTileM, WaveTileN, WaveTileK>
<<<1, wave_size>>>(d_a, d_b, d_c, d_out);
HIP_CHECK_ERROR(hipDeviceSynchronize());